Spaces:
Runtime error
Runtime error
Upload codes
Browse files- config.py +34 -0
- dataset.py +261 -0
- datasetbuilder.py +62 -0
- modelbuilder.py +127 -0
- test.py +164 -0
config.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import yaml
|
| 3 |
+
|
| 4 |
+
def get_config_universal(dataset_name):
    """Load and return the JSON config for *dataset_name* from ./configs/.

    Expects a file named ``<dataset_name>_config.json`` inside ``./configs``.
    """
    path = './configs/' + dataset_name + '_config.json'
    with open(path) as handle:
        return json.load(handle)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_sweep_config_universal(dataset_name):
    """Load and return the YAML sweep config for *dataset_name* from ./configs/.

    Expects a file named ``sweep_<dataset_name>_config.yaml`` inside
    ``./configs``; parsed with ``yaml.FullLoader``.
    """
    path = './configs/sweep_' + dataset_name + '_config.yaml'
    with open(path) as handle:
        return yaml.load(handle, Loader=yaml.FullLoader)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_config():
    """Load and return the default (Camargo) JSON configuration."""
    # NOTE(review): path is relative to the CWD (no leading './'), unlike the
    # other loaders in this module — kept as-is for behavior compatibility.
    with open('configs/camargo_config.json') as handle:
        return json.load(handle)
|
| 20 |
+
|
| 21 |
+
def get_kiha_config():
    """Load and return the KIHA dataset JSON configuration."""
    with open('./configs/kiha_config.json') as handle:
        return json.load(handle)
|
| 25 |
+
|
| 26 |
+
def get_model_config(model_config):
    """Load and return the JSON config named ``<model_config>.json`` from ./configs/."""
    config_path = f'./configs/{model_config}.json'
    with open(config_path) as handle:
        return json.load(handle)
|
| 30 |
+
|
| 31 |
+
def get_sweep_config():
    """Load and return the default (Camargo) YAML sweep configuration."""
    # NOTE(review): path is relative to the CWD (no leading './'), unlike the
    # universal loaders — kept as-is for behavior compatibility.
    with open('configs/sweep_camargo_config.yaml') as handle:
        return yaml.load(handle, Loader=yaml.FullLoader)
