import os
import re
import glob
import argparse
import pickle
import warnings
from collections import Counter
from dataclasses import dataclass
from io import BytesIO
from typing import Optional, List, Dict, Any, Tuple

import torch
from PIL import Image, ImageFile
from tqdm.auto import tqdm

# =========================
# PIL safety
# =========================
# Disable the pixel-count limit, accept truncated files and silence the
# decompression-bomb warning so that large/odd dataset images still load.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.simplefilter("ignore", Image.DecompressionBombWarning)


# =========================
# Data record
# =========================
@dataclass
class GenSample:
    """One generated training record (pickled by ``main``)."""

    image: Any             # image object attached to the prompt
    prompt: str            # user prompt produced by ``solver_text``
    correct_solution: str  # model output whose boxed answer matched ``answer``
    wrong_solution: str    # same reasoning with the boxed answer replaced
    answer: str            # ground-truth letter
    source: str            # dataset tag of the originating example


# =========================
# Choice mapping
# =========================
LETTERS = list("abcdefghijklmnopqrstuvwxyz")
IDX2LETTER = {i: LETTERS[i] for i in range(len(LETTERS))}


# =========================
# Distributed helpers
# =========================
def get_dist_info():
    """Return ``(local_rank, rank, world_size)`` read from the environment.

    Missing variables default to a single-process setup ``(0, 0, 1)``.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return local_rank, rank, world_size


def init_dist_if_needed():
    """Initialise the NCCL process group when running with several processes.

    Nothing is initialised when ``WORLD_SIZE`` is 1, when
    ``torch.distributed`` is unavailable, or when a group already exists.

    Returns:
        The ``(local_rank, rank, world_size)`` tuple from ``get_dist_info``.
    """
    local_rank, rank, world_size = get_dist_info()
    if world_size > 1 and torch.distributed.is_available() and not torch.distributed.is_initialized():
        # Bind this process to its GPU before creating the process group.
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(backend="nccl")
    return local_rank, rank, world_size


def barrier():
    """Synchronise all ranks; a no-op when no process group is initialised."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()


