| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import numpy as np |
| | import torch |
| | from monai.utils import first |
| | from monai.utils.type_conversion import convert_to_numpy |
| |
|
| |
|
def compute_scale_factor(autoencoder, train_loader, device, image_key="image"):
    """
    Compute the latent scale factor ``1 / std(z)`` from the first training batch.

    The first batch yielded by ``train_loader`` is encoded with
    ``autoencoder.encode_stage_2_inputs`` (under ``torch.no_grad``) and the
    reciprocal of the standard deviation of the resulting tensor is returned.

    Args:
        autoencoder: model exposing ``encode_stage_2_inputs(images)``.
        train_loader: iterable of batches; each batch is indexed by ``image_key``
            and the value must support ``.to(device)``.
        device: device the image batch is moved to before encoding.
        image_key: key of the image tensor within a batch (default ``"image"``).

    Returns:
        The scale factor as a Python float.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    with torch.no_grad():
        # Equivalent to monai's first(train_loader): first item or None.
        check_data = next(iter(train_loader), None)
        if check_data is None:
            raise ValueError(
                "train_loader yielded no batches; cannot compute scale factor"
            )
        z = autoencoder.encode_stage_2_inputs(check_data[image_key].to(device))
        # torch.std defaults to the Bessel-corrected (unbiased) estimator.
        scale_factor = 1 / torch.std(z)
    return scale_factor.item()
| |
|
| |
|
def normalize_image_to_uint8(image):
    """
    Normalize an image to uint8 values in [0, 255].

    Negative values are clipped to 0. If the maximum after clipping exceeds
    0.1, the image is divided by that maximum so it lies in [0, 1]. Otherwise
    it is left unscaled. The values are then multiplied by 255 and truncated
    to uint8.

    The input is never modified. The computation runs on a copy, which is
    promoted to float64 when the input is not already a floating-point array.

    Args:
        image: numpy array (any shape).

    Returns:
        numpy uint8 array with the same shape as ``image``.
    """
    # Copy so the caller's array is never clipped/rescaled in place, and
    # promote non-float dtypes so the in-place division below is legal.
    draw_img = np.array(image)
    if not np.issubdtype(draw_img.dtype, np.floating):
        draw_img = draw_img.astype(np.float64)
    if draw_img.size == 0:
        # np.amin / np.amax raise on empty arrays.
        return draw_img.astype(np.uint8)
    if np.amin(draw_img) < 0:
        draw_img[draw_img < 0] = 0
    max_val = np.amax(draw_img)
    # Only rescale when the peak is non-trivial; near-zero images stay dark.
    if max_val > 0.1:
        draw_img /= max_val
    return (255 * draw_img).astype(np.uint8)
| |
|
| |
|
def visualize_2d_image(image):
    """
    Turn a single-channel 2D image into a 3-channel uint8 array for display.

    The image is first converted with ``convert_to_numpy``, then normalized
    by ``normalize_image_to_uint8``. The grayscale result is copied into
    three identical channels along a new trailing axis.

    Args:
        image: 2D image array, sized (H, W).

    Returns:
        uint8 numpy array sized (H, W, 3).
    """
    gray = normalize_image_to_uint8(convert_to_numpy(image))
    # Same result as stacking three copies of `gray` on axis=-1.
    return np.repeat(gray[..., np.newaxis], 3, axis=-1)
| |
|