| import os |
| import json |
| import torch |
| import numpy as np |
| from torch.utils.data import Dataset |
| from PIL import Image |
| from tqdm import tqdm |
| import faiss |
| import torch.nn.functional as F |
| from sentence_transformers import SentenceTransformer |
| import torchvision.transforms as transforms |
| from random import choice |
|
|
|
|
class CCDataset(Dataset):
    """Change-captioning dataset of before/after image pairs with 5 captions each.

    The JSON annotations are filtered by split, every caption is tokenized and
    padded to ``max_length``, sentence embeddings are precomputed with a
    SentenceTransformer, and the image pairs are loaded and transformed up
    front in :meth:`preprocess`.  The raw annotations and the sentence encoder
    are deleted once preprocessing is done.
    """

    def __init__(self, json_file, root_dir, vocab, transform, split, max_length, s_pretrained, device):
        """
        Args:
            json_file: annotation file with an ``images`` list; each entry has
                ``split``, ``filepath``, ``filename``, ``imgid``, ``changeflag``
                and a ``sentences`` list (items carry ``tokens`` and ``raw``).
            root_dir: root directory holding the ``A`` (before) / ``B`` (after)
                image folders.
            vocab: token-to-id mapping with ``START``/``END``/``PAD``/``UNK``.
            transform: torchvision transform applied to every PIL image.
            split: one of ``'train'``, ``'val'``, ``'test'``.
            max_length: fixed token length of every caption after padding.
            s_pretrained: SentenceTransformer checkpoint name/path.
            device: device the sentence encoder runs on during preprocessing.
        """
        super(CCDataset, self).__init__()
        self.vocab = vocab
        self.split = split
        self.max_length = max_length
        self.device = device
        self.transform = transform
        assert self.split in {'train', 'val', 'test'}

        # The sentence encoder is only needed during preprocessing; it is
        # deleted at the end of __init__.
        s_model = SentenceTransformer(s_pretrained)
        self.s_model = s_model.to(device)

        self.root_dir = root_dir
        self.convert = transforms.ToTensor()

        with open(json_file) as f:
            data = json.load(f)['images']

        self.raw_dataset = [entry for entry in data if entry['split'] == split]
        self.sentences = []
        self.embeddings = []

        self.images = []
        self.captions = []
        # Tokenize first (this normalizes every record to exactly 5 captions,
        # in place), then embed, so sentences and embeddings stay index-aligned.
        for record in tqdm(self.raw_dataset, desc='Tokenize ' + self.split):
            self.sentences.extend(self.tokenize(record['sentences']))

        for record in tqdm(self.raw_dataset, desc='Embeddings ' + self.split):
            self.embeddings.extend(self.compute_embeddings(record['sentences']))

        self.preprocess()
        del self.raw_dataset
        del self.sentences
        del self.embeddings
        del self.s_model

    def tokenize(self, batch):
        """Tokenize and pad the captions of one image, in place.

        Captions longer than ``max_length - 2`` (no room for START/END) are
        dropped; each remaining caption gets ``input_ids`` (START + ids + END
        + PAD) and a boolean ``mask`` that is True on padding positions.  The
        list is then topped back up to exactly 5 entries by duplicating
        randomly chosen kept captions.

        BUGFIX: the original kept over-long captions in the returned list
        without ``input_ids``/``mask``, which later made :meth:`preprocess`
        fail with a KeyError.  They are now filtered out before padding to 5.

        ``batch`` (i.e. ``record['sentences']``) is mutated in place so that
        :meth:`compute_embeddings` and the ``raws`` gathered in
        :meth:`preprocess` see the same 5 captions.
        """
        kept = []
        for elem in batch:
            tokens = [self.vocab.get(x, self.vocab['UNK']) for x in elem['tokens']]
            if len(tokens) > self.max_length - 2:
                continue  # no room for START/END markers: drop this caption

            tokens = [self.vocab['START']] + tokens + [self.vocab['END']]
            mask = [False] * len(tokens)

            diff = self.max_length - len(tokens)
            tokens += [self.vocab['PAD']] * diff
            mask += [True] * diff

            elem['input_ids'] = tokens
            elem['mask'] = mask
            kept.append(elem)

        assert kept, 'every caption of this image exceeds max_length'

        # Rebuild the caller's list in place, then pad to 5 with duplicates.
        batch[:] = kept
        while len(batch) < 5:
            batch.append(choice(batch))

        assert len(batch) == 5
        return batch

    def compute_embeddings(self, batch):
        """Return sentence embeddings, one row per caption in ``batch``."""
        sents = [x['raw'].strip() for x in batch]
        embs = self.s_model.encode(sents)
        return embs

    def __len__(self):
        # 5 entries per image for 'train'; one per image for val/test
        # (preprocess steps through sentences in strides of 5 there).
        return len(self.captions)

    def __getitem__(self, idx):
        """Return one caption dict merged with its cached image-pair tensors."""
        # 'train' stores 5 consecutive captions per image; val/test store a
        # single caption per image, so idx maps 1:1 onto images.
        img_idx = idx // 5 if self.split == 'train' else idx
        elem = self.captions[idx]
        for k, v in self.images[img_idx].items():
            elem[k] = v
        return elem

    def preprocess(self):
        """Materialize caption tensors and transformed image pairs.

        Walks the (already tokenized) sentences in order.  For 'train' every
        caption becomes an entry of ``self.captions``; for val/test only the
        first caption of each image is kept (stride 5).  Each new image pair
        is opened, transformed and cached once in ``self.images``.
        """
        idx = 0
        prev_idx = -1
        pbar = tqdm(total=len(self.sentences), desc='Preprocessing ' + self.split)
        while idx < len(self.sentences):
            img_idx = idx // 5
            # Sentences were extended record-by-record (5 per image), so this
            # alignment must hold.
            assert (self.sentences[idx]['imgid'] == self.raw_dataset[img_idx]['imgid'])

            input_ids = torch.tensor(self.sentences[idx]['input_ids'], dtype=torch.long)
            mask = torch.tensor(self.sentences[idx]['mask'], dtype=torch.bool)
            raws = [x['raw'] for x in self.raw_dataset[img_idx]['sentences']]
            # flag: -1 marks a "no change" pair (changeflag == 0), otherwise
            # the image id is used as the label.
            flag = -1 if self.raw_dataset[img_idx]['changeflag'] == 0 else self.raw_dataset[img_idx]['imgid']
            flag = torch.tensor(flag, dtype=torch.long)
            embs = torch.tensor(self.embeddings[idx]) if len(self.embeddings) > 0 else None

            self.captions.append({'input_ids': input_ids, 'pad_masks': mask, 'raws': raws, 'flags': flag, 'embs': embs})

            if img_idx != prev_idx:
                record = self.raw_dataset[img_idx]
                before_img_path = os.path.join(self.root_dir, record['filepath'], 'A',
                                               record['filename'])
                after_img_path = os.path.join(self.root_dir, record['filepath'], 'B',
                                              record['filename'])
                # Context managers close the file handles promptly; the
                # transform forces the pixel data to load while still open.
                with Image.open(before_img_path) as img:
                    image_before = self.transform(img).unsqueeze(0)
                with Image.open(after_img_path) as img:
                    image_after = self.transform(img).unsqueeze(0)

                self.images.append({'image_before': image_before, 'image_after': image_after, 'flags': flag})
                prev_idx = img_idx

            inc = 1 if self.split == 'train' else 5
            idx += inc
            pbar.update(inc)

        pbar.close()
