# ToriiGate-0.5 / scripts / caption_distributed.py
# Hugging Face upload metadata: Minthy — "Upload folder using huggingface_hub" — rev 9229b0a (verified)
import json
import base64
import requests
from pathlib import Path
from typing import Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from PIL import Image
# Import prompt building functions from prompts.py
from prompts import make_user_query, system_prompt, prompts_b
# ==================== CONFIGURATION ====================

# Captioning type; must be one of the formats declared in prompts_b (prompts.py).
C_TYPE = 'long_thoughts_v2'
if C_TYPE not in prompts_b:
    # BUG FIX: `raise f"..."` raises a plain str, which itself fails with
    # "TypeError: exceptions must derive from BaseException" and hides the
    # intended message — wrap it in a real exception instead.
    raise ValueError(f"{C_TYPE} not found in known formats!")

# Content options controlling which metadata is injected into the user prompt.
USE_NAMES = True
ADD_TAGS = False
ADD_CHAR_LIST = False
ADD_CHARS_TAGS = False
ADD_CHARS_DESCR = False

# Grounding requires the image folder to contain same-named JSON files with
# (approximately) the following shape — example values shown:
# {
#     "tags": [],            # list of strings with tags
#     "characters": [],      # list of strings with character tags/names
#     "char_p_tags": {"chars": {"Albedo": ["girl", "horns", "black_hair"]}, "skins": {}},
#     "char_descr": {"chars": {"Albedo": "Albedo is a curvy woman with..."}, "skins": {}}
# }

# Output settings
SUFFIX = "_lsv2_zs.txt"  # appended to each image's base name for the caption file

# API settings
API_URL = "http://127.0.0.1:9001/v1/chat/completions"
API_KEY = "not-needed"  # vllm typically doesn't require auth
MODEL = "toriigate-0.5"  # or your local model name

# Processing settings
INPUT_FOLDER = "/path/to/files"
#OUTPUT_FOLDER = "/path/to/output"
OUTPUT_FOLDER = INPUT_FOLDER  # by default, captions are written next to the images

# Thread pool settings
NUM_WORKERS = 16

# Image settings
MAX_PIXELS = 1.0  # Maximum resolution in megapixels (e.g., 1.0 = 1MP)

# Request settings
MAX_TOKENS = 2048
TEMPERATURE = 0.5
REQUEST_TIMEOUT = 60  # seconds
# ==================== END CONFIGURATION ====================
def encode_image_base64(image_path: str, max_pixels: float = MAX_PIXELS) -> str:
    """Encode an image as base64, downscaling it if it exceeds ``max_pixels``.

    The result is always valid for a ``data:image/jpeg;base64,...`` URL:
    files that are already RGB JPEGs within the pixel budget are passed
    through byte-for-byte; everything else (oversized, non-RGB, or non-JPEG
    formats such as PNG/WebP) is converted to RGB and re-encoded as JPEG.

    BUG FIX vs. original: the original converted non-RGB images to RGB in the
    no-resize branch but then discarded the conversion and returned the raw
    file bytes — so e.g. RGBA PNGs were sent mislabeled as image/jpeg.
    Also closes the PIL image handle (the original leaked it).

    Args:
        image_path: Path to the image file on disk.
        max_pixels: Pixel budget in megapixels (1.0 == 1,000,000 pixels).

    Returns:
        Base64-encoded JPEG bytes as an ASCII string.
    """
    import io

    with Image.open(image_path) as img:
        current_pixels = img.width * img.height
        max_pixels_count = max_pixels * 1_000_000

        if current_pixels > max_pixels_count:
            # Downscale, preserving aspect ratio (scale factor on each axis is
            # the square root of the area ratio).
            scale = (max_pixels_count / current_pixels) ** 0.5
            new_size = (int(img.width * scale), int(img.height * scale))
            img = img.resize(new_size, Image.Resampling.LANCZOS)
        elif img.format == 'JPEG' and img.mode == 'RGB':
            # Already a compliant JPEG within budget: reuse the file bytes verbatim.
            with open(image_path, "rb") as f:
                return base64.b64encode(f.read()).decode("utf-8")

        if img.mode != 'RGB':
            img = img.convert('RGB')

        # Re-encode as JPEG so the payload matches the advertised MIME type.
        buffer = io.BytesIO()
        img.save(buffer, format='JPEG', quality=95)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