|
dataset.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from loading.loadpickledataset import LoadPickleDataSet
|
| 4 |
+
from preprocessing.augmentation.gaussiannoise import GaussianNoise
|
| 5 |
+
from preprocessing.augmentation.imurotation import IMURotation
|
| 6 |
+
from preprocessing.filter_imu import FilterIMU
|
| 7 |
+
from preprocessing.filter_opensim import FilterOpenSim
|
| 8 |
+
from preprocessing.remove_outlier import remove_outlier
|
| 9 |
+
from preprocessing.resample import Resample
|
| 10 |
+
from preprocessing.segmentation.fixwindowsegmentation import FixWindowSegmentation
|
| 11 |
+
from preprocessing.segmentation.gaitcyclesegmentation import GaitCycleSegmentation
|
| 12 |
+
from preprocessing.segmentation.zeropaddingsegmentation import ZeroPaddingSegmentation
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DataSet:
    """One IMU-to-OpenSim dataset.

    Loads pickled trials via ``LoadPickleDataSet``, restricts them to the
    configured trial types / activity labels, applies optional preprocessing
    (outlier removal, resampling, low-pass filtering), and provides
    segmentation, augmentation, and train/test splitting helpers.

    Attributes (after ``load_dataset``):
        x, y, labels: per-trial lists of IMU inputs, OpenSim targets, and
            pandas label DataFrames (one entry per trial).
    """

    def __init__(self, config, load_dataset=True):
        """Store configuration and (optionally) load + preprocess the data.

        :param config: dict of dataset options (see keys read below).
        :param load_dataset: when True, immediately load and preprocess.
        """
        self.config = config
        self.x = []
        self.y = []
        self.labels = []
        self.selected_trial_type = config['selected_trial_type']
        self.selected_activity_label = config['selected_activity_label']
        self.segmentation_method = config['segmentation_method']
        # Gait-cycle datasets are always zero-padded to a fixed length.
        if self.config['gc_dataset']:
            self.segmentation_method = 'zeropadding'
        self.resample = config['resample']
        if load_dataset:
            self.load_dataset()
        # BUG FIX: n_sample was computed *before* load_dataset() ran, so it
        # was always 0. Compute it after the data is in place.
        self.n_sample = len(self.y)
        self.train_subjects = config['train_subjects']
        self.test_subjects = config['test_subjects']
        self.train_activity = config['train_activity']
        self.test_activity = config['test_activity']
        # self.winsize = 128
        self.train_dataset = {}
        self.test_dataset = {}

    def load_dataset(self):
        """Load the pickled trials, filter by activity, and preprocess in place."""
        getdata_handler = LoadPickleDataSet(self.config)
        x, y, labels = getdata_handler.run_get_dataset()
        self.x, self.y, self.labels = self.run_activity_based_filter(x, y, labels)
        self._preprocess()

    def _preprocess(self):
        """Remove outliers, then optionally resample and low-pass filter x/y."""
        self.x, self.y, self.labels = remove_outlier(self.x, self.y, self.labels)
        if self.resample:
            self.x, self.y, self.labels = self.run_resample_signal(self.x, self.y, self.labels)
        if self.config['opensim_filter']:
            # 6 Hz low-pass on the OpenSim targets (fs assumed 100 Hz).
            filteropensim_handler = FilterOpenSim(self.y, lowcut=6, fs=100, order=2)
            self.y = filteropensim_handler.run_lowpass_filter()
        if self.config['imu_filter']:
            # 10 Hz low-pass on the IMU channels (fs assumed 100 Hz).
            filterimu_handler = FilterIMU(self.x, lowcut=10, fs=100, order=2)
            self.x = filterimu_handler.run_lowpass_filter()

    def run_resample_signal(self, x, y, labels):
        """Resample all trials from 200 Hz to 100 Hz and return them."""
        resample_handler = Resample(x, y, labels, 200, 100)
        x, y, labels = resample_handler._run_resample()
        return x, y, labels

    def run_segmentation(self, x, y, labels):
        """Segment the trials with the configured method, then augment.

        Supported methods: 'fixedwindow', 'zeropadding', 'gaitcycle'.
        Optional post-steps: OpenSim low-pass filter, random IMU rotation,
        Gaussian noise. Results are stored on self and also returned.
        """
        if self.segmentation_method == 'fixedwindow':
            segmentation_handler = FixWindowSegmentation(
                x, y, labels, winsize=self.config['target_padding_length'],
                overlap=0.5, start_over=True)
            self.x, self.y, self.labels = segmentation_handler._run_segmentation()
        elif self.segmentation_method == 'zeropadding':
            segmentation_handler = ZeroPaddingSegmentation(
                x, y, labels,
                target_padding_length=self.config['target_padding_length'],
                start_over=True)
            self.x, self.y, self.labels = segmentation_handler._run_segmentation()
        elif self.segmentation_method == 'gaitcycle':
            segmentation_handler = GaitCycleSegmentation(
                x, y, labels, winsize=128, overlap=0.5, start_over=True)
            self.x, self.y, self.labels = segmentation_handler._run_segmentation()

        if self.config['opensim_filter']:
            filteropensim_handler = FilterOpenSim(self.y, lowcut=6, fs=100, order=2)
            self.y = filteropensim_handler.run_lowpass_filter()

        if self.config['rotation']:
            imu_rotation_handler = IMURotation(knom=10)
            self.x, self.y, self.labels = imu_rotation_handler.run_rotation(
                self.x.copy(), self.y.copy(), self.labels.copy())

        if self.config['gaussian_noise']:
            gaussian_noise_handler = GaussianNoise(0, .05)
            self.x, self.y, self.labels = gaussian_noise_handler.run_add_noise(
                self.x, self.y, self.labels)
        del x, y, labels
        return self.x, self.y, self.labels

    def run_activity_based_filter(self, x, y, label):
        """Keep only samples matching the selected trial types / activity labels.

        :param x: list of per-trial IMU arrays.
        :param y: list of per-trial OpenSim target arrays.
        :param label: list of per-trial pandas DataFrames with at least a
            'trialType' column (plus 'Label' for the camargo dataset).
        :return: (updated_x, updated_y, updated_label) filtered lists.
        """
        updated_x = []
        updated_y = []
        updated_label = []
        for ll, xx, yy in zip(label, x, y):
            if (self.config['dataset_name'] == 'camargo'
                    and ll['trialType'].isin(self.selected_trial_type).all()
                    and self.selected_activity_label == ['all_idle']):
                # Keep every sample of the selected trial types (incl. idle).
                l_temp = ll[ll['trialType'].isin(self.selected_trial_type)]
                l_temp_index = l_temp.index.values
                updated_x.append(xx[l_temp_index])
                updated_y.append(yy[l_temp_index])
                updated_label.append(l_temp)
            elif (self.config['dataset_name'] == 'camargo'
                    and ll['trialType'].isin(self.selected_trial_type).all()
                    and self.selected_activity_label == ['all']):
                # Keep every activity of this trial except idle/stand.
                update_selected_activity_label = list(ll['Label'].unique())
                update_selected_activity_label = [
                    i for i in update_selected_activity_label
                    if i not in ['idle', 'stand']]
                l_temp = ll[(ll['trialType'].isin(self.selected_trial_type))
                            & (ll['Label'].isin(update_selected_activity_label))]
                l_temp_index = l_temp.index.values
                updated_x.append(xx[l_temp_index])
                updated_y.append(yy[l_temp_index])
                updated_label.append(l_temp)
            elif (self.config['dataset_name'] == 'camargo'
                    and ll['trialType'].isin(self.selected_trial_type).all()
                    and self.selected_activity_label == ['all_split']):
                # Split each trial into per-activity sub-trials via a derived
                # 'trialType2' column.
                ll_temp = ll.copy()
                ll_temp['trialType2'] = ll_temp['Label']
                if ll['trialType'][0] == 'levelground':
                    # Level-ground trials contain two walking bouts separated
                    # by turns; split at the turn markers.
                    turn1_indx = ll_temp[ll_temp['Label'] == 'turn1'].index.values
                    turn2_indx = ll_temp[ll_temp['Label'] == 'turn2'].index.values
                    # Ensure turn1 is the chronologically first turn.
                    if turn1_indx[0] > turn2_indx[0]:
                        turn1_indx, turn2_indx = turn2_indx, turn1_indx
                    # Divide into two segments; rows strictly between the two
                    # turns get NaN in 'trialType2' and are dropped below by
                    # the isinstance(..., str) check.
                    seg1 = ll_temp.iloc[0:turn1_indx[-1] + 1]
                    seg2 = ll_temp.iloc[turn2_indx[0]:]
                    seg1_trialType2 = seg1['trialType2'].replace(
                        {'idle': 'idle', 'stand': 'idle', 'turn1': 'idle', 'turn2': 'idle',
                         'stand-walk': 'levelground1', 'walk': 'levelground1',
                         'walk-stand': 'levelground1'})
                    seg2_trialType2 = seg2['trialType2'].replace(
                        {'idle': 'idle', 'stand': 'idle', 'turn1': 'idle', 'turn2': 'idle',
                         'stand-walk': 'levelground2', 'walk': 'levelground2',
                         'walk-stand': 'levelground2'})
                    ll_temp['trialType2'] = pd.concat([seg1_trialType2, seg2_trialType2])
                    ll = ll_temp
                elif ll['trialType'][0] == 'ramp':
                    ll_temp['trialType2'] = ll_temp['trialType2'].replace(
                        {'idle': 'idle',
                         'walk-rampascent': 'rampascent', 'rampascent': 'rampascent',
                         'rampascent-walk': 'rampascent',
                         'walk-rampdescent': 'rampdescent', 'rampdescent': 'rampdescent',
                         'rampdescent-walk': 'rampdescent'})
                    ll = ll_temp
                elif ll['trialType'][0] == 'stair':
                    ll_temp['trialType2'] = ll_temp['trialType2'].replace(
                        {'idle': 'idle',
                         'walk-stairascent': 'stairascent', 'stairascent': 'stairascent',
                         'stairascent-walk': 'stairascent',
                         'walk-stairdescent': 'stairdescent', 'stairdescent': 'stairdescent',
                         'stairdescent-walk': 'stairdescent'})
                    ll = ll_temp

                update_selected_activity_label = list(ll['trialType2'].unique())
                # Remove idle samples (and NaNs are skipped below).
                update_selected_activity_label = [
                    i for i in update_selected_activity_label if i not in ['idle']]
                for activity_label in update_selected_activity_label:
                    # NaN entries (un-labelled rows between turns) are not str.
                    if isinstance(activity_label, str):
                        l_temp = ll[(ll['trialType'].isin(self.selected_trial_type))
                                    & (ll['trialType2'] == activity_label)]
                        l_temp_index = l_temp.index.values
                        xx_temp = xx[l_temp_index]
                        yy_temp = yy[l_temp_index]
                        updated_x.append(xx_temp)
                        updated_y.append(yy_temp)
                        updated_label.append(l_temp)
                        if len(xx_temp) == 0:
                            # BUG FIX: original printed `i`, a leaked
                            # comprehension name that no longer exists in
                            # Python 3 (NameError). Report the empty label.
                            print(activity_label)
            elif self.config['dataset_name'] == 'camargo':
                # Explicit list of activity labels was configured.
                l_temp = ll[(ll['trialType'].isin(self.selected_trial_type))
                            & (ll['Label'].isin(self.selected_activity_label))]
                l_temp_index = l_temp.index.values
                updated_x.append(xx[l_temp_index])
                updated_y.append(yy[l_temp_index])
                updated_label.append(l_temp)
            elif self.config['dataset_name'] == 'kiha':
                # KIHA has no activity labels; filter by trial type only.
                l_temp = ll[(ll['trialType'].isin(self.selected_trial_type))]
                l_temp_index = l_temp.index.values
                updated_x.append(xx[l_temp_index])
                updated_y.append(yy[l_temp_index])
                updated_label.append(l_temp)
        return updated_x, updated_y, updated_label

    def concatenate_data(self):
        """Collapse the per-trial lists into single arrays / one DataFrame."""
        self.labels = pd.concat(self.labels, axis=0, ignore_index=True)
        self.x = np.concatenate(self.x, axis=0)
        self.y = np.concatenate(self.y, axis=0)

    def run_dataset_split_loop(self):
        """Split per-trial data into train/test by subject AND activity.

        Requires the 'trialType2' column produced by the 'all_split' mode of
        run_activity_based_filter. Trials matching neither set are dropped.
        """
        train_labels = []
        test_labels = []
        train_x = []
        train_y = []
        test_x = []
        test_y = []
        for t, trial in enumerate(self.labels):
            if all(trial['subject'].isin(self.train_subjects)) and \
                    all(trial['trialType2'].isin(self.train_activity)):
                train_labels.append(trial)
                train_x.append(self.x[t])
                train_y.append(self.y[t])
            elif all(trial['subject'].isin(self.test_subjects)) and \
                    all(trial['trialType2'].isin(self.test_activity)):
                test_labels.append(trial)
                test_x.append(self.x[t])
                test_y.append(self.y[t])

        self.train_dataset['x'] = train_x
        self.train_dataset['y'] = train_y
        self.train_dataset['labels'] = train_labels

        self.test_dataset['x'] = test_x
        self.test_dataset['y'] = test_y
        self.test_dataset['labels'] = test_labels
        return self.train_dataset, self.test_dataset

    def run_dataset_split(self):
        """Split concatenated data into train/test dicts by subject.

        Call concatenate_data() first. If the test subjects are a subset of
        the train subjects, training uses everything *except* them.
        """
        if set(self.test_subjects).issubset(self.train_subjects):
            train_labels = self.labels[~self.labels['subject'].isin(self.test_subjects)]
            # BUG FIX: original indexed the non-existent column 'subjects'
            # (KeyError); the column is named 'subject' everywhere else.
            test_labels = self.labels[(self.labels['subject'].isin(self.test_subjects))]
        else:
            train_labels = self.labels[self.labels['subject'].isin(self.train_subjects)]
            test_labels = self.labels[(self.labels['subject'].isin(self.test_subjects))]
        print(train_labels['subject'].unique())
        print(test_labels['subject'].unique())

        train_index = train_labels.index.values
        test_index = test_labels.index.values
        print('training length', len(train_index))
        print('test length', len(test_index))

        train_x = self.x[train_index]
        train_y = self.y[train_index]
        self.train_dataset['x'] = train_x
        self.train_dataset['y'] = train_y
        self.train_dataset['labels'] = train_labels.reset_index(drop=True)

        test_x = self.x[test_index]
        test_y = self.y[test_index]
        self.test_dataset['x'] = test_x
        self.test_dataset['y'] = test_y
        self.test_dataset['labels'] = test_labels.reset_index(drop=True)
        del train_labels, test_labels, train_x, train_y, test_x, test_y
        return self.train_dataset, self.test_dataset
