| | import os |
| | import re |
| | import ast |
| | import math |
| | import yaml |
| | import warnings |
| | from datetime import datetime |
| | from dataclasses import dataclass, field |
| | from collections import defaultdict |
| | from typing import Any, Callable, Optional, Union, Sized, Dict, Tuple, List, Literal, Type |
| |
|
| | import numpy as np |
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| |
|
| | import datasets |
| |
|
| | from PIL import Image |
| |
|
| | from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config |
| | from trl.models import unwrap_model_for_generation |
| |
|
| | from transformers import ( |
| | TrainingArguments, |
| | Trainer, |
| | GenerationConfig, |
| | ) |
| | from transformers.modeling_utils import PreTrainedModel |
| | from transformers.utils import ( |
| | is_safetensors_available, |
| | is_peft_available |
| | ) |
| |
|
| | if is_safetensors_available(): |
| | import safetensors.torch |
| | from peft import PeftConfig, get_peft_model, PeftModel |
| | from accelerate.utils import is_peft_model, set_seed |
| |
|
| | from qwen_vl_utils import process_vision_info |
| |
|
| | from src.model.vlm_backbone.qwen2_5_vl_gp.process_gp import Qwen2_5_VL_GP_Processor |
| |
|
| | from transformers.trainer import ( |
| | logger, |
| | TRAINING_ARGS_NAME, |
| | CONFIG_NAME, |
| | ADAPTER_WEIGHTS_NAME, |
| | ADAPTER_SAFE_WEIGHTS_NAME, |
| | WEIGHTS_NAME, |
| | WEIGHTS_INDEX_NAME, |
| | SAFE_WEIGHTS_NAME, |
| | SAFE_WEIGHTS_INDEX_NAME, |
| | FSDP_MODEL_NAME, |
| | ) |
| |
|
| | from src.model.vlm_backbone.qwen2_5_vl_gp.warppers import debug_calls |
| | from src.utils_gp import ( |
| | LLMClient, |
| | norm_bboxes, |
| | extract_one_bbox_from_str, |
| | cal_paired_ious, |
| | print_rank0 |
| | ) |
| |
|
| |
|
| | |
| |
|
# Canonical column names shared by mappers, filters, and the collator.
QUERY_KEY = "query"                   # user question / prompt text
IMG_PATH_KEY = "img_path"             # path to the example's image file
ANSWER_KEY = "answer"                 # ground-truth answer text
NORMED_BBOXES_KEY = "normed_bboxes"   # bounding boxes (normalized by the norm_bboxes additional mapper)
SCORE_FUNCS_KEY = "score_funcs"       # names of registered scoring functions to apply

# Columns kept after mapping; every other dataset column is dropped.
REMAIN_KEYS = [
    QUERY_KEY,
    IMG_PATH_KEY,
    NORMED_BBOXES_KEY,
    ANSWER_KEY,
    SCORE_FUNCS_KEY,
]

# Global name -> callable registries populated by the register_* decorator factories below.
MAPPER_REGISTRY: Dict[str, Callable] = {}
FILTER_REGISTRY: Dict[str, Callable] = {}
| |
|
def register_mappers():
    """Return a decorator that registers a mapper in MAPPER_REGISTRY.

    The registry key is the function name with the ``_dataset_mapper``
    suffix stripped (e.g. ``cot_train_dataset_mapper`` -> ``cot_train``).
    """
    def _register(fn):
        key = fn.__name__.replace("_dataset_mapper", "")
        MAPPER_REGISTRY[key] = fn
        return fn

    return _register
| |
|
def register_filters():
    """Decorator factory: register a filter in FILTER_REGISTRY under its
    name with the ``_dataset_filter`` suffix removed."""
    def _add(fn):
        FILTER_REGISTRY[fn.__name__.replace("_dataset_filter", "")] = fn
        return fn
    return _add
