| import os |
| |
| import numpy as np |
| import cv2 |
| import math |
| import argparse |
| from tqdm import tqdm |
| import torch |
| from torch import nn |
| from torchvision import transforms |
| import torch.nn.functional as F |
| from model.raft.core.raft import RAFT |
| from model.raft.core.utils.utils import InputPadder |
| from model.bisenet.model import BiSeNet |
| from model.stylegan.model import Downsample |
|
|
class Options():
    """Command-line options for smoothing face-parsing maps over a video."""

    def __init__(self):
        self.parser = argparse.ArgumentParser(description="Smooth Parsing Maps")
        self.parser.add_argument("--window_size", type=int, default=5, help="temporal window size")

        self.parser.add_argument("--faceparsing_path", type=str, default='./checkpoint/faceparsing.pth', help="path of the face parsing model")
        self.parser.add_argument("--raft_path", type=str, default='./checkpoint/raft-things.pth', help="path of the RAFT model")

        # Required: nothing can be done without an input video, and failing here
        # gives a clear usage message instead of a later crash inside OpenCV.
        self.parser.add_argument("--video_path", type=str, required=True, help="path of the target video")
        self.parser.add_argument("--output_path", type=str, default='./output/', help="path of the output parsing maps")

    def parse(self, args=None):
        """Parse the options, print them, and return the resulting namespace.

        Args:
            args: optional list of argument strings. ``None`` (the default)
                parses ``sys.argv[1:]``, exactly as before.

        Returns:
            The ``argparse.Namespace`` (also stored as ``self.opt``).

        Raises:
            SystemExit: via ``argparse`` when arguments are missing/invalid,
                including a negative ``--window_size``.
        """
        self.opt = self.parser.parse_args(args)
        if self.opt.window_size < 0:
            self.parser.error('--window_size must be non-negative')
        print('Load options')
        for name, value in sorted(vars(self.opt).items()):
            print('%s: %s' % (str(name), str(value)))
        return self.opt
|
|
| |
def warp(x, flo):
    """
    Warp an image/tensor (im2) back to im1 according to the optical flow.

    Helper tensors are created on the inputs' devices rather than a fixed
    CUDA device, so the function works on CPU or on any CUDA device.

    Args:
        x: [B, C, H, W] tensor (im2) to be sampled.
        flo: [B, 2, H, W] flow in pixels; channel 0 is the horizontal (x)
            offset and channel 1 the vertical (y) offset.

    Returns:
        (output, mask):
            output: ``x`` bilinearly sampled at (pixel + flow), multiplied by ``mask``.
            mask: float32 tensor with the same shape as ``x``; 1 where the
                sampling footprint lies (almost) entirely inside the image,
                0 otherwise.
    """
    B, C, H, W = x.size()

    # Base pixel-coordinate grid; channel 0 holds x (column), channel 1 holds y (row),
    # matching the (x, y) order grid_sample expects in its last dimension.
    xx = torch.arange(0, W, device=flo.device).view(1, -1).repeat(H, 1)
    yy = torch.arange(0, H, device=flo.device).view(-1, 1).repeat(1, W)
    xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
    yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
    grid = torch.cat((xx, yy), 1).float()

    vgrid = grid + flo

    # Rescale absolute pixel coordinates to [-1, 1] (align_corners=True convention);
    # max(..., 1) avoids division by zero for 1-pixel-wide/high inputs.
    vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
    vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0

    vgrid = vgrid.permute(0, 2, 3, 1)
    output = nn.functional.grid_sample(x, vgrid, align_corners=True)

    # Sampling an all-ones tensor (zero padding outside the image) yields the
    # in-bounds fraction of each bilinear footprint; keep only fully valid pixels.
    mask = torch.ones(x.size(), device=x.device)
    mask = nn.functional.grid_sample(mask, vgrid, align_corners=True)
    mask[mask < 0.9999] = 0
    mask[mask > 0] = 1

    return output * mask, mask
|
|
| |
def read_video(video_path, transform):
    """Decode a video and return all frames as one tensor.

    Each frame is converted from OpenCV's BGR order to RGB, passed through
    ``transform``, and the results are concatenated along a new leading dim.
    At most ``CAP_PROP_FRAME_COUNT`` frames are read.

    Args:
        video_path: path of the video file.
        transform: callable mapping an RGB HxWx3 frame to a [C, H, W] tensor.

    Returns:
        Tensor of shape [N, C, H, W], one entry per decoded frame.

    Raises:
        OSError: if the video cannot be opened or yields no frames.
    """
    video_cap = cv2.VideoCapture(video_path)
    if not video_cap.isOpened():
        raise OSError('Cannot open video: %s' % video_path)
    frames = []
    try:
        num = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
        for _ in range(num):
            success, frame = video_cap.read()
            if not success:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(transform(frame).unsqueeze(dim=0).cpu())
    finally:
        # Release the capture even if decoding/transforming raises.
        video_cap.release()
    if not frames:
        raise OSError('No frames could be read from video: %s' % video_path)
    return torch.cat(frames, dim=0)


