"""Synthetic-CT inference for SynthRAD using region-specific nnU-Net v2 models."""
import os

# nnU-Net v2 resolves its data/result folders from these variables, so they are
# set before any ``nnunetv2`` import below (predict() also reads nnUNet_results).
os.environ["nnUNet_raw"] = "./nnunet_raw"
os.environ["nnUNet_preprocessed"] = "./nnunet_preprocessed"
os.environ["nnUNet_results"] = "./nnunet_results"
# OpenBLAS reads its thread count once, when the library is loaded (on the first
# numpy / torch / SimpleITK import). It was previously set after those imports,
# where it had no effect; it must come before any numeric library is imported.
os.environ["OPENBLAS_NUM_THREADS"] = "1"

import shutil  # noqa: E402  (env vars above must be set first)
import subprocess  # noqa: E402
import tempfile  # noqa: E402
from typing import Dict  # noqa: E402

import numpy as np  # noqa: E402
import SimpleITK as sitk  # noqa: E402
import torch  # noqa: E402
from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \
    save_json  # noqa: E402
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor  # noqa: E402

from base_algorithm import BaseSynthradAlgorithm  # noqa: E402
from revert_normalisation import get_ct_normalisation_values, revert_normalisation_single_modified  # noqa: E402
|
|
|
|
|
|
|
|
class SynthradAlgorithm(BaseSynthradAlgorithm):
    """
    Synthetic-CT generation from an MR image with region-specific nnU-Net v2 models.

    The case's ``region`` selects a trained model (dataset folder under
    ``$nnUNet_results`` plus a plans JSON). The network runs on the MR image as an
    ensemble of folds 0-4. Its output is then passed, together with the CT
    normalisation values read from the plans JSON and the case mask, to
    ``revert_normalisation_single_modified``, presumably to map it back to CT
    intensities.

    Author: Suraj Pai (b.pai@maastrichtuniversity.nl)
    """

    # Trainer__plans__configuration folder shared by every region's model.
    _MODEL_FOLDER = "nnUNetTrainerMRCT_loss_masked_perception_masked__nnUNetResEncUNetLPlans__3d_fullres"
    # region -> (dataset folder under $nnUNet_results, plans JSON holding the CT normalisation values)
    _REGION_MODELS = {
        "Head and Neck": ("Dataset262", "./262_gt_nnUNetResEncUNetLPlans.json"),
        "Abdomen": ("Dataset260", "./260_gt_nnUNetResEncUNetLPlans.json"),
        "Thorax": ("Dataset264", "./264_gt_nnUNetResEncUNetLPlans.json"),
    }
    # Filename stem handed to nnU-Net for export; the result is read back as "<stem>.nii.gz".
    _PRED_STEM = "TRUNCATED"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def predict(self, input_dict: Dict[str, sitk.Image]) -> sitk.Image:
        """
        Generate a synthetic CT for one case.

        Args:
            input_dict: exactly the keys "image" (MR ``sitk.Image``), "mask"
                (``sitk.Image``) and "region" (a key of ``_REGION_MODELS``), in that order.

        Returns:
            The synthetic CT produced by ``revert_normalisation_single_modified``.

        Raises:
            AssertionError: if the keys of ``input_dict`` are not exactly
                ["image", "mask", "region"] in that order.
            ValueError: if ``region`` has no configured model.
        """
        assert list(input_dict.keys()) == ["image", "mask", "region"]

        region = input_dict["region"]
        mr_sitk = input_dict["image"]
        mask_sitk = input_dict["mask"]

        # Validate the region before allocating a GPU predictor; an unknown region
        # used to surface later as an UnboundLocalError.
        try:
            dataset_name, plans_path = self._REGION_MODELS[region]
        except KeyError:
            raise ValueError(
                f"Unsupported region {region!r}; expected one of {sorted(self._REGION_MODELS)}"
            ) from None

        predictor = self._build_nnunet_predictor(dataset_name)
        pred_sitk = self._run_nnunet(predictor, mr_sitk)

        ct_mean, ct_std = get_ct_normalisation_values(plans_path)
        mask_sitk = sitk.Cast(mask_sitk, sitk.sitkUInt8)
        pred_sitk = revert_normalisation_single_modified(pred_sitk, ct_mean, ct_std, mask_sitk=mask_sitk)

        # Best-effort cleanup kept from the original. Nothing in this class creates
        # these folders. TODO(review): confirm whether they are still needed.
        shutil.rmtree("./imagesTs", ignore_errors=True)
        shutil.rmtree("./predictions", ignore_errors=True)
        return pred_sitk

    def _build_nnunet_predictor(self, dataset_name: str) -> nnUNetPredictor:
        """Create a cuda:0 nnU-Net predictor loaded with folds 0-4 of ``dataset_name``."""
        predictor = nnUNetPredictor(
            tile_step_size=0.5,
            use_gaussian=True,
            use_mirroring=True,
            perform_everything_on_device=True,
            device=torch.device('cuda', 0),
            verbose=True,
            verbose_preprocessing=True,
            allow_tqdm=True,
        )
        predictor.initialize_from_trained_model_folder(
            join(os.environ["nnUNet_results"], f'{dataset_name}/{self._MODEL_FOLDER}'),
            use_folds=(0, 1, 2, 3, 4),
            checkpoint_name='checkpoint_final.pth',
        )
        return predictor

    def _run_nnunet(self, predictor: nnUNetPredictor, mr_sitk: sitk.Image) -> sitk.Image:
        """Run ``predictor`` on the MR image and return the exported (still normalised) prediction."""
        spacing = tuple(mr_sitk.GetSpacing())
        # Image properties in the layout nnU-Net's SimpleITK reader appears to use:
        # the raw sitk geometry, plus 'spacing' reversed to match the axis order of
        # GetArrayFromImage. Indexing [2] assumes a 3-D image.
        props = {
            'sitk_stuff': {
                'spacing': spacing,
                'origin': tuple(mr_sitk.GetOrigin()),
                'direction': tuple(mr_sitk.GetDirection()),
            },
            'spacing': [spacing[2], spacing[1], spacing[0]],
        }

        # NOTE(review): the network receives the UNMASKED MR, as in the original
        # (which computed a masked copy and then never used it). If the models were
        # trained on masked input, zero voxels where the mask is 0 here.
        # TODO confirm against the training preprocessing.
        img = sitk.GetArrayFromImage(mr_sitk).astype(np.float32)
        img = np.expand_dims(img, 0)  # leading channel axis

        # Export into a private temp dir instead of the cwd. This avoids collisions
        # between runs and guarantees cleanup even if reading or prediction fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            stem = join(tmp_dir, self._PRED_STEM)
            predictor.predict_single_npy_array(img, props, None, stem, False)
            # ReadImage loads the voxels into memory, so the temp dir can be removed afterwards.
            return sitk.ReadImage(stem + ".nii.gz")
|
|
if __name__ == "__main__":
    # Script entry point: instantiate the algorithm and run the inherited
    # process() (defined on BaseSynthradAlgorithm, not in this file).
    algorithm = SynthradAlgorithm()
    algorithm.process()
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |