| import torch |
| import numpy as np |
| from super_image import HanModel |
| from typing import Union |
|
|
def create_superimage_model(
    device: Union[str, torch.device] = "cuda"
) -> HanModel:
    """Load the pretrained x4 HAN super image model and move it to a device.

    Args:
        device (Union[str, torch.device], optional): Device the loaded model
            is moved to via ``.to(device)``. Defaults to "cuda".

    Returns:
        HanModel: The ``eugenesiow/han`` pretrained weights (scale=4),
            placed on ``device``.
    """
    pretrained = HanModel.from_pretrained('eugenesiow/han', scale=4)
    return pretrained.to(device)
|
|
|
|
def _to_uint16(values: np.ndarray) -> np.ndarray:
    """Round ``values`` to the nearest integer and saturate them into uint16.

    A bare ``astype(np.uint16)`` truncates toward zero (1233.9999 -> 1233) and
    wraps / is undefined for negative or > 65535 floats; clipping first pins
    out-of-range values to 0 or 65535 instead.
    """
    limits = np.iinfo(np.uint16)
    return np.clip(np.rint(values), limits.min, limits.max).astype(np.uint16)


def run_superimage(
    model: "HanModel",
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cuda"
) -> dict:
    """Run the super image model on three bands of a low resolution image.

    Bands 3, 2 and 1 of ``lr`` (in that order) are divided by 2000, passed to
    ``model`` as a batch of one, and the output is multiplied back by 2000 and
    converted to uint16.

    Args:
        model (HanModel): The super image model. It is called once, under
            ``torch.no_grad()``, with a float32 tensor of shape (1, 3, ...).
        lr (np.ndarray): The low resolution image. Assumes band-first layout
            (bands on axis 0) with at least 4 bands — TODO confirm with callers.
        hr (np.ndarray): The high resolution image; only its first three
            entries along axis 0 are returned.
        device (Union[str, torch.device], optional): The device to run the
            model on. Defaults to "cuda".

    Returns:
        dict: ``"lr"`` — the selected ``lr`` bands as uint16 (squeezed);
            ``"hr"`` — ``hr[0:3]`` squeezed; ``"sr"`` — the model output
            rescaled by 2000, rounded and clipped to [0, 65535] as uint16
            (squeezed, so the batch axis is dropped).
    """
    scale = 2000.0

    # Fancy indexing copies, so this is an independent (3, ...) array.
    lr_bands = lr[[3, 2, 1]]

    # Cast to float32 in numpy first: torch.from_numpy rejects (or barely
    # supports) uint16, and float64 input should not be divided on-device in
    # double precision only to be downcast afterwards.
    lr_tensor = torch.from_numpy(lr_bands.astype(np.float32)).to(device) / scale

    with torch.no_grad():
        sr_tensor = model(lr_tensor[None])

    # The network output is unconstrained; saturate instead of letting the
    # uint16 cast wrap negative or oversized values.
    sr = _to_uint16(sr_tensor.detach().cpu().numpy() * scale)

    # Return the input bands directly rather than the float32 round trip,
    # which could truncate values down by one.
    lr_out = _to_uint16(lr_bands)

    # NOTE(review): lr uses bands [3, 2, 1] while hr uses its first three
    # bands as-is — presumably hr is already in display order; confirm.
    return {
        "lr": lr_out.squeeze(),
        "hr": hr[0:3].squeeze(),
        "sr": sr.squeeze(),
    }