| |
| |
|
|
|
|
| import cv2, os, torch, re |
| import matplotlib.pyplot as plt |
| from scipy.ndimage import zoom |
| import numpy as np |
| from model_two import MakiAlexNet |
| from tqdm import tqdm |
|
|
| |
# --- Paths (relative to the directory the script is launched from) ---------
TRAIN_STORE = "dataset/root/train/"  # input frames, e.g. left1_frame_10.jpg
TEST_IMAGE = "dataset/root/train/left1_frame_10.jpg"  # one sample frame
GIF_STORE = "dataset/gifs2/"  # output location for generated videos
MODEL_PARAMS = "alexnet_2.0.pth"  # state_dict passed to torch.load

# --- Tuning ------------------------------------------------------------------
# NOTE(review): not referenced in the visible part of this script; confirm it
# is still needed.
TOP_ACCURACY_PERCENTILE = 10
|
|
# Build the network and load the trained weights. The weights are loaded onto
# the CPU first so that a checkpoint saved from a GPU also loads on a CPU-only
# machine; the model is moved to CUDA below when one is available.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = MakiAlexNet()
model.load_state_dict(torch.load(MODEL_PARAMS, map_location="cpu"))
model.eval()  # switch to inference mode

if torch.cuda.is_available():
    model = model.cuda()
    print("Running on cuda")

# Debug output: the model's attributes and the names of all submodules.
print(dir(model))
for name, _ in model.named_modules():
    print(name)
|
|
|
|
def extract_file_paths(filename):
    """Split a frame file name such as ``left1_frame_10.jpg`` into its parts.

    The pattern was developed with the aid of https://regex101.com/.

    Args:
        filename: A file name (or path) containing
            ``<left|right><digits>_frame_<digits>``.

    Returns:
        A ``(frame_no, frame_name, video_no)`` tuple of strings, where
        ``frame_name`` is the ``left``/``right`` prefix. For example,
        ``left1_frame_10.jpg`` gives ``("10", "left", "1")``.

    Raises:
        ValueError: If ``filename`` does not contain the expected pattern.
    """
    extractor_reg = r"(left|right)([0-9]+)(_frame_)([0-9]+)"
    result = re.search(extractor_reg, filename)
    if result is None:
        # Without this check the caller gets an opaque AttributeError on None.
        raise ValueError(f"Unrecognised frame file name: {filename!r}")
    frame_name, video_no, _, frame_no = result.groups()
    return frame_no, frame_name, video_no
|
|
|
|
def create_mp4_from_frames(file_name, frames, fps=20, output_dir=None):
    """Encode a list of image files into ``<output_dir>/<file_name>.mp4``.

    Frames are written in lexicographic order of their paths, so any frame
    numbers embedded in the file names must be zero-padded to a common width.

    Args:
        file_name: Base name of the video, without extension.
        frames: Paths of the images to encode. Every frame should have the
            same size as the first (sorted) frame, because cv2.VideoWriter
            does not resize its input.
        fps: Playback frame rate of the output video (default 20).
        output_dir: Destination directory. Defaults to ``GIF_STORE`` under
            the current working directory, and is created if missing.

    Returns:
        The path of the written video, or None when ``frames`` is empty.

    Raises:
        FileNotFoundError: If a frame cannot be read by cv2.imread.
    """
    ordered = sorted(frames)
    if not ordered:
        return None
    print("Sorted frames: ", ordered)

    if output_dir is None:
        output_dir = os.path.join(os.getcwd(), GIF_STORE)
    os.makedirs(output_dir, exist_ok=True)  # VideoWriter fails silently otherwise
    video_path = os.path.join(output_dir, f"{file_name}.mp4")

    first = cv2.imread(ordered[0])
    if first is None:
        raise FileNotFoundError(f"Could not read frame: {ordered[0]}")
    height, width = first.shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    try:
        for frame_path in ordered:
            image = cv2.imread(frame_path)
            if image is None:
                raise FileNotFoundError(f"Could not read frame: {frame_path}")
            # cv2.imread returns BGR and VideoWriter expects BGR, so the frame
            # is written unchanged. Converting to RGB here swaps red and blue.
            video.write(image)
    finally:
        video.release()  # always finalise the container, even on error
    return video_path
