| from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred |
| from torch.utils.data import DataLoader |
|
|
# Registry used by ``data_provider`` to pick the dataset class from ``args.data``.
# The two ETTh* names share ``Dataset_ETT_hour`` and the two ETTm* names share
# ``Dataset_ETT_minute``; any other file is expected to go through ``custom``.
data_dict = dict(
    ETTh1=Dataset_ETT_hour,
    ETTh2=Dataset_ETT_hour,
    ETTm1=Dataset_ETT_minute,
    ETTm2=Dataset_ETT_minute,
    custom=Dataset_Custom,
)
|
|
|
|
def data_provider(args, flag):
    """Build the dataset and its ``DataLoader`` for one split.

    Args:
        args: Experiment configuration object. Attributes read here:
            ``data``, ``embed``, ``train_only``, ``batch_size``, ``freq``,
            ``root_path``, ``data_path``, ``seq_len``, ``label_len``,
            ``pred_len``, ``features``, ``target`` and ``num_workers``.
        flag: Split name, passed through to the dataset constructor.
            ``'test'`` -> no shuffling, keep the last partial batch.
            ``'pred'`` -> ``Dataset_Pred``, batch size 1, no shuffling.
            Anything else (presumably ``'train'`` / ``'val'``) -> shuffled
            batches with the last incomplete batch dropped.

    Returns:
        tuple: ``(data_set, data_loader)``.

    Raises:
        KeyError: if ``flag`` is not ``'pred'`` and ``args.data`` is not a
            key of ``data_dict``.
    """
    if flag == 'pred':
        # Prediction always uses the dedicated loader, so ``args.data`` does
        # not need to be a registered name in this mode.
        Data = Dataset_Pred
        shuffle_flag = False
        drop_last = False
        batch_size = 1
    else:
        try:
            Data = data_dict[args.data]
        except KeyError:
            raise KeyError(
                f"Unknown dataset {args.data!r}; "
                f"expected one of {sorted(data_dict)}"
            ) from None
        # Only the test split keeps its original order and every sample.
        is_test = flag == 'test'
        shuffle_flag = not is_test
        drop_last = not is_test
        batch_size = args.batch_size

    # The dataset receives timeenc=1 only when the embedding type is 'timeF'.
    timeenc = 1 if args.embed == 'timeF' else 0

    data_set = Data(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        timeenc=timeenc,
        freq=args.freq,
        train_only=args.train_only,
    )
    print(flag, len(data_set))
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last,
    )
    return data_set, data_loader
|
|