| | import math |
| | import numpy as np |
| | import torch |
| |
|
| | import logging |
| | import os |
| | import sys |
| | from colorama import Fore, Style, init |
| | from dotenv import load_dotenv |
| |
|
# Load variables from a .env file (python-dotenv) into the process environment.
# setup_logger() reads DEBUG from os.environ when the module-level logger is built
# further down, so this call has to run first.
load_dotenv()
# Initialise colorama. autoreset=True asks colorama to reset styles after each write
# so colors do not bleed into later output.
init(autoreset=True)
| |
|
def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    """Return the power of two nearest to ``x`` from below (or, optionally, from above).

    For integer input the result is computed exactly with ``int.bit_length``.
    This avoids the double-precision rounding of ``math.log2``: for example,
    ``math.log2(2**50 - 1)`` evaluates to ``50.0``, so a log-based floor would
    return ``2**50``, which is larger than ``x``. Non-integer input (e.g. a float)
    keeps the ``math.log2``-based computation.

    Args:
        x: Positive number to round to a power of two.
        round_up: If False, return the largest power of two ``<= x``;
            if True, return the smallest power of two ``>= x``.

    Returns:
        The selected power of two, as an ``int``.

    Raises:
        ValueError: If ``x`` is an integer ``<= 0``. The previous log2-based
            implementation also raised ``ValueError`` for these inputs.
    """
    from operator import index

    try:
        # Accept anything integral (int, bool, numpy/torch integer scalars).
        n = index(x)
    except TypeError:
        # Not an integer (e.g. a float): keep the original log2-based behaviour.
        exponent = math.ceil(math.log2(x)) if round_up else math.floor(math.log2(x))
        return 1 << exponent

    if n <= 0:
        raise ValueError(f"x must be a positive integer, got {n}")
    if round_up:
        # Smallest 2**k >= n: (n - 1) needs exactly k bits (n == 1 gives k == 0).
        return 1 << (n - 1).bit_length()
    # Largest 2**k <= n: k is the position of n's highest set bit.
    return 1 << (n.bit_length() - 1)
