| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Optional |
|
|
| import numpy as np |
| import pandas as pd |
| import timm |
| import torch |
| from huggingface_hub import hf_hub_download |
| from huggingface_hub.utils import HfHubHTTPError |
| from PIL import Image |
| from simple_parsing import field, parse_known_args |
| from timm.data import create_transform, resolve_data_config |
| from torch import Tensor, nn |
| from torch.nn import functional as F |
| import os |
| import time |
| from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer |
| from PIL import Image, UnidentifiedImageError |
| from pathlib import Path |
| from tqdm import tqdm |
|
|
@dataclass
class ScriptOptions:
    """Command-line options for the tagging/captioning script (parsed with simple_parsing)."""

    # Root folder searched recursively for images. The default is a real Path so the
    # value matches its annotation whether or not it is overridden on the command line.
    image_folder: Path = Path("/workspace/ds/reddit")
    # Key into MODEL_REPO_MAP selecting which WD v3 tagger backbone to load.
    model: str = field(default="vit")
    # General tags are kept only when their probability is strictly above this value.
    gen_threshold: float = field(default=0.7)
    # Character tags are kept only when their probability is strictly above this value.
    char_threshold: float = field(default=0.6)
|
|
# Moondream2 (4-bit) caption model, loaded once when the module is imported.
# NOTE(review): it is pinned to "cuda" regardless of torch_device below, which can
# fall back to CPU; a CUDA-less machine presumably fails here first. Confirm.
dream_model = AutoModelForCausalLM.from_pretrained(
    "moondream/moondream-2b-2025-04-14-4bit",
    device_map={"": "cuda"},
    trust_remote_code=True,
)
# compile() comes from the remote-code model class; its behaviour is not visible here.
dream_model.model.compile()


# Device used for the timm tagger: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    torch_device = torch.device("cuda")
else:
    torch_device = torch.device("cpu")


# Import-time working-directory switch; the reason is not visible in this file
# (presumably the wdv3-timm checkout). NOTE(review): confirm it is still required.
new_path = '/workspace/wdv3-timm'
os.chdir(new_path)
print(os.getcwd())


