import argparse
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import ConcatDataset
from mmengine.config import Config
from transformers import AutoTokenizer

sys.path.append('../../../')
from codes.datasets import build_dataset
from codes.models import build_algorithm
from baselines.utils import calc_accuracy, calc_f1


def process_text(text):
    # NOTE: the tokenizer is reloaded on every call; hoist it to module level
    # if this becomes a bottleneck.
    tokenizer_clinical = AutoTokenizer.from_pretrained(
        '/gpfswork/rech/okw/ukw13bv/mmsl/biobert_pretrain_output_all_notes_150000')
    ixtoword = {v: k for k, v in tokenizer_clinical.get_vocab().items()}

    if isinstance(text, str):
        text = [text]

    processed_text_tensors = []
    for t in text:
        text_tensors = tokenizer_clinical(
            t,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=77,
        )
        text_tensors["sent"] = [
            ixtoword[ix] for ix in text_tensors["input_ids"][0].tolist()
        ]
        processed_text_tensors.append(text_tensors)

    caption_ids = torch.stack([x["input_ids"] for x in processed_text_tensors])
    attention_mask = torch.stack(
        [x["attention_mask"] for x in processed_text_tensors]
    )
    token_type_ids = torch.stack(
        [x["token_type_ids"] for x in processed_text_tensors]
    )

    if len(text) == 1:
        caption_ids = caption_ids.squeeze(0).cuda()
        attention_mask = attention_mask.squeeze(0).cuda()
        token_type_ids = token_type_ids.squeeze(0).cuda()
    else:
        caption_ids = caption_ids.squeeze().cuda()
        attention_mask = attention_mask.squeeze().cuda()
        token_type_ids = token_type_ids.squeeze().cuda()

    # Rough caption lengths: counts elements not starting with "[";
    # note that `txt` is the raw string here, so this iterates characters.
    cap_lens = []
    for txt in text:
        cap_lens.append(len([w for w in txt if not w.startswith("[")]))

    return {
        "input_ids": caption_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "cap_lens": cap_lens,
    }


def test(classifier, test_loaders, model, args):
    model.eval()

    # Encode the class prompts once; only needed by the commented-out
    # zero-shot scoring path below.
    with open(args.class_prompt) as f:
        lines = f.readlines()
    class_texts = [i.replace('\n', '') for i in lines]
    class_texts = process_text(class_texts)
    text_features = model(None, class_texts, mode='text')['text_emb'].cuda()
    text_features /= text_features.norm(dim=-1, keepdim=True)

    total_acc = []
    total_f1_phase = []
    with torch.no_grad():
        for test_loader in test_loaders:  # one loader per test video
            probs_list = []
            label_list = []
            for data in test_loader:
                frames = data['video'].cuda()  # (B, C, H, W)
                # B, M, T, C, H, W = frames.shape  # older multi-clip layout
                B, C, H, W = frames.shape
                frames = frames.view(-1, C, H, W)

                image_features = model(frames, None, mode='video')['img_emb']  # (B, D)
                probs = classifier(image_features)
                # Zero-shot alternative:
                # probs = probs / probs.norm(dim=-1, keepdim=True)
                # probs = probs @ text_features.to(dtype=torch.float32).T
                probs = probs.softmax(dim=-1)  # (B, num_classes)

                labels = data['label'].cuda()
                probs_list.append(probs)
                label_list.append(labels)
            # probs_list = torch.cat(probs_list, 0)
            labels = torch.cat(label_list, 0)

            acc = calc_accuracy(probs_list, labels)
            print('accuracy: ', acc)
            f1_class, f1_average = calc_f1(probs_list, labels)
            print('f1 average: ', f1_average)
            print('f1 classes: ', f1_class)
            total_acc.append(acc)
            total_f1_phase.append(f1_average)

    print('f1 phase video-wise average ', np.mean(np.asarray(total_f1_phase)))
    print('Acc video-wise average ', np.mean(np.asarray(total_acc)))


def linear_evaluation(
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    num_classes: int
) -> torch.nn.Module:
    # Freeze the pre-trained model's parameters
    for param in model.parameters():
        param.requires_grad = False
    # Encode the class prompts (reads the global `args`); these features are
    # only used by the commented-out prompt-similarity head below, not by the
    # linear classifier itself.
    with open(args.class_prompt) as f:
        lines = f.readlines()
    class_texts = [i.replace('\n', '') for i in lines]
    class_texts = process_text(class_texts)
    text_features = model(None, class_texts, mode='text')['text_emb'].cuda()
    text_features /= text_features.norm(dim=-1, keepdim=True).to(dtype=torch.float32)

    # Linear probe on top of the frozen 2048-d image embeddings
    classifier = nn.Linear(2048, num_classes).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001, weight_decay=0.0005)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=40)

    model.eval()  # keep the frozen backbone in evaluation mode
    for epoch in range(25):
        for batch in train_loader:
            inputs = batch['video'].cuda()
            labels = batch['label'].cuda()

            # Extract features with the frozen backbone
            with torch.no_grad():
                features = model(inputs, None, mode='video')['img_emb']  # (B, D)
                features = features.to(dtype=torch.float32)

            # Forward pass through the classifier
            outputs = classifier(features)
            # Prompt-similarity alternative:
            # outputs_feat = outputs_feat / outputs_feat.norm(dim=-1, keepdim=True)
            # outputs = outputs_feat @ text_features.T
            loss = criterion(outputs, labels)
            print(loss.item())

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # scheduler.step()

        # Per-epoch validation could be added here, e.g.:
        # classifier = classifier.eval()
        # test(classifier, test_loaders, model, args)
        # classifier = classifier.train()

    return classifier  # the trained linear probe


def get_args(description='CLIP'):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--class_prompt', default='../class_prompt.txt', type=str,
                        help='prompt file for the categories')
    parser.add_argument('--dataset_config', default='./config.py', type=str,
                        help='dataset config')
    parser.add_argument('--batch_size', default=1, type=int,
                        help='batch size for training and testing')
    parser.add_argument('--num_class', default=12, type=int,
                        help='number of classes for classification')
    parser.add_argument('--checkpoint', default='', type=str,
                        help='checkpoint to load')
    args = parser.parse_args()
    return args, parser


if __name__ == "__main__":
    args, _ = get_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    configs = Config.fromfile(args.dataset_config)['config']

    model = build_algorithm(configs.model_config).cuda()

    # Earlier experiments loaded fixed checkpoint paths here (Hierarchy_SurgVLP,
    # HecVL and NIPS runs); the path is now taken from --checkpoint.
    state_dict = torch.load(args.checkpoint)['state_dict']
    new_dict = {}
    for k, v in state_dict.items():
        if 'module.' in k:
            # Strip the 'module.' prefix and convert old-style key names
            new_key = (k[7:]
                       .replace('visual.model.', 'backbone_img.model.')
                       .replace('text_module.model.', 'backbone_text.model.')
                       .replace('visual.global_embedder', 'backbone_img.global_embedder'))
            new_dict[new_key] = v
    missing_keys, unexpected_keys = model.load_state_dict(new_dict, strict=True)
    model.eval()

    train_datasets = [build_dataset(c) for c in configs.train_config]
    train_dataset = ConcatDataset(train_datasets)
    val_datasets = [build_dataset(c) for c in configs.val_config]
    val_dataset = ConcatDataset(val_datasets)
    test_datasets = [build_dataset(c) for c in configs.test_config]  # 40 videos --> 40 datasets

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        drop_last=False, num_workers=4
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        drop_last=False, num_workers=4
    )
    test_loaders = [torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        drop_last=False, num_workers=0
    ) for test_dataset in test_datasets]  # 40 dataloaders

    print(args)
    classifier = linear_evaluation(train_loader, val_loader, model, args.num_class)
    test(classifier, test_loaders, model, args)
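# Example invocation (a sketch: the script name `linear_probe.py` and the
# checkpoint path are placeholders, not taken from this repo; --num_class,
# --class_prompt and --dataset_config match the defaults above, and
# --batch_size 64 is an assumed training batch size):
#
#   python linear_probe.py \
#       --dataset_config ./config.py \
#       --class_prompt ../class_prompt.txt \
#       --num_class 12 \
#       --batch_size 64 \
#       --checkpoint /path/to/epoch0120.pth.tar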