datasetbuilder.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import Dataset
|
| 4 |
+
from preprocessing.augmentation.gaussiannoise import GaussianNoise
|
| 5 |
+
from preprocessing.transformation.transformation import Transformation
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from sklearn import preprocessing
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DataSetBuilder(Dataset):
    """Torch ``Dataset`` wrapping numpy inputs/targets with optional
    scaling, Gaussian-noise augmentation, and activity-label encoding.

    :param x: numpy array of model inputs, indexed by sample along axis 0.
    :param y: numpy array of regression targets, same leading dimension.
    :param labels: per-sample label array; when *classification* is truthy
        the class name is read from ``labels[:, 0, 3]``.
    :param transform_method: dict with 'data_transformer_method' and
        'data_transformer_by' keys, or None for no transformation.
    :param scaler: fitted scaler to reuse (validation/test); None fits a
        new one on *x* and stores it on ``self.scaler``.
    :param noise: when not None, add Gaussian noise to the data.
    :param classification: when truthy, also expose integer class targets
        through ``__getitem__``.
    """

    def __init__(self, x, y, labels, transform_method=None, scaler=None,
                 noise=None, classification=None):
        self.x = x
        self.y = y
        self.labels = labels
        self.y_label = []

        self.transform_method = transform_method
        self.scaler = scaler
        self.noise = noise
        self.classification = classification
        self._preprocess()
        if self.classification:
            self._run_label_encoding()
        self.n_sample = len(y)

        # x = np.transpose(self.x, (0, 2, 1))
        # BUG FIX: original converted the raw `x` argument, silently
        # discarding any transform/noise `_preprocess` applied to self.x.
        self.x = torch.from_numpy(self.x).double()
        self.y = torch.from_numpy(self.y).double()

    def _run_label_encoding(self):
        """Encode the per-sample activity names as int64 class indices."""
        le = preprocessing.LabelEncoder()
        y_label = le.fit_transform(self.labels[:, 0, 3])
        y_label = torch.as_tensor(y_label)
        # self.y_label = F.one_hot(y_label.to(torch.int64))
        # CrossEntropyLoss expects integer class indices, not one-hot.
        self.y_label = y_label.to(torch.int64)

    def _preprocess(self):
        """Apply the configured transformation and/or noise to self.x/self.y."""
        # BUG FIX: guard against transform_method=None (the parameter's
        # default), which previously raised TypeError on subscripting.
        if (self.transform_method is not None
                and self.transform_method['data_transformer_method'] is not None):
            self._run_transform()
        if self.noise is not None:
            self._run_noise()

    def _run_transform(self):
        """Scale self.x, fitting a new scaler when none was supplied."""
        transform_handler = Transformation(
            method=self.transform_method['data_transformer_method'],
            by=self.transform_method['data_transformer_by'])
        if self.scaler is None:
            # Training: fit the scaler and keep it for the val/test sets.
            self.scaler, self.x = transform_handler.run_transform(
                train=self.x, scaler_fit=self.scaler)
        else:
            # Validation/test: reuse the already-fitted scaler.
            self.x = transform_handler.run_transform(
                val=self.x, scaler_fit=self.scaler)

    def _run_noise(self):
        """Add zero-mean unit-std Gaussian noise to the data."""
        gaussiannoise_handler = GaussianNoise(mean=0, std=1)
        self.x, self.y, self.labels = gaussiannoise_handler.run_add_noise(
            self.x, self.y, self.labels)

    def __len__(self):
        return self.n_sample

    def __getitem__(self, item):
        if self.classification:
            return self.x[item], self.y[item], self.y_label[item]
        else:
            return self.x[item], self.y[item]
