| | import json |
| | import random |
| | import os |
| |
|
| | |
# English spellings of the object counts (1-4) that can be requested in an
# edit prompt; looked up by integer count when the prompt text is built.
NUM_TO_WORD = dict(enumerate(("one", "two", "three", "four"), start=1))
| |
|
| | import torch |
| | from PIL import Image |
| | import numpy as np |
| | from diffusers import FluxPipeline |
| | from flow_grpo.diffusers_patch.flux_pipeline_with_logprob import pipeline_with_logprob |
| | import importlib |
| |
|
# Checkpoint identifier passed to `from_pretrained`, and the device the
# pipeline is moved to after loading.
model_id = "black-forest-labs/FLUX.1-dev"
device = "cuda"

# Built once at module level (so this runs on import as well as on script
# execution); `process_jsonl` uses this global to render every prompt.
# Weights are loaded in bfloat16 and the whole pipeline is moved to `device`.
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
| |
|
def _parse_record(line):
    """Decode one JSONL line and extract the fields needed downstream.

    Every required key is read here so that a malformed record is rejected
    *before* any image is generated for it.

    Args:
        line (str): One non-empty line of the input file.

    Returns:
        dict: Keys ``tag``, ``t2i_prompt``, ``class_name`` and
        ``original_count`` taken from the decoded record (the last two come
        from the first entry of its ``include`` list).

    Raises:
        json.JSONDecodeError: The line is not valid JSON.
        KeyError, IndexError, TypeError: The decoded value does not have the
            expected structure (missing key, empty ``include`` list, or a
            non-object value).
    """
    data = json.loads(line)
    first_include = data["include"][0]
    return {
        "tag": data["tag"],
        "t2i_prompt": data["t2i_prompt"],
        "class_name": first_include["class"],
        "original_count": first_include["count"],
    }


def process_jsonl(input_file, output_file, image_directory):
    """
    Read a JSONL metadata file, render one image per record, and write a new
    JSONL file of count-editing tasks that reference those images.

    For each valid input line the record's ``t2i_prompt`` is rendered with
    the module-level ``pipe`` and saved as ``image_<line index>.jpg`` in
    ``image_directory``. Then, for every count in ``NUM_TO_WORD`` other than
    the record's own count, one output line is written containing the
    original ``tag`` and ``t2i_prompt``, an ``include`` entry with the target
    count, an ``exclude`` entry with target count + 1, an edit ``prompt``
    asking to change the number of objects, and the saved ``image`` path.

    Blank lines are skipped silently. Lines that cannot be parsed or lack a
    required field are reported on stdout and skipped; no image is generated
    for them. Errors raised by the pipeline or by file I/O are not caught.

    Args:
        input_file (str): Path of the input JSONL file.
        output_file (str): Path of the JSONL file to write (overwritten).
        image_directory (str): Directory in which generated images are saved;
            created if it does not exist.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(image_directory, exist_ok=True)

    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:
        for i, line in enumerate(infile):
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not records.
                # `i` still counts them, so image names stay tied to the
                # physical line index.
                continue

            # Validate first: generation below is expensive, so a record that
            # would fail later must be rejected here.
            try:
                record = _parse_record(line)
            except (json.JSONDecodeError, KeyError, IndexError, TypeError) as e:
                print(f"处理第 {i+1} 行时出错: {e}")
                continue

            class_name = record["class_name"]

            image = pipe(
                record["t2i_prompt"],
                height=1024,
                width=1024,
                guidance_scale=3.5,
                num_inference_steps=50,
                max_sequence_length=512,
            ).images[0]
            image_path = os.path.join(image_directory, f"image_{i}.jpg")
            image.save(image_path)

            # One editing task per target count that differs from the count
            # in the source record. Iterating the sorted keys of NUM_TO_WORD
            # keeps the output order deterministic and guarantees every
            # target has a word form.
            for num in sorted(NUM_TO_WORD):
                if num == record["original_count"]:
                    continue
                new_data = {
                    "tag": record["tag"],
                    "include": [{"class": class_name, "count": num}],
                    # NOTE(review): presumably marks "num + 1 or more" as a
                    # failure for the downstream evaluator - confirm.
                    "exclude": [{"class": class_name, "count": num + 1}],
                    "t2i_prompt": record["t2i_prompt"],
                    "prompt": f"Change the number of {class_name} in the image to {NUM_TO_WORD[num]}.",
                    "image": image_path
                }
                outfile.write(json.dumps(new_data, ensure_ascii=False) + '\n')
| |
|
if __name__ == '__main__':
    # Script entry point: fixed input/output locations, resolved relative to
    # the current working directory.
    input_filename = "metadata.jsonl"
    output_filename = "output.jsonl"
    image_save_directory = "generated_images"

    # Render an image for each metadata record and write the derived
    # editing-task records.
    process_jsonl(
        input_file=input_filename,
        output_file=output_filename,
        image_directory=image_save_directory,
    )

    # Completion message telling the user where the results were written.
    print(f"处理完成!结果已保存到 '{output_filename}',图片路径保存在 '{image_save_directory}' 目录。")