| | import os |
| | from data import srdata |
| |
|
| | class DIV2K(srdata.SRData): |
| | def __init__(self, args, name='DIV2K', train=True, benchmark=False): |
| | data_range = [r.split('-') for r in args.data_range.split('/')] |
| | if train: |
| | data_range = data_range[0] |
| | else: |
| | if args.test_only and len(data_range) == 1: |
| | data_range = data_range[0] |
| | else: |
| | data_range = data_range[1] |
| |
|
| | self.begin, self.end = list(map(lambda x: int(x), data_range)) |
| | super(DIV2K, self).__init__( |
| | args, name=name, train=train, benchmark=benchmark |
| | ) |
| |
|
| | def _scan(self): |
| | names_hr, names_edge, names_lr = super(DIV2K, self)._scan() |
| | names_hr = names_hr[self.begin - 1:self.end] |
| | names_edge = names_edge[self.begin - 1:self.end] |
| | names_lr = [n[self.begin - 1:self.end] for n in names_lr] |
| |
|
| | return names_hr, names_edge, names_lr |
| |
|
| | def _set_filesystem(self, dir_data): |
| | super(DIV2K, self)._set_filesystem(dir_data) |
| | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') |
| | self.dir_edge = os.path.join(self.apath, 'DIV2K_train_EDGE_disturbed') |
| | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') |
| | if self.input_large: self.dir_lr += 'L' |
| |
|
| |
|