|
modelbuilder.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
# from tsai.models.TST import TST
|
| 4 |
+
from sklearn.neighbors import KNeighborsRegressor
|
| 5 |
+
from config import get_model_config
|
| 6 |
+
from loss.weightedmseloss import WeightedMSELoss
|
| 7 |
+
from loss.weightedmultioutputsloss import WeightedMultiOutputLoss
|
| 8 |
+
from loss.weightedrmseloss import WeightedRMSELoss
|
| 9 |
+
from model.Hernandez2021cnnlstm import Hernandez2021CNNLSTM
|
| 10 |
+
from model.bilstmmodel import BiLSTMModel
|
| 11 |
+
from model.cnnlstm import CNNLSTM
|
| 12 |
+
from model.dorschky2020cnn import Dorschky2020CNN
|
| 13 |
+
from model.gholami2020cnn import Gholami2020CNN
|
| 14 |
+
from model.lstmlstm import Seq2Seq
|
| 15 |
+
from model.lstmlstmattention import Seq2SeqAtt
|
| 16 |
+
from model.lstmlstmrec import Seq2SeqRec
|
| 17 |
+
from model.lstmmodel import LSTMModel
|
| 18 |
+
from model.tcnmodel import TCNModel
|
| 19 |
+
from model.transformer import Transformer
|
| 20 |
+
from model.transformer_seq2seq import Seq2SeqTransformer
|
| 21 |
+
from model.transformer_tsai import TransformerTSAI
|
| 22 |
+
from model.zrenner2018cnn import Zrenner2018CNN
|
| 23 |
+
from utils.update_config import update_model_config
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ModelBuilder:
    """Build the (model, optimizer, criterion) triple described by *config*.

    Reads the per-model JSON config (``config_<model_name>.json``), merges it
    with the run config, and dispatches on ``model_name`` / ``optimizer_name``
    / ``loss``.
    """

    def __init__(self, config):
        self.config = config
        # 6 channels (3 accel + 3 gyro assumed) per selected sensor —
        # TODO confirm against the sensor loader.
        self.n_input_channel = len(self.config['selected_sensors']) * 6
        self.n_output = len(self.config['selected_opensim_labels'])
        self.model_name = self.config['model_name']
        self.model_config = get_model_config(f'config_{self.model_name}')
        self.model_config = update_model_config(self.config, self.model_config)
        self.optimizer_name = self.config['optimizer_name']
        self.learning_rate = self.config['learning_rate']
        self.l2_weight_decay_status = self.config['l2_weight_decay_status']
        self.l2_weight_decay = self.config['l2_weight_decay']
        self.loss = self.config['loss']
        self.weight = self.config['loss_weight']
        self.device = self.config['device']
        # self.device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
        # Per-output loss weights only apply when one weight per output is given.
        if not self.n_output == len(self.weight):
            self.weight = None

    def run_model_builder(self):
        """Return (model, optimizer, criterion) for the configured run."""
        model = self.get_model_architecture()
        criterion = self.get_criterion(self.weight)
        optimizer = self.get_optimizer()
        return model, optimizer, criterion

    def get_model_architecture(self):
        """Instantiate and return the network selected by ``model_name``."""
        if self.model_name == 'lstm':  # done
            self.model = LSTMModel(self.model_config)
        elif self.model_name == 'bilstm':  # done
            self.model = BiLSTMModel(self.model_config)
        elif self.model_name == 'cnnlstm':  # done
            self.model = CNNLSTM(self.model_config)
        elif self.model_name == 'hernandez2021cnnlstm':  # done
            self.model = Hernandez2021CNNLSTM(self.model_config)
        elif self.model_name == 'seq2seq':  # done
            # NOTE(review): receives the run config, not model_config, unlike
            # every sibling branch — looks inconsistent; confirm before changing.
            self.model = Seq2Seq(self.config)
        elif self.model_name == 'seq2seqrec':
            self.model = Seq2SeqRec(self.n_input_channel, self.n_output)
        elif self.model_name == 'seq2seqatt':  # done
            self.model = Seq2SeqAtt(self.model_config)
        elif self.model_name == 'transformer':  # done
            self.model = Transformer(d_input=self.n_input_channel, d_model=12,
                                     d_output=self.n_output,
                                     d_len=self.config['target_padding_length'],
                                     h=8, N=1, attention_size=None,
                                     dropout=0.5, chunk_mode=None,
                                     pe='original', multihead=True)
        elif self.model_name == 'seq2seqtransformer':
            self.model = Seq2SeqTransformer(d_input=self.n_input_channel,
                                            d_model=24, d_output=self.n_output,
                                            h=8, N=4, attention_size=None,
                                            dropout=0.1, chunk_mode=None,
                                            pe='original')
        elif self.model_name == 'transformertsai':
            c_in = self.n_input_channel  # aka channels, features, variables, dimensions
            c_out = self.n_output
            seq_len = self.config['target_padding_length']
            y_range = self.config['target_padding_length']
            max_seq_len = self.config['target_padding_length']
            d_model = self.model_config['tsai_d_model']
            n_heads = self.model_config['tsai_n_heads']
            d_k = d_v = None  # if None --> d_model // n_heads
            d_ff = self.model_config['tsai_d_ff']
            res_dropout = self.model_config['tsai_res_dropout_p']
            activation = "gelu"
            n_layers = self.model_config['tsai_n_layers']
            fc_dropout = self.model_config['tsai_fc_dropout_p']
            classification = self.model_config['classification']
            kwargs = {}
            self.model = TransformerTSAI(c_in, c_out, seq_len,
                                         max_seq_len=max_seq_len,
                                         d_model=d_model, n_heads=n_heads,
                                         d_k=d_k, d_v=d_v, d_ff=d_ff,
                                         res_dropout=res_dropout,
                                         act=activation, n_layers=n_layers,
                                         fc_dropout=fc_dropout,
                                         classification=classification,
                                         **kwargs)
        elif self.model_name == 'Gholami2020CNN':
            self.model = Gholami2020CNN(self.model_config)
        elif self.model_name == 'Dorschky2020CNN':
            self.model = Dorschky2020CNN(self.model_config)
        elif self.model_name == 'Zrenner2018CNN':
            self.model = Zrenner2018CNN(self.model_config)
        elif self.model_name == 'tcn':
            self.model = TCNModel(self.model_config)
        elif self.model_name == 'knn':
            self.model = KNeighborsRegressor()
        return self.model

    def get_optimizer(self):
        """Return the configured optimizer bound to self.model's parameters."""
        if self.optimizer_name == 'Adam':
            if self.l2_weight_decay_status:
                self.optimizer = torch.optim.Adam(self.model.parameters(),
                                                  lr=self.learning_rate,
                                                  weight_decay=self.l2_weight_decay)
            else:
                self.optimizer = torch.optim.Adam(self.model.parameters(),
                                                  lr=self.learning_rate)
        return self.optimizer

    def get_criterion(self, weight=None):
        """Return the loss callable selected by ``self.loss``.

        :param weight: optional per-output weights (moved to self.device);
            falls back to a plain MSELoss when None and no match applies.
        """
        if self.loss == 'RMSE' and weight is not None:
            weight = torch.tensor(weight).to(self.device)
            self.criterion = WeightedRMSELoss(weight)
        elif self.loss == 'RMSE' and weight is None:
            # BUG FIX: original was `torch.sqrt(nn.MSELoss())`, which raises
            # TypeError at construction (sqrt of a module, not a tensor).
            # Wrap MSE and take the square root of the computed loss value.
            mse = nn.MSELoss()
            self.criterion = lambda output, target: torch.sqrt(mse(output, target))
        elif self.loss == 'MSE' and weight is not None:
            weight = torch.tensor(weight).to(self.device)
            self.criterion = WeightedMSELoss(weight)
        elif self.loss == 'MSE-CE' and weight is not None:
            weight = torch.tensor(weight).to(self.device)
            self.criterion = WeightedMultiOutputLoss(weight)
        else:
            self.criterion = nn.MSELoss()
        return self.criterion
