| import math |
| import numpy as np |
| import torch |
|
|
| import logging |
| import os |
| import sys |
| from colorama import Fore, Style, init |
| from dotenv import load_dotenv |
|
|
# Populate os.environ from a .env file (python-dotenv) so settings such as
# DEBUG, read later in setup_logger(), can be supplied there.
load_dotenv()
# Initialise colorama before any colored log output is produced;
# autoreset=True restores default styling after each print.
init(autoreset=True)
|
|
def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    """
    Return the power of two nearest to ``x`` from below (default) or from above.

    Integral inputs are handled exactly via ``int.bit_length``. Going through
    ``math.log2`` loses precision for large values: ``log2(2**53 - 1)`` rounds
    to ``53.0``, so rounding down would wrongly return ``2**53 > x``.
    Non-integral inputs (e.g. ``3.5``) fall back to the float logarithm.

    Args:
        x: Positive number to round.
        round_up: If True, return the smallest power of two >= ``x``;
            otherwise the largest power of two <= ``x``.

    Returns:
        int: The selected power of two.

    Raises:
        ValueError: If ``x <= 0``, or if the result would be a fractional power
            of two (non-integral ``x < 1`` whose exponent is negative).
    """
    if x <= 0:
        raise ValueError(f"x must be positive, got {x!r}")

    n = int(x)
    if n != x:
        # Non-integral input: the float logarithm is adequate here.
        exponent = math.log2(x)
        return 1 << (math.ceil(exponent) if round_up else math.floor(exponent))

    if round_up:
        # For n >= 1, (n - 1).bit_length() == ceil(log2(n)) (0 when n == 1).
        return 1 << (n - 1).bit_length()
    # For n >= 1, n.bit_length() - 1 == floor(log2(n)).
    return 1 << (n.bit_length() - 1)
