import argparse
import copy
import os.path as osp

import torch
from torch.utils.data import DataLoader

from mmengine import Config
from mmengine.dist import init_dist, get_dist_info
from mmengine.utils.dl_utils import set_multi_processing
from transformers import GenerationConfig

from xtuner.configs import cfgs_name_path
from xtuner.registry import BUILDER
from xtuner.tools.chat import TORCH_DTYPE_MAP
from xtuner.tools.utils import get_stop_criteria
from xtuner.utils import PROMPT_TEMPLATE

from projects.llava_sam2.datasets import video_lisa_collate_fn


def parse_args():
    parser = argparse.ArgumentParser(description='ReVOS evaluation')
    parser.add_argument('config', help='config file name or path.')
    parser.add_argument('--pth_model', default=None, help='pth model file')
    parser.add_argument(
        '--split',
        default='val',
        help='Specify a split')
    parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default='internlm2_chat',
        help='Specify a prompt template')
    parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    parser.add_argument(
        '--bits',
        type=int,
        choices=[4, 8, None],
        default=None,
        help='LLM bits')
    parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    parser.add_argument(
        '--max-new-tokens',
        type=int,
        default=100,
        help='Maximum number of new tokens allowed in generated text')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    if args.launcher != 'none':
        set_multi_processing(distributed=True)
        init_dist(args.launcher)
        rank, world_size = get_dist_info()
        torch.cuda.set_device(rank)
    else:
        rank = 0
        world_size = 1
    print(f'Rank: {rank} / World size: {world_size}')

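    # Accept either a config file path or a config name registered with xtuner.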
    if not osp.isfile(args.config):
        try:
            args.config = cfgs_name_path[args.config]
        except KeyError:
            raise FileNotFoundError(f'Cannot find {args.config}')

    cfg = Config.fromfile(args.config)
    model_name = (cfg.model.type if isinstance(cfg.model.type, str)
                  else cfg.model.type.__name__)
    assert model_name in ('VideoLLaVASAMModel',)

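    # Instantiate the model defined in the config through the xtuner registry.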
    model = BUILDER.build(cfg.model)

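    # Optionally load fine-tuned weights. `strict=False` tolerates checkpoints
    # that only cover a subset of the model's parameters.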
    if args.pth_model is not None:
        state_dict = torch.load(args.pth_model, map_location='cpu')
        model.load_state_dict(state_dict, strict=False)

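    # Move the model to this process's GPU (the current device was set per
    # rank above) and switch to eval mode.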
    model.cuda()
    model.eval()

    tokenizer = model.tokenizer

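    # Greedy decoding; fall back to EOS as the pad token when the tokenizer
    # does not define one.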
    gen_config = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )
    stop_words = args.stop_words
    if args.prompt_template:
        template = PROMPT_TEMPLATE[args.prompt_template]
        stop_words += template.get('STOP_WORDS', [])
    stop_criteria = get_stop_criteria(
        tokenizer=tokenizer, stop_words=stop_words)

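    # Build the ReVOS evaluation dataset: reuse the `video_revos_dataset`
    # config, but point its expression file at the requested split.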
    data_cfg = copy.deepcopy(cfg.video_revos_dataset)
    data_cfg.update(expression_file=data_cfg.expression_file.replace(
        'train', args.split))
    dataset = BUILDER.build(data_cfg)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=1,
        shuffle=False,
        collate_fn=video_lisa_collate_fn,
    )

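    # Run generation over the evaluation set. Grounding pixel values and
    # ground-truth masks are popped from the batch because they are not
    # consumed by the text-generation call below.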
    for data_item in dataloader:
        data_item = model.data_preprocessor(data_item)
        inputs, data_samples = data_item['data'], data_item['data_samples']
        g_pixel_values = inputs.pop('g_pixel_values', None)
        gt_masks = inputs.pop('masks', None)

        output = model.mllm.generate(
            pixel_values=inputs['pixel_values'],
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            visual_features=None,
            generation_config=gen_config,
            streamer=None,
            bos_token_id=tokenizer.bos_token_id,
            stopping_criteria=stop_criteria,
            output_hidden_states=True,
            return_dict_in_generate=True
        )
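        # With `return_dict_in_generate=True`, the generated token ids are
        # available as `output.sequences`; decode them for inspection.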
        prediction = tokenizer.decode(
            output.sequences[0], skip_special_tokens=True).strip()
        print(f'[Rank {rank}] {prediction}')


if __name__ == '__main__':
    main()