def load_json_item(json_path: Optional[Path]) -> tuple[Optional[Dict[str, Any]], bool]:
    """Read sidecar JSON metadata for an image.

    Returns a ``(data, was_loaded)`` pair. When the path is ``None``, missing,
    or unreadable, an empty metadata template is returned with ``False``.
    """
    fallback: Dict[str, Any] = {
        "tags": [],
        "characters": [],
        "char_p_tags": {"chars": {}, "skins": {}},
        "char_descr": {"chars": {}, "skins": {}},
    }

    # No sidecar on disk — fall back to the empty template silently.
    if json_path is None or not json_path.exists():
        return fallback, False

    try:
        data = json.loads(json_path.read_text(encoding="utf-8"))
    except Exception as e:
        # Unreadable/corrupt sidecar: report it, but keep the run going.
        print(f"[ERROR] Failed to load {json_path}: {e}")
        return fallback, False
    return data, True
def find_image_path(image_name: str, folder: Path) -> Optional[Path]:
    """Return the first existing image file ``folder/image_name.<ext>``.

    Extensions are tried in a fixed priority order (jpg before png, etc.);
    returns ``None`` when no candidate exists.
    """
    for suffix in ('.jpg', '.jpeg', '.png', '.webp', '.bmp'):
        candidate = folder / (image_name + suffix)
        if candidate.exists():
            return candidate
    return None
def find_json_path(image_name: str, folder: Path) -> Optional[Path]:
    """Return ``folder/image_name.json`` when it exists on disk, else ``None``."""
    candidate = folder / f"{image_name}.json"
    if not candidate.exists():
        return None
    return candidate
def prepare_messages(item: Dict[str, Any], image_data: str) -> list:
    """Assemble OpenAI-style chat messages: system prompt plus one user turn
    containing the base64 image and the generated text query."""
    query = make_user_query(
        item,
        c_type=C_TYPE,
        use_names=USE_NAMES,
        add_tags=ADD_TAGS,
        add_characters=ADD_CHAR_LIST,
        add_char_tags=ADD_CHARS_TAGS,
        add_descritpion=ADD_CHARS_DESCR,  # (sic) keyword name as defined by prompts.make_user_query
        underscores_replace=False,
    )

    system_message = {
        "role": "system",
        "content": [{"type": "text", "text": system_prompt}],
    }
    user_message = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}},
            {"type": "text", "text": query},
        ],
    }
    return [system_message, user_message]
def call_caption_api(messages: list) -> Optional[str]:
    """Call the captioning API once (no retries).

    Args:
        messages: OpenAI-style chat messages (see ``prepare_messages``).

    Returns:
        The generated caption text, or ``None`` on any request, HTTP, or
        response-shape error (errors are printed, never raised).
    """
    payload = {
        "model": MODEL,
        "messages": messages,
        "max_tokens": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "stream": False
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }
    try:
        response = requests.post(
            API_URL,
            headers=headers,
            json=payload,
            timeout=REQUEST_TIMEOUT
        )
        # Raise on 4xx/5xx so it is handled by the RequestException branch.
        response.raise_for_status()
        result = response.json()
        content = result['choices'][0]['message']['content']
        return content
    except requests.exceptions.RequestException as e:
        print(f"[API ERROR] {e}")
        return None
    except (KeyError, IndexError) as e:
        # Response was 200 but not shaped like a chat completion.
        print(f"[PARSE ERROR] Failed to parse API response: {e}")
        return None
    # (Removed an unreachable trailing `return None` — every path above returns.)
