| | import numpy as np |
| | import math |
| |
|
| |
|
def _tile_starts(length, patch, overlapping):
    """Return ascending, de-duplicated start offsets of tiles covering ``length``.

    Tiles advance by ``patch - overlapping``. The final tile is aligned to the
    end so it never runs past ``length``. Requires ``length >= patch`` and
    ``0 <= overlapping < patch``.
    """
    stride = patch - overlapping
    last = length - patch
    return list(range(0, last, stride)) + [last]


def tiling_inference(session, lr, overlapping=8, patch_size=(56, 56)):
    """Run a super-resolution ONNX model tile by tile and stitch the output.

    The input is cut into ``patch_size`` tiles that overlap by ``overlapping``
    pixels. Each tile is run through ``session``. From every upscaled tile, only
    the part inside ``overlapping // 2`` input pixels of each edge that borders
    another tile is written to the output; image-border edges are kept whole.
    This hides seams between tiles.

    Parameters:
    - session: an ONNX Runtime session object that contains the super-resolution
      model. Only ``session.get_inputs()[0].name`` and ``session.run`` are used.
    - lr: the low-resolution image, indexed as (batch, height, width, channels).
    - overlapping: the number of pixels to overlap between adjacent patches.
      Must satisfy ``0 <= overlapping < min(patch_size)``.
    - patch_size: a tuple of (height, width) that specifies the size of each patch.
      Images smaller than a patch are edge-padded to one patch internally.

    Returns:
    - a float64 numpy array of shape (batch, s*height, s*width, channels) holding
      the enhanced image. The upscaling factor ``s``, batch and channel count are
      taken from the model's output for the first tile.

    Raises:
    - ValueError: if ``overlapping`` is negative or not smaller than both patch
      dimensions (the tile stride would be zero or negative).
    """
    patch_h, patch_w = patch_size
    if overlapping < 0 or overlapping >= min(patch_h, patch_w):
        raise ValueError(
            f"overlapping must satisfy 0 <= overlapping < min(patch_size); "
            f"got overlapping={overlapping}, patch_size={tuple(patch_size)}"
        )

    _, h, w, _ = lr.shape

    # A tile offset of (size - patch) would be negative for images smaller than
    # a patch, so pad bottom/right by edge replication until one full patch
    # fits. The padded area is cropped from the result before returning.
    pad_h = max(patch_h - h, 0)
    pad_w = max(patch_w - w, 0)
    if pad_h or pad_w:
        lr = np.pad(lr, ((0, 0), (0, pad_h), (0, pad_w), (0, 0)), mode="edge")
    full_h, full_w = h + pad_h, w + pad_w

    input_name = session.get_inputs()[0].name  # invariant: query once
    margin = overlapping // 2
    sr = None
    scale = None

    for h_idx in _tile_starts(full_h, patch_h, overlapping):
        # Keep the full tile edge on the image border; trim `margin` rows
        # where this tile is overlapped by a neighbour.
        top = 0 if h_idx == 0 else margin
        bottom = patch_h if h_idx + patch_h >= full_h else patch_h - margin
        for w_idx in _tile_starts(full_w, patch_w, overlapping):
            left = 0 if w_idx == 0 else margin
            right = patch_w if w_idx + patch_w >= full_w else patch_w - margin

            tile = lr[:, h_idx:h_idx + patch_h, w_idx:w_idx + patch_w, :]
            sr_tile = session.run(None, {input_name: tile})[0]

            if sr is None:
                # Derive the upscaling factor, batch size and channel count
                # from the model output instead of hard-coding 2x / RGB / 1.
                scale = sr_tile.shape[1] // patch_h
                sr = np.zeros(
                    (sr_tile.shape[0], scale * full_h, scale * full_w, sr_tile.shape[-1])
                )

            sr[
                :,
                scale * (h_idx + top):scale * (h_idx + bottom),
                scale * (w_idx + left):scale * (w_idx + right),
                :,
            ] = sr_tile[:, scale * top:scale * bottom, scale * left:scale * right, :]

    # Drop any padding added for undersized inputs.
    return sr[:, :scale * h, :scale * w, :]