| import torch |
| import pickle |
| import numpy as np |
| import opensr_test |
| import onnxruntime as ort |
| from typing import List, Union |
|
|
def load_evoland(
    weights_path: str = "evoland/weights/carn_3x3x64g4sw_bootstrap.onnx",
) -> List:
    """Build the EvoLand ONNX Runtime session consumed by ``run_evoland``.

    Args:
        weights_path: Path to the ONNX weights file. The default is the
            repository-relative path this loader has always used.

    Returns:
        A two-element list ``[ort_session, run_options]``: an
        ``onnxruntime.InferenceSession`` for the weights and a fresh
        ``onnxruntime.RunOptions`` to pass to ``session.run``.
    """
    so = ort.SessionOptions()
    so.intra_op_num_threads = 10
    so.inter_op_num_threads = 10
    so.use_deterministic_compute = True

    # CUDA is listed first, with CPU as the fallback provider.
    ep_list = ["CUDAExecutionProvider", "CPUExecutionProvider"]

    ort_session = ort.InferenceSession(
        weights_path,
        sess_options=so,
        providers=ep_list,
    )
    # NOTE(review): this immediately re-registers the session with only the
    # CPU provider, so inference never runs on CUDA despite the list above.
    # Presumably intentional (e.g. reproducibility together with
    # use_deterministic_compute) — confirm before changing.
    ort_session.set_providers(["CPUExecutionProvider"])
    ro = ort.RunOptions()

    return [ort_session, ro]
|
|
|
|
def run_evoland(
    model: List,
    lr: np.ndarray,
    hr: np.ndarray
) -> dict:
    """Super-resolve ``lr`` with the EvoLand ONNX model and align it to ``hr``.

    Args:
        model: ``[ort_session, run_options]`` pair as returned by
            ``load_evoland``.
        lr: Low-resolution image, channels first ``(bands, H, W)``. Bands
            ``[1, 2, 3, 7, 4, 5, 6, 8, 10, 11]`` are selected, in that order,
            as model input, so at least 12 bands are required.
        hr: High-resolution reference, channels first ``(C, H, W)``. Its
            spatial size is the target size of the returned ``sr``.

    Returns:
        dict with keys ``"lr"``, ``"sr"`` and ``"hr"``:

        - ``"lr"``: positions 2, 1, 0 of the band selection, with the
          input's dtype.
        - ``"sr"``: the same positions of the model output, resized to
          ``hr``'s spatial size if needed, always ``uint16``.
        - ``"hr"``: the first three channels of ``hr``.

        (Presumably a red/green/blue triplet — confirm against the data.)
    """
    ort_session, ro = model

    bands = [1, 2, 3, 7, 4, 5, 6, 8, 10, 11]
    lr = lr[bands]

    # The model is fed float32. Conversion happens in NumPy so torch never
    # sees an unsigned-16-bit array, which older torch releases reject.
    x = lr[None].astype(np.float32)

    pad_before, pad_after = 3, 4
    padded = lr.shape[1] == 121
    if padded:
        # 121 + 3 + 4 = 128 pixels; presumably the model expects 128-pixel
        # tiles — confirm. NumPy "reflect" mirrors without repeating the
        # edge, matching torch's "reflect" padding.
        x = np.pad(
            x,
            ((0, 0), (0, 0), (pad_before, pad_after), (pad_before, pad_after)),
            mode="reflect",
        )

    sr = ort_session.run(
        None,
        {"input": x},
        run_options=ro
    )[0].squeeze()

    if padded:
        # Remove the padding from the output, scaled by the model's actual
        # upsampling factor (output rows / input rows).
        scale = sr.shape[-2] // x.shape[-2]
        top, bottom = pad_before * scale, sr.shape[-2] - pad_after * scale
        left, right = pad_before * scale, sr.shape[-1] - pad_after * scale
        sr = sr[:, top:bottom, left:right]

    # Compare both spatial dimensions, not just the height.
    if sr.shape[-2:] != hr.shape[-2:]:
        sr = torch.nn.functional.interpolate(
            torch.from_numpy(np.ascontiguousarray(sr, dtype=np.float32))[None],
            size=tuple(hr.shape[-2:]),
            mode="nearest",
        )[0].numpy()

    # Clip before casting: out-of-range floats have no defined uint16 value.
    sr = np.clip(sr, 0, np.iinfo(np.uint16).max).astype(np.uint16)

    return {
        "lr": lr[[2, 1, 0]],
        "sr": sr[[2, 1, 0]],
        "hr": hr[0:3]
    }