|
|
|
|
class Batcher:
    """Batch iterator over a CCDataset with optional hard-negative mining.

    When ``model`` is given and ``hd > 0`` (train split only), a faiss
    inner-product index is built over L2-normalized visual embeddings; for
    every sample in a batch, ``hd`` hard negatives are retrieved by searching
    the index with the sample's textual embedding and appended to the batch.
    """

    def __init__(self, dataset, batch_size, max_len, device, hd=0, model=None, shuffle=False):
        """
        Args:
            dataset: a CCDataset; items provide image/caption tensors.
            batch_size: number of primary samples per batch (mined hard
                negatives are appended on top of this).
            max_len: caption length, used to build the label tensors.
            device: device used when encoding with ``model``.
            hd: number of hard negatives mined per sample (0 disables mining).
            model: model exposing ``encoder``, ``decoder`` and ``feature_dim``;
                required for mining.
            shuffle: reshuffle the sample order at the start of every epoch.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.hd = hd
        self.max_len = max_len
        self.device = device
        self.model = model
        self.index = None
        self.visual = None
        self.textual = None

        self.ptr = 0
        self.indices = np.arange(len(self.dataset))
        self.shuffle = shuffle

        if shuffle:
            np.random.shuffle(self.indices)

        if model and hd > 0 and self.dataset.split == 'train':
            self.create_index()

    def __iter__(self):
        return self

    def __len__(self):
        # BUGFIX: __next__ also yields a final partial batch, so the number of
        # batches per epoch is the ceiling, not the floor.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __next__(self):
        """Return the next batch; at epoch end, reset state and raise StopIteration."""
        if self.ptr >= len(self.dataset):
            self.ptr = 0
            self.index = None
            self.visual = None
            self.textual = None

            if self.shuffle:
                np.random.shuffle(self.indices)
            # Rebuild the index so the next epoch mines with current weights.
            if self.model and self.hd > 0 and self.dataset.split == 'train':
                self.create_index()

            raise StopIteration

        batched = 0
        samples = []
        hard_negatives = []
        while self.ptr < len(self.dataset) and batched < self.batch_size:
            sample = self.dataset[self.indices[self.ptr]]
            samples.append(sample)

            if self.hd > 0 and self.dataset.split == 'train':
                hard_neg = self.mine_negatives(self.indices[self.ptr], self.hd)
                hard_negatives.extend(hard_neg)

            self.ptr += 1
            batched += 1

        return self.create_batch(samples + hard_negatives)

    def get_elem(self, idx):
        """Direct access to one dataset item (no batching)."""
        return self.dataset[idx]

    @torch.no_grad()
    def create_index(self):
        """Encode the whole dataset and build the faiss visual index.

        Visual embeddings are computed once per image (the train split yields
        5 consecutive captions per image), textual embeddings once per
        caption.  Both are L2-normalized so inner product equals cosine
        similarity.  Restores the model's train/eval mode afterwards.
        """
        is_training = self.model.training
        self.model.eval()
        self.index = faiss.IndexFlatIP(self.model.feature_dim)
        prev_img = None
        for idx in tqdm(range(len(self.dataset)), desc='Creating index'):
            sample = self.dataset[idx]
            imgs1, imgs2 = sample['image_before'], sample['image_after']
            input_ids, mask = sample['input_ids'], sample['pad_masks']

            # New image every 5 captions: encode the image pair only once.
            if idx // 5 != prev_img:
                imgs1 = imgs1.to(self.device)
                imgs2 = imgs2.to(self.device)
                vis_emb, _ = self.model.encoder(imgs1, imgs2)
                self.visual = torch.cat([self.visual, vis_emb.cpu()]) if self.visual is not None else vis_emb.cpu()
                prev_img = prev_img + 1 if prev_img is not None else 0

            input_ids = input_ids.unsqueeze(0).to(self.device)
            mask = mask.unsqueeze(0).to(self.device)
            _, text_emb, _, _ = self.model.decoder(input_ids, None, mask, None)
            self.textual = torch.cat([self.textual, text_emb.cpu()]) if self.textual is not None else text_emb.cpu()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        self.visual = F.normalize(self.visual, p=2, dim=1)
        self.textual = F.normalize(self.textual, p=2, dim=1)
        # BUGFIX: faiss expects contiguous float32 numpy arrays, not torch
        # tensors (torch input requires faiss.contrib.torch_utils, which this
        # module does not import).
        self.index.add(np.ascontiguousarray(self.visual.numpy(), dtype=np.float32))
        if is_training:
            self.model.train()

    def mine_negatives(self, idx, n):
        """Return up to ``n`` samples whose change flag differs from sample ``idx``.

        Searches the visual index with the textual embedding of caption
        ``idx``, growing the candidate pool (k = n * m, with m doubling) until
        enough candidates with a different flag are found or the pool would
        cover the whole index.
        """
        negatives = []
        m = 4
        label = self.dataset[idx]['flags'].item()

        while len(negatives) < n and (n * m) < self.index.ntotal:
            k = n * m
            # BUGFIX: faiss needs a contiguous float32 numpy query row.
            query = np.ascontiguousarray(self.textual[idx].unsqueeze(0).numpy(), dtype=np.float32)
            indices = self.index.search(query, k)[1][0]
            # Index rows are one per image; map back to dataset (caption)
            # indices — the train split has 5 captions per image.
            indices = [x * 5 for x in indices]
            negatives = [self.dataset[x] for x in indices if self.dataset[x]['flags'].item() != label][:n]
            m *= 2

        return negatives

    def create_batch(self, samples):
        """Stack a list of sample dicts into batched tensors.

        Labels are the input ids with padded positions replaced by -100
        (conventionally the ignore index of cross-entropy losses).
        """
        images_before = images_after = input_ids = pad_mask = labels = flags = embs = None
        raws = []

        for sample in samples:
            img1 = sample['image_before']
            img2 = sample['image_after']

            tokens = sample['input_ids']
            mask = sample['pad_masks']
            flag = sample['flags']
            emb = sample['embs']

            tokens = tokens.unsqueeze(0)
            mask = mask.unsqueeze(0)
            flag = flag.unsqueeze(0)
            # Zero out padded positions, then write -100 into them.
            lab = tokens.clone() * ~mask
            lab += torch.tensor([[-100]], dtype=torch.long).repeat(1, self.max_len) * mask
            if emb is not None:
                emb = emb.unsqueeze(0)

            images_before = torch.cat([images_before, img1]) if images_before is not None else img1
            images_after = torch.cat([images_after, img2]) if images_after is not None else img2
            input_ids = torch.cat([input_ids, tokens]) if input_ids is not None else tokens
            labels = torch.cat([labels, lab]) if labels is not None else lab
            pad_mask = torch.cat([pad_mask, mask]) if pad_mask is not None else mask
            flags = torch.cat([flags, flag]) if flags is not None else flag
            if emb is not None:
                embs = torch.cat([embs, emb]) if embs is not None else emb

            raws.append(sample['raws'])

        return {'images_before': images_before, 'images_after': images_after, 'input_ids': input_ids,
                'pad_mask': pad_mask, 'labels': labels, 'flags': flags, 'raws': raws, 'embs': embs}
|
|