| |
|
| |
|
| | @register_mappers() |
def cot_train_dataset_mapper(one_data, **kwargs):
    """Map one raw CoT training record to the canonical example schema.

    kwargs: ``img_dir`` (image root dir), ``score_funcs`` (list of score
    function names), optional ``prompt`` template applied to the question
    via ``str.format``.
    """
    question = one_data['question']
    if 'prompt' in kwargs:
        question = kwargs['prompt'].format(question)
    return {
        QUERY_KEY: question,
        ANSWER_KEY: one_data['answer'],
        IMG_PATH_KEY: os.path.join(kwargs['img_dir'], "cot", one_data['dataset'], one_data['image']),
        NORMED_BBOXES_KEY: one_data['bboxs'],
        SCORE_FUNCS_KEY: kwargs['score_funcs'],
    }
| | |
| |
|
| | @register_mappers() |
def cot_train_fullmask_dataset_mapper(one_data, **kwargs):
    """Like ``cot_train_dataset_mapper`` but ignores the record's bboxes and
    uses the full frame ([0,0,1,1] in normalized coordinates) instead."""
    question = one_data['question']
    if 'prompt' in kwargs:
        question = kwargs['prompt'].format(question)
    full_frame = [[0.0, 0.0, 1.0, 1.0]]
    return {
        QUERY_KEY: question,
        ANSWER_KEY: one_data['answer'],
        IMG_PATH_KEY: os.path.join(kwargs['img_dir'], "cot", one_data['dataset'], one_data['image']),
        NORMED_BBOXES_KEY: full_frame,
        SCORE_FUNCS_KEY: kwargs['score_funcs'],
    }
| | |
| | |
| | @register_mappers() |
def norm_bboxes_dataset_mapper(one_data, **kwargs):
    """Replace the example's bboxes with size-normalized coordinates.

    Uses the 'width'/'height' columns when present; otherwise reads the
    dimensions from the image file itself.
    """
    raw_bboxes = one_data.pop(NORMED_BBOXES_KEY)
    if 'width' in one_data:
        width, height = one_data['width'], one_data['height']
    else:
        img_pil = Image.open(one_data[IMG_PATH_KEY])
        width, height = img_pil.size
        img_pil.close()
    one_data[NORMED_BBOXES_KEY] = norm_bboxes(
        raw_bboxes, height, width, bbox_type=kwargs['bbox_type']
    )
    return one_data
| |
|
| | |
| | @register_filters() |
def image_exist_dataset_filter(one_data, **kwargs):
    """Keep only examples whose image file can actually be opened."""
    img_path = one_data[IMG_PATH_KEY]
    try:
        Image.open(img_path).close()
    except (FileNotFoundError, OSError) as e:
        print_rank0(f"Image not found or invalid: {img_path}. Error: {e}")
        return False
    except Exception as e:
        print_rank0(f"Unexpected error while checking image: {img_path}. Error: {e}")
        return False
    return True
| | |
| | @register_filters() |
def inputs_seq_length_dataset_filter(one_data, **kwargs):
    """Drop examples whose tokenized inputs exceed configured length limits.

    kwargs:
        processor: the Qwen2.5-VL GP processor used to tokenize the example.
        max_input_seq_length: optional cap on the total input length.
        max_input_remain_seq_length: optional cap on the length that remains
            after subtracting positions that are zero in ``ref_token_masks``
            (presumably tokens the model later prunes — TODO confirm against
            the processor implementation).

    Returns True when the example passes all configured limits.
    """
    processor = kwargs['processor']
    max_input_seq_length = kwargs.get('max_input_seq_length', None)
    max_input_remain_seq_length = kwargs.get('max_input_remain_seq_length', None)
    if max_input_seq_length is None and max_input_remain_seq_length is None:
        # No limits configured: keep everything without the cost of tokenizing.
        return True
    img_path = one_data[IMG_PATH_KEY]
    query = one_data[QUERY_KEY]
    # Bboxes are only passed through when the "remaining" limit is active.
    normed_bboxes = [one_data[NORMED_BBOXES_KEY]] if max_input_remain_seq_length is not None else None
    messages = [[{"role": "user", "content": [{"type": "image", "image": img_path}, {"type": "text", "text": query}]}]]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        normed_bboxes=normed_bboxes,
        padding=True,
        return_tensors="pt",
    )
    seq_length = inputs.input_ids.shape[1]
    if max_input_seq_length is not None and seq_length > max_input_seq_length:
        return False

    if max_input_remain_seq_length is not None:
        # Count zero entries in the mask; the effective length is the total
        # sequence length minus those masked-out positions.
        ref_token_masks = inputs.ref_token_masks[0]
        reduced_num = ref_token_masks.numel() - ref_token_masks.sum().item()
        remain_seq_length = seq_length - reduced_num
        if remain_seq_length > max_input_remain_seq_length:
            return False
    return True
