Upload surgvlp_linear_evaluation.py with huggingface_hub
Browse files- surgvlp_linear_evaluation.py +277 -0
surgvlp_linear_evaluation.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
import clip
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import sys
|
| 7 |
+
sys.path.append('../../../')
|
| 8 |
+
from codes.datasets import build_dataset
|
| 9 |
+
from codes.models import build_algorithm
|
| 10 |
+
from mmengine.config import Config
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
from baselines.utils import calc_accuracy, calc_f1
|
| 13 |
+
import torchmetrics
|
| 14 |
+
import numpy as np
|
| 15 |
+
from torch.utils.data import ConcatDataset
|
| 16 |
+
import torch.optim as optim
|
| 17 |
+
|
| 18 |
+
def process_text(text, tokenizer_path='/gpfswork/rech/okw/ukw13bv/mmsl/biobert_pretrain_output_all_notes_150000'):
    """Tokenize one caption or a list of captions with the clinical BERT tokenizer.

    Args:
        text: a single caption string or a list of caption strings.
        tokenizer_path: local path of the pretrained tokenizer. Defaults to the
            original hard-coded checkpoint so existing callers are unaffected.

    Returns:
        dict with CUDA tensors 'input_ids', 'attention_mask', 'token_type_ids'
        (stacking dim squeezed away) and 'cap_lens', one length per caption.
    """
    tokenizer_clinical = AutoTokenizer.from_pretrained(tokenizer_path)
    # Reverse vocab: token id -> token string, used to keep readable tokens around.
    ixtoword = {v: k for k, v in tokenizer_clinical.get_vocab().items()}
    if isinstance(text, str):
        text = [text]

    processed_text_tensors = []
    for t in text:
        text_tensors = tokenizer_clinical(
            t,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=77,  # CLIP-style context length
        )
        # Keep the decoded token strings alongside the ids (handy for debugging).
        text_tensors["sent"] = [
            ixtoword[ix] for ix in text_tensors["input_ids"][0].tolist()
        ]
        processed_text_tensors.append(text_tensors)

    caption_ids = torch.stack([x["input_ids"] for x in processed_text_tensors])
    attention_mask = torch.stack(
        [x["attention_mask"] for x in processed_text_tensors]
    )
    token_type_ids = torch.stack(
        [x["token_type_ids"] for x in processed_text_tensors]
    )

    if len(text) == 1:
        # Single caption: drop only the stacking dim, keep the sequence dim.
        caption_ids = caption_ids.squeeze(0).cuda()
        attention_mask = attention_mask.squeeze(0).cuda()
        token_type_ids = token_type_ids.squeeze(0).cuda()
    else:
        caption_ids = caption_ids.squeeze().cuda()
        attention_mask = attention_mask.squeeze().cuda()
        token_type_ids = token_type_ids.squeeze().cuda()

    # NOTE(review): this iterates the *characters* of each caption string, so
    # cap_lens counts characters not starting with '[' rather than non-special
    # tokens — confirm this is intended (upstream variants compute it from the
    # token list in text_tensors["sent"] instead).
    cap_lens = []
    for txt in text:
        cap_lens.append(len([w for w in txt if not w.startswith("[")]))

    return {
        "input_ids": caption_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "cap_lens": cap_lens,
    }
