import logging
import os
import time
from typing import Any, List, Optional

import torch
from accelerate import Accelerator
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats
from torch.utils.data import DataLoader, Dataset

from .utils import seed_all_random_engines

logger = logging.getLogger(__name__)


class TrainingLoopBase(ReplaceableBase):
    """
    Members:
        evaluator: An EvaluatorBase instance, used to evaluate training results.
    """

    evaluator: Optional[EvaluatorBase]
    evaluator_class_type: Optional[str] = "ImplicitronEvaluator"

    def run(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        train_dataset: Dataset,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        **kwargs,
    ) -> None:
        raise NotImplementedError()

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        raise NotImplementedError()

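# TrainingLoopBase is replaceable: concrete loops register themselves so the
# config system can pick an implementation by name. A minimal sketch of a
# custom loop (hypothetical class; only ImplicitronTrainingLoop below ships
# with this module):
#
#   @registry.register
#   class MyTrainingLoop(TrainingLoopBase):
#       def run(self, train_loader, val_loader, test_loader, train_dataset,
#               model, optimizer, scheduler, **kwargs) -> None:
#           ...  # custom epoch logic
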
@registry.register
class ImplicitronTrainingLoop(TrainingLoopBase):
    """
    Members:
        eval_only: If True, only run evaluation using the test dataloader.
        max_epochs: Train for this many epochs. Note that if the model was
            loaded from a checkpoint, we will restart training at the
            appropriate epoch and run for (max_epochs - checkpoint_epoch)
            epochs.
        store_checkpoints: If True, store model and optimizer state checkpoints.
        store_checkpoints_purge: If > 0, remove any checkpoints older than
            this many epochs.
        test_interval: Evaluate on a test dataloader every `test_interval` epochs.
        test_when_finished: If True, evaluate on a test dataloader when training
            completes.
        validation_interval: Validate every `validation_interval` epochs.
        clip_grad: If > 0.0, clip the gradient norms to this maximum value;
            otherwise, no clipping is performed.
        metric_print_interval: The batch interval at which the stats should be
            logged.
        visualize_interval: The batch interval at which the visualizations
            should be plotted.
        visdom_env: The name of the Visdom environment to use for plotting.
        visdom_port: The Visdom port.
        visdom_server: Address of the Visdom server.
    """

    # Parameters of the outer training loop.
    eval_only: bool = False
    max_epochs: int = 1000
    store_checkpoints: bool = True
    store_checkpoints_purge: int = 1
    test_interval: int = -1
    test_when_finished: bool = False
    validation_interval: int = 1

    # Gradient clipping.
    clip_grad: float = 0.0

    # Visualization and logging parameters.
    metric_print_interval: int = 5
    visualize_interval: int = 1000
    visdom_env: str = ""
    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
    visdom_server: str = "http://127.0.0.1"
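
    # A minimal sketch of overriding the members above from an experiment
    # config, assuming the `<member>_<ClassType>_args` key convention of the
    # Implicitron config system (the exact top-level key depends on the
    # experiment runner):
    #
    #   training_loop_ImplicitronTrainingLoop_args:
    #     max_epochs: 400
    #     test_interval: 50
    #     clip_grad: 1.0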

    def __post_init__(self):
        # Instantiates the replaceable members, in particular self.evaluator.
        run_auto_creation(self)

    def run(
        self,
        *,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        train_dataset: Dataset,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        accelerator: Optional[Accelerator],
        device: torch.device,
        exp_dir: str,
        stats: Stats,
        seed: int,
        **kwargs,
    ):
        """
        Entry point to run the training and validation loops
        on the provided model, dataloaders and optimizer.
        """
        start_epoch = stats.epoch + 1
        # The scheduler must already be aligned with the resumed epoch.
        assert scheduler.last_epoch == start_epoch

        # If requested, run evaluation on the test set and return immediately.
        if self.eval_only:
            if test_loader is not None:
                self.evaluator.run(
                    dataloader=test_loader,
                    device=device,
                    dump_to_json=True,
                    epoch=stats.epoch,
                    exp_dir=exp_dir,
                    model=model,
                )
                return
            else:
                raise ValueError(
                    "Cannot evaluate and dump results to json, no test data provided."
                )

        # The main loop over epochs.
        for epoch in range(start_epoch, self.max_epochs):
            # Automatic new_epoch and plotting of stats at every epoch start.
            with stats:

                # Re-seed all random generators so that runs restarted from a
                # checkpoint remain reproducible.
                seed_all_random_engines(seed + epoch)

                cur_lr = float(scheduler.get_last_lr()[-1])
                logger.debug(f"scheduler lr = {cur_lr:1.2e}")

                # Run a training epoch.
                self._training_or_validation_epoch(
                    accelerator=accelerator,
                    device=device,
                    epoch=epoch,
                    loader=train_loader,
                    model=model,
                    optimizer=optimizer,
                    stats=stats,
                    validation=False,
                )

                # Run a validation epoch, if enabled.
                if val_loader is not None and epoch % self.validation_interval == 0:
                    self._training_or_validation_epoch(
                        accelerator=accelerator,
                        device=device,
                        epoch=epoch,
                        loader=val_loader,
                        model=model,
                        optimizer=optimizer,
                        stats=stats,
                        validation=True,
                    )

                # Evaluate on the test set, if enabled.
                if (
                    test_loader is not None
                    and self.test_interval > 0
                    and epoch % self.test_interval == 0
                ):
                    self.evaluator.run(
                        device=device,
                        dataloader=test_loader,
                        model=model,
                    )

                # Store a checkpoint and advance the learning-rate schedule.
                assert stats.epoch == epoch, "inconsistent stats!"
                self._checkpoint(accelerator, epoch, exp_dir, model, optimizer, stats)

                scheduler.step()
                new_lr = float(scheduler.get_last_lr()[-1])
                if new_lr != cur_lr:
                    logger.info(f"LR change! {cur_lr} -> {new_lr}")

        # Evaluate the final model, if requested.
        if self.test_when_finished:
            if test_loader is not None:
                self.evaluator.run(
                    device=device,
                    dump_to_json=True,
                    epoch=stats.epoch,
                    exp_dir=exp_dir,
                    dataloader=test_loader,
                    model=model,
                )
            else:
                raise ValueError(
                    "Cannot evaluate and dump results to json, no test data provided."
                )

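    # A minimal driver sketch (hypothetical setup; in practice the Implicitron
    # experiment runner builds these objects from the config via
    # expand_args_fields/run_auto_creation):
    #
    #   loop = ImplicitronTrainingLoop(max_epochs=100)
    #   stats = loop.load_stats(log_vars=["objective"], exp_dir=exp_dir)
    #   loop.run(
    #       train_loader=train_loader, val_loader=val_loader, test_loader=None,
    #       train_dataset=train_dataset, model=model, optimizer=optimizer,
    #       scheduler=scheduler, accelerator=None, device=device,
    #       exp_dir=exp_dir, stats=stats, seed=42,
    #   )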

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        """
        Load Stats that correspond to the model's log_vars and resume_epoch.

        Args:
            log_vars: A list of variable names to log. Should be a subset of the
                `preds` returned by the forward function of the corresponding
                ImplicitronModelBase instance.
            exp_dir: Root experiment directory.
            resume: If False, do not load stats from a checkpoint; instead,
                create a fresh stats object.
            resume_epoch: If positive, resume from the checkpoint of this
                specific epoch; otherwise, resume from the last checkpoint
                found in `exp_dir`.

        Returns:
            stats: The stats structure (optionally loaded from checkpoint).
        """
        visdom_env_charts = (
            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
        )
        stats = Stats(
            list(log_vars),
            plot_file=os.path.join(exp_dir, "train_stats.pdf"),
            visdom_env=visdom_env_charts,
            visdom_server=self.visdom_server,
            visdom_port=self.visdom_port,
        )

        # Locate the checkpoint to resume from, if any.
        model_path = None
        if resume:
            if resume_epoch > 0:
                model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
                if not os.path.isfile(model_path):
                    raise FileNotFoundError(
                        f"Cannot find stats from epoch {resume_epoch}."
                    )
            else:
                model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is not None:
            stats_path = model_io.get_stats_path(model_path)
            stats_load = model_io.load_stats(stats_path)

            # Determine the epoch to resume from.
            if resume:
                if stats_load is None:
                    logger.warning("Corrupt stats file -> clearing stats.")
                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
                    logger.info(f"Estimated resume epoch = {last_epoch}")

                    # Reset the stats struct to the last-found epoch.
                    for _ in range(last_epoch + 1):
                        stats.new_epoch()
                    assert last_epoch == stats.epoch
                else:
                    logger.info(f"Found previous stats in {stats_path} -> resuming.")
                    stats = stats_load

                # Patch the loaded stats with the current visdom settings.
                stats.visdom_env = visdom_env_charts
                stats.visdom_server = self.visdom_server
                stats.visdom_port = self.visdom_port
                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
                stats.synchronize_logged_vars(log_vars)
            else:
                logger.info("Clearing stats")

        return stats

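    # Resume-behavior sketch (hypothetical values): with resume=True and the
    # default resume_epoch=-1, the last checkpoint in exp_dir is used; a
    # positive resume_epoch pins a specific checkpoint and raises if missing:
    #
    #   stats = loop.load_stats(["objective"], exp_dir, resume=True,
    #                           resume_epoch=20)
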
    def _training_or_validation_epoch(
        self,
        epoch: int,
        loader: DataLoader,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        stats: Stats,
        validation: bool,
        *,
        accelerator: Optional[Accelerator],
        bp_var: str = "objective",
        device: torch.device,
        **kwargs,
    ) -> None:
        """
        The main loop over batches for both training and validation: it runs
        the model's forward pass, computes the loss, runs the backward pass
        (training only), and plots visualizations.

        Args:
            epoch: The index of the current epoch.
            loader: The dataloader to use for the loop.
            model: The model module, optionally loaded from checkpoint.
            optimizer: The optimizer module, optionally loaded from checkpoint.
            stats: The stats struct, also optionally loaded from checkpoint.
            validation: If True, run the loop with the model in eval mode
                and skip the backward pass.
            accelerator: An optional Accelerator instance.
            bp_var: The name of the key in the model output `preds` dict which
                should be used as the loss for the backward pass.
            device: The device on which to run the model.
        """

        if validation:
            model.eval()
            trainmode = "val"
        else:
            model.train()
            trainmode = "train"

        t_start = time.time()

        # Get the visdom connection for plotting image visualizations.
        visdom_env_imgs = stats.visdom_env + "_images_" + trainmode
        viz = vis_utils.get_visdom_connection(
            server=stats.visdom_server,
            port=stats.visdom_port,
        )

        # Loop over the batches of the dataloader.
        n_batches = len(loader)
        for it, net_input in enumerate(loader):
            last_iter = it == n_batches - 1

            # Move the input batch to the target device.
            net_input = net_input.to(device)

            # Run the forward pass; gradients are only needed for training.
            if not validation:
                optimizer.zero_grad()
                preds = model(
                    **{**net_input, "evaluation_mode": EvaluationMode.TRAINING}
                )
            else:
                with torch.no_grad():
                    preds = model(
                        **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
                    )

            # Make sure the model outputs do not overwrite the input keys.
            assert all(k not in preds for k in net_input.keys())
            # Merge everything into one big dict.
            preds.update(net_input)

            # Update the stats logger.
            stats.update(preds, time_start=t_start, stat_set=trainmode)
            assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

            # Print a textual status update.
            if it % self.metric_print_interval == 0 or last_iter:
                std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches)
                logger.info(std_out)

            # Plot visualizations, on the main process only.
            if (
                (accelerator is None or accelerator.is_local_main_process)
                and self.visualize_interval > 0
                and it % self.visualize_interval == 0
            ):
                prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
                if hasattr(model, "visualize"):
                    model.visualize(
                        viz,
                        visdom_env_imgs,
                        preds,
                        prefix,
                    )

            # Run the backward pass and an optimizer step (training only).
            if not validation:
                loss = preds[bp_var]
                assert torch.isfinite(loss).all(), "Non-finite loss!"

                if accelerator is None:
                    loss.backward()
                else:
                    accelerator.backward(loss)

                if self.clip_grad > 0.0:
                    # Clip the gradient norms (mutates the gradients in-place).
                    total_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), self.clip_grad
                    )
                    if total_norm > self.clip_grad:
                        logger.debug(
                            f"Clipping gradient: {total_norm}"
                            + f" with coef {self.clip_grad / float(total_norm)}."
                        )

                optimizer.step()

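    # The default bp_var="objective" means a model whose forward pass returns,
    # e.g., preds = {"objective": loss, "psnr": psnr} (hypothetical keys other
    # than "objective") is trained on preds["objective"]; a caller may select
    # a different scalar key instead:
    #
    #   self._training_or_validation_epoch(..., bp_var="my_loss")
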
    def _checkpoint(
        self,
        accelerator: Optional[Accelerator],
        epoch: int,
        exp_dir: str,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        stats: Stats,
    ):
        """
        Save the model and its corresponding Stats object to a file, if
        `self.store_checkpoints` is True. In addition, if
        `self.store_checkpoints_purge` is positive, remove any checkpoints
        older than `self.store_checkpoints_purge` epochs. E.g. with
        `store_checkpoints_purge=1` at epoch 10, the checkpoints of epochs
        0-8 are removed.
        """
        if self.store_checkpoints and (
            accelerator is None or accelerator.is_local_main_process
        ):
            if self.store_checkpoints_purge > 0:
                for prev_epoch in range(epoch - self.store_checkpoints_purge):
                    model_io.purge_epoch(exp_dir, prev_epoch)
            outfile = model_io.get_checkpoint(exp_dir, epoch)
            # Save the unwrapped model when training with an Accelerator.
            unwrapped_model = (
                model if accelerator is None else accelerator.unwrap_model(model)
            )
            model_io.safe_save_model(
                unwrapped_model, stats, outfile, optimizer=optimizer
            )