| |
| |
| |
| |
| |
| |
| from rich import print |
| from dataclasses import dataclass |
| from pytorch_lightning.utilities import rank_zero_only |
| from typing import Union |
| from pytorch_lightning.callbacks.progress.rich_progress import * |
| from rich.console import Console, RenderableType |
| from rich.progress_bar import ProgressBar |
| from rich.style import Style |
| from rich.text import Text |
| from rich.progress import ( |
| BarColumn, |
| DownloadColumn, |
| Progress, |
| TaskID, |
| TextColumn, |
| TimeRemainingColumn, |
| TransferSpeedColumn, |
| ProgressColumn |
| ) |
| from rich import print, reconfigure |
|
|
@rank_zero_only
def print_only(message: str) -> None:
    """Print ``message`` through ``rich.print`` on the global rank-zero process only.

    The ``rank_zero_only`` decorator (from PyTorch Lightning) restricts execution
    to global rank 0, so multi-process runs do not emit the same line repeatedly.
    """
    print(message)
|
|
@dataclass
class RichProgressBarTheme:
    """Per-component styles used to draw the rich progress bar.

    Each field accepts either a rich style string (the defaults here are hex
    colours) or a ``rich.style.Style`` instance. Style syntax reference:
    https://rich.readthedocs.io/en/stable/style.html
    """

    # Description text, e.g. "Epoch x", "Testing".
    description: Union[str, Style] = "#FF4500"
    # Bar while it is still in progress.
    progress_bar: Union[str, Style] = "#f92672"
    # Bar once it has finished.
    progress_bar_finished: Union[str, Style] = "#b7cc8a"
    # Bar while an `IterableDataset` is being processed (pulse animation).
    progress_bar_pulse: Union[str, Style] = "#f92672"
    # Batch counter, e.g. "10/50" batches completed.
    batch_progress: Union[str, Style] = "#fc608a"
    # Elapsed time and estimated time remaining.
    time: Union[str, Style] = "#45ada2"
    # Speed at which batches are being processed.
    processing_speed: Union[str, Style] = "#DC143C"
    # Metrics text.
    metrics: Union[str, Style] = "#228B22"
|
|
class BatchesProcessedColumn(ProgressColumn):
    """Rich progress column that renders a ``completed/total`` batch counter.

    When the task's total is unknown (``None``) or infinite, the total is shown
    as ``--`` instead of a number.
    """

    def __init__(self, style: Union[str, Style]):
        """
        Args:
            style: Rich style (string or ``Style``) applied to the counter text.
        """
        self.style = style
        super().__init__()

    def render(self, task) -> RenderableType:
        """Return the styled ``completed/total`` text for ``task``.

        Args:
            task: The rich progress task; its ``completed`` and ``total`` are read.

        Returns:
            A ``Text`` such as ``"10/50"``, or ``"10/--"`` when the total is unknown.
        """
        total = task.total
        # Only a finite, known total can go through int(); previously "--" was
        # passed to int() for infinite totals (ValueError), and None raised TypeError.
        if total is None or total == float("inf"):
            shown_total = "--"
        else:
            shown_total = int(total)
        return Text(f"{int(task.completed)}/{shown_total}", style=self.style)
|
|
class MyMetricsTextColumn(ProgressColumn):
    """Progress column that displays the most recently supplied metrics as text."""

    def __init__(self, style):
        """
        Args:
            style: Rich style applied to the rendered metrics text.
        """
        # Bookkeeping attributes; not read by any method in this class.
        self._tasks, self._current_task_id = {}, 0
        self._metrics = {}
        self._style = style
        super().__init__()

    def update(self, metrics):
        """Replace the metrics shown by subsequent ``render`` calls.

        Args:
            metrics: Mapping of metric name to value; stored by reference, not copied.
        """
        self._metrics = metrics

    def render(self, task) -> Text:
        """Render every ``name: value`` pair, each followed by a single space.

        Float values are rounded to three decimal places; other values use their
        default string formatting. ``task`` is not used.
        """
        pieces = []
        for name, value in self._metrics.items():
            shown = round(value, 3) if isinstance(value, float) else value
            pieces.append(f"{name}: {shown} ")
        return Text("".join(pieces), justify="left", style=self._style)
|
|
class MyRichProgressBar(RichProgressBar):
    """A progress bar that prints metrics at the end of each epoch.

    Overrides ``_init_progress`` so that the live progress display is drawn on
    its own ``rich.console.Console``, created with ``force_terminal=True`` by default.
    """

    def _init_progress(self, trainer):
        """Create and start the rich progress display if it is not already running.

        Args:
            trainer: The Lightning trainer. It is forwarded to ``configure_columns``
                and to ``MetricsTextColumn``.
        """
        if self.is_enabled and (self.progress is None or self._progress_stopped):
            self._reset_progress_bar_ids()
            # Apply the user's console options to rich's global console (used by rich.print).
            reconfigure(**self._console_kwargs)
            # Dedicated console for the progress display. force_terminal=True is the
            # default, but user-supplied console_kwargs take precedence. Previously
            # they only reached the global console, not the one drawing the bar.
            console_options = {"force_terminal": True, **self._console_kwargs}
            self._console: Console = Console(**console_options)
            self._console.clear_live()
            self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
            self.progress = CustomProgress(
                *self.configure_columns(trainer),
                self._metric_component,
                auto_refresh=False,
                disable=self.is_disabled,
                console=self._console,
            )
            self.progress.start()
            # The display is live again; mark it as not stopped.
            self._progress_stopped = False