| | import os |
| | import re |
| | import glob |
| | import argparse |
| | import pickle |
| | import warnings |
| | from io import BytesIO |
| | from dataclasses import dataclass |
| | from typing import Optional, List, Dict, Any, Tuple |
| |
|
| | import torch |
| | from PIL import Image, ImageFile |
| | from tqdm.auto import tqdm |
| | from collections import Counter |
| |
|
| | |
| | |
| | |
# PIL overrides for large, possibly messy web-scraped images:
Image.MAX_IMAGE_PIXELS = None  # disable the decompression-bomb pixel cap entirely
ImageFile.LOAD_TRUNCATED_IMAGES = True  # tolerate partially written/downloaded files
warnings.simplefilter("ignore", Image.DecompressionBombWarning)  # silence the matching warning
| |
|
| | |
| | |
| | |
@dataclass
class GenSample:
    """One training example pairing a correct solution with a corrupted one."""
    image: Any              # RGB PIL image (output of get_pil_image) the question refers to
    prompt: str             # full solver prompt: question plus lettered choices
    correct_solution: str   # solver output whose single boxed answer equals `answer`
    wrong_solution: str     # same reasoning re-boxed to the wrong letter "c"
    answer: str             # ground-truth choice letter, e.g. "a"
    source: str             # originating dataset tag: "aokvqa" or "scienceqa"
| |
|
| | |
| | |
| | |
# Lowercase answer letters plus the index -> letter lookup used for MCQ choices.
LETTERS = list("abcdefghijklmnopqrstuvwxyz")
IDX2LETTER = dict(enumerate(LETTERS))
| |
|
| | |
| | |
| | |
def get_dist_info():
    """Return (local_rank, rank, world_size) from torchrun env vars.

    Defaults to a single-process layout (0, 0, 1) when the variables are unset.
    """
    env = os.environ
    return (
        int(env.get("LOCAL_RANK", 0)),
        int(env.get("RANK", 0)),
        int(env.get("WORLD_SIZE", 1)),
    )