# --model option value -> Hugging Face Hub repository of the matching WD v3 tagger.
MODEL_REPO_MAP = dict(
    vit="SmilingWolf/wd-vit-tagger-v3",
    swinv2="SmilingWolf/wd-swinv2-tagger-v3",
    convnext="SmilingWolf/wd-convnext-tagger-v3",
)
|
|
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* in RGB mode, flattening any transparency onto white.

    RGB images are returned unchanged (same object). Other modes are converted;
    anything carrying transparency is first converted to RGBA and then composited
    over an opaque white canvas.

    Args:
        image: Any PIL image.

    Returns:
        An RGB image.
    """
    if image.mode not in ("RGB", "RGBA"):
        # Modes with a real alpha band (e.g. "LA", "PA") carry no "transparency" info
        # key. Before this change they went straight to RGB, which simply drops the
        # alpha and exposes whatever colour lies under transparent pixels.
        has_alpha = "A" in image.getbands() or "transparency" in image.info
        image = image.convert("RGBA") if has_alpha else image.convert("RGB")
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image
|
|
def pil_pad_square(image: Image.Image) -> Image.Image:
    """Centre *image* on a white RGB square whose side equals the image's longer edge.

    Returns a new image; the input is not modified.
    """
    width, height = image.size
    side = max(width, height)
    # Integer-halved margins; any odd leftover pixel ends up on the right/bottom.
    left = (side - width) // 2
    top = (side - height) // 2
    padded = Image.new("RGB", (side, side), (255, 255, 255))
    padded.paste(image, (left, top))
    return padded
|
|
@dataclass
class LabelData:
    """Tag vocabulary of a tagger model, with tag positions grouped by category."""

    # All tag names; position i pairs with the i-th score when zipped in get_tags.
    names: list[str]
    # Positions (into ``names``) of the rating tags.
    rating: list[np.int64]
    # Positions (into ``names``) of the general tags.
    general: list[np.int64]
    # Positions (into ``names``) of the character tags.
    character: list[np.int64]
|
|
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Fetch ``selected_tags.csv`` from a Hub repo and index its tags by category.

    Category codes read from the CSV: 9 = rating, 0 = general, 4 = character.

    Args:
        repo_id: Hugging Face Hub repository containing ``selected_tags.csv``.
        revision: Optional git revision passed to ``hf_hub_download``.
        token: Optional access token passed to ``hf_hub_download``.

    Returns:
        LabelData with every tag name and the row positions of each category.

    Raises:
        FileNotFoundError: if the download fails with an HTTP error.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
    except HfHubHTTPError as err:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from err
    csv_path = Path(downloaded).resolve()

    frame: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    category = frame["category"]

    def rows_with(code):
        # Row positions whose category equals *code*, as a list of numpy ints.
        return list(np.where(category == code)[0])

    return LabelData(
        names=frame["name"].tolist(),
        rating=rows_with(9),
        general=rows_with(0),
        character=rows_with(4),
    )
|
|
def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
):
    """Turn per-tag probabilities into a caption plus per-category score dicts.

    Args:
        probs: One score per tag, aligned with ``labels.names``. Either a CPU
            ``torch.Tensor`` or anything ``np.asarray`` accepts.
        labels: Tag names plus the index lists of each category.
        gen_threshold: General tags scoring strictly above this are kept.
        char_threshold: Character tags scoring strictly above this are kept.

    Returns:
        ``(caption, taglist, rating_labels, char_labels, gen_labels)``:

        * ``caption``: the kept general tags, then the kept character tags,
          comma-separated with underscores turned into spaces, followed by
          ``rating_<name>`` for the highest-scoring rating. There is no leading
          comma when no tag was kept, and no rating suffix when there are no
          rating tags.
        * ``taglist``: the same tags without the rating, with ``(`` and ``)``
          backslash-escaped.
        * The three dicts map tag name -> score, ordered by descending score.
          Ratings are not thresholded.
    """
    scores = probs.numpy() if isinstance(probs, Tensor) else np.asarray(probs)
    scored = list(zip(labels.names, scores))

    def _ranked(indices, threshold=None):
        # Filter, then dict (a later duplicate name wins), then order by descending
        # score. sorted() is stable, so ties keep their original order.
        pairs = (scored[i] for i in indices)
        if threshold is not None:
            pairs = (pair for pair in pairs if pair[1] > threshold)
        return dict(sorted(dict(pairs).items(), key=lambda item: item[1], reverse=True))

    rating_labels = _ranked(labels.rating)
    gen_labels = _ranked(labels.general, gen_threshold)
    char_labels = _ranked(labels.character, char_threshold)

    tags = ", ".join([*gen_labels, *char_labels])
    # Raw strings: "\(" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
    taglist = tags.replace("_", " ").replace("(", r"\(").replace(")", r"\)")
    caption = tags.replace("_", " ")

    if rating_labels:
        # rating_labels is already ordered by descending score.
        rating_tag = "rating_" + next(iter(rating_labels))
        # Avoid a dangling ", " when no general/character tag passed its threshold.
        caption = f"{caption}, {rating_tag}" if caption else rating_tag

    return caption, taglist, rating_labels, char_labels, gen_labels
|
|
def get_all_images(folder, extensions=(".jpeg", ".jpg", ".png")):
    """Yield every image file found anywhere under *folder*.

    Args:
        folder: Root directory to search recursively (``str`` or ``Path``).
        extensions: File suffixes to accept, compared case-insensitively.
            Defaults to JPEG and PNG.

    Yields:
        Path of each regular file whose suffix matches, in ``Path.rglob`` order.
    """
    wanted = {ext.lower() for ext in extensions}
    for path in Path(folder).rglob("*"):
        # rglob also yields directories; a folder named e.g. "album.jpg" is not an image.
        if path.suffix.lower() in wanted and path.is_file():
            yield path
|
|
# Boilerplate openings Moondream puts in front of its captions. Each phrase is
# removed wherever it appears in the caption, in this order.
_MOONDREAM_BOILERPLATE = (
    "The image depicts ",
    "The image presents ",
    "The image features ",
    "The image portrays ",
    "The image is ",
)

# Words of a file stem that start with one of these characters are dropped.
_DIGIT_PREFIXES = tuple("0123456789")


def _clean_moondream_caption(text: str) -> str:
    """Remove Moondream's boilerplate phrases from *text* and trim surrounding whitespace."""
    for phrase in _MOONDREAM_BOILERPLATE:
        text = text.replace(phrase, "")
    return text.strip()


