| import argparse |
| import random |
| import os |
| import matplotlib.pyplot as plt |
|
|
| import numpy as np |
| import torch |
| import json |
| import faiss |
| from tqdm import tqdm |
| import torch.nn.functional as F |
| import torchvision.transforms as T |
|
|
| import open_clip |
|
|
| from Datasets import CCDataset, Batcher |
| from model import ICCModel |
| from utils import get_vocabulary, unormalize, get_eval_score |
|
|
# Cutoffs for the retrieval metrics (P@k / R@k / MRR@k), sorted largest-first
# so a single top-k search can simply be truncated for the smaller cutoffs.
AT_K = sorted([1, 3, 5, 10], reverse=True)
|
|
|
|
def _write_score_items(path, scores):
    """Write each (metric, value) pair of *scores* on its own line of *path*."""
    with open(path, 'w') as out:
        for item in scores.items():
            out.write(str(item) + '\n')


def captioning(args, config, model, data_loader, vocab, device):
    """Generate captions for the test split and record evaluation scores.

    Runs inference twice: once over the full split (keeping per-sample
    results for later plotting) and once restricted to the annotated
    subset (sub=True).  Scores go to 'caption.txt' / 'caption_sub.txt'
    inside args.output_path.

    Returns:
        The per-sample result tuples produced by inference() on the full
        split (consumed by plot()).
    """
    scores, results = inference(config, model, data_loader, vocab, device, return_results=True)
    _write_score_items(os.path.join(args.output_path, 'caption.txt'), scores)

    # Second pass: subset-only scores; the per-sample results are not needed.
    scores, _ = inference(config, model, data_loader, vocab, device, sub=True, return_results=False)
    _write_score_items(os.path.join(args.output_path, 'caption_sub.txt'), scores)

    return results
|
|
|
|
def _write_retrieval_metrics(path, scores_p, scores_r, scores_rr):
    """Write P@k / R@k / MRR@k lines for every cutoff in AT_K to *path*."""
    with open(path, 'w') as out:
        for k in AT_K:
            out.write('P@{0} {1:.4f}\n'.format(k, scores_p[k]))
            out.write('R@{0} {1:.4f}\n'.format(k, scores_r[k]))
            out.write('MRR@{0} {1:.4f}\n'.format(k, scores_rr[k]))
            out.write('\n')


def retrieve(args, config, model, src_loader, device):
    """Evaluate text-to-image retrieval and write the metric reports.

    Runs search() twice — on the full split and on the annotated subset
    (sub=True) — writing 'retrieve.txt' and 'retrieve_sub.txt' inside
    args.output_path.
    """
    scores_p, scores_r, scores_rr = search(config, model, src_loader, device)
    _write_retrieval_metrics(os.path.join(args.output_path, 'retrieve.txt'),
                             scores_p, scores_r, scores_rr)

    scores_p, scores_r, scores_rr = search(config, model, src_loader, device, sub=True)
    _write_retrieval_metrics(os.path.join(args.output_path, 'retrieve_sub.txt'),
                             scores_p, scores_r, scores_rr)