def pad_temporal(frames, window):
    """Pad a frame sequence along dim 0 with ``window`` extra entries per side.

    When ``len(frames) >= window`` the result equals
    ``torch.cat((frames[:window], frames, frames[-window:]))``. Index clamping
    keeps exactly ``window`` entries on each side even for sequences shorter
    than the window, so padded index ``i + window`` always refers to original
    entry ``i``. For ``window <= 0`` the input is returned unchanged.

    Args:
        frames: non-empty tensor whose first dimension indexes frames.
        window: number of entries to add on each side.

    Returns:
        Tensor with ``len(frames) + 2 * max(window, 0)`` entries along dim 0.
    """
    if window <= 0:
        return frames
    n = frames.shape[0]
    head = torch.arange(window, device=frames.device).clamp(max=n - 1)
    tail = torch.arange(n - window, n, device=frames.device).clamp(min=0)
    return torch.cat((frames[head], frames, frames[tail]), dim=0)


if __name__ == "__main__":

    options = Options()
    args = options.parse()
    print('*' * 98)

    device = "cuda"

    # ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps frames to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # RAFT is constructed from its own argparse namespace; use a separate
    # parser so it does not clash with the options parsed above.
    raft_parser = argparse.ArgumentParser()
    raft_parser.add_argument('--model', help="restore checkpoint")
    raft_parser.add_argument('--small', action='store_true', help='use small model')
    raft_parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    raft_parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')

    # Wrapped in DataParallel only while loading (presumably so the checkpoint's
    # 'module.'-prefixed keys match), then unwrapped.
    raft_model = torch.nn.DataParallel(RAFT(raft_parser.parse_args(['--model', args.raft_path])))
    raft_model.load_state_dict(torch.load(args.raft_path))
    raft_model = raft_model.module
    raft_model.to(device)
    raft_model.eval()

    parsingpredictor = BiSeNet(n_classes=19)
    parsingpredictor.load_state_dict(torch.load(args.faceparsing_path, map_location=lambda storage, loc: storage))
    parsingpredictor.to(device).eval()

    # Halves the spatial size, undoing the 2x upsampling applied to the frames below.
    down = Downsample(kernel=[1, 3, 3, 1], factor=2).to(device).eval()

    print('Load models successfully!')

    window = args.window_size
    if window < 0:
        raise ValueError('window_size must be non-negative, got %d' % window)

    # Frames are processed at 2x resolution (F.interpolate replaces the
    # deprecated F.upsample; align_corners=False is upsample's default behaviour).
    Is = read_video(args.video_path, transform)
    Is = F.interpolate(Is, scale_factor=2, mode='bilinear', align_corners=False)
    Is_ = pad_temporal(Is, window)
    n_frames = len(Is)

    print('Load video with %d frames successfully!' % (n_frames))

    # Per-frame parsing logits. NOTE(review): BiSeNet is fed 2*Is, i.e. values in
    # [-2, 2]; presumably this matches the model's expected normalisation — confirm.
    Ps = []
    with torch.no_grad():
        for i in tqdm(range(n_frames)):
            Ps.append(parsingpredictor(2 * Is[i:i + 1].to(device))[0].cpu())
    Ps = torch.cat(Ps, dim=0)
    Ps_ = pad_temporal(Ps, window)

    print('Predict parsing maps successfully!')

    # Gaussian temporal weights over offsets -window..window (sigma = window + 0.5).
    wt = torch.exp(-(torch.arange(2 * window + 1).float() - window) ** 2 / (2 * ((window + 0.5) ** 2))).reshape(2 * window + 1, 1, 1, 1).to(device)

    parse = []
    with torch.no_grad():
        for ii in tqdm(range(n_frames)):
            i = ii + window
            image2 = Is_[i - window:i + window + 1].to(device)
            image1 = Is_[i].repeat(2 * window + 1, 1, 1, 1).to(device)
            parsing = Ps_[i - window:i + window + 1].to(device)
            # Pad the parsing maps exactly like the frames; otherwise the channel
            # concatenation below fails whenever the padder adds any border.
            padder = InputPadder(image1.shape)
            image1, image2, parsing = padder.pad(image1, image2, parsing)
            # RAFT takes images in [0, 255]; frames are in [-1, 1].
            _, flow_up = raft_model((image1 + 1) * 255.0 / 2, (image2 + 1) * 255.0 / 2, iters=20, test_mode=True)
            output, mask = warp(torch.cat((image2, parsing), dim=1), flow_up)
            aligned_Is = output[:, 0:3]
            aligned_Ps = output[:, 3:]

            # Photometric-consistency weight, zeroed where the warp sampled out of bounds.
            ws = torch.exp(-((aligned_Is - image1) ** 2).mean(dim=1, keepdim=True) / (2 * (0.2 ** 2))) * mask[:, 0:1]
            # The centre frame is its own reference: unwarped maps, full weight.
            aligned_Ps[window] = parsing[window]
            ws[window, :, :, :] = 1.0

            weights = ws * wt
            weights = weights / weights.sum(dim=0, keepdim=True)
            fused_Ps = (aligned_Ps * weights).sum(dim=0, keepdim=True)
            parse.append(down(padder.unpad(fused_Ps)).cpu())
    parse = torch.cat(parse, dim=0)

    basename = os.path.basename(args.video_path).split('.')[0]
    os.makedirs(args.output_path, exist_ok=True)
    np.save(os.path.join(args.output_path, basename + '_parsingmap.npy'), parse.numpy())

    print('Done!')