| from urllib.request import urlopen |
| import torch |
| from torch import nn |
| import numpy as np |
| from skimage.morphology import label |
| import os |
| from HD_BET.paths import folder_with_parameter_files |
|
|
|
|
def get_params_fname(fold):
    """Build the path of the parameter file for a given fold.

    The file is named ``<fold>.model`` and lives inside
    ``folder_with_parameter_files``.

    :param fold: fold index, formatted with ``%d``
    :return: path string of the parameter file
    """
    model_file = "%d.model" % fold
    return os.path.join(folder_with_parameter_files, model_file)
|
|
|
|
def maybe_download_parameters(fold=0, force_overwrite=False):
    """
    Downloads the parameters for some fold if it is not present yet.

    The file is streamed into a temporary ``.part`` file next to the target.
    It is moved to its final name only after the download finished, so an
    interrupted download never leaves a truncated parameter file behind.

    :param fold: fold index, must be between 0 and 4 (inclusive)
    :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
    :return: None
    """
    import shutil

    assert 0 <= fold <= 4, "fold must be between 0 and 4"

    # Creates every missing parent and tolerates the folder already existing,
    # including when it is created concurrently.
    os.makedirs(folder_with_parameter_files, exist_ok=True)

    out_filename = get_params_fname(fold)

    if force_overwrite and os.path.isfile(out_filename):
        os.remove(out_filename)

    if not os.path.isfile(out_filename):
        url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
        print("Downloading", url, "...")
        tmp_filename = out_filename + ".part"
        try:
            # Both the HTTP response and the output file are closed even on error.
            with urlopen(url) as response, open(tmp_filename, 'wb') as f:
                shutil.copyfileobj(response, f)
            # The existence check above only looks at out_filename. Publishing the
            # file only after a complete download keeps that check trustworthy.
            os.replace(tmp_filename, out_filename)
        finally:
            # The temporary file only still exists if something above failed.
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)
|
|
|
|
def init_weights(module):
    """Initialise the parameters of an ``nn.Conv3d`` layer in place.

    Meant to be passed to ``nn.Module.apply``. Modules of any other type are
    left untouched.

    The weight is filled with Kaiming-normal values using ``a=1e-2``. The bias,
    if the layer has one, is set to zero.

    :param module: any ``nn.Module``
    :return: None
    """
    if isinstance(module, nn.Conv3d):
        # The trailing-underscore initialisers work in place on the existing
        # Parameter. They replace the deprecated ``kaiming_normal``/``constant``
        # aliases, which emit warnings, and make re-assigning the attribute unnecessary.
        nn.init.kaiming_normal_(module.weight, a=1e-2)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
|
|
|
|
def softmax_helper(x):
    """Numerically stable softmax along dimension 1.

    :param x: tensor with at least two dimensions
    :return: tensor with the same shape as ``x`` whose entries along
        dimension 1 sum to 1
    """
    # Subtracting the per-position maximum keeps exp() from overflowing.
    # keepdim=True leaves a size-1 dimension 1, so the subtraction and the
    # division broadcast, with no repeated full-size copies materialised.
    x_max = x.max(1, keepdim=True)[0]
    e_x = torch.exp(x - x_max)
    return e_x / e_x.sum(1, keepdim=True)
|
|
|
|
class SetNetworkToVal(object):
    """Callable for ``nn.Module.apply`` that sets the train/eval mode of layers.

    - Dropout layers (``Dropout``, ``Dropout2d``, ``Dropout3d``) get training
      mode equal to ``use_dropout_sampling``. Training mode is what keeps
      dropout active.
    - Normalisation layers (``InstanceNorm1d/2d/3d``, ``BatchNorm1d/2d/3d``)
      go to eval mode when ``norm_use_average`` is True and to training mode
      otherwise.
    - Every other module is ignored.
    """

    def __init__(self, use_dropout_sampling=False, norm_use_average=True):
        self.use_dropout_sampling = use_dropout_sampling
        self.norm_use_average = norm_use_average

    def __call__(self, module):
        dropout_layers = (nn.Dropout3d, nn.Dropout2d, nn.Dropout)
        norm_layers = (
            nn.InstanceNorm3d, nn.InstanceNorm2d, nn.InstanceNorm1d,
            nn.BatchNorm2d, nn.BatchNorm3d, nn.BatchNorm1d,
        )
        if isinstance(module, dropout_layers):
            module.train(self.use_dropout_sampling)
            return
        if isinstance(module, norm_layers):
            module.train(not self.norm_use_average)
