| | """A simple progress bar for the console.""" |
| |
|
| | import threading |
| | from typing import Any, Dict, Optional, Sequence |
| | from uuid import UUID |
| |
|
| | from langchain_core.callbacks import base as base_callbacks |
| | from langchain_core.documents import Document |
| | from langchain_core.outputs import LLMResult |
| |
|
| |
|
class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
    """A simple progress bar for the console.

    The bar is redrawn in place (using a carriage return) every time the
    counter is incremented. The counter is incremented whenever a chain,
    retriever, LLM or tool run that has no parent run (``parent_run_id is
    None``) ends or errors, so nested runs are not counted.
    """

    def __init__(self, total: int, ncols: int = 50, **kwargs: Any):
        """Initialize the progress bar and draw it in its empty state.

        Args:
            total: int, the total number of items to be processed.
            ncols: int, the character width of the progress bar.
            **kwargs: accepted for call-site compatibility; not used here.
        """
        self.total = total
        self.ncols = ncols
        self.counter = 0
        # Guards counter updates and redraws when callbacks fire concurrently.
        self.lock = threading.Lock()
        self._print_bar()

    def increment(self) -> None:
        """Increment the counter and update the progress bar."""
        with self.lock:
            self.counter += 1
            self._print_bar()

    def _print_bar(self) -> None:
        """Print the progress bar to the console.

        The filled fraction is clamped to ``[0, 1]`` so the bar never grows
        past ``ncols`` characters, and a non-positive ``total`` is rendered
        as a full bar instead of raising ``ZeroDivisionError``.
        """
        if self.total > 0:
            progress = min(max(self.counter / self.total, 0.0), 1.0)
        else:
            # Nothing to process: treat the (empty) job as complete.
            progress = 1.0
        filled = int(round(progress * self.ncols))
        # A negative repeat count yields "", so the arrow is at least ">".
        arrow = "-" * (filled - 1) + ">"
        spaces = " " * (self.ncols - len(arrow))
        # flush=True: the line has no trailing newline, so a line-buffered
        # stdout would otherwise hold the update back.
        print(
            f"\r[{arrow + spaces}] {self.counter}/{self.total}",
            end="",
            flush=True,
        )

    def _count_if_parentless(self, parent_run_id: Optional[UUID]) -> None:
        """Increment the bar only for runs that have no parent run."""
        if parent_run_id is None:
            self.increment()

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Count a chain run without a parent that ended with an error."""
        self._count_if_parentless(parent_run_id)

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Count a chain run without a parent that has ended."""
        self._count_if_parentless(parent_run_id)

    def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Count a retriever run without a parent that ended with an error."""
        self._count_if_parentless(parent_run_id)

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Count a retriever run without a parent that has ended."""
        self._count_if_parentless(parent_run_id)

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Count an LLM run without a parent that ended with an error."""
        self._count_if_parentless(parent_run_id)

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Count an LLM run without a parent that has ended."""
        self._count_if_parentless(parent_run_id)

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Count a tool run without a parent that ended with an error."""
        self._count_if_parentless(parent_run_id)

    def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Count a tool run without a parent that has ended."""
        self._count_if_parentless(parent_run_id)
| |
|