import base64
import os
import random
import re
from pathlib import Path

import openai
| |
|
| |
|
| |
|
# Rubric prompts sent to GPT-4o, one per evaluation axis.
# Each asks the model to answer with a "Score: N" line (1-5) plus reasoning;
# parse_score() extracts the integer from that line.
# "modification" is a str.format template taking {prompt}; the other four
# are used verbatim.
evaluation_prompts = {
    "identity": """
Compare the original subject image with the generated image.
Rate on a scale of 1-5 how well the essential identifying features
are preserved (logos, brand marks, distinctive patterns).
Score: [1-5]
Reasoning: [explanation]
""",

    "material": """
Evaluate the material quality and surface characteristics.
Rate on a scale of 1-5 how accurately materials are represented
(textures, reflections, surface properties).
Score: [1-5]
Reasoning: [explanation]
""",

    "color": """
Assess color fidelity in regions NOT specified for modification.
Rate on a scale of 1-5 how consistent colors remain.
Score: [1-5]
Reasoning: [explanation]
""",

    "appearance": """
Evaluate the overall realism and coherence of the generated image.
Rate on a scale of 1-5 how realistic and natural it appears.
Score: [1-5]
Reasoning: [explanation]
""",

    "modification": """
Given the text prompt: "{prompt}"
Rate on a scale of 1-5 how well the specified changes are executed.
Score: [1-5]
Reasoning: [explanation]
"""
}
| |
|
| |
|
def encode_image(image_path):
    """Return the file at *image_path* encoded as a base64 ASCII string."""
    raw_bytes = Path(image_path).read_bytes()
    return base64.b64encode(raw_bytes).decode('utf-8')
| |
|
def evaluate_subject_driven_generation(
    original_image_path,
    generated_image_path,
    text_prompt,
    client
):
    """Score one generation against its subject image with GPT-4o vision.

    Sends five independent chat requests (identity, material, color,
    appearance, modification) and parses a 1-5 integer from each reply.

    Args:
        original_image_path: path to the reference subject image.
        generated_image_path: path to the generated image to evaluate.
        text_prompt: the text prompt the image was generated from; used by
            the "modification" rubric.
        client: an openai.Client (or compatible) exposing
            chat.completions.create.

    Returns:
        dict mapping metric name -> int score (or None when no "Score:"
        line could be parsed from the reply).
    """
    original_img = encode_image(original_image_path)
    generated_img = encode_image(generated_image_path)

    # Small builders keep the per-metric content tables readable.
    def _image(b64):
        return {"type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{b64}"}}

    def _text(text):
        return {"type": "text", "text": text}

    # One message-content list per metric. identity/color compare both
    # images; material/appearance/modification look at the generated image
    # only. (Replaces five copy-pasted request stanzas.)
    metric_contents = {
        "identity": [
            _text("Original subject image:"),
            _image(original_img),
            _text("Generated image:"),
            _image(generated_img),
            _text(evaluation_prompts["identity"]),
        ],
        "material": [
            _text("Evaluate this generated image:"),
            _image(generated_img),
            _text(evaluation_prompts["material"]),
        ],
        "color": [
            _text("Original:"),
            _image(original_img),
            _text("Generated:"),
            _image(generated_img),
            _text(evaluation_prompts["color"]),
        ],
        "appearance": [
            _image(generated_img),
            _text(evaluation_prompts["appearance"]),
        ],
        "modification": [
            _text(f"Text prompt: {text_prompt}"),
            _image(generated_img),
            _text(evaluation_prompts["modification"].format(prompt=text_prompt)),
        ],
    }

    results = {}
    for metric, content in metric_contents.items():
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": content}],
            max_tokens=300
        )
        results[metric] = parse_score(response.choices[0].message.content)

    return results
| |
|
def parse_score(response_text):
    """Extract the integer score from a GPT-4o rubric response.

    Accepts both "Score: 4" and the bracketed template form "Score: [4]"
    that models sometimes echo back verbatim from the rubric.

    Returns:
        int score, or None when no score line is present.
    """
    match = re.search(r'Score:\s*\[?(\d+)\]?', response_text)
    return int(match.group(1)) if match else None
