| """ |
| Math Vision Dataset Preprocessing Script |
| |
| This script reads the existing Math Vision dataset from data/math_vision directory |
| and preprocesses it into the format expected by the dataloader. |
| The preprocessed data will be saved with fields: prompt, completion, solution, image_path |
| |
| Usage: |
| # Using config file |
| uv run scripts/math_vision_process.py --config configs/latent_memory/math_vision.yaml |
| |
| # Manual parameters |
| uv run scripts/math_vision_process.py --input_dir data/math_vision --output_dir data/math_vision |
| """ |
|
|
| import os |
| import re |
| import json |
| import logging |
| import argparse |
| from typing import Dict, List, Optional |
| import yaml |
| from datasets import load_dataset, DatasetDict |
|
|
# Configure the root logger once at import time: INFO level, with a
# "timestamp - level - message" layout shared by every log call in this script.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
|
|
|
|
def load_existing_dataset(data_path: str) -> DatasetDict:
    """Load the raw Math Vision splits found in ``data_path``.

    Looks for ``train.json`` and ``test.json`` inside ``data_path`` and loads
    whichever of them exist through the ``json`` builder of ``load_dataset``.

    Args:
        data_path: Directory expected to contain train.json and/or test.json.

    Returns:
        DatasetDict holding one split per file that was found.

    Raises:
        FileNotFoundError: If neither file exists in ``data_path``.
    """
    data_files = {}
    for split in ("train", "test"):
        candidate = os.path.join(data_path, f"{split}.json")
        if os.path.exists(candidate):
            data_files[split] = candidate
            logging.info(f"Found {split}.json at {candidate}")

    # Fail early with a clear message instead of handing an empty mapping
    # to load_dataset.
    if not data_files:
        raise FileNotFoundError(f"No data files found in {data_path}")

    logging.info(f"Loading dataset from {data_path}")
    return load_dataset("json", data_files=data_files)
|
|
|
|
def split_train_valid(dataset_dict: DatasetDict, val_ratio: float = 0.1) -> DatasetDict:
    """Carve a validation split out of the train split when none exists.

    The input is returned untouched when it already has a ``valid`` split,
    has no ``train`` split, or when ``val_ratio`` is not strictly between
    0 and 1.

    Args:
        dataset_dict: Dataset dictionary to split.
        val_ratio: Fraction of the train split moved into ``valid``.

    Returns:
        DatasetDict with ``train`` and ``valid`` splits, plus the original
        ``test`` split when one was present.
    """
    if "valid" in dataset_dict:
        logging.info("Validation set already exists, skipping split")
        return dataset_dict

    if "train" not in dataset_dict:
        logging.warning("No train set found, cannot create validation split")
        return dataset_dict

    if val_ratio >= 1 or val_ratio <= 0:
        logging.warning(f"Invalid val_ratio {val_ratio}, skipping validation split")
        return dataset_dict

    logging.info(f"Splitting train set with val_ratio={val_ratio}")
    # Fixed seed so repeated runs produce the same train/valid partition.
    parts = dataset_dict["train"].train_test_split(test_size=val_ratio, seed=42, shuffle=True)

    # train_test_split names its held-out part "test"; here it serves as "valid".
    splits = {"train": parts["train"], "valid": parts["test"]}

    held_out = dataset_dict.get("test", None)
    if held_out is not None:
        # The original test split is carried over unchanged.
        splits["test"] = held_out

    result = DatasetDict(splits)

    logging.info(f"Split sizes - train: {len(result['train'])}, valid: {len(result['valid'])}")
    if held_out is not None:
        logging.info(f"Test size: {len(result['test'])}")

    return result