def destroy_dist():
    """Tear down the process group if one was initialised."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


# =========================
# Boxed answer utils
# =========================
BOX_RE = re.compile(r"\\boxed\{([^}]+)\}")


def extract_boxed_answer(text: str) -> Optional[str]:
    """Return the content of the last boxed answer, stripped and lower-cased.

    Returns ``None`` for empty input or when no boxed answer is present.
    """
    if not text:
        return None
    ms = BOX_RE.findall(text)
    if not ms:
        return None
    return ms[-1].strip().lower()


def count_boxed(text: str) -> int:
    """Count boxed answers in ``text`` (``None``/empty counts as zero)."""
    return len(BOX_RE.findall(text or ""))


def strip_last_boxed(text: str) -> str:
    """Remove the last boxed answer from ``text``.

    A boxed answer at the very end is removed together with the whitespace
    around it; otherwise the last occurrence anywhere in the text is cut out.
    Falsy input is returned unchanged; the result is right-stripped.
    """
    if not text:
        return text
    s = text.rstrip()
    # Fast path: the boxed answer is the trailing token of the text.
    s2 = re.sub(r"\s*\\boxed\{[^}]+\}\s*$", "", s, flags=re.DOTALL)
    if s2 != s:
        return s2.rstrip()
    matches = list(BOX_RE.finditer(s))
    if not matches:
        return s
    m = matches[-1]
    return (s[:m.start()] + s[m.end():]).rstrip()


# =========================
# Image loader
# =========================
def _open_rgb(src: Any) -> Optional[Image.Image]:
    """Open ``src`` (path or file object) as RGB; ``None`` if it cannot be read."""
    try:
        with Image.open(src) as im:
            return im.convert("RGB")
    except Exception:
        # Best effort: an unreadable image means the example is skipped.
        return None


def _pil_from_any(img: Any) -> Optional[Image.Image]:
    """Convert a dataset image value into an RGB PIL image.

    Accepted inputs: a PIL image, a dict carrying encoded ``"bytes"`` and/or
    a ``"path"`` to an existing file, or a string path to an existing file.
    Anything else (or an unreadable image) yields ``None``.
    """
    if img is None:
        return None
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    if isinstance(img, dict):
        if img.get("bytes") is not None:
            return _open_rgb(BytesIO(img["bytes"]))
        # Undecoded image dicts may carry only a file path instead of bytes.
        path = img.get("path")
        if isinstance(path, str) and os.path.exists(path):
            return _open_rgb(path)
        return None
    if isinstance(img, str) and os.path.exists(img):
        return _open_rgb(img)
    return None


def get_pil_image(ex: Dict[str, Any]) -> Optional[Image.Image]:
    """Return the first loadable image among ``decoded_image`` and ``image``."""
    for k in ("decoded_image", "image"):
        if k in ex:
            im = _pil_from_any(ex.get(k))
            if im is not None:
                return im
    return None


# =========================
# Prompt
# =========================
SOLVER_SYSTEM = "You are a rigorous visual question answering expert."


def solver_text(question: str, choices: List[str]) -> str:
    """Build the user prompt for a multiple-choice question.

    Choices are labelled with consecutive lower-case letters starting at
    ``a``.

    Raises:
        ValueError: if there are more choices than available letters.
    """
    if len(choices) > len(IDX2LETTER):
        raise ValueError(f"Too many choices: {len(choices)}")
    opts = "\n".join([f"{IDX2LETTER[i]}. {c}" for i, c in enumerate(choices)])
    return (
        "Solve the following multiple-choice problem step by step.\n\n"
        f"Problem:\n{question}\n\n"
        f"Choices:\n{opts}\n\n"
        "Give your reasoning in plain text.\n"
        "At the end, output your answer ONLY in LaTeX boxed format, e.g. \\boxed{a}.\n"
    )


def build_messages(system_text, user_text, image):
    """Return a system + user chat message list.

    The user turn contains an image entry before the text entry when
    ``image`` is not ``None``; otherwise it contains only the text entry.
    """
    user_content = [{"type": "text", "text": user_text}]
    if image is not None:
        user_content.insert(0, {"type": "image", "image": image})
    return [
        {"role": "system", "content": [{"type": "text", "text": system_text}]},
        {"role": "user", "content": user_content},
    ]


# =========================
# Qwen runner (FIXED padding slicing)
# =========================
class QwenBatchRunner:
    """Batched text generation with a Qwen2.5-VL model on one GPU."""

    def __init__(self, model_id, cache_dir, local_rank):
        """Load processor and model onto ``cuda:{local_rank}``.

        The tokenizer is switched to left padding so that every prompt in a
        batch ends at the same position (see ``generate_batch``).
        """
        from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

        self.device = torch.device(f"cuda:{local_rank}")
        self.processor = AutoProcessor.from_pretrained(model_id, cache_dir=cache_dir)
        self.processor.tokenizer.padding_side = "left"
        if self.processor.tokenizer.pad_token_id is None:
            # Fall back to EOS as the padding token.
            self.processor.tokenizer.pad_token_id = self.processor.tokenizer.eos_token_id
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map={"": local_rank},
            attn_implementation="flash_attention_2",
        ).eval()

    @torch.inference_mode()
    def generate_batch(self, messages, images, max_new_tokens, temperature, do_sample=True):
        """Generate one completion per chat in ``messages``.

        Args:
            messages: list of chat message lists (see ``build_messages``).
            images: one entry per chat; ``None`` entries are dropped before
                the images are handed to the processor.
            max_new_tokens: generation length limit.
            temperature: sampling temperature, only used when sampling.
            do_sample: sample when true, otherwise decode greedily.

        Returns:
            The decoded, stripped completions in input order.
        """
        texts = [
            self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]
        # Only real images are forwarded; a batch without any passes None.
        present_images = [im for im in images if im is not None]
        enc = self.processor(
            text=texts,
            images=present_images or None,
            padding=True,
            return_tensors="pt",
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}
        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            pad_token_id=self.processor.tokenizer.pad_token_id,
            eos_token_id=self.processor.tokenizer.eos_token_id,
        )
        if do_sample:
            gen_kwargs["temperature"] = temperature
        out = self.model.generate(**enc, **gen_kwargs)
        # With left padding all prompts occupy the first ``in_len`` columns,
        # so everything after that column is newly generated text.
        in_len = enc["input_ids"].shape[1]
        return [
            self.processor.tokenizer.decode(o[in_len:], skip_special_tokens=True).strip()
            for o in out
        ]
# =========================
# Mix helper
# =========================
def interleave(a: List[Any], b: List[Any]) -> List[Any]:
    """Alternate elements of ``a`` and ``b`` (``a`` first).

    Once the shorter list is exhausted, the rest of the longer one follows.
    """
    out = []
    for k in range(max(len(a), len(b))):
        if k < len(a):
            out.append(a[k])
        if k < len(b):
            out.append(b[k])
    return out


# =========================
# Shard helpers
# =========================
def shard_path(out_pkl: str, rank: int, world_size: int) -> str:
    """Return the pickle path a rank writes to.

    A single-process run writes ``out_pkl`` directly; otherwise each rank
    writes ``{out_pkl}.rank{rank}``.
    """
    return out_pkl if world_size == 1 else f"{out_pkl}.rank{rank}"


def merge_shards(out_pkl: str, world_size: int) -> List[Any]:
    """Concatenate the per-rank shard pickles of this run, in rank order.

    Only the shards of ranks ``0 .. world_size - 1`` are read, so leftover
    ``.rank*`` files from an earlier run with more processes are ignored.
    A shard that is not visible from this process is reported and skipped.
    """
    merged: List[Any] = []
    for r in range(world_size):
        fp = f"{out_pkl}.rank{r}"
        if not os.path.exists(fp):
            print(f"[rank0] warning: missing shard {fp}, skipped")
            continue
        # NOTE: pickle.load executes arbitrary code from the file; these
        # shards are expected to be the ones written by this same job.
        with open(fp, "rb") as f:
            merged.extend(pickle.load(f))
    return merged


def print_source_stats(samples: List[Any], out_pkl: str, merged: bool) -> None:
    """Print the total sample count and the per-source breakdown."""
    cnt = Counter([s.source for s in samples])
    if merged:
        print(f"[rank0] merged total={len(samples)} -> {out_pkl}")
    else:
        print(f"[rank0] total={len(samples)} -> {out_pkl}")
    print(f"[rank0] by source: scienceqa={cnt.get('scienceqa', 0)}, aokvqa={cnt.get('aokvqa', 0)}")


# =========================
# Main
# =========================
def main():
    """Generate (correct, wrong) solution pairs and pickle them.

    Each rank processes its strided share of both datasets, keeps the model
    outputs whose single boxed answer equals the ground truth, and derives a
    wrong solution by replacing that boxed answer with a fixed wrong letter.
    With several ranks, rank 0 merges the per-rank shards at the end.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_id", default="Qwen/Qwen2.5-VL-7B-Instruct")
    ap.add_argument("--dataset_id", default="HuggingFaceM4/A-OKVQA")
    ap.add_argument("--split", default="train")
    ap.add_argument("--scienceqa_id", default="derek-thomas/ScienceQA")
    ap.add_argument("--scienceqa_split", default=None)
    ap.add_argument("--cache_dir", default=None)
    ap.add_argument("--out_pkl", default="train.pkl")
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--max_items", type=int, default=3000,
                    help="per-rank cap on examples, split across both datasets (<=0: no cap)")
    ap.add_argument("--solver_max_new_tokens", type=int, default=512)
    ap.add_argument("--solver_temp", type=float, default=0.1)
    ap.add_argument("--solver_greedy", action="store_true")
    args = ap.parse_args()
    if args.batch_size < 1:
        # range() below would otherwise fail with an unhelpful error.
        ap.error("--batch_size must be >= 1")

    local_rank, rank, world_size = init_dist_if_needed()
    is_master = rank == 0

    from datasets import load_dataset, Image as HFImage

    sq_split = args.scienceqa_split or args.split

    # Let rank 0 populate the dataset cache first; every rank then waits at
    # the barrier before loading.
    if world_size > 1 and is_master:
        load_dataset(args.dataset_id, split=args.split, cache_dir=args.cache_dir)
        load_dataset(args.scienceqa_id, split=sq_split, cache_dir=args.cache_dir)
    barrier()

    ds_ok = load_dataset(args.dataset_id, split=args.split, cache_dir=args.cache_dir)
    ds_sq = load_dataset(args.scienceqa_id, split=sq_split, cache_dir=args.cache_dir)

    def without_decoding(ds):
        # Keep images undecoded so they are opened by get_pil_image instead.
        if "image" in ds.column_names and isinstance(ds.features.get("image", None), HFImage):
            return ds.cast_column("image", HFImage(decode=False))
        return ds

    ds_ok = without_decoding(ds_ok)
    ds_sq = without_decoding(ds_sq)

    # Strided sharding: rank r takes examples r, r + world_size, ...
    ok_indices = list(range(rank, len(ds_ok), world_size))
    sq_indices = list(range(rank, len(ds_sq), world_size))

    if args.max_items and args.max_items > 0:
        ok_lim = args.max_items // 2
        sq_lim = args.max_items - ok_lim
        ok_indices = ok_indices[:ok_lim]
        sq_indices = sq_indices[:sq_lim]

    items = interleave(
        [("okvqa", i) for i in ok_indices],
        [("scienceqa", i) for i in sq_indices],
    )

    runner = QwenBatchRunner(args.model_id, args.cache_dir, local_rank)
    samples: List[GenSample] = []

    # The wrong solution always claims this choice, so examples whose ground
    # truth is this choice (or that lack it) cannot be used.
    wrong_idx = 2
    wrong_letter = IDX2LETTER[wrong_idx]
    wrong_tail_re = re.compile(r"\\boxed\{" + re.escape(wrong_letter) + r"\}\s*$")

    # tag -> (dataset, key holding the ground-truth index, source label)
    sources = {
        "okvqa": (ds_ok, "correct_choice_idx", "aokvqa"),
        "scienceqa": (ds_sq, "answer", "scienceqa"),
    }

    def build_meta(ex, gt_key, source):
        """Return prompt metadata for ``ex`` or ``None`` if it is unusable."""
        choices = ex.get("choices", None)
        if not isinstance(choices, (list, tuple)) or len(choices) <= wrong_idx:
            return None
        gt_idx = ex.get(gt_key, None)
        if gt_idx is None:
            return None
        gt_idx = int(gt_idx)
        if gt_idx == wrong_idx or not 0 <= gt_idx < len(choices):
            return None
        image = get_pil_image(ex)
        if image is None:
            return None
        question = ex.get("question", "")
        prompt = solver_text(question, [str(c) for c in choices])
        return {
            "image": image,
            "prompt": prompt,
            "gt_letter": IDX2LETTER[gt_idx],
            "source": source,
        }

    for b in tqdm(range(0, len(items), args.batch_size), desc=f"rank{rank}"):
        batch_items = items[b:b + args.batch_size]

        metas, solver_messages, solver_images = [], [], []
        for tag, i in batch_items:
            ds, gt_key, source = sources[tag]
            meta = build_meta(ds[i], gt_key, source)
            if meta is None:
                continue
            solver_messages.append(build_messages(SOLVER_SYSTEM, meta["prompt"], meta["image"]))
            solver_images.append(meta["image"])
            metas.append(meta)

        if not metas:
            continue

        solver_outs = runner.generate_batch(
            solver_messages,
            solver_images,
            max_new_tokens=args.solver_max_new_tokens,
            temperature=args.solver_temp,
            do_sample=(not args.solver_greedy),
        )

        for meta, solver_out in zip(metas, solver_outs):
            # Keep only outputs with exactly one boxed answer that is correct.
            if extract_boxed_answer(solver_out) != meta["gt_letter"]:
                continue
            if count_boxed(solver_out) != 1:
                continue

            base = strip_last_boxed(solver_out).rstrip()
            if count_boxed(base) != 0:
                continue

            wrong_solution = base + "\n\n" + "but, the answer is \\boxed{" + wrong_letter + "}"

            # Sanity checks on the constructed text.
            if count_boxed(wrong_solution) != 1:
                continue
            if extract_boxed_answer(wrong_solution) != wrong_letter:
                continue
            if not wrong_tail_re.search(wrong_solution):
                continue

            samples.append(GenSample(
                image=meta["image"],
                prompt=meta["prompt"],
                correct_solution=solver_out,
                wrong_solution=wrong_solution,
                answer=meta["gt_letter"],
                source=meta["source"],
            ))

    with open(shard_path(args.out_pkl, rank, world_size), "wb") as f:
        pickle.dump(samples, f)

    # Wait until every rank has written its shard before merging.
    barrier()

    # =========================
    # Merge shards (rank0) + print source stats
    # =========================
    if world_size > 1 and is_master:
        all_samples: List[GenSample] = merge_shards(args.out_pkl, world_size)
        with open(args.out_pkl, "wb") as f:
            pickle.dump(all_samples, f)
        print_source_stats(all_samples, args.out_pkl, merged=True)

    if world_size == 1 and is_master:
        print_source_stats(samples, args.out_pkl, merged=False)

    destroy_dist()


if __name__ == "__main__":
    main()