KKYYKK committed on
Commit
d662031
·
verified ·
1 Parent(s): 2af06ab

Upload config_recog_autolapa_frame_linear.py with huggingface_hub

Browse files
config_recog_autolapa_frame_linear.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torchvision.transforms as transforms

# Inherit shared settings from the base config (mmcv/mmengine-style `_base_`
# inheritance; the loader merges '../base.py' into this file).
_base_ = ['../base.py']

# Dataset locations on the cluster, shared by all three splits.
_CSV_ROOT = '/gpfswork/rech/okw/ukw13bv/mmsl/csv/autolaparo/csvs'
_VIDEO_ROOT = '/gpfsscratch/rech/okw/ukw13bv/autolaparo/frames_output'


def _frame_datasets(first, last):
    """Return one 'Recognition_frame' dataset config per video.

    Args:
        first: first video id (inclusive).
        last: last video id (inclusive).

    Each video gets its own per-frame dataset entry; the CSV filename is the
    zero-padded video id (01.csv, 02.csv, ...). Previously this spec was
    copy-pasted verbatim for train/val/test; it is now built in one place.
    """
    preprocess = transforms.Compose(
        [
            transforms.Resize((360, 640)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # ImageNet statistics — matches the pretrained='imagenet'
            # image backbone configured below.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
    )
    return [
        dict(
            type='Recognition_frame',
            csv_root=_CSV_ROOT,
            vid='%02d.csv' % i,
            video_root=_VIDEO_ROOT,
            transforms=preprocess,
        )
        for i in range(first, last + 1)
    ]


config = dict(
    # Video-level split: videos 01-10 train, 11-14 val, 15-21 test
    # (same ranges as the original range(1,11)/range(11,15)/range(15,22)).
    train_config=_frame_datasets(1, 10),
    val_config=_frame_datasets(11, 14),
    test_config=_frame_datasets(15, 21),
    model_config=dict(
        type='MVNet_feature_extractor',
        # Image branch: ResNet-50 encoder projecting to a 768-d feature
        # (dimension matches the text embedding below).
        backbone_img=dict(
            type='img_backbones/ImageEncoder_feature_extractor',
            # type='img_backbones/ImageEncoder_CLIPVISUAL',
            num_classes=768,
            pretrained='imagenet',  # imagenet/ssl/random
            backbone_name='resnet_50',
            # backbone_name='resnet_50_clip'
            img_norm=False,
        ),
        # Text branch: BioBERT encoder; sums the last 4 hidden layers.
        backbone_text=dict(
            type='text_backbones/BertEncoder',
            text_bert_type='/gpfswork/rech/okw/ukw13bv/mmsl/biobert_pretrain_output_all_notes_150000',
            text_last_n_layers=4,
            text_aggregate_method='sum',
            text_norm=False,
            text_embedding_dim=768,
            text_freeze_bert=False,
            text_agg_tokens=True,
        ),
    ),
)

# Drop the private helpers so the module namespace seen by the config loader
# is identical to the original file (only `_base_` and `config` remain,
# besides the imports).
del _frame_datasets, _CSV_ROOT, _VIDEO_ROOT