| |
|
| |
|
| | |
| |
|
# Registry of loss classes selectable by name.
LOSS_REGISTRY: Dict[str, Type[nn.Module]] = {}


def register_loss(loss_class):
    """Class decorator: add *loss_class* to LOSS_REGISTRY keyed by class name.

    Raises:
        ValueError: If a class with the same name is already registered.
    """
    key = loss_class.__name__
    if key not in LOSS_REGISTRY:
        LOSS_REGISTRY[key] = loss_class
        return loss_class
    raise ValueError(f"Loss class '{key}' is already registered.")
| |
|
| |
|
| | @register_loss |
class DiceLoss(nn.Module):
    """Soft Dice loss averaged over a batch of per-sample 1-D mask logits.

    Args:
        epsilon: Smoothing term that avoids division by zero when both the
            prediction and the ground truth are empty.
    """

    def __init__(self, epsilon: float = 1e-6, **kwargs):
        super().__init__()
        self.epsilon = epsilon

    def forward(self,
                image_token_mask_logits: List[torch.Tensor],
                ref_token_masks: List[torch.Tensor]
                ) -> torch.Tensor:
        """Return the mean (1 - Dice coefficient) over the batch.

        Args:
            image_token_mask_logits: One logits tensor per sample (any shape;
                flattened internally). Sigmoid is applied here.
            ref_token_masks: One binary ground-truth mask per sample with the
                same element count as the corresponding logits.

        Raises:
            TypeError: If either argument is not a list.
            ValueError: If the two lists differ in length.
        """
        if not isinstance(image_token_mask_logits, list) or not isinstance(ref_token_masks, list):
            raise TypeError("Inputs must be lists of tensors.")
        if len(image_token_mask_logits) != len(ref_token_masks):
            raise ValueError(f"Input lists must have the same length, but got "
                             f"{len(image_token_mask_logits)} and {len(ref_token_masks)}")
        if len(image_token_mask_logits) == 0:
            # FIX: the old code selected a device via
            # `image_token_mask_logits[0].device if image_token_mask_logits else None`,
            # but the first arm is unreachable here (the list is empty), so it
            # always resolved to None. Return a plain CPU scalar directly.
            return torch.tensor(0.0)

        batch_size = len(image_token_mask_logits)
        total_dice_loss = 0.0

        for logits, gt in zip(image_token_mask_logits, ref_token_masks):
            pred_mask_1d = logits.flatten().sigmoid()
            gt_mask_1d = gt.flatten().to(pred_mask_1d.device, dtype=torch.float)
            intersection = (pred_mask_1d * gt_mask_1d).sum()
            pred_sum = pred_mask_1d.sum()
            gt_sum = gt_mask_1d.sum()
            dice_coefficient = (2.0 * intersection + self.epsilon) / (pred_sum + gt_sum + self.epsilon)
            total_dice_loss += (1.0 - dice_coefficient)

        return total_dice_loss / batch_size
