| import os |
| import cv2 |
| import gc |
| import numpy as np |
| import pandas as pd |
| import itertools |
| from tqdm.autonotebook import tqdm |
| import albumentations as A |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| import timm |
| from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer |
| import os |
|
|
class CFG:
    """Hyper-parameters, paths and model choices shared by the whole training script."""

    # --- run control and data locations ---
    debug = False
    image_path = ""  # directory prefix for image filenames
    captions_path = os.getcwd()  # directory containing captions.csv (evaluated at import time)
    batch_size = 30
    num_workers = 4

    # --- optimisation ---
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1  # ReduceLROnPlateau patience
    factor = 0.8  # ReduceLROnPlateau decay factor
    epochs = 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- image encoder ---
    model_name = "resnet50"
    image_embedding = 2048  # width of the image features fed to the image projection head

    # --- text encoder ---
    text_encoder_model = "distilbert/distilbert-base-uncased"
    text_embedding = 768  # width of the text features fed to the text projection head
    text_tokenizer = "distilbert/distilbert-base-uncased"
    max_length = 200  # tokenizer truncation length

    # --- encoder flags and loss ---
    pretrained = True
    trainable = True
    temperature = 1.0

    # --- image preprocessing ---
    size = 224  # images are resized to size x size

    # --- projection heads ---
    num_projection_layers = 1
    projection_dim = 256
    dropout = 0.1
| |
class AvgMeter:
    """Track a count-weighted running mean of a scalar metric.

    Args:
        name: label shown by ``repr``.
    """

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Clear the accumulated total, sample count and mean back to zero."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Accumulate ``val`` as if it were observed ``count`` times, then refresh ``avg``."""
        self.sum = self.sum + val * count
        self.count = self.count + count
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{}: {:.4f}".format(self.name, self.avg)
|
|
def get_lr(optimizer):
    """Return the ``"lr"`` of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group["lr"]
|
|
class CLIPDataset(torch.utils.data.Dataset):
    """Dataset pairing image files with pre-tokenized captions.

    ``image_filenames`` and ``captions`` must have the same length. If an image
    has several captions, its filename must be repeated once per caption.
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        Args:
            image_filenames: sequence of filenames, resolved relative to ``CFG.image_path``.
            captions: sequence of caption strings aligned with ``image_filenames``.
            tokenizer: callable invoked once on all captions with padding/truncation;
                its result must expose ``.items()`` of per-caption sequences.
            transforms: callable taking ``image=`` and returning a mapping with an
                ``"image"`` entry.
        """
        self.image_filenames = image_filenames
        self.captions = list(captions)
        # Tokenize every caption once up front; padding=True pads to the longest
        # caption across the whole dataset, so all items share one length.
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return a dict with the token tensors, the image as a CHW float tensor, and the raw caption.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        # os.path.join (not an f-string) so that an empty CFG.image_path yields a
        # relative path instead of "/<filename>" at the filesystem root.
        path = os.path.join(CFG.image_path, self.image_filenames[idx])
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image file: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)["image"]
        # Channels-last array -> channels-first float tensor.
        item["image"] = torch.tensor(image).permute(2, 0, 1).float()
        item["caption"] = self.captions[idx]

        return item

    def __len__(self):
        return len(self.captions)
|
|
|
|
|
|
def get_transforms(mode="train"):
    """Return the albumentations preprocessing pipeline for ``mode``.

    The pipeline resizes to ``CFG.size`` x ``CFG.size`` and normalizes 8-bit pixel
    values. ``mode`` is kept for API compatibility: train and evaluation currently
    share the same deterministic pipeline, so no augmentation is applied.
    """
    # Resize and Normalize default to p=1.0, so they always run. The former
    # ``always_apply=True`` argument is deprecated in recent albumentations
    # releases and added nothing here.
    return A.Compose(
        [
            A.Resize(CFG.size, CFG.size),
            A.Normalize(max_pixel_value=255.0),
        ]
    )
| |
class ImageEncoder(nn.Module):
    """Image backbone from timm that maps an image batch to one feature vector per image."""

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        # num_classes=0 with global_pool="avg" requests pooled features instead of
        # class logits from timm.
        backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        for weight in backbone.parameters():
            weight.requires_grad = trainable
        self.model = backbone

    def forward(self, x):
        features = self.model(x)
        return features
| |
class TextEncoder(nn.Module):
    """DistilBERT wrapper returning the final hidden state at token position 0 for each sequence.

    With ``pretrained=False`` the model is built from a default ``DistilBertConfig``
    and ``model_name`` is ignored.
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        self.model = (
            DistilBertModel.from_pretrained(model_name, use_safetensors=True)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )

        for weight in self.model.parameters():
            weight.requires_grad = trainable

        # Index of the token whose hidden state is used as the sentence embedding;
        # presumably the leading [CLS] token added by the DistilBERT tokenizer.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        hidden_states = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return hidden_states[:, self.target_token_idx, :]
