Anirudh Balaraman
add ci
caf6ee7
import argparse
import logging
import os
import nrrd
import numpy as np
import torch
from monai.bundle import ConfigParser
from monai.data import MetaTensor
from monai.transforms import (
Compose,
EnsureChannelFirstd,
EnsureTyped,
LoadImaged,
NormalizeIntensityd,
Orientationd,
ScaleIntensityd,
Spacingd,
)
from monai.utils import set_determinism
from tqdm import tqdm
# Fix the random seed through MONAI once, at import time, so repeated
# inference runs are reproducible. Note: this is a module-level side effect.
set_determinism(seed=43)
def _list_nrrd_files(directory: str) -> list:
    """Return the sorted names of regular ``.nrrd``/``.nhdr`` files in *directory*."""
    # Filtering skips stray entries (hidden files, detached .raw data, sub-folders)
    # that would otherwise crash nrrd.read / LoadImaged; sorting makes the order
    # deterministic instead of filesystem-dependent.
    return sorted(
        name
        for name in os.listdir(directory)
        if name.lower().endswith((".nrrd", ".nhdr"))
        and os.path.isfile(os.path.join(directory, name))
    )


def _keep_top_slices(mask: np.ndarray, num_slices: int = 10) -> np.ndarray:
    """Zero every slice along axis 2 of *mask* except the *num_slices* with the most non-zero voxels.

    Args:
        mask: 3-D array; slices are taken along its third axis.
        num_slices: How many slices to keep. Values <= 0 keep nothing.

    Returns:
        A new array of the same shape and dtype as *mask*.
    """
    output = np.zeros_like(mask)
    if num_slices <= 0:
        # Guard: a slice of [-0:] would otherwise select *every* slice.
        return output
    nonzero_counts = np.count_nonzero(mask, axis=(0, 1))
    # A stable sort makes tie-breaking deterministic (among equal counts the
    # highest slice indices win), unlike the default introsort.
    top_slices = np.argsort(nonzero_counts, kind="stable")[-num_slices:]
    output[:, :, top_slices] = mask[:, :, top_slices]
    return output


def get_segmask(args: argparse.Namespace) -> argparse.Namespace:
    """
    Generate prostate segmentation masks using a pre-trained deep learning model.

    Runs the segmentation network on every NRRD T2-weighted image in ``args.t2_dir``.
    Each prediction is mapped back to the input image grid with the inverse of the
    preprocessing transforms. The mask is then post-processed to keep only the
    10 slices (along axis 2) with the most non-zero voxels, and written to
    ``<output_dir>/prostate_mask`` under the input file's name, with the input
    file's NRRD header.

    Args:
        args: An arguments object containing:
            - output_dir (str): Base output directory where segmentation masks will be saved
            - project_dir (str): Root project directory containing ``config/inference.json``
              and ``models/prostate_segmentation_model.pt``
            - t2_dir (str): Directory containing input T2-weighted MRI images in NRRD format.
              Only regular files ending in ``.nrrd``/``.nhdr`` (case-insensitive) are processed,
              in sorted order; other entries are ignored.

    Returns:
        args: The same arguments object with ``seg_dir`` added, pointing to the
            ``prostate_mask`` subdirectory within ``output_dir``.

    Raises:
        FileNotFoundError: If the model config or checkpoint file does not exist.
        RuntimeError: Errors raised by torch (e.g. CUDA failures) propagate unchanged.

    Notes:
        - Uses CUDA if available, otherwise CPU.
        - Sets ``PYTORCH_CUDA_ALLOC_CONF`` in the process environment.
        - Preprocessing: load, channel-first, orientation (RAS), 0.5 mm isotropic
          spacing (bilinear), intensity scaling to [0, 1], then normalization.
    """
    args.seg_dir = os.path.join(args.output_dir, "prostate_mask")
    os.makedirs(args.seg_dir, exist_ok=True)

    # The CUDA caching allocator reads this variable when it is first initialised,
    # so set it before any CUDA work; it has no effect if CUDA memory was already
    # allocated earlier in this process.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.empty_cache()

    model_config_file = os.path.join(args.project_dir, "config", "inference.json")
    checkpoint = os.path.join(args.project_dir, "models", "prostate_segmentation_model.pt")
    # Fail fast, before building the network, with the documented exception type.
    for required in (model_config_file, checkpoint):
        if not os.path.isfile(required):
            raise FileNotFoundError(f"Required file not found: {required}")

    files = _list_nrrd_files(args.t2_dir)
    if not files:
        logging.warning("No NRRD files found in %s", args.t2_dir)

    model_config = ConfigParser()
    model_config.read_config(model_config_file)
    model_config["output_dir"] = args.seg_dir
    model_config["dataset_dir"] = args.t2_dir
    model_config["datalist"] = [os.path.join(args.t2_dir, f) for f in files]

    model = model_config.get_parsed_content("network_def").to(device)
    inferer = model_config.get_parsed_content("inferer")
    # NOTE(review): torch.load unpickles the checkpoint — only load trusted files.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()

    keys = "image"
    transform = Compose(
        [
            LoadImaged(keys=keys),
            EnsureChannelFirstd(keys=keys),
            Orientationd(keys=keys, axcodes="RAS"),
            Spacingd(keys=keys, pixdim=[0.5, 0.5, 0.5], mode="bilinear"),
            ScaleIntensityd(keys=keys, minv=0, maxv=1),
            NormalizeIntensityd(keys=keys),
            EnsureTyped(keys=keys),
        ]
    )

    logging.info("Starting prostate segmentation")
    for file in tqdm(files):
        image_path = os.path.join(args.t2_dir, file)
        # Keep the input header so the mask is written with the T2 image's header.
        _, header_t2 = nrrd.read(image_path)

        transformed = transform({"image": image_path})
        image = transformed["image"]

        with torch.no_grad():
            # Prepend a batch dimension of size 1 before running the inferer.
            logits = inferer(image.unsqueeze(0).to(device), network=model)
        # Class index per voxel; argmax over dim 1 drops the class dim and the
        # leftover leading size-1 dim stands in for the channel dim.
        pred = logits.argmax(1).cpu()

        # Swap the prediction into the transformed MetaTensor in place so that its
        # recorded transform history drives the inverse (Orientation/Spacing)
        # back to the input image grid.
        image.data = pred
        restored = transform.inverse(transformed)
        # Round values that may have been interpolated during inverse resampling
        # back to integer labels.
        pred_mask = np.round(restored["image"][0].numpy())

        nrrd.write(os.path.join(args.seg_dir, file), _keep_top_slices(pred_mask), header_t2)
    return args