| | import torch |
| | import numpy as np |
| | import opensr_model |
| | from typing import Union |
| |
|
def create_opensr_model(
    device: Union[str, torch.device] = "cpu",
    weights_path: str = "./weights/opensr_10m_v4_v5.ckpt",
) -> "opensr_model.SRLatentDiffusion":
    """Create the OpenSR latent-diffusion super-resolution model.

    Args:
        device: Device passed straight to ``opensr_model.SRLatentDiffusion``.
        weights_path: Checkpoint handed to ``model.load_pretrained``. Defaults
            to the path this function previously hard-coded; a relative path
            is presumably resolved against the current working directory
            (confirm against ``load_pretrained``).

    Returns:
        The ``SRLatentDiffusion`` instance with the pretrained weights loaded
        and switched to evaluation mode via ``eval()``.
    """
    # The return annotation is a string so it is not evaluated at definition
    # time; the previous ``-> opensr_model`` named a module, not a type.
    model = opensr_model.SRLatentDiffusion(device=device)
    model.load_pretrained(weights_path)
    model.eval()
    return model
| |
|
| |
|
def _to_uint16(img: torch.Tensor, scale: float = 10000) -> np.ndarray:
    """Return the first three bands of ``img`` multiplied by ``scale`` as uint16.

    Values are rounded rather than truncated, so a float32 round trip such as
    ``v / 10000 * 10000`` cannot come back as ``v - 1``. They are then clipped
    to the uint16 range, so negative or overly large model outputs do not
    silently wrap around on the cast.
    """
    arr = img.detach().cpu().numpy()[0:3] * scale
    return np.clip(np.round(arr), 0, np.iinfo(np.uint16).max).astype(np.uint16)


def run_opensr_model(
    model: "opensr_model.SRLatentDiffusion",
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cpu",
) -> dict:
    """Super-resolve ``lr`` with ``model`` and return an LR/SR/HR triplet.

    Args:
        model: Callable super-resolution model, e.g. the object returned by
            ``create_opensr_model``. It is called on a ``(1, 4, H, W)`` float
            tensor. Its output, after ``squeeze()``, must be ``(C, H', W')``
            with ``W'`` an integer multiple of the (padded) input width.
        lr: Low-resolution array indexed ``(band, H, W)``. Bands 3, 2, 1 and 7
            are fed to the model in that order, after division by 10000.
            Presumably these are Sentinel-2 reflectances scaled by 10000
            (B4, B3, B2, B8) -- confirm against the data source.
        hr: High-resolution reference; only ``hr[0:3]`` is returned, untouched.
        device: Device the input tensor is moved to; it should match the
            model's device.

    Returns:
        dict with keys ``"lr"``, ``"sr"`` and ``"hr"``:

        * ``"lr"`` holds the first three model-input bands (3, 2, 1) as
          ``uint16``, rescaled by 10000, rounded and clipped to
          ``[0, 65535]``.
        * ``"sr"`` holds the first three model-output bands, converted the
          same way.
        * ``"hr"`` is ``hr[0:3]``.
    """
    lr_img = torch.from_numpy(lr[[3, 2, 1, 7]] / 10000).to(device).float()
    hr_img = hr[0:3]

    # Inputs of height 121 are reflect-padded by 3 (before) and 4 (after) on
    # both spatial axes, giving 128 pixels. Presumably this is because the
    # model needs a power-of-two-friendly size -- TODO confirm. The padding
    # is cropped off the output again after inference.
    pad_before, pad_after = 3, 4
    needs_pad = lr_img.shape[1] == 121
    model_in = lr_img
    if needs_pad:
        model_in = torch.nn.functional.pad(
            lr_img[None],
            pad=(pad_before, pad_after, pad_before, pad_after),
            mode="reflect",
        )[0]

    with torch.no_grad():
        sr_img = model(model_in[None]).squeeze()

    if needs_pad:
        # Derive the upsampling factor from the shapes instead of assuming 4x.
        scale = sr_img.shape[-1] // model_in.shape[-1]
        lo, hi = pad_before * scale, -pad_after * scale
        sr_img = sr_img[:, lo:hi, lo:hi]

    return {
        "lr": _to_uint16(lr_img),
        "sr": _to_uint16(sr_img),
        "hr": hr_img,
    }