| | from urllib.request import urlopen |
| | import torch |
| | from torch import nn |
| | import numpy as np |
| | from skimage.morphology import label |
| | import os |
| | from HD_BET.paths import folder_with_parameter_files |
| |
|
| |
|
def get_params_fname(fold):
    """
    Build the path of the parameter file that belongs to a fold.
    :param fold: integer fold index; it is formatted with ``%d`` into the file name
    :return: ``<folder_with_parameter_files>/<fold>.model``
    """
    model_filename = "%d.model" % fold
    return os.path.join(folder_with_parameter_files, model_filename)
| |
|
| |
|
def maybe_download_parameters(fold=0, force_overwrite=False):
    """
    Downloads the parameters for some fold if it is not present yet.
    :param fold: index of the fold to fetch, must be in the range 0..4
    :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
    :return: None
    """

    assert 0 <= fold <= 4, "fold must be between 0 and 4"

    # makedirs handles nested, relative and already-existing directories alike.
    os.makedirs(folder_with_parameter_files, exist_ok=True)

    out_filename = get_params_fname(fold)

    if force_overwrite and os.path.isfile(out_filename):
        os.remove(out_filename)

    if not os.path.isfile(out_filename):
        url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
        print("Downloading", url, "...")
        # Close the HTTP response deterministically instead of leaking the socket.
        with urlopen(url) as response:
            data = response.read()

        # Write to a temporary name first and move it into place afterwards, so
        # an interrupted write never leaves a truncated file under the final
        # name (which later calls would mistake for a complete download).
        tmp_filename = out_filename + ".part"
        with open(tmp_filename, 'wb') as f:
            f.write(data)
        os.replace(tmp_filename, out_filename)
| |
|
| |
|
def init_weights(module):
    """
    Initialize the parameters of a 3D convolution in place.

    Weights are drawn with Kaiming-normal initialization (``a=1e-2``) and the
    bias, if the layer has one, is set to zero. Modules that are not
    ``nn.Conv3d`` are left untouched.
    :param module: any ``nn.Module``
    :return: None
    """
    if isinstance(module, nn.Conv3d):
        # The trailing-underscore initializers modify the parameter in place;
        # the non-underscore names are deprecated aliases and re-assigning the
        # result to the module is unnecessary.
        nn.init.kaiming_normal_(module.weight, a=1e-2)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
| |
|
| |
|
def softmax_helper(x):
    """
    Numerically stable softmax over dimension 1.

    The maximum along dimension 1 is subtracted before exponentiation so that
    large inputs do not overflow.
    :param x: tensor with at least two dimensions
    :return: tensor of the same shape as ``x`` whose entries sum to 1 along dimension 1
    """
    # keepdim=True leaves a size-1 axis that broadcasts against x, so the
    # max/sum tensors never have to be expanded into full-size copies.
    x_max = x.max(1, keepdim=True)[0]
    e_x = torch.exp(x - x_max)
    return e_x / e_x.sum(1, keepdim=True)
| |
|
| |
|
class SetNetworkToVal(object):
    """
    Callable that switches dropout and normalization layers between train and
    eval mode. An instance is called with one module at a time; modules of any
    other type are left as they are.
    """

    def __init__(self, use_dropout_sampling=False, norm_use_average=True):
        """
        :param use_dropout_sampling: if True dropout layers are put into train mode, otherwise eval mode
        :param norm_use_average: if True instance/batch norm layers are put into eval mode, otherwise train mode
        """
        self.use_dropout_sampling = use_dropout_sampling
        self.norm_use_average = norm_use_average

    def __call__(self, module):
        dropout_layers = (nn.Dropout3d, nn.Dropout2d, nn.Dropout)
        norm_layers = (
            nn.InstanceNorm3d, nn.InstanceNorm2d, nn.InstanceNorm1d,
            nn.BatchNorm2d, nn.BatchNorm3d, nn.BatchNorm1d,
        )

        if isinstance(module, dropout_layers):
            module.train(self.use_dropout_sampling)
            return

        if isinstance(module, norm_layers):
            # train mode is the inverse of "use the averaged statistics"
            module.train(not self.norm_use_average)
| |
|
| |
|
def postprocess_prediction(seg):
    """
    Keep only the largest connected foreground component of a segmentation.

    Every nonzero entry of ``seg`` counts as foreground. Components are found
    with full connectivity (``connectivity=mask.ndim``); all entries outside the
    largest component are set to 0. If there is no foreground at all, ``seg`` is
    returned unchanged.
    :param seg: label array (assumed to be a numpy array - it is modified in place)
    :return: the same ``seg`` object after filtering
    """
    print("running postprocessing... ")
    mask = seg != 0
    lbls = label(mask, connectivity=mask.ndim)

    # bincount gives the size of every label in a single pass. Index 0 is the
    # background and is always present in the result, even when no background
    # voxel exists, so the component labels line up with the indices.
    region_sizes = np.bincount(lbls.ravel())
    if region_sizes.size <= 1:
        # no foreground component -> nothing to filter
        return seg

    largest_region = np.argmax(region_sizes[1:]) + 1
    seg[lbls != largest_region] = 0
    return seg
| |
|
| |
|
def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
    """
    List the direct subdirectories of a folder.
    :param folder: directory to scan (not recursive)
    :param join: if True each entry is returned joined with ``folder``, otherwise only the entry name
    :param prefix: if given, only names starting with it are kept
    :param suffix: if given, only names ending with it are kept
    :param sort: if True the result list is sorted
    :return: list of matching directories
    """
    found = []
    for entry in os.listdir(folder):
        full_path = os.path.join(folder, entry)
        if not os.path.isdir(full_path):
            continue
        if prefix is not None and not entry.startswith(prefix):
            continue
        if suffix is not None and not entry.endswith(suffix):
            continue
        found.append(full_path if join else entry)

    if sort:
        found.sort()
    return found
| |
|
| |
|
def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
    """
    List the regular files located directly in a folder.
    :param folder: directory to scan (not recursive)
    :param join: if True each entry is returned joined with ``folder``, otherwise only the entry name
    :param prefix: if given, only names starting with it are kept
    :param suffix: if given, only names ending with it are kept
    :param sort: if True the result list is sorted
    :return: list of matching files
    """
    found = []
    for entry in os.listdir(folder):
        full_path = os.path.join(folder, entry)
        if not os.path.isfile(full_path):
            continue
        if prefix is not None and not entry.startswith(prefix):
            continue
        if suffix is not None and not entry.endswith(suffix):
            continue
        found.append(full_path if join else entry)

    if sort:
        found.sort()
    return found
| |
|
| |
|
# Alias: `subfolders` is a second name bound to the same function as `subdirs`.
subfolders = subdirs
| |
|
| |
|
def maybe_mkdir_p(directory):
    """
    Create a directory including all missing parent directories (like ``mkdir -p``).

    Does nothing if the directory already exists or if ``directory`` is empty.
    :param directory: absolute or relative path of the directory to create
    :return: None
    """
    if not directory:
        return
    # os.makedirs works with absolute and relative paths and with any path
    # separator. Splitting on "/" by hand dropped the first path component and
    # rebuilt the remainder as a path relative to the working directory.
    os.makedirs(directory, exist_ok=True)
| |
|