|
| 66 |
+
|
| 67 |
+
def test(classifier, test_loader, model, args):
    """Evaluate a trained linear classifier over a list of per-video test loaders.

    Args:
        classifier: linear head mapping frame features to class logits.
        test_loader: a *list* of DataLoaders, one per test video (the singular
            parameter name is kept for caller compatibility).
        model: frozen backbone; called with mode='video' to extract embeddings.
        args: parsed CLI arguments (kept for interface compatibility).

    Prints per-video accuracy / F1 and the video-wise averages.
    """
    model.eval()

    # The original also built zero-shot text features from args.class_prompt
    # here, but their only uses were commented out, so that dead code (and its
    # file read) has been removed.

    total_acc = []
    total_f1_phase = []

    with torch.no_grad():
        # BUG FIX: the original iterated the module-level global `test_loaders`
        # and silently ignored the `test_loader` parameter; iterate the
        # parameter instead (callers pass the list of loaders positionally).
        for loader in test_loader:
            probs_list = []
            label_list = []

            for data in loader:
                frames = data['video'].cuda()
                # Presumably (B, C, H, W) per-frame batches — the view below is
                # then a no-op reshape kept from the multi-clip variant.
                B, C, H, W = frames.shape
                frames = frames.view(-1, C, H, W)
                image_features = model(frames, None, mode='video')['img_emb']

                probs = classifier(image_features)
                probs = probs.softmax(dim=-1)  # (B, classes)
                labels = data['label'].cuda()

                probs_list.append(probs)
                label_list.append(labels)

            # Aggregate over the whole video before scoring.
            probs_list = torch.cat(probs_list, 0)
            labels = torch.cat(label_list, 0)

            acc = calc_accuracy(probs_list, labels)
            print('accuracy: ', acc)
            f1_class, f1_average = calc_f1(probs_list, labels)
            print('f1 average: ', f1_average)
            print('f1 classes: ', f1_class)

            total_acc.append(acc)
            total_f1_phase.append(f1_average)
    print('f1 phase video-wise average ', np.mean(np.asarray(total_f1_phase)))
    print('Acc video-wise average ', np.mean(np.asarray(total_acc)))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def linear_evaluation(
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    num_classes: int
) -> torch.nn.Module:
    """Train a linear probe on top of the frozen backbone's frame embeddings.

    Args:
        train_loader: yields dicts with 'video' and 'label' tensors.
        val_loader: accepted for interface compatibility; not used (the
            original's validation hook was commented out).
        model: pretrained backbone; its parameters are frozen here.
        num_classes: output dimension of the linear head.

    Returns:
        The trained linear classifier (on CUDA).
    """
    # Freeze the pre-trained model's parameters.
    for param in model.parameters():
        param.requires_grad = False

    # The original also computed zero-shot text features from the module-level
    # `args` global (not a parameter of this function); their only uses were
    # commented out, so that dead code has been removed. A CosineAnnealingLR
    # scheduler that was created but never stepped has been removed as well.

    # Linear head on the frozen features.
    # NOTE(review): assumes the backbone's 'img_emb' is 2048-dimensional —
    # confirm against the model config.
    classifier = nn.Linear(2048, num_classes).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001, weight_decay=0.0005)

    # Training loop; the backbone stays in eval mode (frozen feature extractor).
    model.eval()
    for epoch in range(25):
        for batch in train_loader:
            inputs = batch['video'].cuda()
            labels = batch['label'].cuda()

            # Extract features without tracking gradients through the backbone.
            with torch.no_grad():
                features = model(inputs, None, mode='video')['img_emb']
            features = features.to(dtype=torch.float32)

            outputs = classifier(features)
            loss = criterion(outputs, labels)
            print(loss)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return classifier  # Return the trained classifier
|
| 189 |
+
|
| 190 |
+
def get_args(description='CLIP'):
    """Parse the command-line options for the linear-evaluation script.

    Returns:
        A ``(namespace, parser)`` pair: the parsed arguments plus the parser
        itself for callers that need it.
    """
    cli = argparse.ArgumentParser(description=description)
    cli.add_argument('--class_prompt', type=str, default='../class_prompt.txt',
                     help='prompt for categories')
    cli.add_argument('--dataset_config', type=str, default='./config.py',
                     help='dataset config')
    cli.add_argument('--batch_size', type=int, default=1,
                     help='batch for testing')
    cli.add_argument('--num_class', type=int, default=12,
                     help='class for classification')
    cli.add_argument('--checkpoint', type=str, default='',
                     help='Checkpoint to load')
    parsed = cli.parse_args()
    return parsed, cli
|
| 199 |
+
|
| 200 |
+
import torch.distributed as dist  # kept from original; currently unused here
if __name__ == "__main__":

    args, _ = get_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    configs = Config.fromfile(args.dataset_config)['config']

    model = build_algorithm(configs.model_config).cuda()

    # Load the pre-trained checkpoint given on the command line.
    # (A long list of commented-out hard-coded checkpoint paths was removed.)
    state_dict = torch.load(args.checkpoint)['state_dict']

    # Strip the DataParallel 'module.' prefix and remap old-version key names
    # (visual.* / text_module.*) onto the current backbone naming scheme.
    new_dict = {}
    for k, v in state_dict.items():
        if 'module.' in k:
            new_dict[k[7:].replace('visual.model.', 'backbone_img.model.').replace('text_module.model.', 'backbone_text.model.').replace('visual.global_embedder', 'backbone_img.global_embedder')] = v
    # strict=True: fail loudly if any key is missing or unexpected.
    a, b = model.load_state_dict(new_dict, strict=True)

    model.eval()

    train_datasets = [build_dataset(c) for c in configs.train_config]
    train_dataset = ConcatDataset(train_datasets)

    val_datasets = [build_dataset(c) for c in configs.val_config]
    val_dataset = ConcatDataset(val_datasets)

    # One dataset per test video (e.g. 40 videos -> 40 datasets/loaders), so
    # metrics can be averaged video-wise in test().
    test_datasets = [build_dataset(c) for c in configs.test_config]

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=4
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=4
    )

    test_loaders = [torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0
    ) for test_dataset in test_datasets]
    print(args)

    # Train the linear probe, then evaluate it on every test video.
    classifier = linear_evaluation(train_loader, val_loader, model, args.num_class)

    test(classifier, test_loaders, model, args)
|
| 277 |
+
|