|
|
|
|
@torch.no_grad()
def search(config, model, src_loader, device, sub=False):
    """Evaluate text-to-image retrieval and report P@k, R@k, MRR@k.

    Every sample is indexed by its visual embedding (FAISS inner-product
    index over L2-normalized vectors, i.e. cosine similarity) and queried
    with its textual embedding.  A gallery item x is relevant to query i
    when it carries the same flag, or when the precomputed sentence
    embeddings are at least config['s-threshold'] similar.

    Args:
        config: needs 'd_model' (embedding width) and 's-threshold'.
        model: ICC model exposing .encoder(imgs1, imgs2) and a decoder
            callable as .decoder(input_ids, None, mask, None).
        src_loader: iterable of batches with keys 'images_before',
            'images_after', 'flags', 'embs', 'input_ids', 'pad_mask'.
        device: torch device for model inputs.
        sub: when True, skip samples whose first flag is -1
            (evaluate the annotated subset only).

    Returns:
        (scores_p, scores_r, scores_rr): dicts mapping each k in AT_K to
        the mean precision / recall / reciprocal rank at k.
    """
    model.eval()


    visual = None    # gallery: visual embeddings accumulated on CPU
    textual = None   # queries: textual embeddings accumulated on CPU
    # NOTE(review): each append stores the whole batch's flag object; indexing
    # flags[i] per query below assumes batch size 1 (as set up by
    # Batcher(test_set, 1, ...) in run()) — confirm.
    flags = []
    embs = None      # precomputed sentence embeddings for the relevance matrix
    index = faiss.IndexFlatIP(config['d_model'])


    batcher = src_loader


    for batch in tqdm(batcher, desc='Indexing'):
        imgs1, imgs2, = batch['images_before'], batch['images_after']
        imgs1 = imgs1.to(device)
        imgs2 = imgs2.to(device)
        flag = batch['flags']
        emb = batch['embs']
        # Subset mode: skip unannotated samples (flag -1).
        if sub and flag[0] == -1:
            continue


        flags.append(flag)
        embs = torch.cat([embs, emb]) if embs is not None else emb


        # Visual embedding of the (before, after) image pair.
        vis_emb, _ = model.encoder(imgs1, imgs2)
        visual = torch.cat([visual, vis_emb.cpu()]) if visual is not None else vis_emb.cpu()


        input_ids, mask = batch['input_ids'], batch['pad_mask']
        input_ids = input_ids.to(device)
        mask = mask.to(device)


        # Textual embedding of the caption (decoder run with no visual tokens).
        _, text_emb, _, _ = model.decoder(input_ids, None, mask, None)
        textual = torch.cat([textual, text_emb.cpu()]) if textual is not None else text_emb.cpu()


        if torch.cuda.is_available():
            torch.cuda.empty_cache()


    # Pairwise similarity of the precomputed sentence embeddings; used to
    # widen the relevant set beyond exact flag matches.
    embs = embs.to(device)
    sims = torch.matmul(embs, torch.t(embs))


    # L2-normalize so the inner-product index scores cosine similarity.
    visual = F.normalize(visual, p=2, dim=1)
    textual = F.normalize(textual, p=2, dim=1)


    # NOTE(review): faiss classically takes float32 numpy arrays; passing a
    # torch tensor here relies on faiss.contrib.torch_utils being imported
    # somewhere in the project — confirm.
    index.add(visual)


    scores_p = {k: [] for k in AT_K}
    scores_r = {k: [] for k in AT_K}
    scores_rr = {k: [] for k in AT_K}


    for i in tqdm(range(textual.shape[0]), desc='Ranking'):
        indices = None
        query = textual[i]
        query_lab = flags[i]


        # Relevant items: same flag, or sentence-embedding similarity above
        # the threshold.  Always contains i itself (flags[i] == query_lab),
        # so len(relevants) >= 1.
        relevants = set(
            [x for x in range(len(textual)) if flags[x] == query_lab or sims[i][x] >= config['s-threshold']])


        # AT_K is sorted descending, so the single search at the largest k is
        # simply truncated for the smaller cutoffs.
        for k in AT_K:
            p = 0
            r = 0
            rr = 0


            if indices is None:
                indices = index.search(query.unsqueeze(0), k)[1][0]
            else:
                indices = indices[:k]


            for rank, idx in enumerate(indices):
                if idx in relevants:
                    if p == 0:
                        # First hit in the ranking fixes the reciprocal rank.
                        rr = 1 / (rank + 1)
                    p += 1
                    r += 1


            scores_p[k].append(p / len(indices))
            scores_r[k].append(r / len(relevants))
            scores_rr[k].append(rr)


    # Average each metric over all queries.
    for k in AT_K:
        scores_p[k] = sum(scores_p[k]) / len(scores_p[k])
        scores_r[k] = sum(scores_r[k]) / len(scores_r[k])
        scores_rr[k] = sum(scores_rr[k]) / len(scores_rr[k])


    return scores_p, scores_r, scores_rr
