import argparse
import logging
import os

import nrrd
import numpy as np
import torch
from monai.bundle import ConfigParser
from monai.data import MetaTensor
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    ScaleIntensityd,
    Spacingd,
)
from monai.utils import set_determinism
from tqdm import tqdm
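
# Fix random seeds so repeated runs produce identical segmentations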
set_determinism(43)


def get_segmask(args: argparse.Namespace) -> argparse.Namespace:
    """
    Generate prostate segmentation masks using a pre-trained deep learning model.

    This function performs inference on T2-weighted MRI images to segment the prostate gland.
    It applies preprocessing transformations, runs the segmentation model, and saves the
    predicted masks. Post-processing retains only the 10 slices with the highest
    non-zero voxel counts.

    Args:
        args: An arguments object containing:
            - output_dir (str): Base output directory where segmentation masks will be saved
            - project_dir (str): Root project directory containing the model config and checkpoint
            - t2_dir (str): Directory containing input T2-weighted MRI images in NRRD format

    Returns:
        args: The updated arguments object with seg_dir added, pointing to the
            prostate_mask subdirectory within output_dir

    Raises:
        FileNotFoundError: If the model checkpoint or config file is not found
        RuntimeError: If CUDA operations fail on the GPU

    Notes:
        - Automatically selects GPU (CUDA) if available, otherwise uses CPU
        - Applies MONAI transforms: loading, orientation (RAS), spacing (0.5 mm isotropic),
          intensity scaling and normalization
        - Post-processing keeps only the top 10 slices by non-zero voxel count
        - Output masks are saved in NRRD format, preserving the original image headers
    """
    args.seg_dir = os.path.join(args.output_dir, "prostate_mask")
    os.makedirs(args.seg_dir, exist_ok=True)
    # Set the allocator option before any CUDA allocation so it is picked up when the
    # caching allocator initializes
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.empty_cache()
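    # Load the MONAI bundle inference config and point it at this run's input and output locations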
    model_config_file = os.path.join(args.project_dir, "config", "inference.json")
    model_config = ConfigParser()
    model_config.read_config(model_config_file)
    model_config["output_dir"] = args.seg_dir
    model_config["dataset_dir"] = args.t2_dir
    # Only NRRD volumes; any other files in t2_dir are ignored
    files = [f for f in os.listdir(args.t2_dir) if f.endswith(".nrrd")]
    model_config["datalist"] = [os.path.join(args.t2_dir, f) for f in files]
    checkpoint = os.path.join(
        args.project_dir,
        "models",
        "prostate_segmentation_model.pt",
    )
    model = model_config.get_parsed_content("network_def").to(device)
    inferer = model_config.get_parsed_content("inferer")
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()
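    # Preprocessing applied to each T2 volume; transform.inverse is used later to map
    # predictions back onto the original image grid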
    keys = "image"
    transform = Compose(
        [
            LoadImaged(keys=keys),
            EnsureChannelFirstd(keys=keys),
            Orientationd(keys=keys, axcodes="RAS"),
            Spacingd(keys=keys, pixdim=[0.5, 0.5, 0.5], mode="bilinear"),
            ScaleIntensityd(keys=keys, minv=0, maxv=1),
            NormalizeIntensityd(keys=keys),
            EnsureTyped(keys=keys),
        ]
    )
    logging.info("Starting prostate segmentation")
    for file in tqdm(files):
        data = {"image": os.path.join(args.t2_dir, file)}
        _, header_t2 = nrrd.read(data["image"])
        transformed_data = transform(data)
        with torch.no_grad():
            images = transformed_data["image"].unsqueeze(0).to(device)
            data["pred"] = inferer(images, network=model)
        pred_img = data["pred"].argmax(1).cpu()
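        # Write the prediction back into the preprocessed MetaTensor so the inverse
        # transforms return it to the original orientation and spacing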
        model_output = {"image": MetaTensor(pred_img, meta=transformed_data["image"].meta)}
        transformed_data["image"].data = model_output["image"].data
        temp = transform.inverse(transformed_data)
        pred_temp = temp["image"][0].numpy()
        pred_nrrd = np.round(pred_temp)
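        # Keep only the 10 slices with the most non-zero (prostate) voxels; zero out all others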
        nonzero_counts = np.count_nonzero(pred_nrrd, axis=(0, 1))
        top_slices = np.argsort(nonzero_counts)[-10:]
        output_ = np.zeros_like(pred_nrrd)
        output_[:, :, top_slices] = pred_nrrd[:, :, top_slices]
        nrrd.write(os.path.join(args.seg_dir, file), output_, header_t2)
    return args
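

# Minimal usage sketch, assuming this module is run directly as a script. The command-line
# flag names below are illustrative (not defined elsewhere in this file); get_segmask only
# requires that the parsed namespace carries output_dir, project_dir, and t2_dir.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(description="Prostate segmentation inference on T2-weighted MRI")
    parser.add_argument("--output_dir", required=True, help="Base output directory")
    parser.add_argument("--project_dir", required=True, help="Project root containing config/ and models/")
    parser.add_argument("--t2_dir", required=True, help="Directory of T2-weighted NRRD images")
    cli_args = parser.parse_args()
    cli_args = get_segmask(cli_args)
    logging.info("Segmentation masks written to %s", cli_args.seg_dir)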