| |
| |
| |
|
|
| import os |
| from os import path |
| from argparse import ArgumentParser |
| import shutil |
|
|
| |
| |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| import numpy as np |
| from PIL import Image |
|
|
| from inference.data.test_datasets import DAVISTestDataset_221128_TransColorization_batch |
| from inference.data.mask_mapper import MaskMapper |
| from model.network import ColorMNet |
| from inference.inference_core import InferenceCore |
|
|
| from progressbar import progressbar |
|
|
| from dataset.range_transform import inv_im_trans, inv_lll2rgb_trans |
|
|
| from skimage import color, io |
| import cv2 |
|
|
| try: |
| import hickle as hkl |
| except ImportError: |
| print('Failed to import hickle. Fine if not using multi-scale testing.') |
|
|
|
|
| |
def detach_to_cpu(x):
    """Return ``x`` detached from the autograd graph and moved to the CPU."""
    detached = x.detach()
    return detached.cpu()
|
|
def tensor_to_np_float(image):
    """Convert ``image`` to a NumPy array and cast it to ``float32``.

    ``image.numpy()`` is called directly, so the caller is expected to pass
    a tensor that is already on the CPU and detached.
    """
    as_array = image.numpy()
    return as_array.astype('float32')
|
|
def lab2rgb_transform_PIL(mask):
    """Turn a Lab tensor into an RGB NumPy image with values clipped to [0, 1].

    The tensor is detached and moved to the CPU, passed through
    ``inv_lll2rgb_trans`` (project helper; presumably undoes the dataset's
    Lab normalisation -- TODO confirm), converted to a float32 array, laid
    out channel-last, and finally converted with ``skimage.color.lab2rgb``.
    """
    lab = tensor_to_np_float(inv_lll2rgb_trans(detach_to_cpu(mask)))

    # 3-D input is treated as channel-first and moved to channel-last;
    # anything else gets a trailing singleton channel axis.
    lab = lab.transpose((1, 2, 0)) if lab.ndim == 3 else lab[:, :, None]

    rgb = color.lab2rgb(lab)
    return rgb.clip(0, 1)
|
|
|
|
| |
def build_parser() -> ArgumentParser:
    """Create the command-line parser for the colorization test script.

    Returns:
        An ``ArgumentParser`` with all model, data, memory and multi-scale
        options registered. Nothing is parsed here.
    """
    parser = ArgumentParser()
    add = parser.add_argument

    # --- Model / exemplar handling ---
    add('--model', default='saves/DINOv2FeatureV6_LocalAtten_s2_154000.pth')
    add(
        '--FirstFrameIsNotExemplar',
        action='store_true',
        help='Whether the provided reference frame is exactly the first input frame',
    )

    # --- Input / output locations ---
    add(
        '--d16_batch_path',
        default='input',
        help='Point to folder A/ which contains <video_name>/00000.png etc.',
    )
    add(
        '--ref_path',
        default='ref',
        help='Kept for parity; dataset will also read ref.png under each video folder when args provided',
    )
    add('--output', default='result', help='Directory to save results')

    add(
        '--reverse',
        action='store_true',
        default=False,
        help='whether to reverse the frame order',
    )
    add(
        '--allow_resume',
        action='store_true',
        help='skip existing videos that have been colorized',
    )

    # --- Dataset selection ---
    add('--generic_path')
    add('--dataset', default='D16_batch', help='D16/D17/Y18/Y19/LV1/LV3/G')
    add('--split', default='val', help='val/test')
    add(
        '--save_all',
        action='store_true',
        help='Save all frames. Useful only in YouTubeVOS/long-time video',
    )
    add(
        '--benchmark',
        action='store_true',
        help='enable to disable amp for FPS benchmarking',
    )

    # --- Long-term memory options ---
    add('--disable_long_term', action='store_true')
    add(
        '--max_mid_term_frames',
        type=int,
        default=10,
        help='T_max in paper, decrease to save memory',
    )
    add(
        '--min_mid_term_frames',
        type=int,
        default=5,
        help='T_min in paper, decrease to save memory',
    )
    add(
        '--max_long_term_elements',
        type=int,
        default=10000,
        help='LT_max in paper, increase if objects disappear for a long time',
    )
    add('--num_prototypes', type=int, default=128, help='P in paper')

    add('--top_k', type=int, default=30)
    add(
        '--mem_every',
        type=int,
        default=5,
        help='r in paper. Increase to improve running speed.',
    )
    add(
        '--deep_update_every',
        type=int,
        default=-1,
        help='Leave -1 normally to synchronize with mem_every',
    )

    # --- Multi-scale / test-time augmentation ---
    add('--save_scores', action='store_true')
    add('--flip', action='store_true')
    add(
        '--size',
        type=int,
        default=-1,
        help='Resize the shorter side to this size. -1 to use original resolution. ',
    )
    return parser
