File size: 4,445 Bytes
caf6ee7
 
906fcb9
caf6ee7
906fcb9
caf6ee7
906fcb9
caf6ee7
 
906fcb9
 
caf6ee7
 
906fcb9
 
 
caf6ee7
906fcb9
 
caf6ee7
 
1baebae
 
 
 
caf6ee7
1baebae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
906fcb9
 
 
 
 
1baebae
 
 
906fcb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1baebae
906fcb9
 
 
 
 
 
 
 
 
 
 
 
 
1baebae
 
906fcb9
 
1baebae
906fcb9
1baebae
906fcb9
1baebae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import argparse
import logging
import os

import nrrd
import numpy as np
import torch
from monai.bundle import ConfigParser
from monai.data import MetaTensor
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    ScaleIntensityd,
    Spacingd,
)
from monai.utils import set_determinism
from tqdm import tqdm

# Fix the global random seeds once, at import time, so repeated runs of this
# module produce the same results.
set_determinism(seed=43)


def _list_nrrd_files(directory: str) -> "list[str]":
    """Return the sorted names of NRRD files (``.nrrd`` / ``.nhdr``) directly in *directory*.

    Sub-directories and files with other extensions (e.g. detached ``.raw``
    payloads or ``.DS_Store``) are skipped because they cannot be read as NRRD
    and would otherwise abort the whole run. Sorting makes the processing order
    independent of the filesystem's listing order.
    """
    return sorted(
        name
        for name in os.listdir(directory)
        if name.lower().endswith((".nrrd", ".nhdr"))
        and os.path.isfile(os.path.join(directory, name))
    )


def _build_preprocessing(keys: str) -> "Compose":
    """Return the preprocessing pipeline applied to each T2 image before inference.

    ``Compose.inverse`` is later used on the same pipeline to map predictions
    back into the input image's geometry.
    """
    return Compose(
        [
            LoadImaged(keys=keys),
            EnsureChannelFirstd(keys=keys),
            Orientationd(keys=keys, axcodes="RAS"),
            Spacingd(keys=keys, pixdim=[0.5, 0.5, 0.5], mode="bilinear"),
            ScaleIntensityd(keys=keys, minv=0, maxv=1),
            NormalizeIntensityd(keys=keys),
            EnsureTyped(keys=keys),
        ]
    )


def _keep_top_slices(mask: np.ndarray, num_slices: int = 10) -> np.ndarray:
    """Keep only the *num_slices* slices along axis 2 with the most non-zero voxels.

    Args:
        mask: 3-D array; slices are indexed by its third axis.
        num_slices: Number of slices to keep. ``0`` keeps nothing; values
            larger than the number of slices keep every slice.

    Returns:
        A new array of the same shape and dtype as *mask*, with every other
        slice zeroed.

    Raises:
        ValueError: If *num_slices* is negative.
    """
    if num_slices < 0:
        raise ValueError(f"num_slices must be non-negative, got {num_slices}")
    nonzero_counts = np.count_nonzero(mask, axis=(0, 1))
    # Reverse-then-truncate instead of ``[-num_slices:]`` so that
    # num_slices == 0 selects no slices (``[-0:]`` would select all of them).
    top_slices = np.argsort(nonzero_counts)[::-1][:num_slices]
    kept = np.zeros_like(mask)
    kept[:, :, top_slices] = mask[:, :, top_slices]
    return kept


def get_segmask(args: argparse.Namespace, num_slices: int = 10) -> argparse.Namespace:
    """
    Generate prostate segmentation masks using a pre-trained deep learning model.

    This function performs inference on T2-weighted MRI images to segment the
    prostate gland. It applies preprocessing transformations, runs the
    segmentation model, and saves the predicted masks. Post-processing keeps
    only the ``num_slices`` slices (default 10) with the highest non-zero
    voxel counts.

    Args:
        args: An arguments object containing:
            - output_dir (str): Base output directory where segmentation masks will be saved
            - project_dir (str): Root project directory containing model config and checkpoint
            - t2_dir (str): Directory containing input T2-weighted MRI images in NRRD format
        num_slices: Number of slices (along the third array axis) to keep in
            each saved mask. Defaults to 10.

    Returns:
        args: The updated arguments object with seg_dir added, pointing to the
              prostate_mask subdirectory within output_dir

    Raises:
        FileNotFoundError: If t2_dir, the model config file or the checkpoint is not found
        ValueError: If num_slices is negative
        RuntimeError: If CUDA operations fail on GPU

    Notes:
        - Only regular files ending in ``.nrrd`` / ``.nhdr`` (case-insensitive)
          in t2_dir are processed, in sorted order
        - Automatically selects GPU (CUDA) if available, otherwise uses CPU
        - Applies MONAI transformations: loading, orientation (RAS), spacing (0.5 isotropic),
          intensity scaling and normalization
        - Output masks are saved in NRRD format under the input file name,
          reusing the input image's NRRD header
    """
    args.seg_dir = os.path.join(args.output_dir, "prostate_mask")
    os.makedirs(args.seg_dir, exist_ok=True)

    # The CUDA caching allocator reads this variable when it is first
    # initialised, so set it before any CUDA work; respect a user-supplied value.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        torch.cuda.empty_cache()

    files = _list_nrrd_files(args.t2_dir)

    model_config_file = os.path.join(args.project_dir, "config", "inference.json")
    model_config = ConfigParser()
    model_config.read_config(model_config_file)
    model_config["output_dir"] = args.seg_dir
    model_config["dataset_dir"] = args.t2_dir
    model_config["datalist"] = [os.path.join(args.t2_dir, f) for f in files]
    checkpoint = os.path.join(
        args.project_dir,
        "models",
        "prostate_segmentation_model.pt",
    )
    model = model_config.get_parsed_content("network_def").to(device)
    inferer = model_config.get_parsed_content("inferer")
    # torch.load unpickles the checkpoint: only point this at a trusted file.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()

    transform = _build_preprocessing("image")
    logging.info("Starting prostate segmentation")
    for file in tqdm(files):
        image_path = os.path.join(args.t2_dir, file)
        # Only the header is needed (it is reused when writing the mask), so
        # avoid decoding the voxel data a second time.
        header_t2 = nrrd.read_header(image_path)
        transformed_data = transform({"image": image_path})

        with torch.no_grad():
            images = transformed_data["image"].unsqueeze(0).to(device)  # add batch axis
            logits = inferer(images, network=model)
        # argmax over the channel axis; the remaining batch axis of size 1 then
        # plays the role of the channel axis expected by the inverse transforms.
        pred_img = logits.argmax(1).cpu()

        # Put the prediction inside the preprocessed MetaTensor so it keeps the
        # recorded transform history, then let Compose.inverse undo the
        # invertible steps (spacing, orientation, ...).
        pred_meta = MetaTensor(pred_img, meta=transformed_data["image"].meta)
        transformed_data["image"].data = pred_meta.data
        restored = transform.inverse(transformed_data)
        # NOTE(review): Spacingd was configured with mode="bilinear", which the
        # inverse presumably reuses, so label boundaries get interpolated and
        # np.round snaps them back. For multi-class outputs this can create
        # spurious intermediate labels; consider nearest-neighbour resampling.
        pred_nrrd = np.round(restored["image"][0].numpy())

        output_ = _keep_top_slices(pred_nrrd, num_slices)
        nrrd.write(os.path.join(args.seg_dir, file), output_, header_t2)

    return args