| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import numpy as np |
| import torch |
| from monai.utils import first |
| from monai.utils.type_conversion import convert_to_numpy |
|
|
|
|
def compute_scale_factor(autoencoder, train_loader, device):
    """
    Compute a latent scale factor from the first training batch.

    Encodes the ``"image"`` entry of the first batch yielded by ``train_loader``
    with ``autoencoder.encode_stage_2_inputs`` and returns ``1 / std(z)`` of the
    encoded tensor. Multiplying that batch's encodings by the returned factor
    gives them unit standard deviation.

    Args:
        autoencoder: model exposing ``encode_stage_2_inputs(tensor)``.
        train_loader: iterable of batches; the first batch must support
            ``batch["image"]`` returning a tensor.
        device: device the image tensor is moved to before encoding.

    Returns:
        The scale factor as a Python float.

    Raises:
        ValueError: if ``train_loader`` yields no batches, or if the encoded
            tensor has a zero or non-finite standard deviation (the factor
            would be inf/nan).
    """
    # Only the first batch is used, so this is a single-batch estimate.
    check_data = next(iter(train_loader), None)
    if check_data is None:
        raise ValueError("train_loader yielded no batches; cannot compute scale factor")

    with torch.no_grad():
        z = autoencoder.encode_stage_2_inputs(check_data["image"].to(device))
        std = torch.std(z)

    # A constant batch (std == 0) or a single-element tensor (unbiased std is nan)
    # would otherwise silently produce an inf/nan scale factor.
    if not torch.isfinite(std) or std == 0:
        raise ValueError(f"latent std is {std.item()}; cannot compute a finite scale factor")

    return (1 / std).item()
|
|
|
|
def normalize_image_to_uint8(image):
    """
    Normalize an image to the uint8 range [0, 255].

    Processing steps:

    1. Clip negative values to 0.
    2. If the clipped maximum exceeds 0.1, divide the image by that maximum,
       mapping it to [0, 1]. Otherwise leave it unscaled, so images whose
       maximum is at most 0.1 are not stretched to full brightness.
    3. Multiply by 255 and cast (truncating) to uint8.

    The input is never modified. All work is done on a floating-point copy,
    which also allows integer-typed inputs to be normalized.

    Args:
        image: numpy array (or array-like) of any shape.

    Returns:
        A new ``np.uint8`` array with the same shape as ``image``.
    """
    draw_img = np.asarray(image)
    # Work on a float copy so the caller's array (or a tensor sharing its memory)
    # is never clipped/divided in place, and so integer inputs can be divided.
    if np.issubdtype(draw_img.dtype, np.floating):
        float_dtype = draw_img.dtype
    else:
        float_dtype = np.float64
    draw_img = draw_img.astype(float_dtype, copy=True)

    if draw_img.size == 0:
        # np.amax raises on zero-size arrays; an empty image maps to an empty image.
        return draw_img.astype(np.uint8)

    np.maximum(draw_img, 0, out=draw_img)
    max_val = np.amax(draw_img)
    if max_val > 0.1:
        draw_img /= max_val
    return (255 * draw_img).astype(np.uint8)
|
|
|
|
def visualize_2d_image(image):
    """
    Turn a single-channel 2D image into a 3-channel uint8 array for display.

    Args:
        image: 2D image sized (H, W); any input accepted by ``convert_to_numpy``.

    Returns:
        uint8 array of shape (H, W, 3). The normalized intensity is repeated
        identically in each of the three channels.
    """
    gray = normalize_image_to_uint8(convert_to_numpy(image))
    # Replicate the grayscale plane into three identical channels along a new last axis.
    return np.repeat(gray[..., np.newaxis], 3, axis=-1)
|
|