| |
|
def init_dist_if_needed():
    """Initialize the NCCL process group when launched under torchrun.

    Reads LOCAL_RANK/RANK/WORLD_SIZE from the environment; a no-op for
    single-process runs or when the group already exists.  Returns the
    (local_rank, rank, world_size) triple.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    needs_init = (
        world_size > 1
        and torch.distributed.is_available()
        and not torch.distributed.is_initialized()
    )
    if needs_init:
        # Pin this process to its GPU before creating the group.
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(backend="nccl")
    return local_rank, rank, world_size
| |
|
def barrier():
    """Synchronize all ranks; silently a no-op outside a distributed run."""
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
| |
|
def destroy_dist():
    """Tear down the process group if one was created; safe to call anytime."""
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
| |
|
| | |
| | |
| | |
# Matches a LaTeX \boxed{...} whose payload contains no closing brace; the
# capture group is the boxed content.  Used to extract/count final answers.
BOX_RE = re.compile(r"\\boxed\{([^}]+)\}")
| |
|
def extract_boxed_answer(text: str) -> Optional[str]:
    """Return the payload of the LAST \\boxed{...} in *text*, stripped and
    lowercased, or None when *text* is falsy or contains no boxed answer."""
    if not text:
        return None
    hits = re.findall(r"\\boxed\{([^}]+)\}", text)
    return hits[-1].strip().lower() if hits else None
| |
|
def count_boxed(text: str) -> int:
    """Count \\boxed{...} occurrences in *text*; None/empty input counts as 0."""
    if not text:
        return 0
    return sum(1 for _ in re.finditer(r"\\boxed\{([^}]+)\}", text))
| |
|
def strip_last_boxed(text: str) -> str:
    r"""Remove the final \boxed{...} from *text* and right-strip the result.

    Falsy input is returned unchanged; text with no boxed expression is
    returned right-stripped.
    """
    if not text:
        return text
    trimmed = text.rstrip()
    # Fast path: a boxed expression sitting at the very end of the string.
    without_tail = re.sub(r"\s*\\boxed\{[^}]+\}\s*$", "", trimmed)
    if without_tail != trimmed:
        return without_tail.rstrip()
    # Otherwise cut out the last boxed occurrence wherever it appears.
    last = None
    for m in re.finditer(r"\\boxed\{([^}]+)\}", trimmed):
        last = m
    if last is None:
        return trimmed
    return (trimmed[:last.start()] + trimmed[last.end():]).rstrip()
| |
|
| | |
| | |
| | |
| | def _pil_from_any(img: Any) -> Optional[Image.Image]: |
| | if img is None: |
| | return None |
| | if isinstance(img, Image.Image): |
| | return img.convert("RGB") |
| | if isinstance(img, dict) and img.get("bytes") is not None: |
| | try: |
| | with Image.open(BytesIO(img["bytes"])) as im: |
| | return im.convert("RGB") |
| | except Exception: |
| | return None |
| | if isinstance(img, str) and os.path.exists(img): |
| | try: |
| | with Image.open(img) as im: |
| | return im.convert("RGB") |
| | except Exception: |
| | return None |
| | return None |
| |
|
def get_pil_image(ex: Dict[str, Any]) -> "Optional[Image.Image]":
    """Extract the example's image as an RGB PIL image, trying the
    'decoded_image' key before 'image'; None when neither yields one."""
    for field in ("decoded_image", "image"):
        if field not in ex:
            continue
        pil = _pil_from_any(ex.get(field))
        if pil is not None:
            return pil
    return None
| |
|
| | |
| | |
| | |
# System prompt prepended to every solver conversation.
SOLVER_SYSTEM = "You are a rigorous visual question answering expert."
| |
|
def solver_text(question: str, choices: List[str]) -> str:
    """Render the MCQ solver prompt: problem statement, lettered options, and
    the instruction to finish with a single LaTeX-boxed answer letter.

    Raises ValueError when there are more choices than letters (26).
    """
    letters = "abcdefghijklmnopqrstuvwxyz"
    if len(choices) > len(letters):
        raise ValueError(f"Too many choices: {len(choices)}")
    option_block = "\n".join(
        f"{letters[idx]}. {choice}" for idx, choice in enumerate(choices)
    )
    return (
        "Solve the following multiple-choice problem step by step.\n\n"
        f"Problem:\n{question}\n\n"
        f"Choices:\n{option_block}\n\n"
        "Give your reasoning in plain text.\n"
        "At the end, output your answer ONLY in LaTeX boxed format, e.g. \\boxed{a}.\n"
    )
| |
|
def build_messages(system_text, user_text, image):
    """Assemble a two-turn chat (system + user) in Qwen content-list format.

    When *image* is given it precedes the user text inside the user turn;
    otherwise the user turn is text-only.
    """
    def _text_part(value):
        return {"type": "text", "text": value}

    user_content = []
    if image is not None:
        user_content.append({"type": "image", "image": image})
    user_content.append(_text_part(user_text))
    return [
        {"role": "system", "content": [_text_part(system_text)]},
        {"role": "user", "content": user_content},
    ]
| |
|
| | |
| | |
| | |
class QwenBatchRunner:
    """Thin wrapper around a Qwen2.5-VL checkpoint for batched generation.

    Loads the processor and model once onto a single CUDA device (one per
    rank) and exposes `generate_batch`, which turns chat-style message lists
    into decoded continuation strings.
    """

    def __init__(self, model_id, cache_dir, local_rank):
        # Imported lazily so importing this module never pulls in transformers.
        from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
        self.device = torch.device(f"cuda:{local_rank}")
        self.processor = AutoProcessor.from_pretrained(model_id, cache_dir=cache_dir)
        # Left padding so every prompt ends at the same position — required for
        # correct batched decoder-only generation (and for slicing the prompt
        # off with a single input length below).
        self.processor.tokenizer.padding_side = "left"
        if self.processor.tokenizer.pad_token_id is None:
            # Some checkpoints ship without a pad token; reuse EOS for padding.
            self.processor.tokenizer.pad_token_id = self.processor.tokenizer.eos_token_id

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map={"": local_rank},  # place the entire model on this rank's GPU
            attn_implementation="flash_attention_2",
        ).eval()

    @torch.inference_mode()
    def generate_batch(self, messages, images, max_new_tokens, temperature, do_sample=True):
        """Generate one completion per conversation.

        `messages` is a list of chat message lists; `images` is parallel to it.
        NOTE(review): when only some entries of `images` are None, the filtered
        image list passed to the processor would misalign with the image
        placeholders in the templated text — callers appear to send all-image
        or all-None batches; confirm before reusing elsewhere.
        `temperature` is only forwarded when `do_sample` is True.
        Returns the decoded continuations with prompt tokens stripped.
        """
        texts = [
            self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
            for m in messages
        ]
        enc = self.processor(
            text=texts,
            images=images if any(images) else None,  # None => text-only batch
            padding=True,
            return_tensors="pt",
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}

        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            pad_token_id=self.processor.tokenizer.pad_token_id,
            eos_token_id=self.processor.tokenizer.eos_token_id,
        )
        if do_sample:
            gen_kwargs["temperature"] = temperature

        out = self.model.generate(**enc, **gen_kwargs)

        # With left padding all rows share one prompt length, so the generated
        # tail is simply everything past the input length.
        in_len = enc["input_ids"].shape[1]
        outs = []
        for o in out:
            outs.append(self.processor.tokenizer.decode(o[in_len:], skip_special_tokens=True).strip())
        return outs
| |
|
| | |
| | |
| | |
def interleave(a: List[Any], b: List[Any]) -> List[Any]:
    """Alternate elements of *a* and *b* starting with *a*; once the shorter
    list runs out, the remainder of the longer list follows in order."""
    shared = min(len(a), len(b))
    merged: List[Any] = []
    for left, right in zip(a, b):
        merged.append(left)
        merged.append(right)
    merged.extend(a[shared:])
    merged.extend(b[shared:])
    return merged
| |
|
| | |
| | |
| | |
def main():
    """Generate (correct, wrong) MCQ solution pairs and pickle them.

    Pipeline: shard A-OKVQA and ScienceQA across ranks, prompt a Qwen2.5-VL
    solver on each multiple-choice question, keep only generations whose
    single boxed answer equals the ground truth, then fabricate a corrupted
    counterpart by re-boxing the answer as "c".  Each rank pickles its shard;
    rank 0 merges the shards when running distributed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_id", default="Qwen/Qwen2.5-VL-7B-Instruct")

    ap.add_argument("--dataset_id", default="HuggingFaceM4/A-OKVQA")
    ap.add_argument("--split", default="train")

    ap.add_argument("--scienceqa_id", default="derek-thomas/ScienceQA")
    ap.add_argument("--scienceqa_split", default=None)  # falls back to --split

    ap.add_argument("--cache_dir", default=None)
    ap.add_argument("--out_pkl", default="train.pkl")
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--max_items", type=int, default=3000)  # per-rank cap across both datasets

    ap.add_argument("--solver_max_new_tokens", type=int, default=512)
    ap.add_argument("--solver_temp", type=float, default=0.1)
    ap.add_argument("--solver_greedy", action="store_true")
    args = ap.parse_args()

    local_rank, rank, world_size = init_dist_if_needed()
    is_master = rank == 0

    from datasets import load_dataset, Image as HFImage

    sq_split = args.scienceqa_split or args.split

    # Rank 0 downloads first so the other ranks load from a warm cache.
    if world_size > 1 and is_master:
        load_dataset(args.dataset_id, split=args.split, cache_dir=args.cache_dir)
        load_dataset(args.scienceqa_id, split=sq_split, cache_dir=args.cache_dir)
    barrier()

    ds_ok = load_dataset(args.dataset_id, split=args.split, cache_dir=args.cache_dir)
    ds_sq = load_dataset(args.scienceqa_id, split=sq_split, cache_dir=args.cache_dir)

    # Keep images undecoded (raw bytes/paths): get_pil_image decodes lazily, so
    # a corrupt image skips one example instead of crashing the whole run.
    if "image" in ds_ok.column_names and isinstance(ds_ok.features.get("image", None), HFImage):
        ds_ok = ds_ok.cast_column("image", HFImage(decode=False))

    if "image" in ds_sq.column_names and isinstance(ds_sq.features.get("image", None), HFImage):
        ds_sq = ds_sq.cast_column("image", HFImage(decode=False))

    # Round-robin sharding: rank r takes indices r, r+W, r+2W, ...
    ok_indices = list(range(rank, len(ds_ok), world_size))
    sq_indices = list(range(rank, len(ds_sq), world_size))

    # Split the per-rank budget roughly 50/50 between the two datasets.
    if args.max_items and args.max_items > 0:
        ok_lim = args.max_items // 2
        sq_lim = args.max_items - ok_lim
        ok_indices = ok_indices[:ok_lim]
        sq_indices = sq_indices[:sq_lim]

    # Alternate the sources so every batch mixes both datasets.
    items = interleave(
        [("okvqa", i) for i in ok_indices],
        [("scienceqa", i) for i in sq_indices],
    )

    runner = QwenBatchRunner(args.model_id, args.cache_dir, local_rank)
    samples: List[GenSample] = []

    def build_meta_okvqa(ex):
        # Turn an A-OKVQA row into solver inputs; returns None to skip the row.
        gt_idx = ex.get("correct_choice_idx", None)
        if gt_idx is None:
            return None
        gt_idx = int(gt_idx)
        # Skip questions whose correct answer IS "c" (index 2): the fabricated
        # wrong solution below always answers \boxed{c}, which must be wrong.
        if gt_idx == 2:
            return None
        choices = ex.get("choices", None)
        # Need at least three options so the letter "c" exists as a choice.
        if not isinstance(choices, (list, tuple)) or len(choices) < 3:
            return None
        image = get_pil_image(ex)
        if image is None:
            return None
        question = ex.get("question", "")
        choices = [str(c) for c in choices]
        prompt = solver_text(question, choices)
        return {
            "image": image,
            "prompt": prompt,
            "gt_letter": IDX2LETTER[gt_idx],
            "source": "aokvqa",
        }

    def build_meta_scienceqa(ex):
        # Same filtering as build_meta_okvqa, adapted to the ScienceQA schema
        # ("answer" index instead of "correct_choice_idx").
        choices = ex.get("choices", None)
        if not isinstance(choices, (list, tuple)) or len(choices) < 3:
            return None
        gt_idx = ex.get("answer", None)
        if gt_idx is None:
            return None
        gt_idx = int(gt_idx)
        if gt_idx == 2:  # correct answer would be "c"; see build_meta_okvqa
            return None
        image = get_pil_image(ex)
        if image is None:
            return None
        question = ex.get("question", "")
        choices = [str(c) for c in choices]
        prompt = solver_text(question, choices)
        return {
            "image": image,
            "prompt": prompt,
            "gt_letter": IDX2LETTER[gt_idx],
            "source": "scienceqa",
        }

    for b in tqdm(range(0, len(items), args.batch_size), desc=f"rank{rank}"):
        batch_items = items[b:b + args.batch_size]
        metas, solver_messages, solver_images = [], [], []

        # Materialize the batch, dropping rows that fail validation/decoding.
        for tag, i in batch_items:
            ex = ds_ok[i] if tag == "okvqa" else ds_sq[i]
            meta = build_meta_okvqa(ex) if tag == "okvqa" else build_meta_scienceqa(ex)
            if meta is None:
                continue
            solver_messages.append(build_messages(SOLVER_SYSTEM, meta["prompt"], meta["image"]))
            solver_images.append(meta["image"])
            metas.append(meta)

        if not metas:
            continue

        solver_outs = runner.generate_batch(
            solver_messages,
            solver_images,
            max_new_tokens=args.solver_max_new_tokens,
            temperature=args.solver_temp,
            do_sample=(not args.solver_greedy),
        )

        for meta, solver_out in zip(metas, solver_outs):
            # Keep only generations that answered correctly...
            if extract_boxed_answer(solver_out) != meta["gt_letter"]:
                continue
            # ...with exactly one boxed expression (unambiguous final answer).
            if count_boxed(solver_out) != 1:
                continue

            # Reasoning with the final boxed answer removed.
            base = strip_last_boxed(solver_out).rstrip()
            if count_boxed(base) != 0:
                continue

            # Corrupt the solution: identical reasoning, wrong final answer "c".
            wrong_solution = base + "\n\n" + r"but, the answer is \boxed{c}"

            # Sanity checks on the fabricated negative before keeping it.
            if count_boxed(wrong_solution) != 1:
                continue
            if extract_boxed_answer(wrong_solution) != "c":
                continue
            if not re.search(r"\\boxed\{c\}\s*$", wrong_solution):
                continue

            samples.append(GenSample(
                image=meta["image"],
                prompt=meta["prompt"],
                correct_solution=solver_out,
                wrong_solution=wrong_solution,
                answer=meta["gt_letter"],
                source=meta["source"]
            ))

    # Single-process runs write the final file directly; otherwise one shard per rank.
    shard_pkl = args.out_pkl if world_size == 1 else f"{args.out_pkl}.rank{rank}"
    with open(shard_pkl, "wb") as f:
        pickle.dump(samples, f)

    barrier()

    if world_size > 1 and is_master:
        # NOTE(review): this glob may also pick up stale .rank* shards left by a
        # previous run with a different world size — confirm or clean up first.
        all_samples: List[GenSample] = []
        for fp in sorted(glob.glob(args.out_pkl + ".rank*")):
            with open(fp, "rb") as f:
                all_samples.extend(pickle.load(f))
        with open(args.out_pkl, "wb") as f:
            pickle.dump(all_samples, f)

        cnt = Counter([s.source for s in all_samples])
        print(f"[rank0] merged total={len(all_samples)} -> {args.out_pkl}")
        print(f"[rank0] by source: scienceqa={cnt.get('scienceqa', 0)}, aokvqa={cnt.get('aokvqa', 0)}")

    if world_size == 1 and is_master:
        cnt = Counter([s.source for s in samples])
        print(f"[rank0] total={len(samples)} -> {args.out_pkl}")
        print(f"[rank0] by source: scienceqa={cnt.get('scienceqa', 0)}, aokvqa={cnt.get('aokvqa', 0)}")

    destroy_dist()
| |
|
| |
|
| |
|
# Script entry point (guards against execution on import).
if __name__ == "__main__":
    main()