|
test.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import wandb
|
| 3 |
+
|
| 4 |
+
from model.lstmlstm import Seq2SeqTest
|
| 5 |
+
from model.lstmlstmattention import Seq2SeqAttTest
|
| 6 |
+
from model.transformer_seq2seq import Seq2SeqTransformerTest
|
| 7 |
+
from modelbuilder import ModelBuilder
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Test:
|
| 11 |
+
def run_testing(self, config, model, test_dataloader):
    """Configure the tester from *config*, build the criterion, and run the
    matching evaluation loop.

    :return: (y_pred, y_true, loss) as produced by the selected test loop.
    """
    self.config = config
    self.device = config['device']
    self.loss = self.config['loss']
    self.weight = self.config['loss_weight']
    self.model_name = self.config['model_name']
    self.classification = config['classification']
    self.n_output = len(self.config['selected_opensim_labels'])
    # Per-output loss weights only apply when one weight per output is given.
    if self.n_output != len(self.weight):
        self.weight = None
    modelbuilder_handler = ModelBuilder(self.config)
    criterion = modelbuilder_handler.get_criterion(self.weight)
    self.tester = self.setup_tester()
    return self.tester(model, test_dataloader, criterion, self.device)
|
| 26 |
+
|
| 27 |
+
def setup_tester(self):
    """Pick the evaluation loop matching the configured model / task.

    Attention and transformer seq2seq models use dedicated autoregressive
    testers; classification models use the joint tester; everything else
    falls back to the plain regression loop.
    """
    if self.model_name == 'seq2seqatt':
        return self.testing_seq2seqatt
    if self.model_name == 'seq2seqtransformer':
        return self.testing_transformer_seq2seq
    if self.model_name in ('transformer', 'transformertsai') and not self.classification:
        return self.testing_transformer
    if self.classification:
        return self.testing_w_classification
    return self.testing