|
|
|
|
| |
def run_inference(args):
    """
    Run the full colorization inference loop and write PNG frames to disk.

    This must be called inside the ZeroGPU scheduling context (wrapped by
    ``@spaces.GPU`` in app.py). Do not perform any CUDA initialisation at
    module import time.

    Args:
        args: Parsed namespace as produced by ``build_parser()``. It is
            also used as the model/inference config through ``vars(args)``,
            so the keys written to ``config`` below become attributes of
            ``args`` as well.

    Raises:
        NotImplementedError: If ``args.split`` is not ``'val'``.

    Side effects:
        Globally disables autograd, writes one PNG per saved frame under
        ``<output>/<video_name>/`` (or ``<output>/Annotations/<video_name>/``
        for YouTube datasets or ``--save_scores``), prints timing
        statistics, and may create a zip archive of the results.
    """
    # Local import: only needed for the CPU timing fallback below.
    import time

    # ``vars`` returns the namespace's own dict, so config and args stay in sync.
    config = vars(args)
    config['enable_long_term'] = not config['disable_long_term']

    if args.output is None:
        args.output = f'.output/{args.dataset}_{args.split}'
        print(f'Output path not provided. Defaulting to {args.output}')

    is_youtube = args.dataset.startswith('Y')
    is_davis = args.dataset.startswith('D')

    if is_youtube or args.save_scores:
        out_path = path.join(args.output, 'Annotations')
    else:
        out_path = args.output

    if args.split != 'val':
        raise NotImplementedError('Only split=val is supported in this script.')

    meta_dataset = DAVISTestDataset_221128_TransColorization_batch(
        args.d16_batch_path, imset=args.ref_path, size=args.size, args=args
    )

    torch.autograd.set_grad_enabled(False)

    # Iterable of per-video readers.
    meta_loader = meta_dataset.get_datasets()

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    use_cuda = device.type == 'cuda'
    network = ColorMNet(config, args.model).to(device).eval()
    if args.model is not None:
        # NOTE(security): torch.load unpickles the checkpoint; only load
        # weights from a trusted source.
        model_weights = torch.load(args.model, map_location=device)
        network.load_weights(model_weights, init_as_zero_if_needed=True)
    else:
        print('No model loaded.')

    total_process_time = 0.0
    total_frames = 0

    for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
        vid_name = vid_reader.vid_name
        this_out_path = path.join(out_path, vid_name)

        # Check for resumable output first so skipped videos do not pay for
        # building a DataLoader and an InferenceCore.
        if args.allow_resume and path.exists(this_out_path):
            print(f'Skipping {this_out_path} because output already exists.')
            continue

        loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)
        vid_length = len(loader)

        # Enable usage counting only when the estimated number of long-term
        # elements for this video could reach the configured maximum.
        mid_term_span = config['max_mid_term_frames'] - config['min_mid_term_frames']
        if not config['enable_long_term']:
            count_usage = False
        elif mid_term_span == 0:
            # The estimate would divide by zero (unbounded), so assume the
            # limit can be reached instead of crashing.
            count_usage = True
        else:
            count_usage = (
                vid_length / mid_term_span * config['num_prototypes']
                >= config['max_long_term_elements']
            )
        config['enable_long_term_count_usage'] = count_usage

        processor = InferenceCore(network, config=config)
        first_mask_loaded = False

        for ti, data in enumerate(loader):
            # AMP is CUDA-only here; on CPU it is disabled explicitly.
            with torch.cuda.amp.autocast(enabled=use_cuda and not args.benchmark):
                rgb = data['rgb'].to(device)[0]

                msk = data.get('mask')
                if not config['FirstFrameIsNotExemplar']:
                    # Keep channels 1..2 only (presumably the a/b chroma
                    # channels of the Lab exemplar -- TODO confirm).
                    msk = msk[:, 1:3, :, :] if msk is not None else None

                info = data['info']
                frame = info['frame'][0]
                shape = info['shape']
                need_resize = info['need_resize'][0]

                # CUDA events only exist on CUDA builds/devices; fall back to
                # a wall-clock timer so the CPU path does not crash.
                if use_cuda:
                    start = torch.cuda.Event(enable_timing=True)
                    end = torch.cuda.Event(enable_timing=True)
                    start.record()
                else:
                    cpu_start = time.perf_counter()

                # Skip frames until the first exemplar/mask shows up.
                if not first_mask_loaded:
                    if msk is not None:
                        first_mask_loaded = True
                    else:
                        continue

                if args.flip:
                    rgb = torch.flip(rgb, dims=[-1])
                    msk = torch.flip(msk, dims=[-1]) if msk is not None else None

                if msk is not None:
                    msk = torch.Tensor(msk[0]).to(device)
                    if need_resize:
                        msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
                    processor.set_all_labels(list(range(1, 3)))
                    labels = range(1, 3)
                else:
                    labels = None

                is_last_frame = ti == vid_length - 1
                if config['FirstFrameIsNotExemplar']:
                    prob = processor.step_AnyExemplar(
                        rgb,
                        msk[:1, :, :].repeat(3, 1, 1) if msk is not None else None,
                        msk[1:3, :, :] if msk is not None else None,
                        labels,
                        end=is_last_frame
                    )
                else:
                    prob = processor.step(rgb, msk, labels, end=is_last_frame)

                # Bring the prediction back to the original resolution.
                if need_resize:
                    prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:, 0]

                if use_cuda:
                    end.record()
                    torch.cuda.synchronize()
                    total_process_time += (start.elapsed_time(end) / 1000)
                else:
                    total_process_time += time.perf_counter() - cpu_start
                total_frames += 1

                if args.flip:
                    prob = torch.flip(prob, dims=[-1])

                # NOTE: the previous code overwrote ``prob`` with a uint8
                # NumPy array when --save_scores was set. That array was
                # never written anywhere and made the torch.cat below raise
                # a TypeError, so ``prob`` is now kept as a tensor.

                if args.save_all or info['save'][0]:
                    os.makedirs(this_out_path, exist_ok=True)

                    # First channel of the input joined with the predicted
                    # channels, then converted from Lab to RGB.
                    out_mask_final = lab2rgb_transform_PIL(torch.cat([rgb[:1, :, :], prob], dim=0))
                    out_mask_final = (out_mask_final * 255).astype(np.uint8)

                    out_img = Image.fromarray(out_mask_final)
                    # splitext handles extensions of any length (.png, .jpeg, ...).
                    out_img.save(os.path.join(this_out_path, path.splitext(frame)[0] + '.png'))

    print(f'Total processing time: {total_process_time}')
    print(f'Total processed frames: {total_frames}')
    if total_process_time > 0:
        print(f'FPS: {total_frames / total_process_time}')
    else:
        # Nothing was processed (e.g. every video was skipped on resume).
        print('FPS: n/a (no frames were processed)')
    if use_cuda:
        print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')

    if not args.save_scores:
        if is_youtube:
            print('Making zip for YouTubeVOS...')
            shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
        elif is_davis and args.split == 'test':
            # Currently unreachable: split != 'val' raises above. Kept for
            # when test-dev support is added.
            print('Making zip for DAVIS test-dev...')
            shutil.make_archive(args.output, 'zip', args.output)
|
|
|
|
| |
def run_cli(args_list=None):
    """
    In-process entry point for app.py: ``test.run_cli(args_list)``.

    Args:
        args_list: Command-line style argument list. ``None`` lets argparse
            read ``sys.argv``.

    Returns:
        Whatever ``run_inference`` returns.
    """
    namespace = build_parser().parse_args(args_list)
    return run_inference(namespace)
|
|
|
|
def main():
    """
    Keep the script runnable from the command line:
    ``python test.py --d16_batch_path A --output result ...``

    Note: when main() is run directly in a stateless Hugging Face
    Spaces/ZeroGPU environment, the scheduling context must be provided by
    the caller (e.g. ``@spaces.GPU`` in app.py).
    """
    run_cli(args_list=None)
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|