| |
| import os |
| import cv2 |
| import numpy as np |
| import torch |
| import bisect |
| import shutil |
| import pdb |
| from tqdm import tqdm |
|
|
def init_frame_interpolation_model(
    checkpoint_path="./pretrained_model/film_net_fp16.pt",
    device="cuda",
):
    """Load the TorchScript frame interpolation model in half precision.

    Args:
        checkpoint_path: Path of the TorchScript checkpoint to load. Defaults
            to the path the function has always used.
        device: Device the model is moved to after loading. Defaults to
            ``"cuda"``, the previous hard-coded behaviour; pass ``"cpu"`` on
            hosts without a GPU.

    Returns:
        The loaded script module, in eval mode, cast to fp16 and placed on
        ``device``.
    """
    print("Initializing frame interpolation model")

    # Deserialize onto the CPU first so that loading does not depend on the
    # device the checkpoint was saved from; the move to `device` comes last.
    model = torch.jit.load(checkpoint_path, map_location="cpu")
    model.eval()
    model = model.half()
    return model.to(device=device)
|
|
|
|
def batch_images_interpolation_tool(
    input_tensor, model, inter_frames=1, device="cuda", show_progress=True
):
    """Insert ``inter_frames`` model-predicted frames between consecutive frames.

    Frames are taken along dimension 2 of ``input_tensor``. For every pair of
    neighbouring frames the missing in-between frames are generated one at a
    time: at each step the (known segment, missing frame) combination whose
    missing frame lies closest to the middle of the segment is predicted
    first, and the prediction then becomes a known frame for later steps.

    Args:
        input_tensor: Tensor whose dimension 2 indexes frames.
            NOTE(review): presumably (batch, channels, frames, height, width)
            and on the CPU, since predictions are moved to the CPU before the
            final concatenation -- confirm against the caller.
        model: Callable invoked as ``model(x0, x1, dt)`` with two fp16 frames
            and a ``(1, 1)`` fp16 tensor holding the relative position of the
            wanted frame between them.
        inter_frames: Number of frames to insert between each pair. Converted
            with ``int``; negative values are treated as 0 (no insertion).
        device: Device the frames are moved to before calling ``model``.
            Defaults to ``"cuda"``, the previous hard-coded behaviour.
        show_progress: When true (default) wrap the frame-pair loop in a
            ``tqdm`` progress bar.

    Returns:
        Tensor concatenated along dimension 2 with
        ``(frames - 1) * (inter_frames + 1) + 1`` frames. Original frames are
        passed through untouched; predicted frames are clamped to ``[0, 1]``
        and returned as float32 CPU tensors.

    Raises:
        IndexError: If ``input_tensor`` has no frames along dimension 2.
    """
    # Loop-invariant setup: the frame count per gap and the normalized
    # timestamps (0 = left key frame, 1 = right key frame) never change.
    inter_frames = max(int(inter_frames), 0)
    frame_num = input_tensor.shape[2]
    splits = torch.linspace(0, 1, inter_frames + 2)

    pair_indices = range(frame_num - 1)
    if show_progress:
        pair_indices = tqdm(pair_indices)

    video_tensor = []
    for idx in pair_indices:
        # `results` holds the frames known so far, `idxes` their positions in
        # `splits` (kept sorted), `remains` the positions still to predict.
        results = [input_tensor[:, :, idx], input_tensor[:, :, idx + 1]]
        idxes = [0, inter_frames + 1]
        remains = list(range(1, inter_frames + 1))

        while remains:
            starts = splits[idxes[:-1]]
            ends = splits[idxes[1:]]
            # Row = known segment, column = missing frame; the value is how far
            # the frame's relative position in that segment is from the middle.
            distances = (
                (splits[None, remains] - starts[:, None])
                / (ends[:, None] - starts[:, None])
                - 0.5
            ).abs()
            start_i, step = divmod(
                int(torch.argmin(distances)), distances.shape[1]
            )
            end_i = start_i + 1
            target = remains.pop(step)

            x0 = results[start_i].to(device=device, dtype=torch.float16)
            x1 = results[end_i].to(device=device, dtype=torch.float16)

            # Compute the ratio in float32 and cast once, so the fp16 timestep
            # is rounded a single time instead of once per operand.
            left = splits[idxes[start_i]]
            span = splits[idxes[end_i]] - left
            dt = x0.new_full((1, 1), ((splits[target] - left) / span).item())

            with torch.no_grad():
                prediction = model(x0, x1, dt)

            insert_position = bisect.bisect_left(idxes, target)
            idxes.insert(insert_position, target)
            results.insert(
                insert_position, prediction.clamp(0, 1).cpu().float()
            )

        # The right key frame is skipped here: it is the left key frame of the
        # next pair (or is appended once after the loop for the last pair).
        video_tensor.extend(frame.unsqueeze(2) for frame in results[:-1])

    video_tensor.append(input_tensor[:, :, -1].unsqueeze(2))
    return torch.cat(video_tensor, dim=2)