|
|
|
|
@torch.no_grad()
def inference(config, model, data_loader, vocab, device, sub=False, return_results=False):
    """Greedy-decode a caption for every sample and score against references.

    Args:
        config: dict providing 'max_len', the decoding step budget
            (assumed >= 1 so the loop body always runs).
        model: ICC model exposing .encoder(img1, img2) and a decoder
            callable as .decoder(input_ids, None, None, vis_toks).
        data_loader: iterable of batches; only the first sample of each
            batch is decoded.
        vocab: token -> id mapping containing 'START' and 'END'.
        device: torch device for inputs.
        sub: if True, skip samples whose first flag is -1.
        return_results: if True, also collect per-sample tensors for plot().

    Returns:
        (score_dict, results): score_dict from get_eval_score(); results is a
        list of (img1, img2, attention weights, visual tokens, sentence)
        tuples, empty unless return_results is set.
    """
    results = []
    references = []
    hypotheses = []
    inverse_vocab = {v: k for k, v in vocab.items()}

    model.eval()
    for batch in tqdm(data_loader, desc='Inference'):
        img1 = batch['images_before'][0].unsqueeze(0).to(device)
        img2 = batch['images_after'][0].unsqueeze(0).to(device)
        raws = batch['raws']
        flags = batch['flags']
        if sub and flags[0] == -1:
            continue

        references.append(raws[0])

        # Greedy autoregressive decoding, seeded with the START token.
        input_ids = torch.tensor([[vocab['START']]], dtype=torch.long, device=device)
        _, vis_toks = model.encoder(img1, img2)

        for _ in range(config['max_len']):
            _, _, lm_logits, weights = model.decoder(input_ids, None, None, vis_toks)

            next_item = lm_logits[0][-1].topk(1)[1]
            input_ids = torch.cat([input_ids, next_item.reshape(1, -1)], dim=1)
            if next_item.item() == vocab['END']:
                break

        # Drop the leading START, and the trailing END only when it was
        # actually produced.  (Previously words[1:-1] was sliced
        # unconditionally, silently discarding the last real word whenever
        # decoding stopped by exhausting max_len instead of emitting END.)
        token_ids = input_ids[0].cpu().tolist()[1:]
        if token_ids and token_ids[-1] == vocab['END']:
            token_ids = token_ids[:-1]
        sentence = ' '.join(inverse_vocab[x] for x in token_ids).strip()
        hypotheses.append([sentence])

        if return_results:
            # 'weights' is bound by the final decoder step of the loop above.
            results.append(
                (img1.cpu(), img2.cpu(), weights.detach().cpu(), vis_toks.detach().cpu(), sentence))

    score_dict = get_eval_score(references, hypotheses)
    return score_dict, results
