| import torch |
| import pytorch_lightning as pl |
| from pathlib import Path |
| from typing import Any |
| import torchvision |
| import wandb |
|
|
|
|
class EvalSaveCallback(pl.Callback):
    """Callback that dumps evaluation batches and their outputs to disk.

    After every validation or test batch whose step returned a non-empty
    result, one ``.pt`` file is written below ``save_dir`` with ``torch.save``.
    """

    def __init__(self, save_dir: Path) -> None:
        """Remember where the ``.pt`` files should be written.

        Args:
            save_dir: Target directory. A plain string is accepted as well
                and converted to a ``Path``.
        """
        super().__init__()
        # Coerce so the ``/`` join in ``save`` also works for str input.
        self.save_dir = Path(save_dir)

    def save(self, outputs, batch, batch_idx):
        """Write one batch together with the matching outputs to a file.

        The file is named ``{batch_idx:06d}_{name[0]}.pt`` where ``name`` is
        ``batch['name']``; only the first entry of ``name`` ends up in the
        file name, while the full ``name`` value is stored inside the file.

        Args:
            outputs: Mapping providing the keys ``"output"`` and
                ``"valid_bev"``.
            batch: Mapping providing the keys ``"name"``, ``"image"`` and
                ``"seg_masks"``.
            batch_idx: Batch index, zero-padded to six digits in the file
                name.
        """
        name = batch['name']
        filename = self.save_dir / f"{batch_idx:06d}_{name[0]}.pt"

        # torch.save does not create missing directories, so make sure the
        # destination exists before writing.
        filename.parent.mkdir(parents=True, exist_ok=True)

        torch.save({
            "fpv": batch['image'],
            "seg_masks": batch['seg_masks'],
            'name': name,
            "output": outputs["output"],
            "valid_bev": outputs["valid_bev"],
        }, filename)

    def _save_if_present(self, outputs, batch, batch_idx) -> None:
        """Save the batch unless the step produced a falsy (empty) result."""
        if not outputs:
            return
        self.save(outputs, batch, batch_idx)

    def on_test_batch_end(self, trainer: pl.Trainer,
                          pl_module: pl.LightningModule,
                          outputs: torch.Tensor | Any | None,
                          batch: Any,
                          batch_idx: int,
                          dataloader_idx: int = 0) -> None:
        """Save the test batch and its outputs when outputs are present.

        ``trainer``, ``pl_module`` and ``dataloader_idx`` are accepted to
        match the hook signature but are not used.
        """
        self._save_if_present(outputs, batch, batch_idx)

    def on_validation_batch_end(self, trainer: pl.Trainer,
                                pl_module: pl.LightningModule,
                                outputs: torch.Tensor | Any | None,
                                batch: Any,
                                batch_idx: int,
                                dataloader_idx: int = 0) -> None:
        """Save the validation batch and its outputs when outputs are present.

        ``trainer``, ``pl_module`` and ``dataloader_idx`` are accepted to
        match the hook signature but are not used.
        """
        self._save_if_present(outputs, batch, batch_idx)
|
|
|
|
class ImageLoggerCallback(pl.Callback):
    """Callback that logs input images and per-class masks as wandb images.

    Whenever a training or validation batch with ``batch_idx == 0`` ends, the
    module is run on that batch again and the resulting image grids are sent
    to ``trainer.logger.experiment.log``.
    """

    def __init__(self, num_classes):
        """Store the number of classes to visualise.

        Args:
            num_classes: Number of class indices for which a ground-truth
                and a prediction grid are logged.
        """
        super().__init__()
        self.num_classes = num_classes

    def log_image(self, trainer, pl_module, outputs, batch, batch_idx, mode="train"):
        """Build image grids for one batch and log them under ``{mode}/images``.

        The logged list starts with a grid of ``batch["image"]`` and then
        holds, for every class index, one ground-truth grid and one grid of
        the prediction thresholded at 0.5.

        Args:
            trainer: Trainer whose ``logger.experiment.log`` receives the
                images.
            pl_module: Unused; kept for interface compatibility.
            outputs: Mapping providing ``"output"`` and ``"valid_bev"``.
                NOTE(review): ``"output"`` is permuted with (0, 2, 3, 1), so
                it is presumably laid out as (B, C, H, W) -- confirm.
            batch: Mapping providing ``"image"`` and ``"seg_masks"``; the
                class index is taken from the last axis of ``"seg_masks"``.
            batch_idx: Unused; kept for interface compatibility.
            mode: Prefix of the key the images are logged under.
        """
        fpv_grid = torchvision.utils.make_grid(
            batch["image"], nrow=8, normalize=False)
        images = [wandb.Image(fpv_grid, caption="fpv")]

        # Threshold first: the comparison yields a fresh tensor, so zeroing
        # the invalid cells below never writes into the caller's
        # outputs['output']. The result equals masking before thresholding,
        # because a zeroed cell can never exceed 0.5.
        pred = (outputs['output'].permute(0, 2, 3, 1) > 0.5).float()
        pred[outputs["valid_bev"][..., :-1] == 0] = 0
        pred = pred.permute(0, 3, 1, 2)

        for i in range(self.num_classes):
            gt_class_i = batch['seg_masks'][..., i]
            gt_class_i_grid = torchvision.utils.make_grid(
                gt_class_i.unsqueeze(1), nrow=8, normalize=False, pad_value=0)
            pred_class_i_grid = torchvision.utils.make_grid(
                pred[:, i].unsqueeze(1), nrow=8, normalize=False, pad_value=0)

            images += [
                wandb.Image(gt_class_i_grid, caption=f"gt_class_{i}"),
                wandb.Image(pred_class_i_grid, caption=f"pred_class_{i}"),
            ]

        trainer.logger.experiment.log({f"{mode}/images": images})

    def on_validation_batch_end(self, trainer, pl_module: pl.LightningModule,
                                outputs, batch, batch_idx, dataloader_idx=0):
        """Log images for the validation batch with index 0.

        The ``outputs`` argument is ignored; the module is called on
        ``batch`` again to obtain the tensors that are visualised.
        ``dataloader_idx`` is accepted because Lightning may pass it to this
        hook; it is not used.
        """
        if batch_idx != 0:
            return

        with torch.no_grad():
            outputs = pl_module(batch)
            self.log_image(trainer, pl_module, outputs,
                           batch, batch_idx, mode="val")

    def on_train_batch_end(self, trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx):
        """Log images for the training batch with index 0.

        The module is switched to eval mode for an extra forward pass on
        ``batch`` (the ``outputs`` argument is ignored), and the previous
        train/eval state is restored afterwards.
        """
        if batch_idx != 0:
            return

        # Snapshot every submodule's flag: a blanket ``train()`` afterwards
        # would also flip submodules that were deliberately kept in eval mode.
        previous_modes = [(module, module.training)
                          for module in pl_module.modules()]
        pl_module.eval()
        try:
            with torch.no_grad():
                outputs = pl_module(batch)

            self.log_image(trainer, pl_module, outputs,
                           batch, batch_idx, mode="train")
        finally:
            # Restore even if the forward pass or the logging raised, so
            # training never continues in eval mode.
            for module, was_training in previous_modes:
                module.training = was_training
|
|