| | import torch |
| |
|
| | from diffusers.pipelines import FluxPipeline |
| | from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition |
| | from omini.rotation import RotationConfig, RotationTuner |
| | from PIL import Image |
| |
|
| |
|
def load_rotation(transformer, path: str, adapter_name: str = "default", strict: bool = False):
    """
    Load rotation adapter weights from disk into ``transformer``.

    Looks for ``{adapter_name}.safetensors`` first, then falls back to
    ``{adapter_name}.pth``. Checkpoint keys are rewritten from
    ``...rotation....`` to ``...rotation.{adapter_name}....`` so they land in
    the named adapter's submodule. Floating-point tensors are moved AND cast
    to the transformer's device/dtype; integer/bool tensors are only moved,
    since casting them to a float dtype would corrupt them.

    Args:
        transformer: Module to load the adapter weights into.
        path: Directory containing the saved adapter weights.
        adapter_name: Name of the adapter to load.
        strict: Whether to strictly match all keys in ``load_state_dict``.

    Returns:
        The raw (un-renamed, un-converted) state dict loaded from disk.

    Raises:
        FileNotFoundError: If neither weight file exists under ``path``.
    """
    import os

    # Derive device/dtype from the parameters: ``transformer.device`` is a
    # diffusers ModelMixin property and does not exist on a plain nn.Module.
    transformer_device = next(transformer.parameters()).device
    transformer_dtype = next(transformer.parameters()).dtype
    print(f"device for loading: {transformer_device}")

    safetensors_path = os.path.join(path, f"{adapter_name}.safetensors")
    pth_path = os.path.join(path, f"{adapter_name}.pth")

    if os.path.exists(safetensors_path):
        # Lazy import so the .pth fallback works without safetensors installed.
        from safetensors.torch import load_file

        state_dict = load_file(safetensors_path)
        print(f"Loaded rotation adapter from {safetensors_path}")
    elif os.path.exists(pth_path):
        state_dict = torch.load(pth_path, map_location=transformer_device)
        print(f"Loaded rotation adapter from {pth_path}")
    else:
        raise FileNotFoundError(
            f"No adapter weights found for '{adapter_name}' in {path}\n"
            f"Looking for: {safetensors_path} or {pth_path}"
        )

    # Rewrite keys to target the named adapter and move/cast each tensor.
    # BUG FIX: a later dict comprehension used to rebuild this mapping from the
    # raw ``state_dict``, silently discarding the device/dtype conversion done
    # here — that dead overwrite has been removed.
    state_dict_with_adapter = {}
    for k, v in state_dict.items():
        new_key = k.replace(".rotation.", f".rotation.{adapter_name}.")
        if "_adapter_config" in new_key:
            print(f"adapter_config key: {new_key}")
        if v.dtype in (torch.long, torch.int, torch.int32, torch.int64, torch.bool):
            # Integer/bool tensors (e.g. indices, flags) must keep their dtype.
            state_dict_with_adapter[new_key] = v.to(device=transformer_device)
        else:
            state_dict_with_adapter[new_key] = v.to(
                device=transformer_device, dtype=transformer_dtype
            )

    missing, unexpected = transformer.load_state_dict(
        state_dict_with_adapter,
        strict=strict,
    )

    if missing:
        print(f"Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
    if unexpected:
        print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

    # Report the adapter's hyper-parameters when a sidecar config was saved.
    config_path = os.path.join(path, f"{adapter_name}_config.yaml")
    if os.path.exists(config_path):
        import yaml  # lazy: only needed when a config file is present

        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        print(f"Loaded config: {config}")

    total_params = sum(p.numel() for p in state_dict.values())
    print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)")

    return state_dict
| |
|
| |
|
| | |
# Load the reference photo and square-crop it about its center, then scale
# to the 512x512 resolution the pipeline works at.
image = Image.open("assets/coffee.png").convert("RGB")

w, h = image.size
min_dim = min(w, h)
left, top = (w - min_dim) // 2, (h - min_dim) // 2
image = image.crop((left, top, left + min_dim, top + min_dim)).resize((512, 512))

# Text prompt describing the target scene.
prompt = "In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table."

# Build the Canny-edge conditioning input for the pipeline.
canny_image = convert_to_condition("canny", image)
condition = Condition(canny_image, "canny")

# Fix RNG state for reproducible sampling.
seed_everything()
| |
|
| |
|
| |
|
# Sweep the adapter's ``T`` setting, running one full generation per value
# (T = (i + 1) / 20, i.e. 2.05 .. 3.00 over i = 40..59).
for i in range(40, 60):
    # A fresh pipeline is loaded every iteration — presumably because
    # RotationTuner modifies the transformer in place, so a clean copy is
    # needed per T value. NOTE(review): reloading FLUX.1-dev each pass is
    # expensive; confirm whether the tuner can be re-configured instead.
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )

    transformer = pipe.transformer

    adapter_name = "default"
    # Flag the model as PEFT-managed so the adapter API (set_adapter below)
    # accepts it without a PEFT config having been loaded through diffusers.
    transformer._hf_peft_config_loaded = True

    # Rotation adapter hyper-parameters. ``target_modules`` is one regex
    # alternation matching the x_embedder plus the norm/attention/FF
    # projections of both the double-stream (``transformer_blocks``) and
    # single-stream (``single_transformer_blocks``) FLUX blocks; the
    # ``(?<!single_)`` look-behinds keep the two groups disjoint.
    rotation_adapter_config = {
        "r": 4,
        "num_rotations": 4,
        "target_modules": "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)",
    }

    config = RotationConfig(**rotation_adapter_config)
    # NOTE(review): ``T`` looks like a strength/interpolation knob on the
    # rotation adapter — confirm its semantics in omini.rotation.
    config.T = float(i + 1) / 20
    rotation_tuner = RotationTuner(
        transformer,
        config,
        adapter_name=adapter_name,
    )

    # Cast back to bf16 after the tuner wraps modules (presumably the tuner
    # may create parameters in a different dtype — TODO confirm), then
    # activate the adapter.
    transformer = transformer.to(torch.bfloat16)
    transformer.set_adapter(adapter_name)

    # Load the trained adapter checkpoint; strict=False tolerates keys for
    # modules that are absent on either side.
    load_rotation(
        transformer,
        path="runs/20251110-191859/ckpt/4000",
        adapter_name=adapter_name,
        strict=False,
    )

    pipe = pipe.to("cuda")

    # Canny-conditioned generation with the active rotation adapter.
    result_img = generate(
        pipe,
        prompt=prompt,
        conditions=[condition],
    ).images[0]

    # Side-by-side sheet: [original | canny condition | result], 512px each.
    concat_image = Image.new("RGB", (1536, 512))
    concat_image.paste(image, (0, 0))
    concat_image.paste(condition.condition, (512, 0))
    concat_image.paste(result_img, (1024, 0))

    result_img.save(f"result_{i+1}.png")
    concat_image.save(f"result_concat_{i+1}.png")
    print(f"Saved result_{i+1}.png and result_concat_{i+1}.png")