|
| 39 |
+
|
| 40 |
+
def testing(self, model, test_dataloader, criterion, device):
|
| 41 |
+
model.eval()
|
| 42 |
+
with torch.no_grad():
|
| 43 |
+
test_loss = []
|
| 44 |
+
test_preds = []
|
| 45 |
+
test_trues = []
|
| 46 |
+
for x, y in test_dataloader:
|
| 47 |
+
x = x.to(device)
|
| 48 |
+
y = y.to(device)
|
| 49 |
+
y_pred = model(x.float())
|
| 50 |
+
loss = criterion(y, y_pred)
|
| 51 |
+
test_loss.append(loss.item())
|
| 52 |
+
test_preds.append(y_pred)
|
| 53 |
+
test_trues.append(y)
|
| 54 |
+
test_loss = torch.mean(torch.tensor(test_loss))
|
| 55 |
+
print('Test Accuracy of the model: {}'.format(test_loss))
|
| 56 |
+
# wandb.log({"Test Loss": test_loss})
|
| 57 |
+
return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss
|
| 58 |
+
|
| 59 |
+
def testing_w_classification(self, model, test_dataloader, criterion, device):
    """Evaluation loop for joint regression + classification models.

    The model returns a two-element output (regression head, class head);
    only the regression head's predictions/targets are returned to the
    caller, alongside the mean joint loss. Logs the loss to wandb.
    """
    model.eval()
    with torch.no_grad():
        losses, preds, trues = [], [], []
        for batch_x, batch_y, batch_y_label in test_dataloader:
            batch_x = batch_x.to(device).float()
            # nn.CrossEntropyLoss() requires integer targets in torch.long.
            batch_y_label = batch_y_label.type(torch.LongTensor).to(device)
            batch_y = batch_y.to(device)
            batch_pred = model(batch_x)
            batch_pred[0] = batch_pred[0].double()
            batch_pred[1] = batch_pred[1].double()
            batch_true = [batch_y, batch_y_label]
            batch_loss = criterion(batch_pred, batch_true)
            losses.append(batch_loss.item())
            preds.append(batch_pred)
            trues.append(batch_true)
        test_loss = torch.mean(torch.tensor(losses))
        print('Test Accuracy of the model: {}'.format(test_loss))
        wandb.log({"Test Loss": test_loss})
        # Keep only the regression head for the returned tensors.
        preds_reg = [pair[0] for pair in preds]
        trues_reg = [pair[0] for pair in trues]
        return torch.cat(preds_reg, 0), torch.cat(trues_reg, 0), test_loss
