""" Retry utilities for SPARKNET Provides robust retry mechanisms for LLM calls and external services Following FAANG best practices for fault tolerance """ import asyncio from functools import wraps from typing import Callable, Type, Tuple, Optional, Any from loguru import logger from tenacity import ( retry, stop_after_attempt, wait_exponential, retry_if_exception_type, before_sleep_log, RetryError, ) # Common transient exceptions that should trigger retries TRANSIENT_EXCEPTIONS: Tuple[Type[Exception], ...] = ( ConnectionError, TimeoutError, asyncio.TimeoutError, ) def with_retry( max_attempts: int = 3, min_wait: float = 1.0, max_wait: float = 60.0, exceptions: Optional[Tuple[Type[Exception], ...]] = None, ): """ Decorator for adding retry logic to functions. Uses exponential backoff with jitter for optimal retry behavior. Args: max_attempts: Maximum number of retry attempts min_wait: Minimum wait time between retries (seconds) max_wait: Maximum wait time between retries (seconds) exceptions: Tuple of exception types to retry on Returns: Decorated function with retry logic Example: @with_retry(max_attempts=3) async def call_llm(prompt: str) -> str: ... """ retry_exceptions = exceptions or TRANSIENT_EXCEPTIONS def decorator(func: Callable) -> Callable: @wraps(func) @retry( stop=stop_after_attempt(max_attempts), wait=wait_exponential(multiplier=1, min=min_wait, max=max_wait), retry=retry_if_exception_type(retry_exceptions), before_sleep=before_sleep_log(logger, log_level="WARNING"), reraise=True, ) async def async_wrapper(*args, **kwargs): return await func(*args, **kwargs) @wraps(func) @retry( stop=stop_after_attempt(max_attempts), wait=wait_exponential(multiplier=1, min=min_wait, max=max_wait), retry=retry_if_exception_type(retry_exceptions), before_sleep=before_sleep_log(logger, log_level="WARNING"), reraise=True, ) def sync_wrapper(*args, **kwargs): return func(*args, **kwargs) if asyncio.iscoroutinefunction(func): return async_wrapper return sync_wrapper return decorator def with_fallback( fallback_value: Any, exceptions: Optional[Tuple[Type[Exception], ...]] = None, log_level: str = "WARNING", ): """ Decorator that returns a fallback value on exception. Args: fallback_value: Value to return if function raises exceptions: Exception types to catch (default: all) log_level: Log level for exception logging Returns: Decorated function with fallback behavior Example: @with_fallback(fallback_value=[], exceptions=(ValueError,)) def get_results() -> List[str]: ... """ catch_exceptions = exceptions or (Exception,) def decorator(func: Callable) -> Callable: @wraps(func) async def async_wrapper(*args, **kwargs): try: return await func(*args, **kwargs) except catch_exceptions as e: getattr(logger, log_level.lower())( f"{func.__name__} failed with {type(e).__name__}: {e}, returning fallback" ) return fallback_value @wraps(func) def sync_wrapper(*args, **kwargs): try: return func(*args, **kwargs) except catch_exceptions as e: getattr(logger, log_level.lower())( f"{func.__name__} failed with {type(e).__name__}: {e}, returning fallback" ) return fallback_value if asyncio.iscoroutinefunction(func): return async_wrapper return sync_wrapper return decorator class CircuitBreaker: """ Circuit breaker pattern implementation for protecting external services. States: - CLOSED: Normal operation, requests pass through - OPEN: Failures exceeded threshold, requests fail immediately - HALF_OPEN: Testing if service recovered Example: breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=30) async with breaker: result = await call_external_service() """ CLOSED = "CLOSED" OPEN = "OPEN" HALF_OPEN = "HALF_OPEN" def __init__( self, failure_threshold: int = 5, recovery_timeout: float = 30.0, name: str = "default", ): self.failure_threshold = failure_threshold self.recovery_timeout = recovery_timeout self.name = name self._state = self.CLOSED self._failure_count = 0 self._last_failure_time: Optional[float] = None self._lock = asyncio.Lock() @property def state(self) -> str: return self._state async def __aenter__(self): async with self._lock: if self._state == self.OPEN: # Check if recovery timeout has passed if self._last_failure_time: elapsed = asyncio.get_event_loop().time() - self._last_failure_time if elapsed >= self.recovery_timeout: self._state = self.HALF_OPEN logger.info(f"Circuit breaker '{self.name}' entering HALF_OPEN state") else: raise CircuitBreakerError( f"Circuit breaker '{self.name}' is OPEN, retry in {self.recovery_timeout - elapsed:.1f}s" ) return self async def __aexit__(self, exc_type, exc_val, exc_tb): async with self._lock: if exc_type is not None: self._failure_count += 1 self._last_failure_time = asyncio.get_event_loop().time() if self._failure_count >= self.failure_threshold: self._state = self.OPEN logger.warning( f"Circuit breaker '{self.name}' OPENED after {self._failure_count} failures" ) else: # Success - reset on successful call if self._state == self.HALF_OPEN: self._state = self.CLOSED self._failure_count = 0 logger.info(f"Circuit breaker '{self.name}' recovered to CLOSED state") elif self._state == self.CLOSED: self._failure_count = 0 return False # Don't suppress exceptions class CircuitBreakerError(Exception): """Raised when circuit breaker is open.""" pass