|
|
def get_hankel(seq_len: int, use_hankel_L: bool = False, device: torch.device = None, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """
    Build a ``seq_len x seq_len`` Hankel matrix whose entries depend only on
    ``s = i + j``, with 1-based indices ``i, j`` in ``1..seq_len``.

    * ``use_hankel_L`` falsy: ``Z[i, j] = 2 / (s**3 - s)``.
    * ``use_hankel_L`` truthy:
      ``Z[i, j] = ((-1)**(s - 2) + 1) * 8 / ((s + 3) * (s - 1) * (s + 1))``,
      i.e. ``16 / ((s + 3)(s - 1)(s + 1))`` for even ``s`` and ``0`` for odd ``s``.

    Both denominators are strictly positive because ``s >= 2``.

    Args:
        seq_len: Matrix side length (``0`` yields an empty ``(0, 0)`` tensor).
        use_hankel_L: Selects the second formula. Evaluated by truthiness.
        device: Device for the result (``None`` uses torch's default device).
        dtype: Floating dtype of the result.

    Returns:
        torch.Tensor: Tensor of shape ``(seq_len, seq_len)``.
    """
    entries = torch.arange(1, seq_len + 1, dtype=dtype, device=device)
    # Outer sum: i_plus_j[a, b] = (a + 1) + (b + 1).
    i_plus_j = entries[:, None] + entries[None, :]

    if use_hankel_L:
        # (-1)**(s - 2) + 1 is 2 for even s and 0 for odd s.
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        return sgn * (8.0 / denom)

    return 2.0 / (i_plus_j**3 - i_plus_j)
|
|
|
|
class ColorFormatter(logging.Formatter):
    """
    A log formatter that colors the timestamp, level name and source location
    of each record using Colorama escape codes.

    ``format`` always produces this fixed layout, regardless of ``fmt``::

        <asctime> - <levelname> - <filename>:<lineno> - <message>

    When present, the formatted exception traceback and stack info follow the
    line, as ``logging.Formatter.format`` would append them.

    Attributes:
        LOG_COLORS (dict): Maps log levels to the color codes applied to the level
            name. Levels not listed fall back to ``Fore.WHITE``.
        TIME_COLOR (str): Color code applied to the timestamp.
        FILE_COLOR (str): Color code applied to ``filename:lineno``.
        LEVEL_COLOR (str): Not used by ``format``; kept for compatibility.
    """

    LOG_COLORS = {
        logging.DEBUG: Fore.LIGHTMAGENTA_EX + Style.BRIGHT,
        logging.INFO: Fore.CYAN,
        logging.WARNING: Fore.YELLOW + Style.BRIGHT,
        logging.ERROR: Fore.RED + Style.BRIGHT,
        # NOTE(review): Style.NORMAL cancels the preceding Style.BRIGHT, so
        # CRITICAL currently renders as plain red (less prominent than ERROR).
        # Confirm the intended styling before changing it.
        logging.CRITICAL: Fore.RED + Style.BRIGHT + Style.NORMAL,
    }

    TIME_COLOR = Fore.GREEN
    FILE_COLOR = Fore.BLUE
    LEVEL_COLOR = Style.BRIGHT

    def __init__(self, fmt=None, datefmt="%Y-%m-%d %H:%M:%S"):
        """
        Args:
            fmt: Format string forwarded to ``logging.Formatter`` (a default layout
                is used when omitted). ``format`` builds its own fixed layout and
                does not read it.
            datefmt: ``strftime`` pattern used for the timestamp.
        """
        super().__init__(
            fmt or "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
            datefmt,
        )

    def format(self, record):
        """
        Format a log record as a colored line, plus any traceback or stack info.

        Args:
            record (logging.LogRecord): The log record to format.

        Returns:
            str: The formatted log message with colors applied.
        """
        # getMessage() merges %-style args (logger.info("x=%s", x)); record.msg
        # alone would print the unformatted template.
        record.message = record.getMessage()

        level_color = self.LOG_COLORS.get(record.levelno, Fore.WHITE)
        # formatTime() only uses the datefmt it is given, so pass ours explicitly.
        time_str = f"{self.TIME_COLOR}{self.formatTime(record, self.datefmt)}{Style.RESET_ALL}"
        levelname_str = f"{level_color}{record.levelname}{Style.RESET_ALL}"
        file_info_str = f"{self.FILE_COLOR}{record.filename}:{record.lineno}{Style.RESET_ALL}"

        log_msg = f"{time_str} - {levelname_str} - {file_info_str} - {record.message}"

        # Mirror logging.Formatter.format: cache the formatted traceback on the
        # record and append it, then append any captured stack info.
        if record.exc_info and not record.exc_text:
            record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            log_msg = f"{log_msg}\n{record.exc_text}"
        if record.stack_info:
            log_msg = f"{log_msg}\n{self.formatStack(record.stack_info)}"
        return log_msg
|
|
def setup_logger():
    """
    Configure and return this module's logger with colored output on stdout.

    A ``StreamHandler(sys.stdout)`` using ``ColorFormatter`` is attached only
    once. Repeated calls, or a module reload, reuse the existing handler instead
    of adding a duplicate that would print every record again.

    The level is DEBUG when the ``DEBUG`` environment variable is ``"true"``,
    ``"1"`` or ``"t"`` (case-insensitive), and INFO otherwise. Propagation to
    ancestor loggers is disabled.

    Returns:
        logging.Logger: The logger named after this module.
    """
    logger = logging.getLogger(__name__)

    # Identify our handler by name rather than by class: after importlib.reload()
    # ColorFormatter is a new class object, so an isinstance() check would miss it.
    handler_name = f"{__name__}.color_stdout"
    if not any(h.get_name() == handler_name for h in logger.handlers):
        handler = logging.StreamHandler(sys.stdout)
        handler.set_name(handler_name)
        handler.setFormatter(ColorFormatter())
        logger.addHandler(handler)

    debug = os.environ.get("DEBUG", "False").lower() in ("true", "1", "t")
    logger.setLevel(logging.DEBUG if debug else logging.INFO)
    logger.propagate = False

    return logger
|
|
# Module-level logger, configured once at import time by setup_logger().
# Import and reuse this object rather than calling setup_logger() again.
logger = setup_logger()
|
|