"""Batch-caption images in a folder through an OpenAI-compatible chat API.

For every image found in ``INPUT_FOLDER`` the script optionally loads a
same-named ``.json`` grounding file, builds a prompt with the helpers from
``prompts.py``, sends the image plus prompt to ``API_URL`` and writes the
returned caption next to the image as ``<image stem><SUFFIX>``.
"""

import base64
import io
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, Optional

import requests
from PIL import Image
from tqdm import tqdm

# Import prompt building functions from prompts.py
from prompts import make_user_query, system_prompt, prompts_b

# ==================== CONFIGURATION ====================

# Captioning type (from prompts_b in prompts.py)
C_TYPE = 'long_thoughts_v2'
if C_TYPE not in prompts_b:
    # A plain string cannot be raised; use a real exception type so the
    # message actually reaches the user.
    raise ValueError(f"{C_TYPE} not found in known formats!")

# Content options
USE_NAMES = True
ADD_TAGS = False
ADD_CHAR_LIST = False
ADD_CHARS_TAGS = False
ADD_CHARS_DESCR = False

# Grounding requires the image folder to contain JSON files with the same
# base name as the image, in the following format:
# {
#     "tags": [],          # list of strings with tags
#     "characters": [],    # list of strings with character tags/names
#     "char_p_tags": {"chars": {"Albedo": ["girl", "horns", "black_hair", ...]}, "skins": {}},
#     "char_descr": {"chars": {"Albedo": "Albedo is a curvy woman with..."}, "skins": {}}
# }
# NOTE(review): the per-character tag value is presumably a list of strings;
# confirm against what make_user_query in prompts.py expects.

# Output settings
SUFFIX = "_lsv2_zs.txt"

# API settings
API_URL = "http://127.0.0.1:9001/v1/chat/completions"
API_KEY = "not-needed"  # vllm typically doesn't require auth
MODEL = "toriigate-0.5"  # or your local model name

# Processing settings
INPUT_FOLDER = "/path/to/files"
#OUTPUT_FOLDER = "/path/to/output"
OUTPUT_FOLDER = INPUT_FOLDER

# Thread pool settings
NUM_WORKERS = 16

# Image settings
MAX_PIXELS = 1.0  # Maximum resolution in megapixels (e.g., 1.0 = 1MP)

# Request settings
MAX_TOKENS = 2048
TEMPERATURE = 0.5
REQUEST_TIMEOUT = 60  # seconds

# ==================== END CONFIGURATION ====================

# File extensions treated as images, shared by lookup and folder scanning.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp', '.bmp')


def encode_image_base64(image_path: str, max_pixels: float = MAX_PIXELS) -> str:
    """Return the image as a base64 string, downscaling it if it is too large.

    Images at or below ``max_pixels`` megapixels are sent as their original
    file bytes (no re-encoding). Larger images are resized with LANCZOS to
    fit the pixel budget while preserving aspect ratio, converted to RGB and
    re-encoded as JPEG (quality 95).

    Args:
        image_path: Path of the image file to encode.
        max_pixels: Pixel budget in megapixels; must be positive.

    Returns:
        Base64 text of either the untouched file or the resized JPEG.

    Raises:
        ValueError: If ``max_pixels`` is not positive.
        OSError: If the file cannot be read or is not a readable image.
    """
    if max_pixels <= 0:
        raise ValueError(f"max_pixels must be positive, got {max_pixels}")

    max_pixels_count = max_pixels * 1_000_000

    # The context manager closes the underlying file handle; without it each
    # worker thread leaks a descriptor per processed image.
    with Image.open(image_path) as img:
        current_pixels = img.width * img.height

        if current_pixels > max_pixels_count:
            # Scale both sides by the same factor so the area hits the budget.
            scale = (max_pixels_count / current_pixels) ** 0.5
            # Clamp to 1 so extreme aspect ratios never produce a 0-sized side.
            new_width = max(1, int(img.width * scale))
            new_height = max(1, int(img.height * scale))

            resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
            if resized.mode != 'RGB':
                # JPEG cannot store alpha/palette modes.
                resized = resized.convert('RGB')

            buffer = io.BytesIO()
            resized.save(buffer, format='JPEG', quality=95)
            return base64.b64encode(buffer.getvalue()).decode("utf-8")

    # No resize needed: forward the file exactly as stored on disk.
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


