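# Training workflow configuration (MONAI bundle syntax) for a 2D DenseNet121 classifier:
# datalist split, transforms, data loaders, loss/optimizer, validation evaluator, handlers,
# and the SupervisedTrainer executed by `run`.
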
imports:
  - $import datetime
  - $import glob
  - $import json
  - $import os
  - $import pathlib
  - $import torch

# Dictionary-key names shared by the transforms, the postprocessing and the metric output transforms.
image: "$monai.utils.CommonKeys.IMAGE"
label: "$monai.utils.CommonKeys.LABEL"
pred: "$monai.utils.CommonKeys.PRED"

# Process rank. Components with `_disabled_: '@is_not_rank0'` (the checkpoint savers) run on rank 0 only.
rank: 0
is_not_rank0: "$@rank > 0"

# Training hyperparameters and runtime settings.
num_epochs: 10
val_interval: 1   # ValidationHandler interval, in epochs
ckpt_interval: 1  # save_interval of the training CheckpointSaver, in epochs
batch_size: 5
num_substeps: 1   # passed as `repeats` to the training ThreadDataLoader
num_workers: 4
rand_prob: 0.5    # probability used by each random augmentation
learning_rate: 0.001
num_classes: 4
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"

# Paths, all relative to the bundle root.
bundle_root: .
ckpt_path: $@bundle_root + '/models/model.pt'
dataset_dir: $@bundle_root + '/data/train_data'
results_dir: $@bundle_root + '/results'
# Per-run output directory, e.g. <results_dir>/output_240131_142500.
# Only the timestamp suffix is passed to strftime, so a '%' anywhere in results_dir is kept
# literally instead of being interpreted as a format directive.
output_dir: "$@results_dir + datetime.datetime.now().strftime('/output_%y%m%d_%H%M%S')"

# 2D DenseNet121 with one input channel and num_classes outputs, moved to the configured device.
network_def: {_target_: DenseNet121, spatial_dims: 2, in_channels: 1, out_channels: "@num_classes"}
network: "$@network_def.to(@device)"

# Datalist: parsed from a JSON file, then split 4:1 into training and validation subsets
# (shuffled with a fixed seed).
data_json: $@bundle_root + '/data/train_samples.json'
# Retained so existing references to @data_fp keep working; data_dict no longer uses it.
# NOTE: evaluating it opens a file handle that nothing closes.
data_fp: "$open(@data_json,'r', encoding='utf8')"
# read_text opens and closes the file itself, so no handle outlives the parse.
data_dict: "$json.loads(pathlib.Path(@data_json).read_text(encoding='utf8'))"
partitions: '$monai.data.partition_dataset(@data_dict, (4, 1), shuffle=True, seed=0)'
train_sub: '$@partitions[0]'
val_sub: '$@partitions[1]'

# Transforms shared by training and validation: load the image and ensure a channel-first layout.
base_transforms:
  - {_target_: LoadImaged, keys: "@image"}
  - {_target_: EnsureChannelFirstd, keys: "@image"}

# Training-only augmentation: three random transforms, each with probability rand_prob,
# followed by intensity scaling.
# NOTE(review): RandGaussianNoised (std 0.05) runs before ScaleIntensityd, so the noise is added
# to unscaled intensities. Confirm this ordering is intended.
train_transforms:
  - _target_: RandAxisFlipd
    prob: "@rand_prob"
    keys: "@image"
  - _target_: RandRotate90d
    prob: "@rand_prob"
    keys: "@image"
  - _target_: RandGaussianNoised
    prob: "@rand_prob"
    std: 0.05
    keys: "@image"
  - {_target_: ScaleIntensityd, keys: "@image"}

# Validation gets no augmentation, only intensity scaling.
val_transforms:
  - {_target_: ScaleIntensityd, keys: "@image"}

# Complete pipelines: the shared loading transforms followed by the split-specific ones.
preprocessing:
  _target_: Compose
  transforms: "$@base_transforms + @train_transforms"
val_preprocessing:
  _target_: Compose
  transforms: "$@base_transforms + @val_transforms"

# Datasets pairing each data subset with its preprocessing pipeline.
train_dataset: {_target_: Dataset, data: "@train_sub", transform: "@preprocessing"}
val_dataset: {_target_: Dataset, data: "@val_sub", transform: "@val_preprocessing"}

# Training batches come from a ThreadDataLoader (repeats = num_substeps); validation uses a
# plain DataLoader.
train_dataloader:
  _target_: ThreadDataLoader
  dataset: '@train_dataset'
  batch_size: '@batch_size'
  # DataLoader defaults to shuffle=False, which would feed every epoch in the same order
  # (the datalist is shuffled only once, by partition_dataset). Reshuffle each epoch instead.
  shuffle: true
  repeats: '@num_substeps'
  num_workers: '@num_workers'
val_dataloader:
  _target_: DataLoader
  dataset: '@val_dataset'
  batch_size: '@batch_size'
  num_workers: '@num_workers'

# Cross-entropy loss with reduction 'sum' rather than the default mean.
lossfn: {_target_: torch.nn.CrossEntropyLoss, reduction: sum}

# Adam over all network parameters.
optimizer:
  _target_: torch.optim.Adam
  lr: "@learning_rate"
  params: "$@network.parameters()"

# Inferer used by the trainer.
inferer: {_target_: SimpleInferer}

# Evaluation postprocessing: softmax on predictions; AsDiscreted with argmax on predictions
# only and to_onehot=num_classes; then tensors on the configured device.
postprocessing:
  _target_: Compose
  transforms:
    - {_target_: Activationsd, keys: "@pred", softmax: true}
    - _target_: AsDiscreted
      keys: ["@pred", "@label"]
      argmax: [true, false]
      to_onehot: "@num_classes"
    - {_target_: ToTensord, keys: ["@pred", "@label"], device: "@device"}

# Handlers attached to the evaluator.
val_handlers:
  # Stats logging with an output_transform that discards the engine output.
  - _target_: StatsHandler
    name: null
    output_transform: "$lambda x: None"
  - _target_: LogfileHandler
    output_dir: "@output_dir"
  # Key-metric (val_accuracy) checkpoint of the network; rank 0 only.
  - _target_: CheckpointSaver
    _disabled_: "@is_not_rank0"
    save_dir: "@output_dir"
    save_dict: {model: "@network"}
    save_interval: 0
    save_final: false
    epoch_level: false
    save_key_metric: true
    key_metric_name: val_accuracy

# Validation engine: accuracy is the key metric; F1 score (via ConfusionMatrix) is reported as well.
evaluator:
  _target_: SupervisedEvaluator
  device: "@device"
  val_data_loader: "@val_dataloader"
  network: "@network"
  postprocessing: "@postprocessing"
  key_val_metric:
    val_accuracy:
      _target_: ignite.metrics.Accuracy
      output_transform: "$monai.handlers.from_engine([@pred, @label])"
  additional_metrics:
    val_f1:
      _target_: ConfusionMatrix
      metric_name: "f1 score"
      output_transform: "$monai.handlers.from_engine([@pred, @label])"
  val_handlers: "@val_handlers"

# Metric logger bound to the evaluator; the first training handler and part of the training checkpoint.
metriclogger: {_target_: MetricLogger, evaluator: "@evaluator"}

# Handlers attached to the trainer, in this order.
handlers:
  - "@metriclogger"
  # Load model weights from ckpt_path, but only when that file exists.
  - _target_: CheckpointLoader
    _disabled_: "$not os.path.exists(@ckpt_path)"
    load_path: "@ckpt_path"
    load_dict: {model: "@network"}
  # Run the evaluator every val_interval epochs.
  - _target_: ValidationHandler
    validator: "@evaluator"
    epoch_level: true
    interval: "@val_interval"
  # Save network and metric logger every ckpt_interval epochs and at the end; rank 0 only.
  - _target_: CheckpointSaver
    _disabled_: "@is_not_rank0"
    save_dir: "@output_dir"
    save_dict:
      model: "@network"
      logger: "@metriclogger"
    save_interval: "@ckpt_interval"
    save_final: true
    epoch_level: true
  # Log the training loss under the tag train_loss.
  - _target_: StatsHandler
    name: null
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
  - _target_: LogfileHandler
    output_dir: "@output_dir"

# Training engine wiring together the components above; no training-time key metric.
trainer:
  _target_: SupervisedTrainer
  max_epochs: "@num_epochs"
  device: "@device"
  train_data_loader: "@train_dataloader"
  network: "@network"
  inferer: "@inferer"
  loss_function: "@lossfn"
  optimizer: "@optimizer"
  key_train_metric: null
  train_handlers: "@handlers"

# set_determinism(seed=...) turns cuDNN benchmarking off. Re-enabling it unconditionally lets
# cuDNN choose different convolution algorithms between runs, which undermines the seed.
# Benchmarking is therefore opt-in: set to true to trade reproducibility for autotuning speed.
cudnn_benchmark: false
# Steps executed before `run`: seed the RNGs, then apply the cuDNN benchmark setting.
initialize:
  - "$monai.utils.set_determinism(seed=123)"
  - "$setattr(torch.backends.cudnn, 'benchmark', @cudnn_benchmark)"
run:
  - "$@trainer.run()"