| import sys |
|
|
|
|
| from train.datasets import COCOFlickrDataset, ImageNetDataset |
| from CLIP_eval.eval_utils import load_clip_model |
|
|
| sys.path.append("open_flamingo") |
| import os |
| import shutil |
| import time |
| import string |
| import random |
|
|
| import numpy as np |
| import open_clip |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from training.scheduler import cosine_lr |
| from torchvision import transforms |
| from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL |
| from train.pgd_train import pgd |
| from train.apgd_train import apgd_train as apgd |
| import wandb |
| from train.utils import init_wandb, AverageMeter |
| from train.sam_data import SamData |
| from open_flamingo.eval.models.utils import unwrap_model |
| from train.utils import str2bool |
|
|
| from slots.DINOSAUR import DINOSAURpp |
| import matplotlib.pyplot as plt |
| from einops import rearrange, repeat |
| from tqdm import tqdm |
|
|
|
|
|
|
| import argparse |
|
|
# Command-line flags, declared as (flag, add_argument keyword options) pairs.
# They are registered on ``parser`` in exactly this order, which is also the
# order in which they show up in ``--help``.
_ARGUMENT_SPECS = (
    # model / data
    ('--clip_model_name', dict(type=str, default='ViT-L-14', help='ViT-L-14, ViT-B-32')),
    ('--pretrained', dict(type=str, default='openai')),
    ('--dataset', dict(type=str, default='imagenet')),
    ('--template', dict(type=str, default='std')),
    ('--imagenet_root', dict(type=str, default='/mnt/datasets/imagenet',
                             help='Imagenet dataset root directory')),
    ('--output_normalize', dict(type=str2bool, default=False,
                                help='Whether the embedding is normalized')),
    # schedule
    ('--start_step', dict(type=int, default=0, help='Start step for training')),
    ('--optimizer_state', dict(type=str, default='', help='Optimizer state file path')),
    ('--steps', dict(type=int, default=20000, help='Number of training steps')),
    ('--warmup', dict(type=int, default=14000, help='Warmup steps')),
    ('--batch_size', dict(type=int, default=256)),
    # losses
    ('--loss', dict(type=str, default='l2', help='ce, l2')),
    ('--loss_clean', dict(type=str, default='none', help='ce, l2')),
    ('--clean_weight', dict(type=float, default=0., help='Weight for clean loss')),
    ('--trades', dict(type=str2bool, default=False, help='Use TRADES')),
    # optimizer
    ('--opt', dict(type=str, default='adamw', help='Optimizer type; sgd, adamw')),
    ('--momentum_sgd', dict(type=float, default=0.9, help='Momentum for SGD optimizer')),
    ('--lr', dict(type=float, default=1e-5, help='Learning rate')),
    ('--wd', dict(type=float, default=1e-4, help='Weight decay')),
    # adversarial attack
    ('--attack', dict(type=str, default='apgd', help='Adversarial attack type')),
    ('--inner_loss', dict(type=str, default='l2',
                          help='Inner loss function for adversarial training')),
    ('--norm', dict(type=str, default='linf', help='Norm for adversarial perturbation')),
    ('--eps', dict(type=float, default=4, help='Epsilon for adversarial perturbation')),
    ('--iterations_adv', dict(type=int, default=10, help='Iterations for adversarial attack')),
    ('--stepsize_adv', dict(type=float, default=1.,
                            help='Step size for adversarial attack (no effect for apgd)')),
    # logging / output
    ('--wandb', dict(type=str2bool, default=True, help='Use Weights & Biases for logging')),
    ('--experiment_name', dict(type=str, default='')),
    ('--overwrite', dict(type=str2bool, default=False, help='Overwrite existing directory')),
    ('--log_freq', dict(type=int, default=1, help='Logging frequency')),
    ('--eval_freq', dict(type=int, default=50, help='Evaluation frequency')),
    ('--output_dir', dict(type=str, default='', help='Output directory')),
    ('--save_checkpoints', dict(type=str2bool, default=True,
                                help='Save 10 training checkpoints')),
    ('--devices', dict(type=str, default='', help='Device IDs for CUDA')),
)