|
|
|
|
def postprocess_prediction(seg):
    """Keep only the largest connected foreground component of ``seg``.

    Every non-zero entry counts as foreground. Components are labelled with
    full connectivity (``connectivity=seg.ndim``), and every entry outside the
    largest component is set to 0.

    Note: ``seg`` is modified in place and also returned.

    :param seg: numpy array; non-zero entries are treated as foreground
    :return: ``seg`` with all but the largest component zeroed. If ``seg``
        has no foreground at all it is returned unchanged.
    """
    print("running postprocessing... ")
    mask = seg != 0
    lbls = label(mask, connectivity=mask.ndim)
    # lbls_sizes[v] is the number of entries carrying label value v,
    # computed in a single pass over the array.
    lbls_sizes = np.bincount(lbls.ravel())
    if lbls_sizes.size <= 1:
        # Only label 0 (or nothing at all): there is no foreground component.
        return seg
    # Label 0 is background and must never be chosen. Sizes are indexed by
    # label value, so this stays correct even when no background voxel exists.
    lbls_sizes[0] = 0
    largest_region = np.argmax(lbls_sizes)
    seg[lbls != largest_region] = 0
    return seg
|
|
|
|
def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
    """List the sub-directories of ``folder``.

    :param folder: directory to inspect
    :param join: if true, return ``os.path.join(folder, name)`` paths;
        otherwise return bare entry names
    :param prefix: if not None, keep only names starting with it
    :param suffix: if not None, keep only names ending with it
    :param sort: if true, sort the result
    :return: list of matching directory paths or names
    """
    matches = []
    for entry in os.listdir(folder):
        if not os.path.isdir(os.path.join(folder, entry)):
            continue
        if prefix is not None and not entry.startswith(prefix):
            continue
        if suffix is not None and not entry.endswith(suffix):
            continue
        matches.append(os.path.join(folder, entry) if join else entry)
    if sort:
        matches.sort()
    return matches
|
|
|
|
def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
    """List the files directly inside ``folder``.

    :param folder: directory to inspect
    :param join: if true, return ``os.path.join(folder, name)`` paths;
        otherwise return bare entry names
    :param prefix: if not None, keep only names starting with it
    :param suffix: if not None, keep only names ending with it
    :param sort: if true, sort the result
    :return: list of matching file paths or names
    """
    def _wanted(name):
        # The file check uses the full path; prefix/suffix apply to the bare name.
        return (os.path.isfile(os.path.join(folder, name))
                and (prefix is None or name.startswith(prefix))
                and (suffix is None or name.endswith(suffix)))

    names = [name for name in os.listdir(folder) if _wanted(name)]
    result = [os.path.join(folder, name) for name in names] if join else names
    if sort:
        result.sort()
    return result
|
|
|
|
# ``subfolders`` is a second name bound to the same function as ``subdirs``.
subfolders = subdirs
|
|
|
|
def maybe_mkdir_p(directory):
    """Create ``directory`` together with any missing parent directories.

    Does nothing if the directory already exists. Absolute and relative paths
    are both resolved the usual way (relative ones against the current working
    directory).

    :param directory: path of the directory to create; an empty string is a no-op
    :raises OSError: if the path, or one of its parents, exists but is not a directory
    """
    if not directory:
        # os.makedirs("") would raise; treat an empty path as nothing to create.
        return
    # exist_ok=True also tolerates the directory appearing concurrently
    # between an existence check and the mkdir call.
    os.makedirs(directory, exist_ok=True)
|
|