| | """ |
| | Comprehensive error handling and recovery utilities for BitTransformerLM. |
| | |
| | Provides robust error recovery mechanisms, graceful degradation, and detailed |
| | error logging for production deployments. |
| | """ |
| |
|
| | import logging |
| | import traceback |
| | import functools |
| | from typing import Dict, Any, Optional, Callable, Union, Type |
| | from contextlib import contextmanager |
| | import torch |
| | import numpy as np |
| |
|
| | from .types import ErrorHandler, RecoveryStrategy, LogLevel, TensorLike |
| |
|
| |
|
class BitTransformerError(Exception):
    """Base exception class for BitTransformerLM errors.

    Carries a machine-readable ``error_code`` and an optional ``context``
    mapping alongside the human-readable message so handlers can log
    structured details.
    """

    def __init__(self, message: str, error_code: str = "BTLM_ERROR",
                 context: Optional[Dict[str, Any]] = None):
        # Keep the raw pieces accessible for structured logging upstream.
        self.message = message
        self.error_code = error_code
        self.context = {} if context is None else context
        formatted = f"[{error_code}] {message}"
        super().__init__(formatted)
| |
|
| |
|
class ModelError(BitTransformerError):
    """Errors related to model operations."""
| |
|
| |
|
class CompressionError(BitTransformerError):
    """Errors related to compression/decompression."""
| |
|
| |
|
class SafetyError(BitTransformerError):
    """Errors related to safety gates and telemetry."""
| |
|
| |
|
class DataError(BitTransformerError):
    """Errors related to data processing."""
| |
|
| |
|
class DistributedError(BitTransformerError):
    """Errors related to distributed training."""
| |
|
| |
|
class ErrorRecoveryManager:
    """Manages error recovery strategies and fallback mechanisms.

    Strategies are keyed by exception type; on a matching error the first
    registered matching strategy is invoked, up to ``max_retries`` times per
    distinct error signature.
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        # Maps exception type -> zero-arg callable producing a fallback value.
        self.recovery_strategies: Dict[Type[Exception], RecoveryStrategy] = {}
        # Counts occurrences per "TypeName:message" signature.
        self.error_counts: Dict[str, int] = {}
        self.max_retries = 3

    def register_recovery_strategy(self,
                                   error_type: Type[Exception],
                                   strategy: RecoveryStrategy) -> None:
        """Register a recovery strategy for a specific error type."""
        self.recovery_strategies[error_type] = strategy

    def handle_error(self,
                     error: Exception,
                     context: Optional[Dict[str, Any]] = None,
                     allow_recovery: bool = True) -> Any:
        """Log *error* and attempt recovery; re-raise when none succeeds.

        Returns the recovery strategy's result when one matches and runs
        cleanly; otherwise the original error is raised.
        """
        signature = f"{type(error).__name__}:{str(error)}"
        self.error_counts[signature] = self.error_counts.get(signature, 0) + 1

        self.logger.error(
            f"Error occurred: {error}\n"
            f"Context: {context}\n"
            f"Traceback: {traceback.format_exc()}"
        )

        within_budget = allow_recovery and self.error_counts[signature] <= self.max_retries
        if within_budget:
            # First registered strategy whose type matches wins.
            strategy = next(
                (s for t, s in self.recovery_strategies.items() if isinstance(error, t)),
                None,
            )
            if strategy is not None:
                try:
                    self.logger.info(f"Attempting recovery for {type(error).__name__}")
                    return strategy()
                except Exception as recovery_error:
                    self.logger.error(f"Recovery failed: {recovery_error}")

        raise error
| |
|
| |
|
| | |
# Module-level singleton used by the helper functions below for logging
# and for the registered fallback strategies.
error_manager = ErrorRecoveryManager()
| |
|
| |
|
def with_error_recovery(recovery_value: Any = None,
                        max_retries: int = 3,
                        error_types: Optional[tuple] = None):
    """Decorator for adding error recovery to functions.

    Retries the wrapped callable up to ``max_retries`` additional times.
    If every attempt fails, returns ``recovery_value`` when one was supplied,
    otherwise re-raises the last error. When ``error_types`` is given, only
    those exception types are retried; any other exception propagates
    immediately.
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            last_error = None
            attempt = 0
            while attempt <= max_retries:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # Non-retriable types bubble up untouched.
                    if error_types and not isinstance(e, error_types):
                        raise
                    last_error = e
                    if attempt == max_retries:
                        error_manager.logger.error(
                            f"Function {func.__name__} failed after {max_retries + 1} attempts"
                        )
                        break
                    error_manager.logger.warning(
                        f"Function {func.__name__} failed (attempt {attempt + 1}), retrying..."
                    )
                    attempt += 1

            # All attempts exhausted: fall back or re-raise the final error.
            if recovery_value is not None:
                return recovery_value
            raise last_error

        return wrapper
    return decorator