parser = argparse.ArgumentParser()
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)
|
|
|
|
def main(args):
    """Prepare models, data and optimizer, then run ``train_one_epoch_slots`` once.

    In order, this function: initialises wandb, creates
    ``<args.output_dir>/checkpoints`` and writes ``args.txt`` next to it, builds
    the CLIP model and the ``DINOSAURpp`` slot model, builds the train and
    eval data loaders, encodes the ImageNet class prompts with the CLIP text
    encoder, wraps the CLIP vision tower in ``ClipVisionModel``, creates the
    optimizer and LR scheduler over the slot-model parameters, and finally
    calls ``train_one_epoch_slots`` a single time.

    Args:
        args: parsed command-line namespace. In addition to the flags
            registered on the module-level ``parser`` it must carry
            ``finetuned_model_name`` (read when ``args.wandb`` is true).
            ``args.total_epochs`` is set here.

    Raises:
        FileExistsError: if the ``checkpoints`` directory already exists and
            ``args.overwrite`` is false.
        ValueError: for an unknown dataset, template, CLIP model name or
            optimizer, or when no COCO directory is found.
    """
    # --- experiment tracking ---
    if args.wandb:
        init_wandb(
            project_name='clip-finetune',
            model_name=args.finetuned_model_name,
            config=vars(args)
        )
    else:
        wandb.init(mode='disabled')

    # --- echo the configuration ---
    print(f"Arguments:\n{'-' * 20}")
    for arg, value in vars(args).items():
        print(f"{arg}: {value}")
    print(f"{'-' * 20}")

    # --- output directory ---
    if args.overwrite:
        shutil.rmtree(args.output_dir, ignore_errors=True)
    # exist_ok=False on purpose: an existing run directory is an error unless
    # --overwrite removed it just above.
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=False)

    with open(os.path.join(args.output_dir, 'args.txt'), 'w') as f:
        f.write(str(args))

    main_device = 0
    # Queried locally so that main() does not depend on a module global that
    # only exists when the file is executed as a script.
    num_gpus = torch.cuda.device_count()

    # --- CLIP model ---
    from open_clip.model import CLIPVisionCfg
    # Class-level switch, set before the model is created; forward() of
    # ClipVisionModel unpacks two outputs from the vision tower.
    CLIPVisionCfg.output_tokens = True
    # NOTE(review): the checkpoint tag is hard-coded to 'openai'; --pretrained
    # only influences the run name. Confirm this is intended.
    model_orig, _, image_processor = open_clip.create_model_and_transforms(
        args.clip_model_name, pretrained='openai'
    )

    # The last transform of the processor is split off and applied inside
    # ClipVisionModel.forward instead of in the data pipeline.
    preprocessor_without_normalize = transforms.Compose(image_processor.transforms[:-1])
    normalize = image_processor.transforms[-1]
    del image_processor
    print(f'[preprocessor_without_normalize] {preprocessor_without_normalize}')

    # --- slot model ---
    cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
    model_slots = DINOSAURpp(cfg_dict)

    # --- datasets ---
    if args.dataset == 'imagenet':
        dataset = ImageNetDataset(
            root=args.imagenet_root + '/train',
            transform=preprocessor_without_normalize,
        )
    elif args.dataset == 'segment_anything':
        dataset = SamData('/data/naman_deep_singh/datasets/newSAM', transform=preprocessor_without_normalize)
        print(len(dataset))
    elif args.dataset == 'coco':
        if os.path.exists('/mnt/datasets/coco'):
            image_dir_path = '/mnt/datasets/coco/train2017'
            annotations_path = '/mnt/datasets/coco/annotations/captions_train2017.json'
        elif os.path.exists('/mnt/lustre'):
            image_dir_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/train2017'
            annotations_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/annotations/captions_train2017.json'
        else:
            raise ValueError('COCO dataset not found')
        dataset = COCOFlickrDataset(
            image_dir_path=image_dir_path,
            annotations_path=annotations_path,
            transform=preprocessor_without_normalize
        )
    else:
        # Previously an unknown name fell through and failed later with a
        # NameError on ``dataset``.
        raise ValueError(f'Unknown dataset: {args.dataset}')
    dataset_eval = ImageNetDataset(
        root=args.imagenet_root + '/val',
        transform=preprocessor_without_normalize,
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)
    dataloader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)

    # --- text embeddings of the ImageNet class prompts ---
    if args.template == 'std':
        template = 'This is a photo of a {}'
    elif args.template == 'blurry':
        template = 'This is a blurry photo of a {}'
    else:
        raise ValueError(f'Unknown template: {args.template}')
    print(f'template: {template}')
    texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
    text_tokens = open_clip.tokenize(texts)
    model_orig.to(main_device)
    with torch.no_grad():
        embedding_text_labels_norm = []
        # The prompts are encoded in two halves rather than in one batch
        # (presumably to limit memory use - confirm).
        for el in (text_tokens[:500], text_tokens[500:]):
            embedding_text_labels_norm.append(
                model_orig.encode_text(el.to(main_device), normalize=True).detach().cpu()
            )
        # Transposed so that columns index classes (see the shape asserts below).
        embedding_text_labels_norm = torch.cat(embedding_text_labels_norm).T.to(main_device)
        assert torch.allclose(
            F.normalize(embedding_text_labels_norm, dim=0),
            embedding_text_labels_norm
        )
        if args.clip_model_name == 'ViT-B-32':
            assert embedding_text_labels_norm.shape == (512, 1000), embedding_text_labels_norm.shape
        elif args.clip_model_name in ('ViT-L-14', 'ViT-L-14-336'):
            assert embedding_text_labels_norm.shape == (768, 1000), embedding_text_labels_norm.shape
        else:
            raise ValueError(f'Unknown model: {args.clip_model_name}')

    # --- keep only the vision tower, wrapped with its normalisation ---
    model_orig.cpu()
    model_orig = ClipVisionModel(model=model_orig.visual, args=args, normalize=normalize)
    if num_gpus > 1:
        model_orig = torch.nn.DataParallel(model_orig)
    model_orig.cuda()

    if num_gpus > 1:
        model_slots = torch.nn.DataParallel(model_slots)
    model_slots.cuda()

    # --- optimizer over the slot-model parameters only ---
    params = unwrap_model(model_slots).parameters()

    if args.opt == 'adamw':
        optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)
    elif args.opt == 'sgd':
        optimizer = torch.optim.SGD(
            params,
            lr=args.lr,
            momentum=args.momentum_sgd,
            weight_decay=args.wd
        )
    else:
        # The flag is stored as ``args.opt``; reading ``args.optimizer`` here
        # raised AttributeError instead of this ValueError.
        raise ValueError(f'Optimizer {args.opt} not supported.')
    if args.optimizer_state != '':
        optimizer.load_state_dict(torch.load(args.optimizer_state))

    # --- learning-rate schedule ---
    scheduler = cosine_lr(optimizer, args.lr, args.warmup, args.steps)

    # --- bookkeeping ---
    total_epochs = args.steps / len(dataloader)
    print(f'train for {total_epochs} epochs')
    args.total_epochs = total_epochs

    # --- single pass over the training loader ---
    step_total = args.start_step
    epoch = 0

    step_total = train_one_epoch_slots(
        step_total,
        model_slots=model_slots,
        model_orig=model_orig,
        dataloader=dataloader,
        dataloader_eval=dataloader_eval,
        optimizer=optimizer,
        scheduler=scheduler,
        embedding_text_labels_norm=embedding_text_labels_norm,
        normalize=normalize,
        args=args,
        epoch=epoch
    )
    print(f'Epoch {epoch} done.')