| |
class ProjectionHead(nn.Module):
    """Project encoder features into the shared embedding space.

    A linear projection is followed by GELU -> Linear -> Dropout. The result is
    added back to the projection (residual) and layer-normalized.

    Args:
        embedding_dim: width of the incoming features.
        projection_dim: width of the output embedding.
        dropout: dropout probability applied before the residual add.
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        # Attribute names and creation order are kept: they determine state_dict
        # keys and the order in which parameter initialisation draws randomness.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        residual = self.projection(x)
        hidden = self.dropout(self.fc(self.gelu(residual)))
        return self.layer_norm(hidden + residual)
| |
class CLIPModel(nn.Module):
    """Dual-encoder image/text model trained with a symmetric soft-target contrastive loss.

    Args:
        temperature: divisor applied to the text-image logits.
        image_embedding: width of the image-encoder features.
        text_embedding: width of the text-encoder features.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        """Return the scalar training loss for ``batch``.

        ``batch`` must provide ``"image"``, ``"input_ids"`` and ``"attention_mask"``.
        """
        # Both encoders run before either projection head, matching the original
        # order of random dropout draws.
        img_feats = self.image_encoder(batch["image"])
        txt_feats = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        img_emb = self.image_projection(img_feats)
        txt_emb = self.text_projection(txt_feats)

        # Row i of ``logits`` scores text i against every image in the batch.
        logits = txt_emb @ img_emb.T / self.temperature

        # Soft targets: softmax of the averaged within-modality similarities.
        # NOTE(review): by operator precedence this is ((S_img + S_txt) / 2) * temperature,
        # while the logits are divided by temperature. If "/ (2 * temperature)" was
        # intended, the two forms agree only at temperature == 1.0 (the CFG default).
        # Confirm before changing.
        img_sim = img_emb @ img_emb.T
        txt_sim = txt_emb @ txt_emb.T
        targets = F.softmax((img_sim + txt_sim) / 2 * self.temperature, dim=-1)

        per_text = cross_entropy(logits, targets, reduction='none')
        per_image = cross_entropy(logits.T, targets.T, reduction='none')
        return ((per_image + per_text) / 2.0).mean()