| |
|
| |
|
@contextmanager
def safe_operation(operation_name: str,
                   context: Optional[Dict[str, Any]] = None,
                   recovery_value: Any = None):
    """Context manager for safe operations with error handling.

    On failure the error is routed through ``error_manager.handle_error``.
    If a registered strategy recovers, or ``recovery_value`` is supplied,
    the exception is suppressed; otherwise it propagates.

    Note: a ``with`` statement cannot receive a value from ``__exit__``, so
    ``recovery_value`` only controls *suppression* — the value itself cannot
    be delivered to the caller. (The previous implementation tried to
    ``return`` these values, which silently discarded them.)
    """
    try:
        error_manager.logger.debug(f"Starting operation: {operation_name}")
        yield
        error_manager.logger.debug(f"Completed operation: {operation_name}")
    except Exception as e:
        error_context = {"operation": operation_name}
        if context:
            error_context.update(context)

        try:
            error_manager.handle_error(e, error_context)
            # Strategy recovered: returning from a @contextmanager generator
            # suppresses the in-flight exception.
            return
        except Exception:
            # Was a bare `except:` before, which also swallowed
            # KeyboardInterrupt/SystemExit raised during recovery.
            if recovery_value is not None:
                error_manager.logger.warning(
                    f"Operation {operation_name} failed, using recovery value"
                )
                return
            raise