|
|
class ClipVisionModel(torch.nn.Module):
    """Vision backbone bundled with the input normalisation that precedes it.

    ``model`` is called on the normalised input and must return a pair
    ``(embedding, patches)``. ``args`` is only stored on the instance.
    """

    def __init__(self, model, args, normalize):
        super().__init__()
        self.model = model
        self.args = args
        self.normalize = normalize

    def forward(self, vision, output_normalize):
        """Return ``(embedding, patches)`` for the un-normalised input ``vision``.

        When ``output_normalize`` is truthy the embedding is L2-normalised
        along its last dimension; ``patches`` is always returned untouched.
        """
        pooled, tokens = self.model(self.normalize(vision))
        if not output_normalize:
            return pooled, tokens
        return F.normalize(pooled, dim=-1), tokens
|
|
|
|
class ComputeLossWrapper:
    """Callable that binds the fixed arguments of ``compute_loss``.

    The instance remembers the reference embedding, the text-label
    embeddings, the reduction mode, the loss name and the logit scale, so
    that a call only needs ``(embedding, targets)``.
    """

    def __init__(self, embedding_orig, embedding_text_labels_norm, reduction='mean', loss=None,
                 logit_scale=100.):
        self.embedding_orig = embedding_orig
        self.embedding_text_labels_norm = embedding_text_labels_norm
        self.reduction = reduction
        # the constructor argument is ``loss``; it is stored as ``loss_str``
        self.loss_str = loss
        self.logit_scale = logit_scale

    def __call__(self, embedding, targets):
        """Evaluate ``compute_loss`` with the stored settings."""
        return compute_loss(
            loss_str=self.loss_str,
            embedding=embedding,
            targets=targets,
            embedding_orig=self.embedding_orig,
            logit_scale=self.logit_scale,
            embedding_text_labels_norm=self.embedding_text_labels_norm,
            reduction=self.reduction,
        )
