KKYYKK committed on
Commit
2af06ab
·
verified ·
1 Parent(s): d2e2487

Upload surgvlp_linear_evaluation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. surgvlp_linear_evaluation.py +277 -0
surgvlp_linear_evaluation.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import argparse
3
+ import torch
4
+ import clip
5
+ from PIL import Image
6
+ import sys
7
+ sys.path.append('../../../')
8
+ from codes.datasets import build_dataset
9
+ from codes.models import build_algorithm
10
+ from mmengine.config import Config
11
+ from transformers import AutoTokenizer
12
+ from baselines.utils import calc_accuracy, calc_f1
13
+ import torchmetrics
14
+ import numpy as np
15
+ from torch.utils.data import ConcatDataset
16
+ import torch.optim as optim
17
+
18
def process_text(text):
    """Tokenize caption text for the clinical BERT text encoder.

    Args:
        text: A single caption string or a list of caption strings.

    Returns:
        dict with CUDA tensors:
            'input_ids', 'attention_mask', 'token_type_ids' — shape
            (77,)-padded encodings (batch dim squeezed when len(text) == 1),
            and 'cap_lens' — list of per-caption token counts, excluding
            special tokens such as [CLS]/[SEP]/[PAD].
    """
    tokenizer_clinical = AutoTokenizer.from_pretrained('/gpfswork/rech/okw/ukw13bv/mmsl/biobert_pretrain_output_all_notes_150000')
    # Reverse vocabulary: token id -> token string.
    ixtoword = {v: k for k, v in tokenizer_clinical.get_vocab().items()}
    if isinstance(text, str):
        text = [text]

    processed_text_tensors = []
    for t in text:
        text_tensors = tokenizer_clinical(
            t,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=77,
        )
        # Human-readable token sequence for the (single) encoded sentence.
        text_tensors["sent"] = [
            ixtoword[ix] for ix in text_tensors["input_ids"][0].tolist()
        ]
        processed_text_tensors.append(text_tensors)

    caption_ids = torch.stack([x["input_ids"] for x in processed_text_tensors])
    attention_mask = torch.stack(
        [x["attention_mask"] for x in processed_text_tensors]
    )
    token_type_ids = torch.stack(
        [x["token_type_ids"] for x in processed_text_tensors]
    )

    if len(text) == 1:
        # squeeze(0) keeps the length-77 dimension for a single caption.
        caption_ids = caption_ids.squeeze(0).cuda()
        attention_mask = attention_mask.squeeze(0).cuda()
        token_type_ids = token_type_ids.squeeze(0).cuda()
    else:
        caption_ids = caption_ids.squeeze().cuda()
        attention_mask = attention_mask.squeeze().cuda()
        token_type_ids = token_type_ids.squeeze().cuda()

    # BUG FIX: the original iterated over the *characters* of each raw
    # caption (`for w in txt`), so cap_lens held character counts and the
    # special-token filter `w.startswith("[")` matched single characters.
    # Count decoded tokens instead, excluding [CLS]/[SEP]/[PAD]-style tokens.
    cap_lens = [
        len([w for w in x["sent"] if not w.startswith("[")])
        for x in processed_text_tensors
    ]

    return {
        "input_ids": caption_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "cap_lens": cap_lens,
    }
66
+
67
def test(classifier, test_loader, model, args):
    """Evaluate a trained linear classifier on each per-video test loader.

    Args:
        classifier: trained nn.Linear head mapping image features to logits.
        test_loader: list of DataLoaders, one per test video.
        model: frozen backbone; model(frames, None, mode='video') must
            return a dict with key 'img_emb'.
        args: parsed CLI args (unused here; kept for interface
            compatibility with existing call sites).

    Prints per-video accuracy / F1 and the video-wise averages.
    """
    model.eval()

    # NOTE: the original also read args.class_prompt and encoded text
    # features here, but they were only consumed by commented-out code;
    # that dead code has been removed.

    total_acc = []
    total_f1_phase = []

    with torch.no_grad():
        # BUG FIX: the original iterated the module-level global
        # `test_loaders` instead of the parameter passed in.
        for loader in test_loader:
            probs_list = []
            label_list = []

            for data in loader:
                frames = data['video'].cuda()  # (B, C, H, W)
                B, C, H, W = frames.shape
                frames = frames.view(-1, C, H, W)

                image_features = model(frames, None, mode='video')['img_emb']
                probs = classifier(image_features).softmax(dim=-1)  # (B, classes)

                probs_list.append(probs)
                label_list.append(data['label'].cuda())

            # Concatenate all frames of this video before scoring.
            probs_list = torch.cat(probs_list, 0)
            labels = torch.cat(label_list, 0)

            acc = calc_accuracy(probs_list, labels)
            print('accuracy: ', acc)
            f1_class, f1_average = calc_f1(probs_list, labels)
            print('f1 average: ', f1_average)
            print('f1 classes: ', f1_class)

            total_acc.append(acc)
            total_f1_phase.append(f1_average)

    print('f1 phase video-wise average ', np.mean(np.asarray(total_f1_phase)))
    print('Acc video-wise average ', np.mean(np.asarray(total_acc)))