| |
|
| |
|
| | @register_loss |
class BCELoss(nn.Module):
    """Mean binary cross-entropy (with logits) over per-sample 1-D masks."""

    def __init__(self, **kwargs):
        # BUG FIX: this was ``def ___init__`` (three leading underscores), so
        # the constructor was never overridden and any keyword argument passed
        # to ``BCELoss(...)`` raised TypeError via nn.Module.__init__.
        # Extra kwargs are accepted and ignored, matching the other losses.
        super().__init__()

    def forward(self,
                image_token_mask_logits: List[torch.Tensor],
                ref_token_masks: List[torch.Tensor]
                ) -> torch.Tensor:
        """Return the average BCE-with-logits loss across the batch.

        Args:
            image_token_mask_logits: One raw-logits tensor per sample.
            ref_token_masks: One binary ground-truth mask per sample.
        """
        batch_size = len(image_token_mask_logits)
        total_bce_loss = 0.0
        for i in range(batch_size):
            pred_mask_1d = image_token_mask_logits[i].flatten()
            gt_mask_1d = ref_token_masks[i].flatten().to(pred_mask_1d.device)
            bce_loss = F.binary_cross_entropy_with_logits(
                pred_mask_1d.float(),
                gt_mask_1d.float(),
            )
            total_bce_loss += bce_loss
        return total_bce_loss / batch_size
| |
|
| |
|
| | @register_loss |
class MaskLoss(nn.Module):
    """Weighted combination of the Dice and BCE mask losses."""

    def __init__(self,
                 dice_weight: float = 0.5,
                 bce_weight: float = 0.5,
                 epsilon: float = 1e-6,
                 **kwargs):
        super().__init__()
        self.dice_loss = DiceLoss(epsilon=epsilon)
        self.bce_loss = BCELoss()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight

    def forward(self, image_token_mask_logits: List[torch.Tensor],
                ref_token_masks: List[torch.Tensor]
                ) -> torch.Tensor:
        """Return dice_weight * dice + bce_weight * bce over the batch."""
        weighted = self.dice_weight * self.dice_loss(image_token_mask_logits, ref_token_masks)
        weighted = weighted + self.bce_weight * self.bce_loss(image_token_mask_logits, ref_token_masks)
        return weighted
| |
|
| |
|
| | |
| |
|
# Registry of score functions selectable by name from configs.
SCORE_REGISTRY: Dict[str, Callable] = {}


def register_score():
    """Return a decorator registering a score function in SCORE_REGISTRY.

    The registry key is the function name without the trailing ``_score``.
    """
    def _decorator(score_fn):
        SCORE_REGISTRY[score_fn.__name__.replace("_score", "")] = score_fn
        return score_fn

    return _decorator
| |
|
| | @register_score() |
def llm_score(query, completion, answer, args):
    """Placeholder scorer.

    A YAML config may list ``score_funcs: [llm]``; this project does not use
    these scores, so a zero score is returned per query as a stand-in.
    """
    if isinstance(query, list):
        return [0.0 for _ in query]
    return [0.0]