|
|
def train_one_epoch_slots(
        step_total, model_slots, model_orig, dataloader, optimizer, scheduler, normalize,
        embedding_text_labels_norm, args, epoch, dataloader_eval=None,
        feature_root='/home/tly/RobustVLM/datasets/imagenet_features'
):
    """Run ``model_orig`` over ``dataloader`` and dump per-sample patch features to disk.

    Despite the name, no parameter update happens here: both models are put
    in eval mode, the second output of ``model_orig`` is passed through the
    wrapped vision model's ``ln_pre`` layer, and every sample of the result
    is written below ``feature_root``.

    For batch index ``i``, sample index ``j`` and label ``c`` two files are
    written into ``<feature_root>/<c>/``:
    ``class<c>_batch<i>_sample<j>.npy`` (``np.save``) and
    ``class<c>_batch<i>_sample<j>.npy.npz`` (``np.savez_compressed``, key ``x``).

    Args:
        step_total: step counter; returned unchanged.
        model_slots: slot model; only switched to eval mode.
        model_orig: ``ClipVisionModel``, optionally wrapped in
            ``torch.nn.DataParallel``. Its wrapped vision model must expose
            ``ln_pre``.
        dataloader: yields ``(data, targets)`` batches; ``targets`` must be
            a tensor because it provides the per-sample label used in the
            file names.
        optimizer, scheduler, normalize, embedding_text_labels_norm, epoch,
            dataloader_eval: accepted for interface compatibility; not used.
        args: namespace; only ``args.output_normalize`` is read.
        feature_root: directory below which the per-label folders are created.

    Returns:
        ``step_total``, unchanged.

    Raises:
        TypeError: if a batch carries non-tensor targets (for example
            caption strings), since no label can be derived from them.
    """
    model_orig.eval()
    model_slots.eval()

    # Reach through an optional DataParallel wrapper by inspecting the model
    # itself rather than relying on a module-level GPU count.
    if isinstance(model_orig, torch.nn.DataParallel):
        ln_pre = model_orig.module.model.ln_pre
    else:
        ln_pre = model_orig.model.ln_pre

    for i, (data, targets) in tqdm(enumerate(dataloader)):
        if not isinstance(targets, torch.Tensor):
            # Fail before the forward pass instead of with an AttributeError
            # halfway through the save loop.
            raise TypeError(
                f'feature dumping needs tensor class labels, got {type(targets).__name__}'
            )
        data = data.cuda()

        # Pure inference: keep the extra ln_pre call under no_grad as well so
        # that no autograd graph is built for it.
        with torch.no_grad():
            _, patches_orig = model_orig(vision=data, output_normalize=args.output_normalize)
            patches_orig = ln_pre(patches_orig)

        # One device-to-host transfer per batch instead of one per sample.
        patches_np = patches_orig.detach().cpu().numpy()
        labels_np = targets.detach().cpu().numpy()

        for j, store_npy in enumerate(patches_np):
            label = labels_np[j]
            store_name = 'class{}_batch{}_sample{}.npy'.format(label, i, j)
            store_dir = os.path.join(feature_root, str(label))
            os.makedirs(store_dir, exist_ok=True)
            store_path = os.path.join(store_dir, store_name)
            # NOTE(review): every sample is stored twice. np.save writes the
            # plain '.npy' file; np.savez_compressed appends '.npz' to the same
            # name. Both are kept to preserve the existing on-disk layout -
            # confirm whether one of them can be dropped.
            np.save(store_path, store_npy)
            np.savez_compressed(store_path, x=store_npy)

        torch.cuda.empty_cache()

    return step_total