|
|
|
|
def preprocess_batch(batch: Dict, image_root: str) -> Dict:
    """Convert a batch of raw Math Vision examples into dataloader fields.

    Args:
        batch: Column-oriented batch (column name -> list of values). The
            columns read are ``question``, ``options``, ``answer``,
            ``solution`` and ``image``. When one of the last four is missing,
            an empty default is used for every example.
        image_root: Root directory for images. Not read by this function;
            the value of the ``image`` column is passed through as-is.

    Returns:
        Dict of equally long lists:
            - prompt: instruction line followed by the question and, when
              options are present, an ``Options:`` list lettered from A.
            - completion: the stripped ``solution`` text when it is
              non-empty, otherwise ``"The answer is \\boxed{...}"``.
            - solution: the answer wrapped in ``\\boxed{}``. For multiple
              choice questions this is the answer letter only, not the
              option text.
            - image_path: the ``image`` value, or ``None`` when it is empty.
    """

    def _clean(value) -> str:
        """Return ``value`` as stripped text; ``None`` becomes ``""``."""
        # str() keeps non-string values (e.g. a numeric answer) usable
        # instead of failing on a missing .strip() method.
        return "" if value is None else str(value).strip()

    def _format_answer(answer) -> str:
        """Wrap the answer in \\boxed{} unless it is already wrapped."""
        text = _clean(answer)
        if text.startswith("\\boxed{") and text.endswith("}"):
            return text
        return "\\boxed{" + text + "}"

    def _format_question_with_options(question, options: Optional[List[str]] = None) -> str:
        """Return the question text followed by lettered options, if any."""
        formatted = _clean(question)
        if options:
            lettered = [f"{chr(ord('A') + i)}. {opt}" for i, opt in enumerate(options)]
            formatted += "\nOptions:\n" + "\n".join(lettered) + "\n"
        return formatted

    format_template = r"""Solve the problem and output the answer in the format of \boxed{your answer}."""
    prompt_template = "\n Question: {prompt}\n"

    questions: List[str] = batch.get("question", [])
    size = len(questions)
    options_list: List[List[str]] = batch.get("options", [[]] * size)
    answers: List[str] = batch.get("answer", [""] * size)
    solutions: List[Optional[str]] = batch.get("solution", [None] * size)
    image_paths_src: List[str] = batch.get("image", [""] * size)

    prompts: List[str] = []
    completions: List[str] = []
    solution_labels: List[str] = []
    image_paths: List[Optional[str]] = []

    for q, opts, ans, sol, img_path in zip(questions, options_list, answers, solutions, image_paths_src):
        formatted_q = _format_question_with_options(q, opts)
        prompts.append(format_template + prompt_template.format(prompt=formatted_q))

        solution_label = _format_answer(ans)
        solution_labels.append(solution_label)

        # Prefer the worked solution; fall back to a one-line answer
        # statement when the example has none.
        worked = _clean(sol)
        completions.append(worked if worked else f"The answer is {solution_label}")

        # Empty or missing image references are normalised to None.
        image_paths.append(img_path if img_path else None)

    return {
        "prompt": prompts,
        "completion": completions,
        "solution": solution_labels,
        "image_path": image_paths,
    }
|
|
|
|
def preprocess_dataset(dataset_dict: DatasetDict, image_root: str, batch_size: int = 512) -> DatasetDict:
    """Preprocess every split and drop examples without a usable answer.

    Each split is mapped through ``preprocess_batch`` (replacing all raw
    columns) and then filtered so that only examples with a non-empty
    ``solution`` remain.

    Args:
        dataset_dict: Raw dataset dictionary.
        image_root: Root directory for images; forwarded to
            ``preprocess_batch``.
        batch_size: Number of examples per ``map`` batch.

    Returns:
        Preprocessed DatasetDict with fields: prompt, completion, solution,
        image_path.
    """
    keep_keys = ["prompt", "completion", "solution", "image_path"]
    boxed_prefix = "\\boxed{"

    def _has_valid_solution(example) -> bool:
        """Return True when the example's solution holds an actual answer."""
        solution = example.get("solution", "")
        if solution is None:
            return False
        text = solution.strip()
        # preprocess_batch wraps every answer in \boxed{...}, so an empty
        # answer arrives here as "\boxed{}". Unwrap before testing for
        # emptiness; otherwise nothing would ever be filtered out.
        if text.startswith(boxed_prefix) and text.endswith("}"):
            text = text[len(boxed_prefix):-1].strip()
        return len(text) > 0

    def _map(split):
        logging.info(f"Preprocessing {split} split with batch_size={batch_size}")
        ds = dataset_dict[split].map(
            lambda batch: preprocess_batch(batch, image_root),
            batched=True,
            batch_size=batch_size,
            num_proc=None,
            # Drop all raw columns so only the preprocessed fields remain.
            remove_columns=dataset_dict[split].column_names,
            desc=f"Math_Vision preprocess ({split})",
        )

        ds_filtered = ds.filter(_has_valid_solution, desc=f"Filter empty solutions ({split})")

        num_filtered = len(ds) - len(ds_filtered)
        if num_filtered > 0:
            logging.info(f"Filtered {num_filtered} samples with empty solutions from {split}")

        return ds_filtered.select_columns(keep_keys)

    result = DatasetDict({split: _map(split) for split in dataset_dict.keys()})

    for split, ds in result.items():
        logging.info(f"Preprocessed {split}: {len(ds)} samples")

    return result