|
|
|
|
def plot(args, feat_size, results):
    """Save one 2x2 figure per sample under args.output_path.

    Panels: before image, after image, the visual-token difference map, and
    the decoder attention weights overlaid on the after image, with the
    generated caption printed beneath the grid.  Files are named
    '0.png', '1.png', ... in result order.
    """
    for fig_idx, (before, after, att, diff, caption) in enumerate(tqdm(results, desc='Plot')):
        # Undo input normalization and move channels last for imshow
        # (assumes CHW tensors — the permute puts dim 0 last).
        before = unormalize(before)[0].permute(1, 2, 0)
        after = unormalize(after)[0].permute(1, 2, 0)

        # Upsample feature-resolution maps to full image resolution.
        upsample = T.Resize(size=(before.size(0), before.size(1)))

        att = att[0].reshape(-1, feat_size, feat_size)
        att = upsample(att).permute(1, 2, 0)
        att = torch.sum(att, 2) / att.shape[2]  # average over the last axis

        diff_map = diff[:, 0, :].reshape(-1, feat_size, feat_size)
        diff_map = upsample(diff_map).permute(1, 2, 0)
        diff_map = torch.sum(diff_map, 2) / diff_map.shape[2]

        fig, axes = plt.subplots(2, 2, figsize=(6, 8))
        fig.tight_layout()

        axes[0, 0].imshow(before)
        axes[0, 0].set_title("Before")
        axes[0, 0].axis('off')

        axes[0, 1].imshow(after)
        axes[0, 1].set_title("After")
        axes[0, 1].axis('off')

        axes[1, 0].set_title("Img diff")
        axes[1, 0].imshow(diff_map)
        axes[1, 0].axis('off')

        # Attention heat-map blended over the after image.
        axes[1, 1].set_title("Att weights")
        axes[1, 1].imshow(after, interpolation='nearest')
        axes[1, 1].imshow(att, interpolation='bilinear', alpha=0.5)
        axes[1, 1].axis('off')

        fig.text(.1, .05, caption, wrap=True)

        with open(os.path.join(args.output_path, str(fig_idx) + '.png'), 'wb') as fh:
            plt.savefig(fh)
        plt.close()
|
|
|
|
def run(args, config):
    """Build the model and test data, then run the full evaluation suite:
    captioning, retrieval, and per-sample attention plots.
    """
    print('Initializing...')
    # Seed every RNG source and pin cuDNN for reproducible evaluation.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the cached vocabulary if present; otherwise build it from the
    # annotations via get_vocabulary().
    if os.path.exists(args.vocab):
        with open(args.vocab, 'r') as vocab_file:
            vocabulary = json.load(vocab_file)
    else:
        vocabulary = get_vocabulary(args.annotation_json, args.vocab)

    clip, _, clip_preprocess = open_clip.create_model_and_transforms(config['backbone'])

    model = ICCModel(device, clip, config['backbone'], config['d_model'],
                     len(vocabulary), config['max_len'], config['num_heads'], config['h_dim'], config['a_dim'],
                     config['encoder_layers'], config['decoder_layers'], config['dropout'],
                     learnable=config['learnable'], fine_tune=config['fine_tune'],
                     tie_embeddings=config['tie_embeddings'], prenorm=config['prenorm'])

    model.load_state_dict(torch.load(args.model, map_location=device))
    model = model.to(device)
    # The CLIP backbone was handed to ICCModel; drop the outer reference.
    del clip

    print('Loading...')
    dataset = CCDataset(args.annotation_json, args.image_dir, vocabulary, clip_preprocess, 'test',
                        config['max_len'], config['s-transformers'], device)
    loader = Batcher(dataset, 1, config['max_len'], device)

    print('Final evaluation...')
    results = captioning(args, config, model, loader, vocabulary, device)
    retrieve(args, config, model, loader, device)
    plot(args, model.encoder.encoder.feat_size, results)
|
|
|
|
def main():
    """CLI entry point: parse arguments, load the JSON config, run the eval."""
    # String options registered table-style; order matches the --help output.
    string_options = (
        ('--model', '../input/model_best.pt'),
        ('--annotation_json', '../input/Levir_CC/LevirCCcaptions.json'),
        ('--image_dir', '../input/Levir_CC/images/'),
        ('--vocab', '../input/levir_vocab.json'),
        ('--config', '../config.json'),
        ('--output_path', '../output/'),
    )

    parser = argparse.ArgumentParser()
    for flag, default in string_options:
        parser.add_argument(flag, type=str, default=default)
    parser.add_argument('--seed', type=int, default=42)

    args = parser.parse_args()

    with open(args.config, 'r') as config_file:
        config = json.load(config_file)

    run(args, config)
|
|
|
|
# Script entry point.
if __name__ == '__main__':
    main()
|
|