|
|
|
|
@torch.no_grad()
def compute_acc(logits, targets):
    """Return the top-1 accuracy of ``logits`` against ``targets`` in percent.

    The prediction for each row is the index of its largest entry along
    dimension 1; the result is a Python float.
    """
    _, predicted = logits.max(dim=1)
    num_correct = predicted.eq(targets).sum()
    fraction = num_correct / targets.shape[0]
    return fraction.item() * 100
|
|
|
|
def compute_loss(loss_str, embedding, targets, embedding_orig, logit_scale,
                 embedding_text_labels_norm=None, reduction='mean'):
    """Dispatch to the loss selected by ``loss_str``.

    * ``'l2'``: ``l2`` between ``embedding`` and ``embedding_orig``.
    * ``'ce'``: ``ce`` between the logits
      ``embedding @ (logit_scale * embedding_text_labels_norm)`` and ``targets``.

    Raises:
        ValueError: for any other ``loss_str``.
    """
    if loss_str == 'l2':
        return l2(out=embedding, targets=embedding_orig, reduction=reduction)
    if loss_str == 'ce':
        logits = embedding @ (logit_scale * embedding_text_labels_norm)
        return ce(out=logits, targets=targets, reduction=reduction)
    raise ValueError(f'loss {loss_str} not supported')
|
|
def l2(out, targets, reduction='none'):
    """Squared L2 distance between ``out`` and ``targets``, summed over dimension 1.

    With ``reduction='mean'`` the per-sample sums are averaged into a scalar;
    with any other value the per-sample sums are returned and must have
    shape ``(batch,)``. Shapes of both inputs must match and the batch
    dimension must be larger than one (both checked with ``assert``).
    """
    assert out.shape == targets.shape, f'{out.shape} != {targets.shape}'
    assert out.shape[0] > 1

    # element-wise squared error, then summed per sample (not divided by the
    # size of dimension 1)
    per_sample = F.mse_loss(out, targets, reduction='none').sum(dim=1)
    if reduction == 'mean':
        return torch.mean(per_sample)

    assert per_sample.shape == (out.shape[0],), f'{per_sample.shape} != {(out.shape[0],)}'
    return per_sample
|
|
def ce(out, targets, reduction='mean'):
    """Cross-entropy of logits ``out`` against ``targets``.

    ``out`` and ``targets`` must agree in their first dimension, which must
    be larger than one (both checked with ``assert``). ``reduction`` is
    forwarded to ``F.cross_entropy``.
    """
    batch_size = out.shape[0]
    assert batch_size == targets.shape[0], (out.shape, targets.shape)
    assert batch_size > 1

    return F.cross_entropy(out, targets, reduction=reduction)
|
|
if __name__ == '__main__':
    # Fixed seeds for torch and numpy. The stdlib ``random`` module is left
    # unseeded, so the run suffix generated below differs between runs.
    torch.manual_seed(0)
    np.random.seed(0)

    args = parser.parse_args()

    # Both attack magnitudes are rescaled by 1/255 (presumably they are given
    # in 0-255 pixel units on the command line - confirm).
    args.eps /= 255
    args.stepsize_adv /= 255

    # Sanity checks: no boolean flag may have been left as the raw strings
    # 'True' / 'False', and the evaluation interval must align with logging.
    stringly_typed = [
        value for value in vars(args).values()
        if isinstance(value, str) and value in ['True', 'False']
    ]
    assert not stringly_typed, f'args contains a string that should be a bool: {args}'
    assert args.eval_freq % args.log_freq == 0, 'eval_freq must be a multiple of log_freq'

    # Restrict the visible GPUs before the device count is queried.
    if args.devices != '':
        os.environ['CUDA_VISIBLE_DEVICES'] = args.devices

    # Module-level on purpose: functions in this file read ``num_gpus`` as a global.
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        print(f'Number of GPUs available: {num_gpus}')
    else:
        print('No multiple GPUs available.')

    # Run name: configuration fields plus a random 5-character suffix, with
    # '/' replaced so the name can be used as a single directory component.
    alphabet = string.ascii_letters + string.digits
    random_str = ''.join(random.choices(alphabet, k=5))
    run_name = f'{args.clip_model_name}_{args.pretrained}_{args.dataset}_{args.loss}_{args.dataset}_{args.experiment_name}_{random_str}'
    args.finetuned_model_name = run_name.replace('/', '_')
    args.output_dir = os.path.join(args.output_dir, args.finetuned_model_name)

    main(args)