def load_json_item(json_path: Optional[Path]) -> tuple[Optional[Dict[str, Any]], bool]:
    """Load JSON grounding metadata for one image.

    Args:
        json_path: Path to the metadata file, or ``None`` when there is none.

    Returns:
        ``(data, was_loaded)``. When the path is ``None``, the file does not
        exist, or it cannot be read/parsed, ``data`` is an empty template with
        the expected keys and ``was_loaded`` is ``False``.
    """
    empty_template = {
        "tags": [],
        "characters": [],
        "char_p_tags": {"chars": {}, "skins": {}},
        "char_descr": {"chars": {}, "skins": {}},
    }

    # A missing grounding file is an expected situation, so it is not logged.
    if json_path is None or not json_path.exists():
        return empty_template, False

    try:
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f), True
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"[ERROR] Failed to load {json_path}: {e}")
        return empty_template, False


def find_image_path(image_name: str, folder: Path) -> Optional[Path]:
    """Return the first existing ``folder/image_name<ext>`` for a known image extension.

    Extensions are tried in the order of ``IMAGE_EXTENSIONS``; ``None`` is
    returned when no candidate exists.
    """
    for ext in IMAGE_EXTENSIONS:
        path = folder / f"{image_name}{ext}"
        if path.exists():
            return path
    return None


def find_json_path(image_name: str, folder: Path) -> Optional[Path]:
    """Return ``folder/image_name.json`` if it exists, otherwise ``None``."""
    path = folder / f"{image_name}.json"
    return path if path.exists() else None


def prepare_messages(item: Dict[str, Any], image_data: str) -> list:
    """Build the OpenAI-style message list (system prompt + image + user query).

    Args:
        item: Grounding metadata passed through to ``make_user_query``.
        image_data: Base64-encoded image contents.

    Returns:
        A two-element list: the system message and the user message holding
        the image (as a data URL) followed by the text query.
    """
    user_query = make_user_query(
        item,
        c_type=C_TYPE,
        use_names=USE_NAMES,
        add_tags=ADD_TAGS,
        add_characters=ADD_CHAR_LIST,
        add_char_tags=ADD_CHARS_TAGS,
        # Keyword spelling kept exactly as the caller used it; presumably it
        # matches the parameter name in prompts.py - verify before renaming.
        add_descritpion=ADD_CHARS_DESCR,
        underscores_replace=False,
    )

    # NOTE(review): the data URL always declares image/jpeg, but images that
    # were not resized are sent in their original format (PNG, WebP, ...).
    # Confirm the server detects the format from the bytes.
    image_url = f"data:image/jpeg;base64,{image_data}"

    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": user_query},
            ],
        },
    ]