def process_image(image_path: Path, json_path: Path) -> tuple[Optional[str], bool]:
    """Caption a single image.

    Returns ``(caption, json_loaded)`` where ``caption`` is ``None`` on
    failure and ``json_loaded`` reports whether a metadata sidecar was found
    (an empty template is used when it was not).
    """
    metadata, json_found = load_json_item(json_path)

    try:
        encoded = encode_image_base64(str(image_path), MAX_PIXELS)
    except Exception as exc:
        print(f"[ERROR] Failed to encode image {image_path.name}: {exc}")
        return None, json_found

    # Build the chat request and call the API (single attempt, no retries).
    caption = call_caption_api(prepare_messages(metadata, encoded))
    return caption, json_found
def get_base_name(filename: str) -> str:
    """Return ``filename`` stripped of its final extension (e.g. 'a.png' -> 'a')."""
    path_obj = Path(filename)
    return path_obj.stem
def main():
    """Caption every image in INPUT_FOLDER using a thread pool, with a
    progress bar, and write one caption file per image to OUTPUT_FOLDER."""
    input_dir = Path(INPUT_FOLDER)
    output_dir = Path(OUTPUT_FOLDER)

    if not input_dir.exists():
        print(f"Error: Input folder '{INPUT_FOLDER}' not found")
        return
    output_dir.mkdir(exist_ok=True)

    # Collect candidate images: deduplicate (jpg/jpeg overlap etc.) and sort
    # for a deterministic processing order.
    patterns = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.bmp']
    image_files = sorted({p for pattern in patterns for p in input_dir.glob(pattern)})

    if not image_files:
        print(f"No image files found in '{INPUT_FOLDER}'")
        return

    print(f"Found {len(image_files)} images to process")
    print("Configuration:")
    print(f" C_TYPE: {C_TYPE}")
    print(f" USE_NAMES: {USE_NAMES}")
    print(f" ADD_TAGS: {ADD_TAGS}")
    print(f" ADD_CHAR_LIST: {ADD_CHAR_LIST}")
    print(f" ADD_CHARS_TAGS: {ADD_CHARS_TAGS}")
    print(f" ADD_CHARS_DESCR: {ADD_CHARS_DESCR}")
    print(f" MODEL: {MODEL}")
    print(f" API_URL: {API_URL}")
    print(f" NUM_WORKERS: {NUM_WORKERS}")
    print(f" MAX_PIXELS: {MAX_PIXELS} MP")
    print("-" * 50)

    saved = 0
    errors = 0
    missing_json = 0

    # Pair every image with its (possibly absent) JSON sidecar up front.
    tasks = [
        (img, find_json_path(get_base_name(img.name), input_dir))
        for img in image_files
    ]

    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as pool:
        futures = {
            pool.submit(process_image, img, sidecar): (img, sidecar)
            for img, sidecar in tasks
        }

        for future in tqdm(as_completed(futures), total=len(tasks), desc="Processing", unit="img"):
            img_path, _sidecar = futures[future]
            target = output_dir / f"{get_base_name(img_path.name)}{SUFFIX}"

            try:
                caption, had_json = future.result()
            except Exception as exc:
                # Worker crashed outright (not an API/encode failure it handled).
                tqdm.write(f"[ERROR] Task failed for {img_path.name}: {exc}")
                errors += 1
                continue

            if not had_json:
                missing_json += 1
            if not caption:
                tqdm.write(f"[ERROR] Captioning failed for {img_path.name}")
                errors += 1
                continue

            try:
                target.write_text(caption, encoding="utf-8")
                saved += 1
            except Exception as exc:
                tqdm.write(f"[ERROR] Failed to save {target.name}: {exc}")
                errors += 1

    print("=" * 50)
    print("Processing complete:")
    print(f" Processed: {saved}")
    print(f" JSON missing (warnings): {missing_json}")
    print(f" Failed: {errors}")
    print(f" Output folder: {OUTPUT_FOLDER}")


if __name__ == "__main__":
    main()