| | import sys |
| | import open_clip |
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from torchvision.transforms import transforms |
| |
|
| | from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL |
| |
|
| |
|
| |
|
def print_statistics(arr):
    """Print one line of summary statistics for a 1-D array.

    The line lists mean, median, min, max and standard deviation (each with
    4 decimals) plus the element count, and is followed by an empty line.

    Args:
        arr: array whose ``shape`` has exactly one dimension (asserted).
    """
    assert len(arr.shape) == 1
    mean_val = arr.mean()
    median_val = np.median(arr)
    smallest = arr.min()
    largest = arr.max()
    spread = arr.std()
    count = len(arr)
    # The trailing '\n' plus print's own newline leaves a blank separator line.
    print(f'[mean] {mean_val:.4f} [median] {median_val:.4f} [min] {smallest:.4f} [max] '
          f'{largest:.4f} [std] {spread:.4f} [n] {count}\n')
| |
|
# Candidate locations of the reference CLIP ViT-L vision-tower state dict,
# tried in order; the first one that can be opened is used.
_CLIP_VIT_L_VISUAL_PATHS = (
    "/mnt/nsingh/project_multimodal/models/clip-vit-l-visual.pt",
    "/data/naman_deep_singh/project_multimodal/clip-vit-l-visual.pt",
)


def interpolate_state_dict(m1, beta=0., paths=_CLIP_VIT_L_VISUAL_PATHS):
    """Linearly interpolate ``m1`` towards a reference state dict loaded from disk.

    Each entry of the result is ``(1 - beta) * m1[k] + beta * m2[k]``, where
    ``m2`` is loaded with ``map_location='cpu'`` from the first entry of
    ``paths`` that can be opened.

    Args:
        m1: state dict (name -> tensor) to interpolate from. Every key must
            also exist in the reference state dict.
        beta: weight of the reference state dict; ``0.`` reproduces ``m1``.
        paths: candidate checkpoint locations, tried in order. Defaults to the
            two hard-coded locations of the CLIP ViT-L visual checkpoint.

    Returns:
        A new dict with the same keys (in the same order) as ``m1``.

    Raises:
        FileNotFoundError: if none of ``paths`` could be opened.
        KeyError: if a key of ``m1`` is missing from the reference state dict.
    """
    last_err = None
    for path in paths:
        try:
            m2 = torch.load(path, map_location='cpu')
            break
        except OSError as err:
            # Only a missing/unreadable file falls through to the next location;
            # a corrupt checkpoint or a KeyboardInterrupt propagates.
            last_err = err
    else:
        raise FileNotFoundError(
            f'reference state dict not found in any of {list(paths)}'
        ) from last_err
    return {k: (1 - beta) * m1[k] + beta * m2[k] for k in m1}
| |
|
| |
|
def load_clip_model(clip_model_name, pretrained, beta=0.):
    """Build an open_clip model and split its image transform into preprocessing + last step.

    The model ``clip_model_name`` is first created on CPU with the ``'openai'``
    weights. If ``pretrained`` is not ``'openai'``, it is a checkpoint path or an
    already-loaded state dict whose weights are then loaded into ``model.visual``.
    If any of this raises ``RuntimeError``, the whole model is instead created via
    ``open_clip.create_model_and_transforms(..., pretrained=pretrained,
    force_quick_gelu=True)``.

    Args:
        clip_model_name: open_clip architecture name.
        pretrained: ``'openai'``, a path to a checkpoint, or a state dict. A
            checkpoint may nest the vision weights under
            ``'vision_encoder_state_dict'``.
        beta: if non-zero, the vision weights are first interpolated towards a
            reference state dict via ``interpolate_state_dict``.

    Returns:
        ``(model, preprocessor_no_norm, normalizer)``: the model in eval mode,
        the image transform without its final step, and that final step
        (presumably the normalization transform — confirm for new processors).
    """
    try:
        model, _, image_processor = open_clip.create_model_and_transforms(
            clip_model_name, pretrained='openai', device='cpu'
        )
        if pretrained != 'openai':
            if isinstance(pretrained, str):
                checkpoint = torch.load(pretrained, map_location=torch.device('cpu'))
            else:
                checkpoint = pretrained

            # Unwrap nested vision weights *before* interpolating so the keys
            # line up with the reference state dict.
            if 'vision_encoder_state_dict' in checkpoint.keys():
                checkpoint = checkpoint['vision_encoder_state_dict']

            if beta != 0.:
                print("beta", beta)
                # Interpolate the loaded weights, not `pretrained` (which may be
                # a path string).
                checkpoint = interpolate_state_dict(checkpoint, beta)

            model.visual.load_state_dict(checkpoint)
    except RuntimeError as e:
        print(f'error: {e}', file=sys.stderr)
        print('retrying by loading whole model..', file=sys.stderr)
        torch.cuda.empty_cache()
        model, _, image_processor = open_clip.create_model_and_transforms(
            clip_model_name, pretrained=pretrained, force_quick_gelu=True, device='cpu'
        )
    model.eval()

    # Separate the last transform so callers can apply it independently.
    preprocessor_no_norm = transforms.Compose(image_processor.transforms[:-1])
    normalizer = image_processor.transforms[-1]
    return model, preprocessor_no_norm, normalizer