125
+
126
+
127
+
128
def linear_evaluation(
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    num_classes: int
) -> torch.nn.Module:
    """Train a linear probe on top of frozen backbone image features.

    Args:
        train_loader: yields dicts with 'video' (B, C, H, W) and 'label'.
        val_loader: unused; kept for interface compatibility (validation
            hooks are left commented out below).
        model: frozen backbone; model(x, None, mode='video') must return
            a dict whose 'img_emb' features have dimension 2048.
        num_classes: number of output classes for the linear head.

    Returns:
        The trained nn.Linear classifier (on CUDA).
    """
    # Freeze the pre-trained model's parameters.
    for param in model.parameters():
        param.requires_grad = False

    # NOTE: the original also built text-prompt features here through the
    # module-level global `args`; they were never used, so that dead code
    # (and its hidden global dependency) has been removed.

    # Feature dim 2048 — assumes the backbone's 'img_emb' output size
    # (ResNet-50-style); TODO confirm against the model config.
    classifier = nn.Linear(2048, num_classes).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001, weight_decay=0.0005)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=40)

    # Training loop.
    model.eval()  # frozen backbone stays in eval mode (BN/dropout fixed)
    for epoch in range(25):
        for batch in train_loader:
            inputs = batch['video'].cuda()
            labels = batch['label'].cuda()

            # Extract frozen features; no autograd graph needed here.
            with torch.no_grad():
                features = model(inputs, None, mode='video')['img_emb']

            features = features.to(dtype=torch.float32)
            outputs = classifier(features)

            loss = criterion(outputs, labels)
            print(loss)

            # Backward and optimize (classifier parameters only).
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # scheduler.step()  # deliberately disabled, matching original behavior

        # Validation can be added here if needed
        # classifier = classifier.eval()
        # test(classifier, test_loaders, model, args)
        # classifier = classifier.train()

    return classifier  # Return the trained classifier
189
+
190
def get_args(description='CLIP'):
    """Build and parse the command-line arguments for linear evaluation.

    Args:
        description: text shown in the argparse help header.

    Returns:
        (args, parser): the parsed namespace and the parser itself.
    """
    parser = argparse.ArgumentParser(description=description)
    # (flag, default, type, help) — registered in a single pass.
    option_table = [
        ('--class_prompt', '../class_prompt.txt', str, 'prompt for categories'),
        ('--dataset_config', './config.py', str, 'dataset config'),
        ('--batch_size', 1, int, 'batch for testing'),
        ('--num_class', 12, int, 'class for classification'),
        ('--checkpoint', '', str, 'Checkpoint to load'),
    ]
    for flag, default_value, value_type, help_text in option_table:
        parser.add_argument(flag, default=default_value, type=value_type, help=help_text)
    return parser.parse_args(), parser
199
+
200
import torch.distributed as dist

if __name__ == "__main__":

    args, _ = get_args()
    # Everything below calls .cuda() directly, so a CUDA device is
    # effectively required; `device` is computed but not used further.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # mmengine config file must define a top-level 'config' entry with
    # model_config / train_config / val_config / test_config fields.
    configs = Config.fromfile(args.dataset_config)['config']

    model = build_algorithm(configs.model_config).cuda()

    ###### load weights
    # (Historical hard-coded checkpoint paths removed; pass --checkpoint.)
    state_dict = torch.load(args.checkpoint)['state_dict']

    # Strip the DataParallel 'module.' prefix (k[7:]) and remap legacy
    # module names (visual.* / text_module.*) onto the current
    # backbone_img.* / backbone_text.* naming.
    # NOTE(review): keys WITHOUT 'module.' are silently dropped — confirm
    # every checkpoint of interest was saved from a DataParallel model,
    # otherwise strict loading below will fail on missing keys.
    new_dict = {}
    for k, v in state_dict.items():
        if 'module.' in k:
            new_dict[k[7:].replace('visual.model.', 'backbone_img.model.').replace('text_module.model.', 'backbone_text.model.').replace('visual.global_embedder','backbone_img.global_embedder')] = v
    # load_state_dict returns (missing_keys, unexpected_keys); with
    # strict=True it raises on mismatch, so a/b are informational only.
    a, b = model.load_state_dict(new_dict, strict=True)

    model.eval()

    # One ConcatDataset over all configured training splits.
    train_datasets = [build_dataset(c) for c in configs.train_config]
    train_dataset = ConcatDataset(train_datasets)

    val_datasets = [build_dataset(c) for c in configs.val_config]
    val_dataset = ConcatDataset(val_datasets)

    # Test splits stay separate — one dataset (and loader) per video so
    # metrics can be averaged video-wise (40 videos --> 40 datasets).
    test_datasets = [build_dataset(c) for c in configs.test_config]

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=4
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=4
    )

    # num_workers=0 here: each per-video loader is small and short-lived.
    test_loaders = [torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0
    ) for test_dataset in test_datasets]  # 40 dataloaders
    print(args)

    # Train the linear probe on frozen features, then evaluate per video.
    classifier = linear_evaluation(train_loader, val_loader, model, args.num_class)

    test(classifier, test_loaders, model, args)
277
+