|
|
|
|
|
|
# Width used to zero-pad frame numbers in the saved overlay names. The overlays
# are later sorted as strings, so every number needs the same width for the
# video to play in order (2 digits broke at frame 100).
FRAME_NO_WIDTH = 5
RAW_STORE = os.path.join(os.getcwd(), GIF_STORE, "raw")
os.makedirs(RAW_STORE, exist_ok=True)


def _class_activation_map(image_rgb):
    """Return a 2-D heatmap for the model's top-scoring class on one image.

    ``image_rgb`` is an (H, W, 3) array. It is fed to the module-level
    ``model`` as a (1, 3, H, W) float tensor. The activations read from
    ``model.layer_outputs['Conv2d']`` are then weighted by the ``f_linear``
    weight row of the predicted class.

    NOTE(review): ``layer_outputs`` is presumably filled by hooks during the
    forward pass. This recipe also assumes ``f_linear``'s input size equals the
    conv channel count; confirm both against MakiAlexNet.
    NOTE(review): pixels are passed as raw 0-255 floats with no normalisation;
    confirm this matches training.
    """
    x = torch.from_numpy(image_rgb.astype(np.float32)).unsqueeze(0)
    x = x.permute(0, 3, 1, 2)  # (1, H, W, C) -> (1, C, H, W)
    if torch.cuda.is_available():
        x = x.cuda()

    with torch.no_grad():  # inference only: don't build an autograd graph
        model(x)

    conv = model.layer_outputs['Conv2d'].permute(0, 2, 3, 1)  # -> (1, h, w, C)
    conv = conv.detach().cpu().numpy()
    pred = model.layer_outputs["Linear"].detach().cpu().numpy()
    target = int(np.argmax(pred, axis=1)[0])  # batch of one image
    weights = model.f_linear.weight[target, :].detach().cpu().numpy()
    return conv.squeeze(0) @ weights


def _save_overlay(image_rgb, heatmap, file_path):
    """Save ``image_rgb`` with ``heatmap`` upsampled to its size at 50% alpha."""
    # Derive the zoom factors from the actual shapes instead of a fixed 224/12,
    # so the heatmap stays aligned with images of any size.
    factors = (image_rgb.shape[0] / heatmap.shape[0],
               image_rgb.shape[1] / heatmap.shape[1])
    plt.figure(figsize=(12, 12))
    try:
        plt.imshow(image_rgb)
        plt.imshow(zoom(heatmap, zoom=factors), cmap='jet', alpha=0.5)
        plt.savefig(file_path)
    finally:
        plt.close()


current_video_name = None
selected_frames = []
for image_filename in tqdm(sorted(os.listdir(TRAIN_STORE)), desc="Running Images"):
    frame_no, frame_name, video_no = extract_file_paths(image_filename)
    obtained_video_name = video_no + "vid" + frame_name
    if obtained_video_name != current_video_name:
        # With sorted names that start with left<N>_/right<N>_, all frames of
        # one video are contiguous. A new name therefore means the previous
        # video is complete.
        if selected_frames:
            create_mp4_from_frames(current_video_name, selected_frames)
        selected_frames = []
        current_video_name = obtained_video_name

    image_path = os.path.join(TRAIN_STORE, image_filename)
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    heatmap = _class_activation_map(image_rgb)

    overlay_name = video_no + frame_name + frame_no.zfill(FRAME_NO_WIDTH) + ".jpg"
    overlay_path = os.path.join(RAW_STORE, overlay_name)
    _save_overlay(image_rgb, heatmap, overlay_path)
    selected_frames.append(overlay_path)

# A video is only written when the next one starts, so flush the final video.
if selected_frames:
    create_mp4_from_frames(current_video_name, selected_frames)

# Stop the script here; same as exit() but without relying on the site module.
raise SystemExit
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
|
|