| |
| import sys |
| import os |
| import torch |
|
|
|
|
| |
# Put the current working directory on sys.path so the project-local imports
# that follow (opt, architecture.grl, train_code.train_master) can be found.
# NOTE(review): presumably the script is launched from the repository root.
root_path = os.path.abspath(".")
sys.path.append(root_path)
| from opt import opt |
| from architecture.grl import GRL |
| from train_code.train_master import train_master |
|
|
|
|
|
|
| |
# Module-level AMP gradient scaler (enabled=True is the constructor default).
# NOTE(review): nothing else in this file references `scaler` — confirm it is
# used elsewhere before relying on it, or remove it.
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
|
|
|
class train_grl(train_master):
    """Trainer for the GRL generator.

    Builds the GRL variant selected by ``opt['model_size']`` and trains it
    with a pixel loss only; the surrounding training loop is inherited from
    ``train_master`` (see ``run``).
    """

    def __init__(self, options, args) -> None:
        # "grl" is forwarded to train_master as the model identifier —
        # presumably used there for naming checkpoints/logs (TODO confirm).
        super().__init__(options, args, "grl")

    def loss_init(self):
        """Initialise the loss terms used by this trainer (pixel loss only)."""
        self.pixel_loss_load()

    def call_model(self):
        """Create ``self.generator`` on CUDA and switch it to training mode.

        The variant is chosen by ``opt['model_size']``; the upscale factor
        comes from ``opt['scale']``.

        Raises:
            NotImplementedError: if ``opt['model_size']`` is not one of
                "small", "tiny" or "tiny2".
        """
        patch_size = 144
        window_size = 8

        # Only these settings differ between the supported model sizes.
        variants = {
            "small": dict(img_size=patch_size, embed_dim=128,
                          upsampler="pixelshuffle"),
            "tiny": dict(img_size=64, embed_dim=64,
                         upsampler="pixelshuffledirect"),
            "tiny2": dict(img_size=64, embed_dim=64,
                          upsampler="nearest+conv"),
        }

        # Validate the size before reading any other option, so an unknown
        # size is reported as NotImplementedError (as before).
        model_size = opt['model_size']
        if model_size not in variants:
            raise NotImplementedError("We don't support such model size in GRL model")

        # Hyper-parameters shared by every variant.
        common = dict(
            upscale=opt['scale'],
            window_size=window_size,
            depths=[4, 4, 4, 4],
            num_heads_window=[2, 2, 2, 2],
            num_heads_stripe=[2, 2, 2, 2],
            mlp_ratio=2,
            qkv_proj_type="linear",
            anchor_proj_type="avgpool",
            anchor_window_down_factor=2,
            out_proj_type="linear",
            conv_type="1conv",
        )

        self.generator = GRL(**common, **variants[model_size]).cuda()
        self.generator.train()

    def run(self):
        """Start training via the shared ``train_master`` loop."""
        self.master_run()

    def calculate_loss(self, gen_hr, imgs_hr):
        """Add the pixel loss between ``gen_hr`` and ``imgs_hr`` to the generator loss.

        The latest value is also kept in ``self.weight_store["pixel_loss"]``
        so that ``tensorboard_report`` can log it.
        """
        # self.cri_pix is presumably created by pixel_loss_load() (TODO confirm).
        l_g_pix = self.cri_pix(gen_hr, imgs_hr, self.batch_idx)
        self.weight_store["pixel_loss"] = l_g_pix
        self.generator_loss += l_g_pix

    def tensorboard_report(self, iteration):
        """Log the most recent pixel loss to TensorBoard at ``iteration``."""
        self.writer.add_scalar('Loss/train-Pixel_Loss-Iteration', self.weight_store["pixel_loss"], iteration)
|
|