|
|
|
|
def save_dataset(dataset_dict: DatasetDict, output_dir: str):
    """Write each preprocessed split to ``<output_dir>/<split>.json``.

    Every file holds a JSON array of objects restricted to the fields
    prompt, completion, solution and image_path.

    Args:
        dataset_dict: Preprocessed dataset, one entry per split.
        output_dir: Destination directory; created when it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    fields = ("prompt", "completion", "solution", "image_path")

    for split_name, split_ds in dataset_dict.items():
        target = os.path.join(output_dir, f"{split_name}.json")
        logging.info(f"Saving {split_name} split to {target}")

        # Materialise the split as plain dicts so json.dump can serialise it.
        records = [{key: row[key] for key in fields} for row in split_ds]

        with open(target, "w", encoding="utf-8") as handle:
            json.dump(records, handle, ensure_ascii=False, indent=2)

        logging.info(f"Saved {len(records)} samples to {target}")
|
|
|
|
def main():
    """Command-line entry point: load, split, preprocess and save the dataset.

    ``--val_ratio`` and ``--image_root`` come from the command line unless
    ``--config`` is given, in which case the values under
    ``datasets.math_vision.<mode>`` in the YAML file take precedence
    (``mode`` defaults to ``"sft"``). Input/output directories and batch size
    are always taken from the command line.
    """
    parser = argparse.ArgumentParser(description="Preprocess Math Vision dataset")
    parser.add_argument("--config", type=str, help="Path to config YAML file")
    parser.add_argument("--input_dir", type=str, default="/root/CVPR/MemGen/data/math_vision", help="Input directory with train.json/test.json")
    parser.add_argument("--output_dir", type=str, default="/root/CVPR/MemGen/data/math_vision", help="Output directory for preprocessed data")
    parser.add_argument("--val_ratio", type=float, default=0.1, help="Validation set ratio")
    parser.add_argument("--image_root", type=str, default="/root/CVPR/MemGen/dataset/math_vision/images", help="Image root directory")
    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for preprocessing")

    args = parser.parse_args()

    val_ratio = args.val_ratio
    image_root = args.image_root

    if args.config:
        logging.info(f"Loading config from {args.config}")
        with open(args.config, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty document; treat that (and
            # any null section below) as "no overrides" instead of crashing.
            cfg = yaml.safe_load(f) or {}

        dataset_cfg = (cfg.get("datasets") or {}).get("math_vision") or {}
        mode = dataset_cfg.get("mode", "sft")
        mode_cfg = dataset_cfg.get(mode) or {}

        val_ratio = mode_cfg.get("val_ratio", args.val_ratio)
        image_root = mode_cfg.get("image_root", args.image_root)

    input_dir = args.input_dir
    output_dir = args.output_dir
    batch_size = args.batch_size

    logging.info("=" * 80)
    logging.info("Math Vision Dataset Preprocessing")
    logging.info("=" * 80)
    logging.info(f"Input directory: {input_dir}")
    logging.info(f"Output directory: {output_dir}")
    logging.info(f"Validation ratio: {val_ratio}")
    logging.info(f"Image root: {image_root}")
    logging.info(f"Batch size: {batch_size}")

    # The raw splits are read from <input_dir>/train.json and test.json and
    # the preprocessed splits are written to <output_dir>/<split>.json, so
    # identical directories (the default) replace the raw files.
    if os.path.abspath(input_dir) == os.path.abspath(output_dir):
        logging.warning(
            "Output directory equals input directory; the raw train.json/test.json "
            "will be overwritten with preprocessed data"
        )

    dataset_dict = load_existing_dataset(input_dir)

    dataset_dict = split_train_valid(dataset_dict, val_ratio=val_ratio)

    preprocessed = preprocess_dataset(dataset_dict, image_root, batch_size=batch_size)

    save_dataset(preprocessed, output_dir)

    logging.info("=" * 80)
    logging.info("Preprocessing complete!")
    logging.info("=" * 80)
    for split, ds in preprocessed.items():
        logging.info(f"{split}: {len(ds)} samples")
    logging.info(f"Output saved to: {output_dir}")
|
|
|
|
# Run the pipeline only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or any file I/O.
if __name__ == "__main__":
    main()
|
|
|
|