| import os |
| import cv2 |
| import sys |
| import tqdm |
| import torch |
| import argparse |
|
|
| import numpy as np |
|
|
| from PIL import Image |
|
|
# Make the repository root importable so the project packages below
# (lib, utils, data) resolve when this script is run directly.
filepath = os.path.dirname(os.path.abspath(__file__))   # directory holding this script
repopath = os.path.dirname(filepath)                    # its parent: the repository root
sys.path.append(repopath)
|
|
| from lib import * |
| from utils.misc import * |
| from data.dataloader import * |
| from data.custom_transforms import * |
|
|
# Disable TensorFloat-32 for both cuDNN convolutions and CUDA matmuls so GPU
# math runs in full FP32 (presumably for reproducible outputs across GPUs).
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
|
|
| def _args(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--config', '-c', type=str, default='configs/InSPyReNet_SwinB.yaml') |
| parser.add_argument('--source', '-s', type=str) |
| parser.add_argument('--dest', '-d', type=str, default=None) |
| parser.add_argument('--type', '-t', type=str, default='map') |
| parser.add_argument('--gpu', '-g', action='store_true', default=False) |
| parser.add_argument('--jit', '-j', action='store_true', default=False) |
| parser.add_argument('--verbose', '-v', action='store_true', default=False) |
| return parser.parse_args() |
|
|
def get_format(source):
    """Classify a collection of file names as image or video input.

    Extensions are compared case-insensitively: ``.jpg``/``.png``/``.jpeg``
    count as images and ``.mp4``/``.avi``/``.mov`` as videos; anything else is
    ignored.

    Returns:
        ``'Image'`` if only images are present, ``'Video'`` if only videos are
        present, and ``''`` when both kinds or neither are present.
    """
    image_exts = ('.jpg', '.png', '.jpeg')
    video_exts = ('.mp4', '.avi', '.mov')

    n_images = sum(1 for name in source if name.lower().endswith(image_exts))
    n_videos = sum(1 for name in source if name.lower().endswith(video_exts))

    if n_images and n_videos:
        return ''  # mixed input is not supported
    if n_images:
        return 'Image'
    if n_videos:
        return 'Video'
    return ''
|
|
def _load_model(opt, args):
    """Build the network from ``opt`` and return it ready for inference.

    Weights come from ``latest.pth`` in ``opt.Test.Checkpoint.checkpoint_dir``.
    With ``--jit``, a TorchScript module is used instead. It is loaded from
    ``jit.pt`` in the same directory, or traced and saved there first if that
    file does not exist.
    """
    ckpt_dir = opt.Test.Checkpoint.checkpoint_dir
    device = torch.device('cuda' if args.gpu is True else 'cpu')

    # NOTE(review): `eval` resolves the model class named in the config and
    # `torch.load` unpickles the checkpoint -- only use trusted configs/checkpoints.
    model = eval(opt.Model.name)(**opt.Model)
    model.load_state_dict(torch.load(os.path.join(
        ckpt_dir, 'latest.pth'), map_location=torch.device('cpu')), strict=True)

    model = model.to(device)
    model.eval()

    if args.jit is True:
        jit_path = os.path.join(ckpt_dir, 'jit.pt')
        if os.path.isfile(jit_path) is False:
            model = Simplify(model)
            # Trace with a dummy input on the model's own device; the previous
            # hard-coded .cuda() made tracing fail whenever --gpu was not given.
            dummy = torch.rand(1, 3, *opt.Test.Dataset.transforms.static_resize.size, device=device)
            model = torch.jit.trace(model, dummy, strict=False)
            torch.jit.save(model, jit_path)
        else:
            del model
            # map_location moves the saved parameters to the requested device.
            # NOTE(review): a trace may still embed device constants; delete
            # jit.pt to re-trace if you switch between CPU and GPU.
            model = torch.jit.load(jit_path, map_location=device)
    return model


def _resolve_source(args):
    """Determine the input format and output directory for ``args.source``.

    Returns:
        ``(save_dir, _format)``, where ``_format`` is ``'Webcam'``, ``'Image'``
        or ``'Video'``. ``save_dir`` is ``None`` only for a webcam source
        without ``--dest``. The directory is created when it is not ``None``.

    Raises:
        ValueError: if the source is missing, is not a webcam index or an
            existing file/directory, or does not contain exactly one
            supported media kind.
    """
    if args.source is None:
        raise ValueError('No --source given.')

    save_dir = None
    _format = None

    if args.source.isnumeric() is True:
        _format = 'Webcam'
    elif os.path.isdir(args.source):
        # normpath drops a trailing separator, so 'imgs/' maps to 'results/imgs'
        # rather than 'results/'.
        save_dir = os.path.join('results', os.path.basename(os.path.normpath(args.source)))
        _format = get_format(os.listdir(args.source))
    elif os.path.isfile(args.source):
        save_dir = 'results'
        _format = get_format([args.source])

    if not _format:
        # Previously fell through to eval('Loader') / None + 'Loader'.
        raise ValueError(
            f'Unsupported source {args.source!r}: expected a webcam index, or a file/'
            f'directory containing only images (.jpg/.jpeg/.png) or only videos (.mp4/.avi/.mov).')

    if args.dest is not None:
        save_dir = args.dest

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    return save_dir, _format


def inference(opt, args):
    """Run the model on ``args.source`` and save or display the results.

    The output is selected by ``args.type``:
      * ``'map'``     -- the prediction as a 3-channel grayscale image;
      * ``'rgba'``    -- the original frame with the prediction as alpha;
      * ``'green'``   -- the frame composited over a solid (120, 255, 155) colour;
      * ``'blur'``    -- the frame composited over a Gaussian-blurred copy of itself;
      * ``'overlay'`` -- the masked region tinted with that colour, plus an outline;
      * a ``.jpg``/``.jpeg``/``.png`` path -- composited over that image;
      * ``'debug'``   -- also saves the maps named in ``opt.Train.Debug.keys``.
    Any other value leaves the original frame unchanged.

    Images are written as PNG files and videos are re-encoded as mp4 into the
    output directory. Webcam frames are shown in a window.
    """
    model = _load_model(opt, args)
    save_dir, _format = _resolve_source(args)

    # _format was validated above, so this picks ImageLoader / VideoLoader /
    # WebcamLoader from the module namespace (no eval needed).
    sample_list = globals()[_format + 'Loader'](args.source, opt.Test.Dataset.transforms)

    if args.verbose is True:
        samples = tqdm.tqdm(sample_list, desc='Inference', total=len(
            sample_list), position=0, leave=False, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
    else:
        samples = sample_list

    writer = None
    background_src = None  # user-supplied background (RGB), loaded lazily
    background = None      # background_src resized to the current frame size

    for sample in samples:
        if _format == 'Video' and writer is None:
            writer = cv2.VideoWriter(os.path.join(save_dir, sample['name'] + '.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), sample_list.fps, sample['shape'][::-1])
            # Only the tqdm bar has a `total` to grow; the bare loader may not.
            if args.verbose is True:
                samples.total += int(sample_list.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if _format == 'Video' and sample['image'] is None:
            # An empty sample presumably marks the end of one video: close its file.
            if writer is not None:
                writer.release()
            writer = None
            continue

        if args.gpu is True:
            sample = to_cuda(sample)

        with torch.no_grad():
            if args.jit is True:
                out = model(sample['image'])
            else:
                out = model(sample)

        pred = to_numpy(out['pred'], sample['shape'])  # presumably in [0, 1]
        img = np.array(sample['original'])
        alpha = pred[..., np.newaxis]  # broadcastable against an HxWx3 frame

        if args.type == 'map':
            img = (np.stack([pred] * 3, axis=-1) * 255).astype(np.uint8)
        elif args.type == 'rgba':
            r, g, b = cv2.split(img)
            pred = (pred * 255).astype(np.uint8)
            img = cv2.merge([r, g, b, pred])
        elif args.type == 'green':
            bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155]
            img = img * alpha + bg * (1 - alpha)
        elif args.type == 'blur':
            img = img * alpha + cv2.GaussianBlur(img, (0, 0), 15) * (1 - alpha)
        elif args.type == 'overlay':
            bg = (np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155] + img) // 2
            img = bg * alpha + img * (1 - alpha)
            border = cv2.Canny(((pred > .5) * 255).astype(np.uint8), 50, 100)
            img[border != 0] = [120, 255, 155]
        elif args.type.lower().endswith(('.jpg', '.jpeg', '.png')):
            if background_src is None:
                raw = cv2.imread(args.type)
                if raw is None:  # cv2.imread signals failure by returning None
                    raise FileNotFoundError(f'Cannot read background image: {args.type}')
                background_src = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
            # Re-fit whenever the frame size changes (e.g. a folder of images with
            # different resolutions); resizing only once broke broadcasting.
            if background is None or background.shape[:2] != img.shape[:2]:
                background = cv2.resize(background_src, img.shape[:2][::-1])
            img = img * alpha + background * (1 - alpha)
        elif args.type == 'debug':
            debs = []
            for k in opt.Train.Debug.keys:
                debs.extend(out[k])
            for i, j in enumerate(debs):
                log = torch.sigmoid(j).cpu().detach().numpy().squeeze()
                lo, hi = log.min(), log.max()
                # Min-max normalise; a constant map would otherwise divide by zero.
                log = (log - lo) / (hi - lo) if hi > lo else np.zeros_like(log)
                log = (log * 255).astype(np.uint8)
                log = cv2.cvtColor(log, cv2.COLOR_GRAY2RGB)
                log = cv2.resize(log, img.shape[:2][::-1])
                Image.fromarray(log).save(os.path.join(save_dir, sample['name'] + '_' + str(i) + '.png'))

        img = img.astype(np.uint8)

        if _format == 'Image':
            Image.fromarray(img).save(os.path.join(save_dir, sample['name'] + '.png'))
        elif _format == 'Video' and writer is not None:
            # Frames are RGB; OpenCV writers expect BGR (the swap is symmetric).
            writer.write(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        elif _format == 'Webcam':
            cv2.imshow('InSPyReNet', img)

    # Finalise a video whose writer was not closed by an end-of-video sample.
    if writer is not None:
        writer.release()
    if args.verbose is True:
        samples.close()
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the YAML config it names, run.
    args = _args()
    inference(load_config(args.config), args)
|
|