| |
|
@torch.no_grad()
def get_text_embeddings(model, dataset, texts, chunk_size=500, device=None):
    """Encode prompts into normalized text embeddings, one column per prompt.

    Exactly one of ``dataset`` and ``texts`` must be given.

    Args:
        model: wrapper whose ``.model`` attribute provides
            ``encode_text(tokens, normalize=True)``.
        dataset: ``'imagenet'`` to embed ``'This is a photo of a {label}'`` for
            every label in ``IMAGENET_1K_CLASS_ID_TO_LABEL``; falsy otherwise.
        texts: list of strings to embed when ``dataset`` is falsy.
        chunk_size: number of prompts encoded per ``encode_text`` call.
        device: device the tokens are moved to. Defaults to the device of the
            first parameter of ``model.model``.

    Returns:
        CPU tensor of shape ``(embed_dim, n_prompts)`` (the concatenated
        per-chunk outputs, transposed).

    Raises:
        ValueError: if neither ``dataset`` nor a non-empty ``texts`` is given.
    """
    assert not (dataset and texts)
    if dataset:
        assert dataset == 'imagenet'
        template = 'This is a photo of a {}'
        texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
    elif not texts:
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError('either `dataset` or a non-empty `texts` must be given')
    text_tokens = open_clip.tokenize(texts)

    if device is None:
        # Follow the model instead of assuming the current CUDA device.
        device = next(model.model.parameters()).device

    chunks = [
        model.model.encode_text(text_tokens[i:i + chunk_size].to(device), normalize=True).cpu()
        for i in range(0, len(text_tokens), chunk_size)
    ]
    embedding_text_labels_norm = torch.cat(chunks).T
    if dataset == 'imagenet':
        assert (embedding_text_labels_norm.shape == (512, 1000)
                or embedding_text_labels_norm.shape == (768, 1000)), embedding_text_labels_norm.shape
    return embedding_text_labels_norm
| |
|
| |
|
@torch.inference_mode()
def compute_accuracy_no_dataloader(model, data, targets, device, batch_size=1000):
    """Return the top-1 accuracy of ``model`` on in-memory ``data``/``targets``.

    Samples are processed in slices of ``batch_size`` along the first
    dimension; each slice is cloned and moved to ``device``. The model is put
    in eval mode for the evaluation, and its training flag is restored
    afterwards even if the forward pass raises.

    Args:
        model: module returning logits with classes along dim 1.
        data: tensor whose first dimension indexes samples.
        targets: tensor of class indices, one per sample.
        device: device each batch is moved to.
        batch_size: number of samples per forward pass.

    Returns:
        Fraction of samples whose arg-max logit equals the target (a float).

    Raises:
        ValueError: if ``data`` has no samples.
    """
    n_samples = data.shape[0]
    if n_samples == 0:
        raise ValueError('cannot compute accuracy on empty data')

    train_flag = model.training
    model.eval()
    try:
        n_total = 0
        n_correct = 0
        for start_idx in range(0, n_samples, batch_size):
            end_idx = min(start_idx + batch_size, n_samples)
            # clone() keeps the caller's tensors safe from in-place ops in the model.
            data_batch = data[start_idx:end_idx].clone().to(device)
            targets_batch = targets[start_idx:end_idx].clone().to(device)
            # softmax is monotonic, so the arg-max of the logits is the prediction.
            preds = model(data_batch).argmax(dim=1)
            n_total += targets_batch.size(0)
            n_correct += preds.eq(targets_batch).sum().item()
    finally:
        if train_flag:
            model.train()
    return n_correct / n_total
| |
|
| |
|
| |
|