| |
|
| |
|
| | |
| |
|
| | def _resolve_rel_path(rel_path: str, base_dir: str) -> str: |
| | """ |
| | Resolve a relative path against base_dir; if not found, try parent dirs up to 4 levels. |
| | """ |
| | if os.path.isabs(rel_path): |
| | return rel_path |
| | candidates = [os.path.join(base_dir, rel_path)] |
| | parent = base_dir |
| | for _ in range(4): |
| | parent = os.path.dirname(parent) |
| | if not parent or parent in ("/", ""): |
| | break |
| | candidates.append(os.path.join(parent, rel_path)) |
| | for cand in candidates: |
| | if os.path.exists(cand): |
| | return cand |
| | return candidates[0] |
| |
|
| |
|
class GPDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset that loads and combines multiple datasets
    based on a YAML configuration file. It handles sampling
    and applies specified mapping functions.
    """

    @classmethod
    def _load_config(cls, config_path: str) -> Dict[str, Any]:
        """Load the YAML config, following a 'train_dataset' indirection.

        When the top-level YAML has no 'datasets' key but provides
        'train_dataset', that path is resolved relative to the config file
        and loaded as the actual dataset config. The directory of the final
        config is recorded under '__root_dir__' so relative paths inside it
        can be resolved later.

        Raises:
            ValueError: If the YAML is empty, has neither 'datasets' nor
                'train_dataset', or the indirected file lacks 'datasets'.
        """
        print_rank0(f"Loading configuration from: {config_path}")
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                conf = yaml.safe_load(f)
            if conf is None:
                raise ValueError("YAML config is empty.")

            base_dir = os.path.dirname(config_path)

            if 'datasets' not in conf:
                if 'train_dataset' in conf:
                    ds_yaml = _resolve_rel_path(conf['train_dataset'], base_dir)
                    print_rank0(f"Loading dataset config from: {ds_yaml}")
                    with open(ds_yaml, 'r', encoding='utf-8') as f:
                        conf2 = yaml.safe_load(f)
                    if conf2 is None or 'datasets' not in conf2:
                        raise ValueError(f"'{ds_yaml}' missing 'datasets' key.")
                    conf = conf2
                    base_dir = os.path.dirname(ds_yaml)
                else:
                    raise ValueError("YAML config is missing both 'datasets' and 'train_dataset' keys.")

            conf['__root_dir__'] = base_dir
            print_rank0("Configuration loaded successfully.")
            return conf
        except Exception as e:
            print_rank0(f"Failed to load config: {e}")
            raise

    @classmethod
    def _apply_sampling(cls, dataset: datasets.Dataset, strategy: Optional[str], seed: Optional[int] = None) -> datasets.Dataset:
        """Applies sampling strategy to a dataset.

        ``strategy`` is 'type:value' where type is 'first', 'end', or
        'random' and value is a positive sample count. On any parse error
        the full dataset is returned unchanged (logged, not raised).
        """
        if not strategy:
            print_rank0("No sampling strategy specified, using full dataset.")
            return dataset

        try:
            parts = strategy.split(':')
            if len(parts) != 2:
                raise ValueError(f"Invalid sampling strategy format: '{strategy}'. Expected 'type:value'.")
            strat_type, strat_value = parts[0].lower(), parts[1]
            num_samples = int(strat_value)
            total_size = len(dataset)
            if num_samples <= 0:
                raise ValueError(f"Sampling value must be positive, got: {num_samples} [{strategy}]")
            # Never request more samples than the dataset holds.
            num_samples = min(num_samples, total_size)

            print_rank0(f"Applying sampling: {strategy} ({num_samples} samples) to dataset of size {total_size}")

            if strat_type == "first":
                return dataset.select(range(num_samples))
            elif strat_type == "end":
                start_index = max(0, total_size - num_samples)
                return dataset.select(range(start_index, total_size))
            elif strat_type == "random":
                shuffled_dataset = dataset.shuffle(seed=seed)
                return shuffled_dataset.select(range(num_samples))
            else:
                print_rank0(f"Warning: Unknown sampling strategy type: '{strat_type}'. Using full dataset.")
                return dataset
        except ValueError as e:
            print_rank0(f"Error parsing sampling strategy '{strategy}': {e}. Using full dataset.")
            return dataset
        except Exception as e:
            print_rank0(f"An unexpected error occurred during sampling: {e}. Using full dataset.")
            return dataset

    @classmethod
    def _all_processed_datasets(cls, config, processor, args):
        """Load, sample, map, and filter every entry in config['datasets'].

        Returns a dict mapping dataset name -> processed ``datasets.Dataset``.
        Entries that fail to load or end up empty are skipped with a log
        message rather than aborting the whole run.
        """
        root_dir = config.get('__root_dir__', os.getcwd())
        all_processed_datasets: Dict[str, datasets.Dataset] = {}
        for i, dataset_config in enumerate(config['datasets']):
            print_rank0(f"\nProcessing dataset entry {i+1}/{len(config['datasets'])}...")
            json_path = dataset_config.get('json_path')
            if not json_path:
                print_rank0(f"Warning: Skipping dataset entry {i+1} due to missing 'json_path'.")
                continue
            json_path = _resolve_rel_path(json_path, root_dir)

            # Default dataset name: json file name without its extension.
            base_name = '.'.join(os.path.basename(json_path).split('.')[:-1])
            dataset_name = dataset_config.get('dataset_name', base_name)

            sampling_strategy = dataset_config.get('sampling_strategy', None)
            # Per-entry settings fall back to script-level args when absent.
            sampling_seed = dataset_config['sampling_seed'] if 'sampling_seed' in dataset_config else getattr(args, 'sampling_seed', 42)

            mapper_name = dataset_config.get('mapper')
            bbox_type = dataset_config.get('bbox_type')

            if 'img_dir' in dataset_config:
                img_dir = _resolve_rel_path(dataset_config['img_dir'], root_dir)
            else:
                img_dir = getattr(args, 'img_dir', None)
                if img_dir is not None:
                    img_dir = _resolve_rel_path(img_dir, root_dir)

            additional_mappers = dataset_config.get('additional_mappers', [])
            score_funcs = dataset_config.get('score_funcs', [])
            prompt = dataset_config.get('prompt', None)

            max_input_seq_length = dataset_config['max_input_seq_length'] if 'max_input_seq_length' in dataset_config else getattr(args, 'max_input_seq_length', None)
            max_input_remain_seq_length = dataset_config['max_input_remain_seq_length'] if 'max_input_remain_seq_length' in dataset_config else getattr(args, 'max_input_remain_seq_length', None)

            # Silently drop score functions that were never registered (with a warning).
            if score_funcs:
                filtered = []
                for sf in score_funcs:
                    if sf in SCORE_REGISTRY:
                        filtered.append(sf)
                    else:
                        print_rank0(f"Warning: Score function '{sf}' not registered. Will ignore.")
                score_funcs = filtered

            try:
                print_rank0(f"Loading raw data from: {json_path}")
                raw_dataset = datasets.load_dataset('json', data_files=json_path, split='train')
                print_rank0(f"Loaded {len(raw_dataset)} examples raw.")

                sampled_dataset = cls._apply_sampling(raw_dataset, sampling_strategy, sampling_seed)
                if len(sampled_dataset) == 0:
                    print_rank0("Dataset is empty after sampling, skipping.")
                    continue
                print_rank0(f"Dataset size after sampling: {len(sampled_dataset)}")

                mapper_func = MAPPER_REGISTRY[mapper_name]
                print_rank0(f"Applying mapper: '{mapper_name}'")
                mapper_kwargs = {
                    'img_dir': img_dir,
                    'score_funcs': score_funcs,
                }
                if prompt is not None:
                    mapper_kwargs['prompt'] = prompt
                print_rank0(f"Mapper arguments: {mapper_kwargs}")

                processed_dataset = sampled_dataset.map(
                    mapper_func,
                    num_proc=8,
                    fn_kwargs=mapper_kwargs,
                )

                # Keep only the canonical columns used downstream.
                processed_dataset = processed_dataset.remove_columns(
                    [col for col in processed_dataset.column_names if col not in REMAIN_KEYS]
                )

                print_rank0("Applying dataset filter: 'image_exist_dataset_filter'")
                processed_dataset = processed_dataset.filter(
                    image_exist_dataset_filter,
                    num_proc=8,
                    fn_kwargs={}
                )
                print_rank0(f"Processed dataset size after image_exist_dataset_filter: {len(processed_dataset)}")

                if max_input_seq_length is not None or max_input_remain_seq_length is not None:
                    processed_dataset = processed_dataset.filter(
                        inputs_seq_length_dataset_filter,
                        num_proc=8,
                        fn_kwargs={
                            'processor': processor,
                            'max_input_seq_length': max_input_seq_length,
                            'max_input_remain_seq_length': max_input_remain_seq_length,
                        }
                    )
                    print_rank0(f"Processed dataset size after inputs_seq_length_dataset_filter: {len(processed_dataset)}")

                for additional_mapper in additional_mappers:
                    mapper_func = MAPPER_REGISTRY[additional_mapper]
                    print_rank0(f"Applying additional mapper: '{additional_mapper}'")
                    processed_dataset = processed_dataset.map(
                        mapper_func,
                        num_proc=8,
                        fn_kwargs={
                            'bbox_type': bbox_type,
                        }
                    )
                print_rank0(f"Processed dataset size: {len(processed_dataset)}")
                if len(processed_dataset) == 0:
                    print_rank0(f"Warning: Processed dataset {dataset_name} is empty after mapping. Skipping.")
                    continue

                if dataset_name in all_processed_datasets:
                    # Name collision: disambiguate with a timestamp suffix.
                    dataset_name_with_uuid = f"{dataset_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                    print_rank0(f"Warning: Dataset name '{dataset_name}' already exists. Renaming to '{dataset_name_with_uuid}'")
                    all_processed_datasets[dataset_name_with_uuid] = processed_dataset
                else:
                    all_processed_datasets[dataset_name] = processed_dataset

            except FileNotFoundError:
                print_rank0(f"Error: Data file not found for dataset entry {i+1}: {json_path}. Skipping.")
            except Exception as e:
                print_rank0(f"Error processing dataset entry {i+1} ({json_path}): {e}. Skipping.")

        return all_processed_datasets

    def __init__(self, config_path: str, processor: Qwen2_5_VL_GP_Processor, script_args: Optional[Any] = None):
        """
        Initializes the GPDataset.

        Args:
            config_path (str): Path to the YAML configuration file.
            processor (Qwen2_5_VL_GP_Processor): Processor for handling text and vision data.
            script_args (Any, optional): Additional arguments passed from the script
                (e.g., training args, could contain seed). Defaults to None.

        Raises:
            ValueError: If no dataset could be processed, or the combined
                dataset is empty after concatenation.
        """
        super().__init__()
        self.args = script_args
        self.config = self._load_config(config_path)
        self.processor = processor
        all_processed_datasets = self._all_processed_datasets(self.config, self.processor, self.args)
        if all_processed_datasets:
            print_rank0(f"\nConcatenating {len(all_processed_datasets)} processed dataset(s)...")
            self.final_dataset = datasets.concatenate_datasets(list(all_processed_datasets.values()))
            if len(self.final_dataset) == 0:
                raise ValueError("Final dataset is empty after concatenation.")
            print_rank0(f"Final combined dataset size: {len(self.final_dataset)}")
            print_rank0(f"Final dataset features: {self.final_dataset.features}")
        else:
            raise ValueError("No datasets were successfully processed. Please check your configuration.")
        # BUG FIX: removed a stray `self.final_dataset = None` that executed
        # after the successful branch above and wiped the freshly built
        # dataset, making __len__ always report 0 and __getitem__ always raise.

    def __len__(self) -> int:
        # A datasets.Dataset is falsy when empty; treat missing/empty as 0.
        return len(self.final_dataset) if self.final_dataset else 0

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if self.final_dataset is None:
            raise IndexError("Dataset is not initialized or is empty.")
        if not 0 <= index < len(self.final_dataset):
            raise IndexError(f"Index {index} out of bounds for dataset of size {len(self.final_dataset)}")
        return self.final_dataset[index]

    @classmethod
    def get_processed_dataset_dict(cls, config_path: str, processor: Qwen2_5_VL_GP_Processor, script_args: Optional[Any] = None) -> Dict[str, datasets.Dataset]:
        """Load the config and return the per-dataset processed dict without
        concatenating the datasets into one."""
        config = cls._load_config(config_path)
        all_processed_datasets = cls._all_processed_datasets(config, processor, script_args)
        return all_processed_datasets
| |
|
| |
|
| |
|
class GPCollator:
    """Batch collator: tokenizes chat messages (image + text) with the GP
    processor and attaches grounding metadata (normed bboxes, raw queries,
    answers, score function names) to the batch."""

    def __init__(self, processor, is_sft):
        # processor: Qwen2.5-VL GP processor (tokenizer + image processing).
        # is_sft: when True, the assistant answer is appended to the chat and
        # "labels" are produced for supervised fine-tuning.
        self.processor = processor
        self.is_sft = is_sft
        # Token id of "<|im_start|>", used to locate the final chat turn.
        self.im_start_id = self.processor.tokenizer.encode("<|im_start|>")[0]

    def _prepare_labels_from_input_ids(self, input_ids):
        """Build SFT labels from ``input_ids``.

        Copies input_ids and sets every position up to and including the LAST
        "<|im_start|>" token plus the next two tokens to -100, so only the
        final assistant response is supervised. The two extra tokens after the
        marker are presumably the role token and the following newline — TODO
        confirm against the chat template.

        NOTE(review): if a row contains no im_start token at all, argmax over
        an all-zero mask returns 0, which masks the entire row.
        """
        B, L = input_ids.shape
        labels = input_ids.clone()
        mask = input_ids == self.im_start_id
        # Flip along the sequence so argmax finds the LAST im_start position.
        flipped_mask = mask.flip(dims=(1,))
        first_idx_in_flipped = torch.argmax(flipped_mask.int(), dim=1)
        last_pos = (L - 1) - first_idx_in_flipped
        # Mask through the marker plus the two tokens that follow it.
        mask_until_idx = last_pos + 3
        mask_until_idx = torch.clamp(mask_until_idx, max=L)
        arange_l = torch.arange(L, device=input_ids.device).expand(B, -1)
        modification_mask = arange_l < mask_until_idx.unsqueeze(1)
        labels[modification_mask] = -100
        return labels

    def __call__(self, features):
        """Collate a list of mapped examples into a model-ready batch.

        Returns the processor output (input_ids, vision tensors, ...) with
        the raw query/answer/score_funcs lists attached, plus "labels" when
        in SFT mode.
        """
        messages = []
        normed_bboxes = []
        answers = []
        querys = []
        score_funcs = []
        for feature in features:
            query = feature[QUERY_KEY]
            answer = feature[ANSWER_KEY]
            img_path = feature[IMG_PATH_KEY]
            if self.is_sft:
                # Full dialogue (user + assistant) so the answer can be supervised.
                messages.append([{"role": "user", "content": [{"type": "image", "image": img_path}, {"type": "text", "text": query}]}, {"role": "assistant", "content": [{"type": "text", "text": answer}]}])
            else:
                # User turn only; a generation prompt is appended below.
                messages.append([{"role": "user", "content": [{"type": "image", "image": img_path}, {"type": "text", "text": query}]}])
            normed_bboxes.append(feature[NORMED_BBOXES_KEY])
            querys.append(query)
            answers.append(answer)
            score_funcs.append(feature[SCORE_FUNCS_KEY])

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=(not self.is_sft)
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=text,
            normed_bboxes=normed_bboxes,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        if self.is_sft:
            labels = self._prepare_labels_from_input_ids(inputs.input_ids)
            inputs["labels"] = labels

        # Pass the raw strings through for downstream scoring/logging.
        inputs[QUERY_KEY] = querys
        inputs[ANSWER_KEY] = answers
        inputs[SCORE_FUNCS_KEY] = score_funcs
        return inputs