| |
|
# The 30 DreamBooth benchmark subject folder names, indexed by subject_id.
subject_names = [
    "backpack", "backpack_dog", "bear_plushie", "berry_bowl", "can",
    "candle", "cat", "cat2", "clock", "colorful_sneaker",
    "dog", "dog2", "dog3", "dog5", "dog6",
    "dog7", "dog8", "duck_toy", "fancy_boot", "grey_sloth_plushie",
    "monster_toy", "pink_sunglasses", "poop_emoji", "rc_car", "red_cartoon",
    "robot_toy", "shiny_sneaker", "teapot", "vase", "wolf_plushie"
]

# subject_ids of inanimate objects; they get extra scene prompts, while the
# remaining (live) subjects get "wearing/outfit" prompts instead.
_OBJECT_SUBJECT_IDS = frozenset(
    [0, 1, 2, 3, 4, 5, 8, 9, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
)


def get_prompt(subject_id, prompt_id):
    """Return the prompt_id-th (0-24) evaluation prompt for subject_id.

    Prompts 0-9 (scenes) and 20-24 (property edits) are shared by all
    subjects; prompts 10-19 differ between object and live subjects.

    Bug fix: subject_name was previously assigned only inside the object
    branch, so live subjects raised NameError.
    """
    subject_name = subject_names[subject_id]

    # Shared by every subject (prompts 0-9).
    prompts = [
        f"a {subject_name} in the jungle",
        f"a {subject_name} in the snow",
        f"a {subject_name} on the beach",
        f"a {subject_name} on a cobblestone street",
        f"a {subject_name} on top of pink fabric",
        f"a {subject_name} on top of a wooden floor",
        f"a {subject_name} with a city in the background",
        f"a {subject_name} with a mountain in the background",
        f"a {subject_name} with a blue house in the background",
        f"a {subject_name} on top of a purple rug in a forest",
    ]

    if subject_id in _OBJECT_SUBJECT_IDS:
        # Object-only scene prompts (10-19).
        prompts += [
            f"a {subject_name} with a wheat field in the background",
            f"a {subject_name} with a tree and autumn leaves in the background",
            f"a {subject_name} with the Eiffel Tower in the background",
            f"a {subject_name} floating on top of water",
            f"a {subject_name} floating in an ocean of milk",
            f"a {subject_name} on top of green grass with sunflowers around it",
            f"a {subject_name} on top of a mirror",
            f"a {subject_name} on top of the sidewalk in a crowded street",
            f"a {subject_name} on top of a dirt road",
            f"a {subject_name} on top of a white rug",
        ]
    else:
        # Live-subject accessory/outfit prompts (10-19).
        prompts += [
            f"a {subject_name} wearing a red hat",
            f"a {subject_name} wearing a santa hat",
            f"a {subject_name} wearing a rainbow scarf",
            f"a {subject_name} wearing a black top hat and a monocle",
            f"a {subject_name} in a chef outfit",
            f"a {subject_name} in a firefighter outfit",
            f"a {subject_name} in a police outfit",
            f"a {subject_name} wearing pink glasses",
            f"a {subject_name} wearing a yellow shirt",
            f"a {subject_name} in a purple wizard outfit",
        ]

    # Property-modification prompts shared by every subject (20-24).
    prompts += [
        f"a red {subject_name}",
        f"a purple {subject_name}",
        f"a shiny {subject_name}",
        f"a wet {subject_name}",
        f"a cube shaped {subject_name}",
    ]

    return prompts[prompt_id]
| |
|
| |
|
| |
|
| |
|
| |
|
def batch_evaluate_dreambooth(client, generate_fn, dataset_path, output_csv):
    """Run the full DreamBooth evaluation: 30 subjects x 25 prompts x 5 seeds.

    For each (subject, prompt, seed) triple: generate an image with
    generate_fn, score it against the first original image of the subject
    with GPT-4o, and collect all scores into a CSV.

    Args:
        client: openai.Client passed through to the evaluator.
        generate_fn: callable invoked as generate_fn(prompt=...,
            subject_image_path=..., output_image_path=..., seed=...).
        dataset_path: root folder containing one subfolder per subject.
        output_csv: path the results table is written to.
    """
    import pandas as pd

    results_list = []
    score_cols = ['identity', 'material', 'color', 'appearance', 'modification']

    for subject_id in range(30):
        subject_name = subject_names[subject_id]
        subject_dir = Path(dataset_path) / subject_name

        # Resolve the reference image once per subject (the original code
        # re-globbed inside the prompt loop). sorted() makes the choice of
        # reference image deterministic across runs.
        original_files = sorted(subject_dir.glob("*.png"))
        if not original_files:
            raise ValueError(f"No original images found in {subject_dir}")
        original = str(original_files[0])

        generated_folder = subject_dir / "generated"
        generated_folder.mkdir(parents=True, exist_ok=True)

        for prompt_id in range(25):
            prompt = get_prompt(subject_id, prompt_id)

            for seed in range(5):
                generated = str(
                    generated_folder / f"gen_seed{seed}_prompt{prompt_id}.png"
                )

                generate_fn(
                    prompt=prompt,
                    subject_image_path=original,
                    output_image_path=generated,
                    seed=seed
                )

                scores = evaluate_subject_driven_generation(
                    original, generated, prompt, client
                )

                results_list.append({
                    'subject_id': subject_id,
                    'subject_name': subject_name,
                    'prompt_id': prompt_id,
                    'seed': seed,
                    'prompt': prompt,
                    **scores
                })

    df = pd.DataFrame(results_list)
    df.to_csv(output_csv, index=False)

    # Restrict to the score columns: a bare .mean() raises on the
    # non-numeric columns (subject_name, prompt) with pandas >= 2.0.
    print(df.groupby('subject_id')[score_cols].mean())
    print("\nOverall averages:")
    print(df[score_cols].mean())
| | |
| | |
def evaluate_omini_control():
    """Build and return a generate_fn for OminiControl subject generation.

    Loads FLUX.1-schnell on CUDA with the OminiControl subject LoRA, and
    returns a callable with the exact keyword interface expected by
    batch_evaluate_dreambooth: (prompt, subject_image_path,
    output_image_path, seed).
    """
    import torch
    from diffusers.pipelines import FluxPipeline
    from PIL import Image

    from omini.pipeline.flux_omini import Condition, generate, seed_everything

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    )

    pipe = pipe.to("cuda")
    pipe.load_lora_weights(
        "Yuanshi/OminiControl",
        weight_name="omini/subject_512.safetensors",
        adapter_name="subject",
    )

    # Bug fix: parameter names now match the keyword arguments used by
    # batch_evaluate_dreambooth (previously image_path/output_path, which
    # raised TypeError on the keyword call).
    def generate_fn(prompt, subject_image_path, output_image_path, seed):
        seed_everything(seed)

        # The subject LoRA was trained at 512x512 conditioning resolution.
        image = Image.open(subject_image_path).convert("RGB").resize((512, 512))
        condition = Condition.from_image(
            image,
            "subject", position_delta=(0, 32)
        )

        result_img = generate(
            pipe,
            prompt=prompt,
            conditions=[condition],
        ).images[0]

        result_img.save(output_image_path)

    return generate_fn
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: score one hand-picked original/generated pair.
    # Requires OPENAI_API_KEY in the environment (openai.Client() reads it
    # as well).
    openai.api_key = os.getenv("OPENAI_API_KEY")

    client = openai.Client()
    scores = evaluate_subject_driven_generation(
        "data/dreambooth/backpack/00.jpg",
        "data/dreambooth/backpack/01.jpg",
        "a backpack in the jungle",
        client,
    )

    print(scores)