|
|
"""
Retry utilities for SPARKNET.

Fault-tolerance helpers for LLM calls and other external services: a retry
decorator with exponential backoff, a decorator that returns a fallback value
on failure, and an asyncio circuit breaker.
"""
|
|
|
|
|
import asyncio |
|
|
from functools import wraps |
|
|
from typing import Callable, Type, Tuple, Optional, Any |
|
|
from loguru import logger |
|
|
from tenacity import ( |
|
|
retry, |
|
|
stop_after_attempt, |
|
|
wait_exponential, |
|
|
retry_if_exception_type, |
|
|
before_sleep_log, |
|
|
RetryError, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
# Default exception types that ``with_retry`` retries on when the caller does
# not pass an explicit ``exceptions`` tuple: connection failures and timeouts.
# On Python 3.11+ ``asyncio.TimeoutError`` is an alias of the builtin
# ``TimeoutError``; it is listed separately for older interpreters where the
# two classes are distinct.
TRANSIENT_EXCEPTIONS: Tuple[Type[Exception], ...] = (
    ConnectionError, TimeoutError, asyncio.TimeoutError,
)
|
|
|
|
|
|
|
|
def with_retry(
    max_attempts: int = 3,
    min_wait: float = 1.0,
    max_wait: float = 60.0,
    exceptions: Optional[Tuple[Type[Exception], ...]] = None,
):
    """
    Decorator for adding retry logic to sync or async functions.

    Waits between attempts grow exponentially (tenacity ``wait_exponential``
    with ``multiplier=1``) and are clamped to ``[min_wait, max_wait]``. No
    random jitter is applied. After the final attempt the last exception is
    re-raised unchanged (``reraise=True``) rather than wrapped in
    ``RetryError``. Exceptions not matching ``exceptions`` propagate
    immediately without retrying.

    Args:
        max_attempts: Total number of attempts, including the first call.
        min_wait: Minimum wait time between retries (seconds).
        max_wait: Maximum wait time between retries (seconds).
        exceptions: Tuple of exception types to retry on. ``None`` (the
            default) means ``TRANSIENT_EXCEPTIONS``; an empty tuple disables
            retrying.

    Returns:
        A decorator that wraps the function with retry logic, preserving the
        function's metadata via ``functools.wraps``.

    Example:
        @with_retry(max_attempts=3)
        async def call_llm(prompt: str) -> str:
            ...
    """
    # ``is None`` rather than ``or`` so an explicit empty tuple is honoured
    # instead of silently falling back to the defaults.
    retry_exceptions = TRANSIENT_EXCEPTIONS if exceptions is None else exceptions

    def decorator(func: Callable) -> Callable:
        func_name = getattr(func, "__qualname__", None) or repr(func)

        def log_before_sleep(retry_state: Any) -> None:
            # tenacity's ``before_sleep_log`` calls
            # ``logger.log(level, msg, exc_info=...)``. Loguru treats keyword
            # arguments as ``str.format`` arguments for the message, so an
            # exception text containing braces (e.g. a JSON error body) made
            # the log call itself raise, masking the real error and aborting
            # the retry loop. Passing values as positional format arguments
            # keeps the exception text out of the format template.
            outcome = retry_state.outcome
            error = outcome.exception() if outcome is not None else None
            sleep_for = getattr(retry_state.next_action, "sleep", 0.0)
            logger.warning(
                "Retrying {} in {:.2f} seconds after attempt {} raised {}: {}",
                func_name,
                sleep_for,
                retry_state.attempt_number,
                type(error).__name__,
                error,
            )

        retrying = retry(
            stop=stop_after_attempt(max_attempts),
            wait=wait_exponential(multiplier=1, min=min_wait, max=max_wait),
            retry=retry_if_exception_type(retry_exceptions),
            before_sleep=log_before_sleep,
            reraise=True,
        )

        # Build only the wrapper that matches the function's flavour; tenacity
        # picks AsyncRetrying (asyncio.sleep) for the coroutine wrapper.
        if asyncio.iscoroutinefunction(func):
            @wraps(func)
            @retrying
            async def async_wrapper(*args, **kwargs):
                return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        @retrying
        def sync_wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator
|
|
|
|
|
|
|
|
def with_fallback(
    fallback_value: Any,
    exceptions: Optional[Tuple[Type[Exception], ...]] = None,
    log_level: str = "WARNING",
):
    """
    Decorator that returns a fallback value when the wrapped function raises.

    Works for both sync and async functions. Exceptions not listed in
    ``exceptions`` propagate unchanged.

    Args:
        fallback_value: Value returned when a caught exception occurs. The
            same object is returned on every failure (it is not copied), so a
            mutable fallback such as ``[]`` is shared by all callers; prefer
            an immutable value if callers may modify the result.
        exceptions: Exception types to catch. ``None`` (the default) catches
            any ``Exception``; an explicit empty tuple catches nothing.
        log_level: Name of the loguru logger method used to report the
            failure, case-insensitive (``"WARNING"`` -> ``logger.warning``).

    Returns:
        Decorated function with fallback behavior.

    Example:
        @with_fallback(fallback_value=[], exceptions=(ValueError,))
        def get_results() -> List[str]:
            ...
    """
    # ``is None`` rather than ``or`` so an explicit empty tuple is honoured
    # instead of silently widening to "catch everything".
    catch_exceptions = (Exception,) if exceptions is None else exceptions

    def decorator(func: Callable) -> Callable:
        # functools.partial and other callables may lack ``__name__``; resolve
        # it up front so the except-path cannot raise AttributeError and mask
        # the original failure.
        func_name = getattr(func, "__name__", None) or repr(func)

        def report(error: BaseException) -> None:
            """Log the swallowed exception at the configured level."""
            getattr(logger, log_level.lower())(
                f"{func_name} failed with {type(error).__name__}: {error}, returning fallback"
            )

        if asyncio.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                try:
                    return await func(*args, **kwargs)
                except catch_exceptions as e:
                    report(e)
                    return fallback_value

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except catch_exceptions as e:
                report(e)
                return fallback_value

        return sync_wrapper

    return decorator
|
|
|
|
|
|
|
|
class CircuitBreaker:
    """
    Circuit breaker pattern implementation for protecting external services.

    Use it as an async context manager around each call. Any exception that
    escapes the ``async with`` body (including BaseException subclasses such
    as task cancellation) counts as a failure and is always re-raised; a
    clean exit counts as a success.

    States:
        - CLOSED: Normal operation, requests pass through. A success resets
          the failure count, so ``failure_threshold`` counts consecutive
          failures.
        - OPEN: Failures reached ``failure_threshold``; entering raises
          ``CircuitBreakerError`` until ``recovery_timeout`` seconds have
          passed since the most recent failure.
        - HALF_OPEN: Recovery timeout elapsed; calls are admitted again. The
          next success closes the circuit, the next failure re-opens it.
          Every caller entering while HALF_OPEN is admitted (there is no
          single-probe limit).

    Example:
        breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=30)

        async with breaker:
            result = await call_external_service()
    """

    CLOSED = "CLOSED"
    OPEN = "OPEN"
    HALF_OPEN = "HALF_OPEN"

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        name: str = "default",
    ):
        """
        Args:
            failure_threshold: Consecutive failures that open the circuit.
            recovery_timeout: Seconds to stay OPEN before admitting calls again.
            name: Label used in log and error messages.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.name = name

        self._state = self.CLOSED
        self._failure_count = 0
        # Event-loop clock reading taken at the most recent failure.
        self._last_failure_time: Optional[float] = None
        self._lock = asyncio.Lock()

    @property
    def state(self) -> str:
        """Current state; an OPEN circuit only moves to HALF_OPEN when a caller next enters."""
        return self._state

    @staticmethod
    def _now() -> float:
        """Current time from the running event loop's monotonic clock."""
        return asyncio.get_running_loop().time()

    async def __aenter__(self):
        async with self._lock:
            if self._state == self.OPEN and self._last_failure_time is not None:
                elapsed = self._now() - self._last_failure_time
                if elapsed < self.recovery_timeout:
                    raise CircuitBreakerError(
                        f"Circuit breaker '{self.name}' is OPEN, retry in {self.recovery_timeout - elapsed:.1f}s"
                    )
                # Recovery window has passed: let calls through to test the service.
                self._state = self.HALF_OPEN
                logger.info(f"Circuit breaker '{self.name}' entering HALF_OPEN state")
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        async with self._lock:
            if exc_type is not None:
                self._failure_count += 1
                self._last_failure_time = self._now()

                # A failed trial call in HALF_OPEN re-opens immediately.
                should_open = (
                    self._state == self.HALF_OPEN
                    or self._failure_count >= self.failure_threshold
                )
                if should_open:
                    # Requests admitted before the circuit opened may still fail
                    # afterwards; only log the actual transition into OPEN.
                    if self._state != self.OPEN:
                        logger.warning(
                            f"Circuit breaker '{self.name}' OPENED after {self._failure_count} failures"
                        )
                    self._state = self.OPEN
            elif self._state == self.HALF_OPEN:
                self._state = self.CLOSED
                self._failure_count = 0
                logger.info(f"Circuit breaker '{self.name}' recovered to CLOSED state")
            elif self._state == self.CLOSED:
                self._failure_count = 0

        # Never suppress the caller's exception.
        return False
|
|
|
|
|
|
|
|
class CircuitBreakerError(Exception):
    """Signals that a call was rejected because the circuit breaker is OPEN."""
|
|
|