import os

import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.optim as optim
from torch.cuda.amp import autocast
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ExponentialLR, ReduceLROnPlateau
from torch.utils.data import ConcatDataset
from torch.utils.data.distributed import DistributedSampler
from torch_geometric.loader import DataLoader
from torch_geometric.nn import to_hetero
from loguru import logger

from src.utils import seed_worker
from .utils import count_parameters, AverageMeter, AVGMeter, Reporter, Timer

class Oven(object):
    """Training harness: builds data loaders, the optimizer and LR scheduler,
    resumes from checkpoints, and logs/saves per-epoch summaries."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.ngpus = cfg.get('ngpus', 1)
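    # Illustrative cfg layout, inferred from the keys this class reads; the
    # concrete values below are placeholders, not defaults shipped with the repo:
    #
    #   cfg = {
    #       'ngpus': 1, 'distributed': False, 'log_path': 'runs/exp1',
    #       'train': {'optimizer_type': 'Adam', 'learning_rate': 1e-3,
    #                 'weight_decay': 1e-5, 'momentum': 0.9, 'epochs': 100},
    #       'scheduler': {'type': 'Cosine', 'eta_min': 1e-6, 'gamma': 0.98},
    #       'data': {'batch_size': 32, 'batch_size_test': 32, 'num_workers': 4,
    #                'train': {'dataset_type': 'PowerFlowDataset', 'data_root': ...,
    #                          'split_txt': ..., 'pq_len': ..., 'pv_len': ...,
    #                          'slack_len': ..., 'mask_num': ...},
    #                'val': {...}},
    #   }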
    def _init_training_wt_checkpoint(self, filepath_ckp):
        """Resume training state from `filepath_ckp` if it exists."""
        if not os.path.exists(filepath_ckp):
            # No checkpoint: start from scratch. Use np.inf (np.Infinity was
            # removed in NumPy 2.0) and return an empty metrics dict so the
            # arity matches the resumed-path return below.
            return np.inf, -1, 0, {}

        checkpoint_resum = torch.load(filepath_ckp)
        self.model.load_state_dict(checkpoint_resum['model_state'])
        epoch = checkpoint_resum['epoch']
        previous_best = checkpoint_resum['best_performance']
        previous_best_epoch = checkpoint_resum["best_epoch"]
        previous_best_metrics = checkpoint_resum["local_best_metrics"]
        return previous_best, previous_best_epoch, epoch, previous_best_metrics
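    # Note: `ckpt_latest.pt` (written in `summary_epoch` below) also stores
    # 'optimizer_state' and 'scheduler_state', but the resume path above only
    # restores model weights. If full resumption were wanted, a minimal sketch
    # (assuming `self.optim` / `self.scheduler` already exist when the method
    # runs) would add, before its final return:
    #
    #   self.optim.load_state_dict(checkpoint_resum['optimizer_state'])
    #   self.scheduler.load_state_dict(checkpoint_resum['scheduler_state'])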
    def _init_optim(self):
        # Exact match on the optimizer name; the original substring test
        # (`... in "adam"`) would also accept strings such as "a" or "da".
        if self.cfg['train'].get("optimizer_type", "Adam").lower() == "adam":
            optimizer = optim.Adam(self.model.parameters(),
                                   lr=float(self.cfg['train']['learning_rate']),
                                   weight_decay=self.cfg['train'].get("weight_decay", 1e-5))
        else:
            optimizer = optim.SGD(self.model.parameters(),
                                  lr=float(self.cfg['train']['learning_rate']),
                                  momentum=self.cfg['train'].get("momentum", 0.9),
                                  weight_decay=self.cfg['train'].get("weight_decay", 1e-5))

        if self.cfg['scheduler']['type'] == 'Cosine':
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=self.cfg['train']['epochs'],
                                          eta_min=float(self.cfg['scheduler']['eta_min']))
        elif self.cfg['scheduler']['type'] == 'Exponential':
            scheduler = ExponentialLR(optimizer, gamma=self.cfg['scheduler']['gamma'])
        elif self.cfg['scheduler']['type'] == 'ReduceLROnPlateau':
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=5, min_lr=1e-5)
        else:
            scheduler = None
        return optimizer, scheduler
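    # `StepLR` is imported above but no 'Step' scheduler type is handled here.
    # If the config schema is meant to support it, a hypothetical branch (the
    # 'step_size' key is an assumption, not part of the original config) could be:
    #
    #   elif self.cfg['scheduler']['type'] == 'Step':
    #       scheduler = StepLR(optimizer,
    #                          step_size=self.cfg['scheduler']['step_size'],
    #                          gamma=self.cfg['scheduler']['gamma'])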
    def _init_data(self):
        train_dataset = self.get_dataset(**self.cfg['data']['train'])
        val_dataset = self.get_dataset(**self.cfg['data']['val'])

        if not self.cfg['distributed']:
            train_loader = DataLoader(
                train_dataset,
                batch_size=self.cfg['data']['batch_size'],
                num_workers=self.cfg['data']['num_workers'],
                shuffle=True,
                worker_init_fn=seed_worker,
                drop_last=True
            )
            val_loader = DataLoader(
                val_dataset,
                batch_size=self.cfg['data'].get("batch_size_test", self.cfg['data']['batch_size']),
                num_workers=self.cfg['data']['num_workers'],
                shuffle=False,
                drop_last=True,
                worker_init_fn=seed_worker
            )
        else:
            train_sampler = DistributedSampler(train_dataset, shuffle=True)
            train_loader = DataLoader(train_dataset,
                                      batch_size=self.cfg['data']['batch_size'],
                                      num_workers=self.cfg['data']['num_workers'],
                                      sampler=train_sampler,
                                      drop_last=True,
                                      worker_init_fn=seed_worker)

            valid_sampler = DistributedSampler(val_dataset, shuffle=False)
            val_loader = DataLoader(val_dataset,
                                    batch_size=self.cfg['data'].get("batch_size_test", self.cfg['data']['batch_size']),
                                    num_workers=self.cfg['data']['num_workers'],
                                    sampler=valid_sampler,
                                    drop_last=True,
                                    worker_init_fn=seed_worker)

        return train_loader, val_loader
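    # When `distributed` is on, the training loop (not shown in this file) is
    # expected to call `DistributedSampler.set_epoch(epoch)` at the start of
    # every epoch so each epoch reshuffles differently; otherwise every epoch
    # sees the same shard order. Minimal sketch, assuming the loop has `epoch`
    # and the `train_loader` returned above:
    #
    #   if self.cfg['distributed']:
    #       train_loader.sampler.set_epoch(epoch)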
    def get_dataset(self, dataset_type, **kwargs):
        if dataset_type == 'PowerFlowDataset':
            from src.dataset.powerflow_dataset import PowerFlowDataset
            return PowerFlowDataset(
                data_root=kwargs['data_root'],
                split_txt=kwargs['split_txt'],
                pq_len=kwargs['pq_len'],
                pv_len=kwargs['pv_len'],
                slack_len=kwargs['slack_len'],
                mask_num=kwargs['mask_num']
            )
        # Fail loudly instead of silently returning None for unknown types.
        raise ValueError(f"Unsupported dataset_type: {dataset_type}")
    def summary_epoch(self,
                      epoch,
                      train_loss, train_matrix,
                      valid_loss, val_matrix,
                      timer, local_best,
                      local_best_ep=-1,
                      local_best_metrics={},
                      local_best_ema=100,
                      local_best_ep_ema=-1,
                      local_best_metrics_ema={},
                      valid_loss_ema=None, val_matrix_ema=None):
        """Log per-epoch train/val metrics, track the running best validation
        score (the maximum over RMSE-type metrics, lower is better), and save
        `ckpt_best.pt` / `ckpt_latest.pt`."""

        if self.cfg['distributed']:
            if dist.get_rank() == 0:
                cur_lr = self.optim.param_groups[0]["lr"]

                self.reporter.record({'loss/train_loss': train_loss}, epoch=epoch)
                self.reporter.record({'loss/val_loss': valid_loss}, epoch=epoch)
                self.reporter.record({'lr': cur_lr}, epoch=epoch)
                self.reporter.record(train_matrix, epoch=epoch)
                self.reporter.record(val_matrix, epoch=epoch)

                logger.info(f"Epoch {str(epoch+1).zfill(3)}/{self.cfg['train']['epochs']},"
                            + f" lr: {cur_lr: .8f}, eta: {timer.eta}h, "
                            + f"train_loss: {train_loss.agg(): .5f}, "
                            + f"valid_loss: {valid_loss.agg(): .5f}")

                train_matrix_info = "Train: "
                for key in train_matrix.keys():
                    tkey = str(key).split("/")[-1]
                    train_matrix_info += f"{tkey}:{train_matrix[key].agg(): .6f} "
                logger.info(f"\t{train_matrix_info}")

                val_matrix_info = "ZTest: "
                performance_record = dict()
                for key in val_matrix.keys():
                    tkey = str(key).split("/")[-1]
                    val_matrix_info += f"{tkey}:{val_matrix[key].agg(): .6f} "
                    performance_record[key] = val_matrix[key].agg()
                logger.info(f"\t{val_matrix_info}")

                if val_matrix_ema is not None:
                    val_matrix_info_ema = "ZTest-ema: "
                    performance_record_ema = dict()
                    for key in val_matrix_ema.keys():
                        tkey = str(key).split("/")[-1]
                        val_matrix_info_ema += f"{tkey}:{val_matrix_ema[key].agg(): .6f} "
                        performance_record_ema[key] = val_matrix_ema[key].agg()
                    logger.info(f"\t{val_matrix_info_ema}")

                    checked_performance_ema = {x: y for x, y in performance_record_ema.items() if "rmse" in x}
                    best_performance_ema = max(checked_performance_ema.values())
                    if best_performance_ema < local_best_ema:
                        local_best_ema = best_performance_ema
                        local_best_ep_ema = epoch
                        local_best_metrics_ema = checked_performance_ema
                    logger.info(f"\t ValOfEMA:{best_performance_ema:.6f}/{local_best_ema:.6f}, Epoch:{epoch+1}/{local_best_ep_ema+1}")

                checked_performance = {x: y for x, y in performance_record.items() if "rmse" in x}
                best_performance = max(checked_performance.values())
                if best_performance < local_best:
                    local_best = best_performance
                    local_best_metrics = checked_performance
                    local_best_ep = epoch
                    # Under DDP, save the unwrapped model (self.model.module) so the
                    # best checkpoint loads without the DistributedDataParallel wrapper.
                    torch.save(self.model.module, os.path.join(self.cfg['log_path'], 'ckpt_best.pt'))

                state = {
                    "epoch": epoch + 1,
                    "model_state": self.model.state_dict(),
                    "optimizer_state": self.optim.state_dict(),
                    "scheduler_state": self.scheduler.state_dict(),
                    "best_performance": local_best,
                    "best_epoch": local_best_ep,
                    "local_best_metrics": local_best_metrics,
                }
                torch.save(state, os.path.join(self.cfg['log_path'], 'ckpt_latest.pt'))
                logger.info(f"\tTime(ep):{int(timer.elapsed_time)}s, Val(curr/best):{best_performance:.6f}/{local_best:.6f}, Epoch(curr/best):{epoch+1}/{local_best_ep+1}")

        else:
            cur_lr = self.optim.param_groups[0]["lr"]
            self.reporter.record({'loss/train_loss': train_loss}, epoch=epoch)
            self.reporter.record({'loss/val_loss': valid_loss}, epoch=epoch)
            self.reporter.record({'lr': cur_lr}, epoch=epoch)
            self.reporter.record(train_matrix, epoch=epoch)
            self.reporter.record(val_matrix, epoch=epoch)

            logger.info(f"Epoch {epoch}/{self.cfg['train']['epochs']},"
                        + f" lr: {cur_lr: .8f}, eta: {timer.eta}h, "
                        + f"train_loss: {train_loss.agg(): .5f}, "
                        + f"valid_loss: {valid_loss.agg(): .5f}")

            train_matrix_info = "Train: "
            for key in train_matrix.keys():
                tkey = str(key).split("/")[-1]
                train_matrix_info += f"{tkey}:{train_matrix[key].agg(): .8f} "
            logger.info(f"\t{train_matrix_info}")

            val_matrix_info = "ZTest: "
            performance_record = dict()
            for key in val_matrix.keys():
                tkey = str(key).split("/")[-1]
                val_matrix_info += f"{tkey}:{val_matrix[key].agg(): .8f} "
                performance_record[key] = val_matrix[key].agg()
            logger.info(f"\t{val_matrix_info}")

            if val_matrix_ema is not None:
                val_matrix_info_ema = "ZTest-ema: "
                performance_record_ema = dict()
                for key in val_matrix_ema.keys():
                    tkey = str(key).split("/")[-1]
                    val_matrix_info_ema += f"{tkey}:{val_matrix_ema[key].agg(): .6f} "
                    performance_record_ema[key] = val_matrix_ema[key].agg()
                logger.info(f"\t{val_matrix_info_ema}")

                checked_performance_ema = {x: y for x, y in performance_record_ema.items() if "rmse" in x}
                best_performance_ema = max(checked_performance_ema.values())
                if best_performance_ema < local_best_ema:
                    local_best_ema = best_performance_ema
                    local_best_metrics_ema = checked_performance_ema
                    local_best_ep_ema = epoch
                logger.info(f"\t ValOfEMA:{best_performance_ema:.6f}/{local_best_ema:.6f}, Epoch:{epoch+1}/{local_best_ep_ema+1}")

            checked_performance = {x: y for x, y in performance_record.items() if "rmse" in x}
            best_performance = max(checked_performance.values())
            if best_performance < local_best:
                local_best = best_performance
                local_best_ep = epoch
                local_best_metrics = checked_performance
                torch.save(self.model, os.path.join(self.cfg['log_path'], 'ckpt_best.pt'))

            state = {
                "epoch": epoch + 1,
                "model_state": self.model.state_dict(),
                "optimizer_state": self.optim.state_dict(),
                "scheduler_state": self.scheduler.state_dict(),
                "best_performance": local_best,
                "best_epoch": local_best_ep,
                "local_best_metrics": local_best_metrics,
            }
            torch.save(state, os.path.join(self.cfg['log_path'], 'ckpt_latest.pt'))
            logger.info(f"\tTime(ep):{int(timer.elapsed_time)}s, Val(curr/best):{best_performance:.6f}/{local_best:.6f}, Epoch(curr/best):{epoch+1}/{local_best_ep+1}")

        if val_matrix_ema is not None:
            return local_best, local_best_ep, local_best_ema, local_best_ep_ema, local_best_metrics_ema
        else:
            return local_best, local_best_ep, local_best_metrics
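    # Expected call pattern (the train/validate loop itself is outside this
    # section, so this is only an illustrative sketch): the loop keeps
    # `local_best` / `local_best_ep` across epochs and feeds them back in, e.g.
    #
    #   local_best, local_best_ep, local_best_metrics = self.summary_epoch(
    #       epoch, train_loss, train_matrix, valid_loss, val_matrix,
    #       timer, local_best, local_best_ep, local_best_metrics)
    #
    # Note the return arity differs when `val_matrix_ema` is supplied
    # (five values, including the EMA bests) versus when it is not (three).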