def call_caption_api(messages: list) -> Optional[str]:
    """Send one chat-completion request and return the caption text.

    No retries are attempted. Transport/HTTP errors and malformed responses
    are printed and reported as ``None``.

    Args:
        messages: Message list as produced by ``prepare_messages``.

    Returns:
        The content of the first choice, or ``None`` on any failure.
    """
    payload = {
        "model": MODEL,
        "messages": messages,
        "max_tokens": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "stream": False,
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }

    try:
        response = requests.post(
            API_URL,
            headers=headers,
            json=payload,
            timeout=REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        result = response.json()
        return result['choices'][0]['message']['content']
    except requests.exceptions.RequestException as e:
        print(f"[API ERROR] {e}")
        return None
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # TypeError: a field such as "choices" is null; ValueError: the body
        # was not valid JSON (older requests versions raise it directly).
        print(f"[PARSE ERROR] Failed to parse API response: {e}")
        return None


def process_image(image_path: Path, json_path: Optional[Path]) -> tuple[Optional[str], bool]:
    """Caption a single image.

    Args:
        image_path: Image file to caption.
        json_path: Matching grounding file, or ``None``; an empty template is
            used when it is missing or unreadable.

    Returns:
        ``(caption, json_loaded)`` where ``caption`` is ``None`` if encoding
        the image or calling the API failed.
    """
    item, json_loaded = load_json_item(json_path)

    # Encode image (with resizing if needed). Any decoding problem is reported
    # per image so one bad file does not abort the whole batch.
    try:
        image_data = encode_image_base64(str(image_path), MAX_PIXELS)
    except Exception as e:
        print(f"[ERROR] Failed to encode image {image_path.name}: {e}")
        return None, json_loaded

    messages = prepare_messages(item, image_data)

    # Call API (no retries)
    caption = call_caption_api(messages)
    return caption, json_loaded


def get_base_name(filename: str) -> str:
    """Return the file name without its final extension."""
    return Path(filename).stem


def main():
    """Caption every image in ``INPUT_FOLDER`` using a thread pool and a progress bar.

    Captions are written to ``OUTPUT_FOLDER`` as ``<image stem><SUFFIX>``;
    a summary of processed/failed images is printed at the end.
    """
    input_dir = Path(INPUT_FOLDER)
    output_dir = Path(OUTPUT_FOLDER)

    if not input_dir.exists():
        print(f"Error: Input folder '{INPUT_FOLDER}' not found")
        return

    # parents=True so a nested output path that does not exist yet still works.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Find all image files
    image_files = []
    for ext in IMAGE_EXTENSIONS:
        image_files.extend(input_dir.glob(f"*{ext}"))

    # Remove duplicates and sort
    image_files = sorted(set(image_files))

    if not image_files:
        print(f"No image files found in '{INPUT_FOLDER}'")
        return

    print(f"Found {len(image_files)} images to process")
    print("Configuration:")
    print(f"  C_TYPE: {C_TYPE}")
    print(f"  USE_NAMES: {USE_NAMES}")
    print(f"  ADD_TAGS: {ADD_TAGS}")
    print(f"  ADD_CHAR_LIST: {ADD_CHAR_LIST}")
    print(f"  ADD_CHARS_TAGS: {ADD_CHARS_TAGS}")
    print(f"  ADD_CHARS_DESCR: {ADD_CHARS_DESCR}")
    print(f"  MODEL: {MODEL}")
    print(f"  API_URL: {API_URL}")
    print(f"  NUM_WORKERS: {NUM_WORKERS}")
    print(f"  MAX_PIXELS: {MAX_PIXELS} MP")
    print("-" * 50)

    processed = 0
    failed = 0
    json_missing = 0

    # Pair every image with its (possibly missing) grounding file.
    tasks = [
        (image_file, find_json_path(get_base_name(image_file.name), input_dir))
        for image_file in image_files
    ]

    # Process with thread pool and progress bar
    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
        future_to_file = {
            executor.submit(process_image, img_path, json_path): (img_path, json_path)
            for img_path, json_path in tasks
        }

        progress = tqdm(
            as_completed(future_to_file),
            total=len(tasks),
            desc="Processing",
            unit="img",
        )
        for future in progress:
            image_path, _ = future_to_file[future]
            output_file = output_dir / f"{get_base_name(image_path.name)}{SUFFIX}"

            try:
                caption, json_loaded = future.result()
            except Exception as e:
                tqdm.write(f"[ERROR] Task failed for {image_path.name}: {e}")
                failed += 1
                continue

            if not json_loaded:
                json_missing += 1

            if not caption:
                tqdm.write(f"[ERROR] Captioning failed for {image_path.name}")
                failed += 1
                continue

            # Save caption; a write failure counts as a failed image but does
            # not stop the remaining work.
            try:
                with open(output_file, "w", encoding="utf-8") as f:
                    f.write(caption)
            except Exception as e:
                tqdm.write(f"[ERROR] Failed to save {output_file.name}: {e}")
                failed += 1
            else:
                processed += 1

    print("=" * 50)
    print("Processing complete:")
    print(f"  Processed: {processed}")
    print(f"  JSON missing (warnings): {json_missing}")
    print(f"  Failed: {failed}")
    print(f"  Output folder: {OUTPUT_FOLDER}")


if __name__ == "__main__":
    main()