| |
|
def get_hankel(seq_len: int, use_hankel_L: bool = False, device: torch.device = None, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Build a ``(seq_len, seq_len)`` Hankel matrix whose entries depend only on ``i + j``.

    With 1-based indices ``i, j`` and ``s = i + j`` (so ``s >= 2``):

    * ``use_hankel_L`` falsy:  ``Z[i, j] = 2 / (s**3 - s)``
    * ``use_hankel_L`` truthy: ``Z[i, j] = ((-1)**(s - 2) + 1) * 8 / ((s + 3)(s - 1)(s + 1))``,
      which is ``16 / ((s + 3)(s - 1)(s + 1))`` for even ``s`` and ``0`` for odd ``s``.

    Args:
        seq_len: Number of rows/columns of the square matrix.
        use_hankel_L: Selects the second formula. It is interpreted by truthiness,
            so any value is accepted.
        device: Device passed to ``torch.arange`` for the index tensor.
        dtype: dtype of the index tensor. With a floating dtype, the result has
            the same dtype.

    Returns:
        A ``(seq_len, seq_len)`` tensor.
    """
    entries = torch.arange(1, seq_len + 1, dtype=dtype, device=device)
    # Broadcast column + row to get s = i + j for every (i, j) pair.
    i_plus_j = entries[:, None] + entries[None, :]

    if use_hankel_L:
        # (-1)**(s - 2) + 1 is 2 for even s and 0 for odd s.
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        return sgn * (8.0 / denom)

    # s >= 2, so s**3 - s = (s - 1) * s * (s + 1) is never zero.
    return 2.0 / (i_plus_j**3 - i_plus_j)
| |
|
| |
|
class ColorFormatter(logging.Formatter):
    """
    A custom log formatter that applies color based on the log level using the Colorama library.

    Each record is rendered as ``<time> - <LEVEL> - <file>:<line> - <message>``. The
    timestamp, level name and file location are each wrapped in their own color and
    followed by ``Style.RESET_ALL``. Exception and stack information attached to the
    record is appended on the following lines.

    Attributes:
        LOG_COLORS (dict): Maps log levels to their color codes; levels not listed use ``Fore.WHITE``.
        TIME_COLOR (str): Color applied to the timestamp.
        FILE_COLOR (str): Color applied to the ``filename:lineno`` location.
        LEVEL_COLOR (str): Not used by ``format``; kept as a public class attribute.
    """

    LOG_COLORS = {
        logging.DEBUG: Fore.LIGHTMAGENTA_EX + Style.BRIGHT,
        logging.INFO: Fore.CYAN,
        logging.WARNING: Fore.YELLOW + Style.BRIGHT,
        logging.ERROR: Fore.RED + Style.BRIGHT,
        # NOTE(review): Style.NORMAL comes after Style.BRIGHT and cancels the bright
        # intensity, so CRITICAL renders as normal-intensity red (dimmer than ERROR).
        # Confirm whether that is intended.
        logging.CRITICAL: Fore.RED + Style.BRIGHT + Style.NORMAL,
    }

    TIME_COLOR = Fore.GREEN
    FILE_COLOR = Fore.BLUE
    LEVEL_COLOR = Style.BRIGHT

    def __init__(self, fmt=None, datefmt="%Y-%m-%d %H:%M:%S"):
        """
        Args:
            fmt: Format string handed to ``logging.Formatter``. ``format`` below builds
                its own layout and does not use this string.
            datefmt: ``strftime`` pattern used to render the timestamp.
        """
        super().__init__(fmt or "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", datefmt)

    def format(self, record):
        """
        Formats a log record with the appropriate color based on the log level.

        Args:
            record (logging.LogRecord): The log record to format.

        Returns:
            str: The formatted log message with colors applied. Any traceback or
            stack information is appended after a newline.
        """
        level_color = self.LOG_COLORS.get(record.levelno, Fore.WHITE)
        # Pass self.datefmt explicitly: formatTime(record) alone ignores the configured
        # pattern and falls back to logging's default (which appends milliseconds).
        time_str = f"{self.TIME_COLOR}{self.formatTime(record, self.datefmt)}{Style.RESET_ALL}"
        levelname_str = f"{level_color}{record.levelname}{Style.RESET_ALL}"
        file_info_str = f"{self.FILE_COLOR}{record.filename}:{record.lineno}{Style.RESET_ALL}"

        # getMessage() merges record.args into record.msg (e.g. logger.info("x=%s", x));
        # reading record.msg directly would print the raw, uninterpolated template.
        message = record.getMessage()
        log_msg = f"{time_str} - {levelname_str} - {file_info_str} - {message}"

        # Append traceback / stack info the way logging.Formatter.format does, so
        # logger.exception(...) and stack_info=True are not silently dropped.
        if record.exc_info and not record.exc_text:
            record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            log_msg = f"{log_msg}\n{record.exc_text}"
        if record.stack_info:
            log_msg = f"{log_msg}\n{self.formatStack(record.stack_info)}"
        return log_msg
| |
|
def setup_logger():
    """
    Return this module's logger, configured to write colored output to stdout.

    The level is DEBUG when the ``DEBUG`` environment variable is "true", "1" or "t"
    (case-insensitive), and INFO otherwise. Propagation is disabled, so records are
    not passed on to ancestor loggers such as the root logger.

    The function is safe to call repeatedly. The ColorFormatter stdout handler is
    attached only once, so later calls do not make every message print twice; they
    still re-apply the level from the current environment.

    Returns:
        logging.Logger: The logger named after this module (``__name__``).
    """
    logger = logging.getLogger(__name__)

    debug = os.environ.get("DEBUG", "False").lower() in ("true", "1", "t")
    logger.setLevel(logging.DEBUG if debug else logging.INFO)

    # getLogger returns the same object on every call, so guard against stacking a
    # second identical handler onto it.
    handler_name = "color_stdout"
    if not any(existing.get_name() == handler_name for existing in logger.handlers):
        handler = logging.StreamHandler(sys.stdout)
        handler.set_name(handler_name)
        handler.setFormatter(ColorFormatter())
        logger.addHandler(handler)

    logger.propagate = False
    return logger
| |
|
# Module-level logger, configured once at import time. Its level depends on the DEBUG
# environment variable as it stands when this line runs (after load_dotenv() above).
logger = setup_logger()
| |
|