| import os |
| import gc |
| import numpy as np |
| import pandas as pd |
| from tqdm import tqdm |
|
|
| import torch |
| from torch import nn |
| from transformers import DistilBertTokenizer |
|
|
| import config as CFG |
| from dataset import CLIPDataset, get_transforms |
| from CLIP import CLIPModel |
| from utils import AvgMeter, get_lr |
|
|
|
|
def make_train_valid_dfs():
    """Load the captions CSV and split it into train/valid dataframes.

    Returns:
        (train_dataframe, valid_dataframe): a reproducible 80/20 row split
        (fixed seed 42), each with its index reset.
    """
    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
    dataframe['id'] = dataframe.index
    # In debug mode only the first 100 rows take part in the split.
    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
    image_ids = np.arange(0, max_id)
    np.random.seed(42)
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    # Membership test against a set is O(1); the original tested `id_ not in
    # valid_ids` against a numpy array, which is a linear scan per element
    # (O(n^2) overall) and noticeably slow for large datasets.
    valid_id_set = set(valid_ids.tolist())
    train_ids = [id_ for id_ in image_ids if id_ not in valid_id_set]
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe
|
|
|
|
def build_loaders(dataframe, tokenizer, mode):
    """Wrap a captions dataframe in a CLIPDataset and return a DataLoader.

    Args:
        dataframe: must provide "image" and "caption" columns.
        tokenizer: text tokenizer passed straight through to the dataset.
        mode: "train" enables shuffling (and selects the train transforms);
            anything else disables shuffling.
    """
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=get_transforms(mode=mode),
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=(mode == "train"),
    )
|
|
|
|
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Run one training pass over `train_loader` and return the loss meter.

    Args:
        model: callable that maps a batch dict to a scalar loss tensor.
        train_loader: yields dicts of tensors (plus a raw "caption" entry
            that is dropped before the forward pass).
        optimizer: optimizer to step every batch.
        lr_scheduler: stepped once per batch only when `step == "batch"`.
        step: either "batch" or "epoch"; controls scheduler stepping here.

    Returns:
        AvgMeter holding the running average train loss for the epoch.
    """
    meter = AvgMeter()
    progress = tqdm(train_loader, total=len(train_loader))
    for raw_batch in progress:
        # Move tensors to the target device; "caption" is raw text, so skip it.
        batch = {
            key: tensor.to(CFG.device)
            for key, tensor in raw_batch.items()
            if key != "caption"
        }
        loss = model(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        batch_size = batch["image"].size(0)
        meter.update(loss.item(), batch_size)

        progress.set_postfix(train_loss=meter.avg, lr=get_lr(optimizer))
    return meter
|
|
|
|
def valid_epoch(model, valid_loader):
    """Run one evaluation pass over `valid_loader` and return the loss meter.

    Does not manage grad mode itself; the caller is expected to wrap this
    in `torch.no_grad()` and to have called `model.eval()` beforehand.

    Returns:
        AvgMeter holding the running average validation loss.
    """
    meter = AvgMeter()

    progress = tqdm(valid_loader, total=len(valid_loader))
    for raw_batch in progress:
        # "caption" holds raw strings, not tensors — drop it before .to().
        batch = {
            key: tensor.to(CFG.device)
            for key, tensor in raw_batch.items()
            if key != "caption"
        }
        loss = model(batch)

        batch_size = batch["image"].size(0)
        meter.update(loss.item(), batch_size)

        progress.set_postfix(valid_loss=meter.avg)
    return meter
|
|
|
|
def main():
    """Train a CLIP model, validating each epoch and saving the best weights.

    Builds the data loaders, trains for `CFG.epochs` epochs, steps the
    plateau scheduler on the validation loss, and checkpoints to "best2.pt"
    whenever the validation loss improves.
    """
    train_df, valid_df = make_train_valid_dfs()
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    train_loader = build_loaders(train_df, tokenizer, mode="train")
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    # "epoch" tells train_epoch NOT to step the scheduler per batch;
    # ReduceLROnPlateau must instead be stepped here with the metric.
    step = "epoch"

    best_loss = float('inf')
    for epoch in range(CFG.epochs):
        print(f"Epoch: {epoch + 1}")
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)

        # BUG FIX: the scheduler was never stepped anywhere (train_epoch only
        # steps it when step == "batch"), so the LR never decayed.
        # ReduceLROnPlateau takes the monitored metric as its argument.
        lr_scheduler.step(valid_loss.avg)

        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            torch.save(model.state_dict(), "best2.pt")
            print("Saved Best Model!")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|