|
| 86 |
+
|
| 87 |
+
def testing_seq2seq(self, model, test_dataloader, criterion, device):
    """Evaluate a seq2seq model using autoregressive decoding.

    ``Seq2SeqTest`` rolls the decoder forward without teacher forcing.
    The first time step is excluded from the loss (presumably the decoder
    start token — TODO confirm against Seq2SeqTest).

    Returns:
        Tuple ``(preds, trues, test_loss)`` with full-length concatenated
        sequences and the mean per-batch test loss.
    """
    model.eval()
    with torch.no_grad():
        batch_losses, preds, trues = [], [], []
        for x, y in test_dataloader:
            x, y = x.to(device), y.to(device)
            # Teacher-forced alternative: model(x.float(), y.float())
            y_pred = Seq2SeqTest(model, x.float())
            loss = criterion(y_pred[:, 1:, :].to(device), y[:, 1:, :])
            batch_losses.append(loss.item())
            preds.append(y_pred)
            trues.append(y)
        test_loss = torch.mean(torch.tensor(batch_losses))
        print('Test Accuracy of the model: {}'.format(test_loss))
        wandb.log({"Test Loss": test_loss})
        return torch.cat(preds, 0), torch.cat(trues, 0), test_loss
def testing_seq2seqatt(self, model, test_dataloader, criterion, device):
    """Evaluate an attention-based seq2seq model via autoregressive decoding.

    Mirrors ``testing_seq2seq`` but decodes through ``Seq2SeqAttTest``.
    The first time step is excluded from the loss (presumably the decoder
    start token — TODO confirm against Seq2SeqAttTest).

    Returns:
        Tuple ``(preds, trues, test_loss)`` with full-length concatenated
        sequences and the mean per-batch test loss.
    """
    model.eval()
    with torch.no_grad():
        batch_losses, preds, trues = [], [], []
        for x, y in test_dataloader:
            x, y = x.to(device), y.to(device)
            # Teacher-forced alternative: model(x.float(), y.float())
            y_pred = Seq2SeqAttTest(model, x.float())
            loss = criterion(y_pred[:, 1:, :].to(device), y[:, 1:, :])
            batch_losses.append(loss.item())
            preds.append(y_pred)
            trues.append(y)
        test_loss = torch.mean(torch.tensor(batch_losses))
        print('Test Accuracy of the model: {}'.format(test_loss))
        wandb.log({"Test Loss": test_loss})
        return torch.cat(preds, 0), torch.cat(trues, 0), test_loss