| |
|
| |
|
def safe_tensor_operation(tensor_op: Callable[[torch.Tensor], torch.Tensor],
                          fallback_value: Optional[torch.Tensor] = None) -> Callable:
    """Wrapper for tensor operations with safety checks.

    Validates the input tensor, scrubs NaN/Inf values, and falls back to
    chunked execution on OOM or to a CPU retry on device errors.
    """
    def wrapper(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if not isinstance(tensor, torch.Tensor):
            raise DataError("Input must be a torch.Tensor")

        if tensor.numel() == 0:
            if fallback_value is None:
                raise DataError("Cannot operate on empty tensor")
            return fallback_value

        # Scrub non-finite values before running the operation.
        if torch.isnan(tensor).any():
            error_manager.logger.warning("NaN values detected in tensor, attempting to clean")
            tensor = torch.nan_to_num(tensor, nan=0.0)
        if torch.isinf(tensor).any():
            error_manager.logger.warning("Inf values detected in tensor, attempting to clean")
            tensor = torch.nan_to_num(tensor, posinf=1e6, neginf=-1e6)

        try:
            return tensor_op(tensor, *args, **kwargs)
        except (RuntimeError, ValueError) as e:
            reason = str(e).lower()
            if "out of memory" in reason:
                error_manager.logger.warning("OOM detected, attempting chunked operation")
                return _chunked_tensor_operation(tensor_op, tensor, *args, **kwargs)
            if "device" in reason:
                error_manager.logger.warning("Device mismatch, attempting CPU fallback")
                return tensor_op(tensor.cpu(), *args, **kwargs)
            raise

    return wrapper
| |
|
| |
|
| | def _chunked_tensor_operation(tensor_op: Callable, |
| | tensor: torch.Tensor, |
| | chunk_size: int = 1024, |
| | *args, **kwargs) -> torch.Tensor: |
| | """Execute tensor operation in chunks to avoid OOM.""" |
| | if tensor.size(0) <= chunk_size: |
| | return tensor_op(tensor, *args, **kwargs) |
| | |
| | results = [] |
| | for i in range(0, tensor.size(0), chunk_size): |
| | chunk = tensor[i:i + chunk_size] |
| | chunk_result = tensor_op(chunk, *args, **kwargs) |
| | results.append(chunk_result) |
| | |
| | return torch.cat(results, dim=0) |
| |
|
| |
|
def validate_model_inputs(inputs: torch.Tensor,
                          max_seq_len: int = 8192,
                          expected_dtype: torch.dtype = torch.long) -> torch.Tensor:
    """Validate and sanitize model inputs.

    Promotes 1-D input to a batch of one, truncates over-long sequences,
    coerces dtype, and clamps values to the bit range [0, 1] when the
    expected dtype is ``torch.long``. Raises ``DataError`` for non-tensor
    or >2-D input.
    """
    if not isinstance(inputs, torch.Tensor):
        raise DataError("Model inputs must be torch.Tensor")

    ndim = inputs.dim()
    if ndim == 1:
        inputs = inputs.unsqueeze(0)  # treat as batch of one
    elif ndim > 2:
        raise DataError(f"Input tensor has too many dimensions: {ndim}")

    seq_len = inputs.size(-1)
    if seq_len > max_seq_len:
        error_manager.logger.warning(f"Sequence length {seq_len} exceeds max {max_seq_len}, truncating")
        inputs = inputs[:, :max_seq_len]

    if inputs.dtype != expected_dtype:
        error_manager.logger.warning(f"Converting input dtype from {inputs.dtype} to {expected_dtype}")
        inputs = inputs.to(expected_dtype)

    # Bit-valued models only accept {0, 1}.
    if expected_dtype == torch.long:
        out_of_range = (inputs < 0) | (inputs > 1)
        if out_of_range.any():
            error_manager.logger.warning("Invalid bit values detected, clamping to [0, 1]")
            inputs = torch.clamp(inputs, 0, 1)

    return inputs
| |
|
| |
|
def safe_model_forward(model: torch.nn.Module,
                       inputs: torch.Tensor,
                       **kwargs) -> torch.Tensor:
    """Safely execute model forward pass with error recovery.

    Inputs are validated/sanitized first. On CUDA OOM the pass is retried
    with activation checkpointing after releasing cached allocator memory;
    on a device mismatch the inputs are moved to the model's device and the
    pass is retried once. Any other RuntimeError is re-raised.
    """
    inputs = validate_model_inputs(inputs)

    try:
        with safe_operation("model_forward"):
            return model(inputs, **kwargs)
    except RuntimeError as e:
        reason = str(e).lower()
        if "out of memory" in reason:
            error_manager.logger.warning("OOM in forward pass, enabling gradient checkpointing")
            # Bug fix: return cached allocator blocks to the device before
            # retrying; otherwise the checkpointed retry is likely to hit
            # the same OOM immediately.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            from torch.utils.checkpoint import checkpoint
            return checkpoint(model, inputs, **kwargs)
        if "device" in reason:
            # NOTE(review): only `inputs` is moved — tensors hidden inside
            # **kwargs stay where they are; confirm callers pass only
            # device-agnostic keyword args.
            device = next(model.parameters()).device
            inputs = inputs.to(device)
            return model(inputs, **kwargs)
        raise
| |
|
| |
|
def recovery_checkpoint_save(model: torch.nn.Module,
                             path: str,
                             additional_data: Optional[Dict[str, Any]] = None) -> bool:
    """Save model checkpoint with error recovery.

    Attempts a primary save to ``path``; on failure retries once to
    ``path + ".backup"``. Returns True when either save succeeds,
    False otherwise. Never raises.
    """
    checkpoint_data: Optional[Dict[str, Any]] = None
    try:
        checkpoint_data = {
            'model_state_dict': model.state_dict(),
            # Placeholder value; presumably meant to be a real timestamp —
            # TODO confirm what consumers expect here before changing it.
            'timestamp': torch.tensor(0),
        }
        if additional_data:
            checkpoint_data.update(additional_data)

        torch.save(checkpoint_data, path, _use_new_zipfile_serialization=True)
        error_manager.logger.info(f"Checkpoint saved successfully to {path}")
        return True

    except Exception as e:
        error_manager.logger.error(f"Failed to save checkpoint to {path}: {e}")

        # Bug fix: the backup path previously referenced `checkpoint_data`
        # even when the failure happened while *building* it (e.g. inside
        # state_dict()), raising UnboundLocalError instead of returning
        # False. Only attempt the backup when the payload exists.
        if checkpoint_data is None:
            return False

        backup_path = path + ".backup"
        try:
            torch.save(checkpoint_data, backup_path)
            error_manager.logger.info(f"Checkpoint saved to backup location: {backup_path}")
            return True
        except Exception as backup_e:
            error_manager.logger.error(f"Backup save also failed: {backup_e}")
            return False
| |
|
| |
|
def setup_error_logging(log_level: LogLevel = "INFO",
                        log_file: Optional[str] = None) -> logging.Logger:
    """Set up comprehensive error logging.

    Configures the "BitTransformerLM" logger with a console handler and,
    when ``log_file`` is given, a file handler with source-location detail.

    Args:
        log_level: Name of a standard logging level ("DEBUG", "INFO", ...).
        log_file: Optional path for an additional file handler.

    Returns:
        The configured logger instance.
    """
    logger = logging.getLogger("BitTransformerLM")
    logger.setLevel(getattr(logging, log_level))

    # Bug fix: repeated calls used to stack duplicate handlers on the same
    # named logger, emitting each record multiple times. Reset this logger's
    # handlers before (re)configuring so the function is idempotent.
    logger.handlers.clear()

    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    console_handler.setFormatter(console_formatter)
    logger.addHandler(console_handler)

    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
        )
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    return logger
| |
|
| |
|
| | |
def default_tensor_recovery() -> torch.Tensor:
    """Default recovery strategy for tensor operations.

    Returns a minimal valid bit tensor (a single zero) so downstream
    consumers receive something well-formed after a failure.
    """
    return torch.tensor([0], dtype=torch.long)
| |
|
| |
|
def default_model_recovery() -> Dict[str, torch.Tensor]:
    """Default recovery strategy for model operations.

    Returns a stand-in result dict with a single zero "output" tensor.
    """
    return dict(output=torch.zeros(1, dtype=torch.float32))
| |
|
| |
|
| | |
# Install the default fallbacks at import time: generic RuntimeErrors fall
# back to a zero bit tensor, ModelErrors to a zero-"output" dict.
error_manager.register_recovery_strategy(RuntimeError, default_tensor_recovery)
error_manager.register_recovery_strategy(ModelError, default_model_recovery)