|
|
|
|
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and soft (probability) targets along the last dimension.

    Args:
        preds: unnormalized logits.
        targets: tensor of the same shape as ``preds`` holding target
            probabilities along the last dimension.
        reduction: ``"none"`` returns one loss per row, ``"mean"`` their mean,
            ``"sum"`` their sum.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values (previously
            such values silently returned None).
    """
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(f"Unsupported reduction: {reduction!r}")
    # Normalize and sum over the same (last) dimension so N-D inputs are handled
    # consistently; for 2-D input this matches the previous .sum(1).
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
| |
def make_train_valid_dfs():
    """Load ``captions.csv`` from ``CFG.captions_path`` and split it into train/validation frames.

    20% of the ids are chosen (seeded with 42) for validation and the rest go to
    training. In debug mode only ids 0..99 are used.

    Returns:
        (train_dataframe, valid_dataframe), each re-indexed from 0.
    """
    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
    # Each row gets its own id, so the split is per CSV row.
    # NOTE(review): if the CSV holds several captions per image, the same image can
    # end up in both splits — confirm that is intended.
    dataframe['id'] = dataframe.index
    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
    image_ids = np.arange(0, max_id)
    # Fixed seed for a reproducible split (this reseeds NumPy's global RNG).
    np.random.seed(42)
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    # Vectorized set difference; the previous per-element `in valid_ids` test
    # scanned the whole array each time (O(n*m)).
    train_ids = image_ids[~np.isin(image_ids, valid_ids)]
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe
|
|
|
|
def build_loaders(dataframe, tokenizer, mode):
    """Build a DataLoader over ``dataframe``'s ``"image"`` and ``"caption"`` columns.

    Batches are shuffled only when ``mode == "train"``.
    """
    pipeline = get_transforms(mode=mode)
    is_train = mode == "train"
    clip_dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=pipeline,
    )
    return torch.utils.data.DataLoader(
        clip_dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=is_train,
    )
|
|
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Run one training pass over ``train_loader``.

    ``lr_scheduler.step()`` is called after every batch only when
    ``step == "batch"``.

    Returns:
        AvgMeter holding the batch-size-weighted mean training loss.
    """
    meter = AvgMeter()
    progress = tqdm(train_loader, total=len(train_loader))
    per_batch_schedule = step == "batch"
    for raw_batch in progress:
        # Raw caption strings are not tensors, so they are dropped before moving to the device.
        batch = {
            name: tensor.to(CFG.device)
            for name, tensor in raw_batch.items()
            if name != "caption"
        }
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if per_batch_schedule:
            lr_scheduler.step()

        meter.update(loss.item(), batch["image"].size(0))
        progress.set_postfix(train_loss=meter.avg, lr=get_lr(optimizer))
    return meter
|
|
|
|
def valid_epoch(model, valid_loader):
    """Evaluate ``model`` on ``valid_loader``.

    Gradient tracking is disabled here, so the function no longer depends on the
    caller wrapping it in ``torch.no_grad()``. Nesting is harmless. Switching the
    model to eval mode remains the caller's responsibility.

    Returns:
        AvgMeter holding the batch-size-weighted mean validation loss.
    """
    loss_meter = AvgMeter()

    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    with torch.no_grad():
        for batch in tqdm_object:
            # Raw caption strings are not tensors, so they are dropped before moving to the device.
            batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
            loss = model(batch)

            count = batch["image"].size(0)
            loss_meter.update(loss.item(), count)

            tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter
|
|
|
|
def main():
    """Train CLIPModel for CFG.epochs epochs, keeping the lowest-validation-loss weights in "best.pt"."""
    train_df, valid_df = make_train_valid_dfs()
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    train_loader = build_loaders(train_df, tokenizer, mode="train")
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    head_params = itertools.chain(
        model.image_projection.parameters(), model.text_projection.parameters()
    )
    param_groups = [
        {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
        # Only the projection heads get weight decay; the encoder groups inherit 0.0.
        {"params": head_params, "lr": CFG.head_lr, "weight_decay": CFG.weight_decay},
    ]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    # The plateau scheduler is stepped once per epoch with the validation loss.
    step = "epoch"

    best_loss = float("inf")
    for epoch_number in range(1, CFG.epochs + 1):
        print(f"Epoch: {epoch_number}")
        model.train()
        train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)

        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            torch.save(model.state_dict(), "best.pt")
            print("Saved Best Model!")

        lr_scheduler.step(valid_loss.avg)
| |
# Start training only when run as a script, not when imported.
if __name__ == "__main__":
    main()