def _name_without_numbers(stem: str) -> str:
    """Drop whitespace-separated words of *stem* that start with an ASCII digit; join the rest with single spaces."""
    return " ".join(word for word in stem.split() if not word.startswith(_DIGIT_PREFIXES))


def main(opts: ScriptOptions):
    """Tag and caption every JPEG/PNG found under ``opts.image_folder``.

    For each image without an existing ``.txt`` sidecar, two files are written
    next to it (UTF-8):

    * ``<image>.tag``: the tag caption produced by ``get_tags``.
    * ``<image>.txt``: ``"<moondream caption> <tag caption>. <clean name>"``, where
      the clean name is the file stem with digit-leading words removed.

    Images that fail to open or decode are reported and skipped.

    Raises:
        ValueError: if ``opts.model`` is not a key of ``MODEL_REPO_MAP``.
        NotADirectoryError: if ``opts.image_folder`` is not a directory.
    """
    repo_id = MODEL_REPO_MAP.get(opts.model)
    if repo_id is None:
        # Before this check, an unknown model crashed with an opaque TypeError
        # from "hf-hub:" + None.
        raise ValueError(
            f"Unknown model '{opts.model}'; choose one of: {', '.join(MODEL_REPO_MAP)}"
        )

    image_folder = Path(opts.image_folder).resolve()
    if not image_folder.is_dir():
        raise NotADirectoryError(f"Image folder not found: {image_folder}")

    print(f"Loading model '{opts.model}' from '{repo_id}'...")
    model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
    state_dict = timm.models.load_state_dict_from_hf(repo_id)
    model.load_state_dict(state_dict)
    # Move the tagger to the target device once, instead of copying it to the
    # GPU and back for every image.
    model = model.to(torch_device)

    print("Loading tag list...")
    labels: LabelData = load_labels_hf(repo_id=repo_id)

    print("Creating data transform...")
    transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))

    image_paths = list(get_all_images(image_folder))

    for image_path in tqdm(image_paths, desc="Processing images"):
        txt_file = image_path.with_suffix(".txt")
        if txt_file.exists():
            # This image already has a caption file; skip it.
            continue
        try:
            # The context manager closes the file handle. pil_pad_square returns
            # a separate in-memory canvas, so img_input stays valid afterwards.
            with Image.open(image_path) as raw_image:
                img_input = pil_pad_square(pil_ensure_rgb(raw_image))
            inputs: Tensor = transform(img_input).unsqueeze(0)
            # Swap channels RGB -> BGR (presumably the order the tagger was trained on).
            inputs = inputs[:, [2, 1, 0]]

            with torch.inference_mode():
                mdream_capt = _clean_moondream_caption(
                    dream_model.caption(img_input, length="normal")["caption"]
                )
                outputs = F.sigmoid(model(inputs.to(torch_device))).to("cpu")

            caption, _taglist, _ratings, _character, _general = get_tags(
                probs=outputs.squeeze(0),
                labels=labels,
                gen_threshold=opts.gen_threshold,
                char_threshold=opts.char_threshold,
            )
            clean_name = _name_without_numbers(image_path.stem)

            # Explicit UTF-8: tags and file names may contain characters that the
            # platform's default locale encoding cannot represent.
            image_path.with_suffix(".tag").write_text(caption, encoding="utf-8")
            txt_file.write_text(f"{mdream_capt} {caption}. {clean_name}", encoding="utf-8")
        except (OSError, UnidentifiedImageError, Image.DecompressionBombError) as e:
            # DecompressionBombError is not an OSError; catch it too so that one
            # oversized image does not abort the whole batch.
            print(f"Error processing {image_path}: {e}")

    print("Done!")
|
|
if __name__ == "__main__":
    # parse_known_args returns (parsed options, unrecognised arguments); only the
    # options are used.
    opts = parse_known_args(ScriptOptions)[0]
    main(opts)
|
|