def testing_transformer(self, model, test_dataloader, criterion, device):
    """Evaluate a (non-autoregressive) transformer model on the test set.

    Args:
        model: trained transformer module; invoked as ``model(x.float())``.
        test_dataloader: iterable yielding ``(x, y)`` batches.
        criterion: loss callable following the PyTorch ``(input, target)``
            convention.
        device: torch device (or device string) batches are moved to.

    Returns:
        Tuple ``(preds, trues, test_loss)``: concatenated predictions and
        targets, and the mean per-batch test loss.
    """
    model.eval()
    with torch.no_grad():
        test_loss = []
        test_preds = []
        test_trues = []
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x.float())  # just for transformer
            # Fix: PyTorch criteria expect (input, target); the original
            # passed (target, input), which is wrong for asymmetric losses
            # and inconsistent with the seq2seq test methods in this class.
            loss = criterion(y_pred.to(device), y)
            test_loss.append(loss.item())
            test_preds.append(y_pred)
            test_trues.append(y)
        test_loss = torch.mean(torch.tensor(test_loss))
        # Fix: the reported quantity is a loss, not an accuracy.
        print('Test Loss of the model: {}'.format(test_loss))
        wandb.log({"Test Loss": test_loss})
        return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss
def testing_transformer_seq2seq(self, model, test_dataloader, criterion, device):
    """Evaluate a seq2seq transformer using autoregressive decoding.

    ``Seq2SeqTransformerTest`` rolls the decoder forward without teacher
    forcing; its output predicts the target sequence shifted by one step
    (the teacher-forced call was ``model(x.float(), y.float()[:, :-1, :])``),
    so the loss and the stored targets both use ``y[:, 1:, :]``.

    Returns:
        Tuple ``(preds, trues, test_loss)``: concatenated shifted-sequence
        predictions/targets and the mean per-batch test loss.
    """
    model.eval()
    with torch.no_grad():
        test_loss = []
        test_preds = []
        test_trues = []
        for x, y in test_dataloader:
            x = x.to(device)
            y = y.to(device)
            y_pred = Seq2SeqTransformerTest(model, x.float())
            # y_pred = model(x.float(), y.float()[:, :-1, :]) # just for seq 2 seq transformer
            # Fix: compare against the shifted target y[:, 1:, :], matching
            # the sequence stored in test_trues below; the original compared
            # the (length L-1) prediction against the full-length target.
            loss = criterion(y_pred, y[:, 1:, :].to(device))
            test_loss.append(loss.item())
            test_preds.append(y_pred)
            test_trues.append(y[:, 1:, :])
        test_loss = torch.mean(torch.tensor(test_loss))
        # Fix: the reported quantity is a loss, not an accuracy.
        print('Test Loss of the model: {}'.format(test_loss))
        # wandb.log({"Test Loss": test_loss})
        return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss