| """
|
| Load Balancer for Mamba Swarm API
|
| Distributes requests across multiple API server instances
|
| """
|
|
|
| import asyncio
|
| import aiohttp
|
| import random
|
| import time
|
| import logging
|
| from typing import List, Dict, Any, Optional, Tuple
|
| from dataclasses import dataclass, field
|
| from enum import Enum
|
| from collections import defaultdict, deque
|
| import json
|
| import hashlib
|
|
|
class LoadBalancingStrategy(Enum):
    """Server-selection policy applied by ``LoadBalancer.select_server``.

    Each member's value is the lowercase identifier reported as the
    ``strategy`` field of ``LoadBalancer.get_stats``.
    """

    # Cycle through the healthy servers in order.
    ROUND_ROBIN = "round_robin"
    # Pick the server with the fewest in-flight requests.
    LEAST_CONNECTIONS = "least_connections"
    # Random pick proportional to each server's weight.
    WEIGHTED_ROUND_ROBIN = "weighted_round_robin"
    # Pick the server with the lowest average response time.
    LEAST_RESPONSE_TIME = "least_response_time"
    # Route by a hash of the request's "prompt" field (session affinity).
    HASH_BASED = "hash_based"
    # Weighted random pick favouring servers with a low load score.
    RESOURCE_AWARE = "resource_aware"
|
|
|
@dataclass
class ServerInstance:
    """Runtime state and rolling metrics for one backend API server.

    The counters are updated by ``LoadBalancer`` while it forwards requests
    and runs health checks; the properties derive routing signals from them.
    """

    host: str
    port: int
    weight: float = 1.0  # relative share used by weighted selection
    max_connections: int = 100  # capacity used to normalise connection load
    timeout: float = 30.0  # total per-request timeout, in seconds
    current_connections: int = 0  # requests currently in flight
    total_requests: int = 0  # completed requests (successful or not)
    failed_requests: int = 0
    # Most recent 100 response times, in milliseconds.
    response_times: deque = field(default_factory=lambda: deque(maxlen=100))
    last_health_check: float = 0.0  # time.time() of the last successful check
    is_healthy: bool = True
    health_check_failures: int = 0  # consecutive failed health checks

    @property
    def url(self) -> str:
        """Base URL of the server (plain HTTP)."""
        return f"http://{self.host}:{self.port}"

    @property
    def avg_response_time(self) -> float:
        """Mean of the recorded response times, or 0.0 if none are recorded."""
        if not self.response_times:
            return 0.0
        return sum(self.response_times) / len(self.response_times)

    @property
    def success_rate(self) -> float:
        """Fraction of completed requests that succeeded, clamped to [0, 1].

        Returns 1.0 before any request has completed. The clamp keeps the
        value meaningful even if ``failed_requests`` ever exceeds
        ``total_requests`` (failures recorded without a matching request).
        """
        total = self.total_requests
        if total == 0:
            return 1.0
        return max(0.0, (total - self.failed_requests) / total)

    @property
    def load_score(self) -> float:
        """Composite load estimate for resource-aware balancing; lower is better.

        Weighted sum of connection utilisation (40%), average response time
        normalised to 1000 ms and capped at 1.0 (40%), and failure rate capped
        at 1.0 (20%). A server with ``max_connections <= 0`` has no capacity,
        so it is treated as fully utilised instead of dividing by zero.
        """
        if self.max_connections > 0:
            connection_load = self.current_connections / self.max_connections
        else:
            connection_load = 1.0
        response_time_load = min(self.avg_response_time / 1000.0, 1.0)
        failure_rate = min(self.failed_requests / max(self.total_requests, 1), 1.0)

        return connection_load * 0.4 + response_time_load * 0.4 + failure_rate * 0.2
|
|
|
class LoadBalancer:
    """Advanced load balancer for Mamba Swarm API servers.

    Picks a healthy backend per request according to a
    ``LoadBalancingStrategy``, forwards the request over a shared
    ``aiohttp.ClientSession`` with retries, and periodically polls each
    backend's ``/health`` endpoint in a background task.
    """

    def __init__(self,
                 servers: List[Tuple[str, int]],
                 strategy: LoadBalancingStrategy = LoadBalancingStrategy.RESOURCE_AWARE,
                 health_check_interval: float = 30.0,
                 health_check_timeout: float = 5.0,
                 max_retries: int = 3):
        """Create a balancer over ``servers``.

        Args:
            servers: ``(host, port)`` pairs; each becomes a ``ServerInstance``
                with default weight and limits.
            strategy: Server-selection strategy.
            health_check_interval: Seconds between background health checks.
            health_check_timeout: Total timeout, in seconds, of one health check.
            max_retries: Extra attempts allowed after the first failed attempt.
        """
        self.logger = logging.getLogger(__name__)
        self.strategy = strategy
        self.health_check_interval = health_check_interval
        self.health_check_timeout = health_check_timeout
        self.max_retries = max_retries

        self.servers = [
            ServerInstance(host=host, port=port)
            for host, port in servers
        ]

        # Selection state.
        self.round_robin_index = 0
        self.request_counts = defaultdict(int)  # NOTE(review): not read in this class

        # Created by start(); None before start() and after stop().
        self.session: Optional[aiohttp.ClientSession] = None
        self.health_check_task: Optional[asyncio.Task] = None

        # Balancer-wide counters reported by get_stats().
        self.total_requests = 0
        self.failed_requests = 0
        self.start_time = time.time()

    async def __aenter__(self):
        """Async context manager entry: start the balancer and return it."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: stop the balancer."""
        await self.stop()

    async def start(self):
        """Open the HTTP session, launch the background health-check loop and
        run one health check of every server immediately."""
        timeout = aiohttp.ClientTimeout(total=30.0, connect=10.0)
        self.session = aiohttp.ClientSession(timeout=timeout)

        self.health_check_task = asyncio.create_task(self._health_check_loop())

        await self._check_all_servers_health()

        self.logger.info(f"Load balancer started with {len(self.servers)} servers using {self.strategy.value} strategy")

    async def stop(self):
        """Cancel the health-check loop and close the HTTP session.

        Clears both handles afterwards, so calling stop() twice is harmless.
        """
        if self.health_check_task:
            self.health_check_task.cancel()
            try:
                await self.health_check_task
            except asyncio.CancelledError:
                pass
            self.health_check_task = None

        if self.session:
            await self.session.close()
            self.session = None

        self.logger.info("Load balancer stopped")

    def get_healthy_servers(self) -> List[ServerInstance]:
        """Return the servers currently marked healthy, in pool order."""
        return [server for server in self.servers if server.is_healthy]

    def select_server(self, request_data: Optional[Dict[str, Any]] = None) -> Optional[ServerInstance]:
        """Select a healthy server using the configured strategy.

        Args:
            request_data: Decoded request body; only the hash-based strategy
                reads it.

        Returns:
            The chosen server, or None if no server is healthy.
        """
        healthy_servers = self.get_healthy_servers()

        if not healthy_servers:
            self.logger.warning("No healthy servers available")
            return None

        if self.strategy == LoadBalancingStrategy.ROUND_ROBIN:
            return self._round_robin_selection(healthy_servers)
        elif self.strategy == LoadBalancingStrategy.LEAST_CONNECTIONS:
            return self._least_connections_selection(healthy_servers)
        elif self.strategy == LoadBalancingStrategy.WEIGHTED_ROUND_ROBIN:
            return self._weighted_round_robin_selection(healthy_servers)
        elif self.strategy == LoadBalancingStrategy.LEAST_RESPONSE_TIME:
            return self._least_response_time_selection(healthy_servers)
        elif self.strategy == LoadBalancingStrategy.HASH_BASED:
            return self._hash_based_selection(healthy_servers, request_data)
        elif self.strategy == LoadBalancingStrategy.RESOURCE_AWARE:
            return self._resource_aware_selection(healthy_servers)
        else:
            return random.choice(healthy_servers)

    def _round_robin_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Return the next server in rotation over ``servers``."""
        server = servers[self.round_robin_index % len(servers)]
        self.round_robin_index += 1
        return server

    def _least_connections_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Return the server with the fewest in-flight requests (first on ties)."""
        return min(servers, key=lambda s: s.current_connections)

    def _weighted_round_robin_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Pick a server at random with probability proportional to its weight.

        Despite the strategy name this is weighted *random* selection, not a
        deterministic rotation.
        """
        total_weight = sum(s.weight for s in servers)
        random_weight = random.uniform(0, total_weight)

        current_weight = 0
        for server in servers:
            current_weight += server.weight
            if random_weight <= current_weight:
                return server

        # Only reachable through floating-point rounding in the running sum.
        return servers[-1]

    def _least_response_time_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Return the server with the lowest average response time.

        Servers with no measurements average 0.0 and so are preferred, which
        guarantees every server gets sampled. Ranking them last instead would
        starve any server that has not served a request yet: it could never
        acquire a measurement and would never be chosen.
        """
        return min(servers, key=lambda s: s.avg_response_time)

    def _hash_based_selection(self, servers: List[ServerInstance], request_data: Optional[Dict[str, Any]]) -> ServerInstance:
        """Map a request to a server by hashing its ``prompt`` field.

        Requests without a dict body containing ``prompt`` get a random
        server. The hash is taken modulo the number of *healthy* servers, so
        the mapping shifts whenever the healthy set changes.
        """
        if not isinstance(request_data, dict) or 'prompt' not in request_data:
            return random.choice(servers)

        # str() so non-string prompts (lists, numbers) hash instead of crashing;
        # it is the identity for string prompts, preserving their mapping.
        prompt_hash = hashlib.md5(str(request_data['prompt']).encode()).hexdigest()
        server_index = int(prompt_hash, 16) % len(servers)
        return servers[server_index]

    def _resource_aware_selection(self, servers: List[ServerInstance]) -> ServerInstance:
        """Weighted random pick favouring servers with a low ``load_score``.

        Each server's weight is ``1 / (load_score + 0.1)``; the 0.1 offset
        bounds the weight of an idle server.
        """
        sorted_servers = sorted(servers, key=lambda s: s.load_score)

        weights = [1.0 / (s.load_score + 0.1) for s in sorted_servers]
        total_weight = sum(weights)

        random_value = random.uniform(0, total_weight)
        current_weight = 0

        for server, weight in zip(sorted_servers, weights):
            current_weight += weight
            if random_value <= current_weight:
                return server

        # Only reachable through floating-point rounding: fall back to the least loaded.
        return sorted_servers[0]

    async def _send_once(self,
                         server: ServerInstance,
                         method: str,
                         url: str,
                         request_kwargs: Dict[str, Any]) -> Tuple[int, Any, float]:
        """Send one request; return ``(status, decoded_json, latency_seconds)``.

        Keeps ``server.current_connections`` balanced even when the request
        raises. Latency is measured up to the response headers, before the body
        is read.
        """
        server.current_connections += 1
        start_time = time.monotonic()
        try:
            async with self.session.request(method, url, **request_kwargs) as response:
                response_time = time.monotonic() - start_time
                # content_type=None: decode JSON even if the backend omits or
                # mislabels the Content-Type header.
                response_data = await response.json(content_type=None)
                return response.status, response_data, response_time
        finally:
            server.current_connections = max(0, server.current_connections - 1)

    async def forward_request(self,
                              path: str,
                              method: str = "POST",
                              data: Optional[Dict[str, Any]] = None,
                              headers: Optional[Dict[str, str]] = None,
                              **kwargs) -> Tuple[int, Dict[str, Any]]:
        """Forward a request to a selected server, retrying on failure.

        Every attempt re-runs server selection. Transport errors, undecodable
        bodies, 5xx and 429 responses are retried up to ``max_retries`` extra
        times. Other 4xx responses are client errors: they are returned at once
        and not counted against the server.

        Args:
            path: Path appended to the server's base URL (should start with "/").
            method: HTTP method.
            data: JSON-serialisable body; sent whenever it is not None (an empty
                ``{}`` is still sent).
            headers: Headers to send upstream.
            **kwargs: Extra arguments for ``aiohttp.ClientSession.request``.

        Returns:
            ``(status, body)``: the backend's status and decoded JSON body; or
            503 when no server is healthy; or 502 when every attempt raised.
        """
        self.total_requests += 1

        for attempt in range(self.max_retries + 1):
            server = self.select_server(data)
            if server is None:
                self.failed_requests += 1
                return 503, {"error": "No healthy servers available"}

            url = f"{server.url}{path}"
            request_kwargs = {
                "timeout": aiohttp.ClientTimeout(total=server.timeout),
                **kwargs
            }
            if headers:
                request_kwargs["headers"] = headers
            if data is not None:
                request_kwargs["json"] = data

            try:
                status, response_data, response_time = await self._send_once(
                    server, method, url, request_kwargs
                )
            except Exception as e:
                # Count the attempt as a completed request too, so failed_requests
                # can never exceed total_requests for this server.
                server.total_requests += 1
                server.failed_requests += 1
                self.logger.error(f"Request failed on {server.url}: {e}")
                if attempt < self.max_retries:
                    await asyncio.sleep(0.1 * (attempt + 1))
                continue

            server.total_requests += 1
            server.response_times.append(response_time * 1000)

            # Only server-side errors and throttling say anything about the
            # backend; another server may succeed, so those are retried.
            if status >= 500 or status == 429:
                server.failed_requests += 1
                if attempt < self.max_retries:
                    self.logger.warning(f"Request failed on {server.url} (attempt {attempt + 1}), retrying...")
                    continue
                self.failed_requests += 1

            return status, response_data

        self.failed_requests += 1
        return 502, {"error": "All servers failed after retries"}

    async def _check_server_health(self, server: ServerInstance) -> bool:
        """Poll ``GET {server.url}/health``; True on HTTP 200 with a JSON body.

        On success, records ``last_health_check`` and resets
        ``health_check_failures``; otherwise increments it. Errors are logged
        at debug level and reported as False rather than raised.
        """
        try:
            url = f"{server.url}/health"
            timeout = aiohttp.ClientTimeout(total=self.health_check_timeout)

            async with self.session.get(url, timeout=timeout) as response:
                if response.status != 200:
                    server.health_check_failures += 1
                    return False

                # The body must decode as JSON; its contents are not inspected.
                await response.json()
                server.last_health_check = time.time()
                server.health_check_failures = 0
                return True

        except Exception as e:
            server.health_check_failures += 1
            self.logger.debug(f"Health check failed for {server.url}: {e}")
            return False

    async def _check_all_servers_health(self):
        """Health-check every server concurrently and update ``is_healthy``.

        One failed check marks a server unhealthy; the next successful check
        restores it. Health transitions are logged.
        """
        # Snapshot the pool: remove_server() rebinds self.servers, which would
        # misalign results and servers if it ran while the checks are awaited.
        servers = list(self.servers)
        results = await asyncio.gather(
            *(self._check_server_health(server) for server in servers),
            return_exceptions=True,
        )

        for server, result in zip(servers, results):
            was_healthy = server.is_healthy
            if isinstance(result, BaseException):
                # _check_server_health handles Exception itself, so this is
                # mainly BaseException subclasses such as CancelledError.
                server.health_check_failures += 1
                server.is_healthy = False
            else:
                # A successful check has just reset health_check_failures,
                # so health follows the latest result.
                server.is_healthy = bool(result)

            if not was_healthy and server.is_healthy:
                self.logger.info(f"Server {server.url} is back online")
            elif was_healthy and not server.is_healthy:
                self.logger.warning(f"Server {server.url} is unhealthy")

    async def _health_check_loop(self):
        """Re-check all servers every ``health_check_interval`` seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(self.health_check_interval)
                await self._check_all_servers_health()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; one bad round must not stop health checks.
                self.logger.error(f"Health check loop error: {e}")

    def add_server(self, host: str, port: int, weight: float = 1.0):
        """Add a server to the pool; it starts out marked healthy."""
        server = ServerInstance(host=host, port=port, weight=weight)
        self.servers.append(server)
        self.logger.info(f"Added server {server.url}")

    def remove_server(self, host: str, port: int):
        """Remove every server matching ``host`` and ``port`` from the pool."""
        self.servers = [s for s in self.servers if not (s.host == host and s.port == port)]
        self.logger.info(f"Removed server http://{host}:{port}")

    def get_stats(self) -> Dict[str, Any]:
        """Return balancer-wide counters plus a per-server metrics snapshot."""
        uptime = time.time() - self.start_time

        server_stats = [
            {
                "url": server.url,
                "is_healthy": server.is_healthy,
                "current_connections": server.current_connections,
                "total_requests": server.total_requests,
                "failed_requests": server.failed_requests,
                "success_rate": server.success_rate,
                "avg_response_time_ms": server.avg_response_time,
                "load_score": server.load_score,
                "weight": server.weight
            }
            for server in self.servers
        ]

        return {
            "strategy": self.strategy.value,
            "uptime_seconds": uptime,
            "total_requests": self.total_requests,
            "failed_requests": self.failed_requests,
            "success_rate": (self.total_requests - self.failed_requests) / max(self.total_requests, 1),
            "healthy_servers": len(self.get_healthy_servers()),
            "total_servers": len(self.servers),
            "servers": server_stats
        }
|
|
|
|
|
| from fastapi import FastAPI, Request, HTTPException
|
| from fastapi.responses import JSONResponse
|
| import uvicorn
|
|
|
# Headers that must not be copied from the client request to the backend.
# They are hop-by-hop headers (RFC 9110 section 7.6.1), or they describe the
# client's original body encoding and framing. The proxy re-serialises the JSON
# body (or drops a non-JSON body), so forwarding the client's Content-Length
# would misframe the upstream request.
_NON_FORWARDED_HEADERS = frozenset({
    "host",
    "connection",
    "keep-alive",
    "proxy-connection",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
    "content-length",
    "content-encoding",
    # The upstream body is decoded by aiohttp and re-encoded by JSONResponse,
    # so the client's compression preferences do not apply upstream.
    "accept-encoding",
})


def create_load_balancer_app(servers: List[Tuple[str, int]],
                             strategy: LoadBalancingStrategy = LoadBalancingStrategy.RESOURCE_AWARE) -> FastAPI:
    """Create a FastAPI app that proxies requests through a ``LoadBalancer``.

    Routes:
        ``GET /lb/health``: balancer status plus statistics.
        ``GET /lb/stats``: balancer statistics.
        Any other path (GET/POST/PUT/DELETE/PATCH): proxied to a backend.

    Args:
        servers: ``(host, port)`` pairs of backend API servers.
        strategy: Server-selection strategy.
    """
    app = FastAPI(title="Mamba Swarm Load Balancer", version="1.0.0")
    load_balancer = LoadBalancer(servers, strategy)

    @app.on_event("startup")
    async def startup():
        await load_balancer.start()

    @app.on_event("shutdown")
    async def shutdown():
        await load_balancer.stop()

    @app.get("/lb/health")
    async def lb_health():
        """Load balancer health endpoint"""
        return {"status": "healthy", "stats": load_balancer.get_stats()}

    @app.get("/lb/stats")
    async def lb_stats():
        """Get load balancer statistics"""
        return load_balancer.get_stats()

    @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
    async def proxy_request(request: Request, path: str):
        """Proxy the request to a backend chosen by the load balancer.

        Only JSON bodies are forwarded. A body that is not UTF-8 JSON is
        dropped (logged at debug level) and the request goes out without one.
        """
        try:
            body = await request.body()
            headers = {
                name: value
                for name, value in dict(request.headers).items()
                if name.lower() not in _NON_FORWARDED_HEADERS
            }

            data = None
            if body:
                try:
                    data = json.loads(body.decode())
                except ValueError:
                    # UnicodeDecodeError and json.JSONDecodeError both subclass ValueError.
                    load_balancer.logger.debug("Request body for /%s is not JSON; forwarding without a body", path)

            status, response_data = await load_balancer.forward_request(
                f"/{path}",
                request.method,
                data=data,
                headers=headers,
                # multi_items() keeps repeated keys (?a=1&a=2) that dict() collapses.
                params=request.query_params.multi_items()
            )

            return JSONResponse(content=response_data, status_code=status)

        except Exception as e:
            load_balancer.logger.exception("Load balancer error while proxying /%s", path)
            return JSONResponse(
                content={"error": f"Load balancer error: {str(e)}"},
                status_code=500
            )

    return app
|
|
|
def run_load_balancer(servers: List[Tuple[str, int]],
                      host: str = "0.0.0.0",
                      port: int = 8080,
                      strategy: LoadBalancingStrategy = LoadBalancingStrategy.RESOURCE_AWARE):
    """Build the load-balancer app and serve it with uvicorn.

    Args:
        servers: ``(host, port)`` pairs of backend API servers.
        host: Interface the balancer listens on.
        port: Port the balancer listens on.
        strategy: Server-selection strategy.
    """
    uvicorn.Server(
        uvicorn.Config(
            app=create_load_balancer_app(servers, strategy),
            host=host,
            port=port,
            log_level="info",
        )
    ).run()
|
|
|
if __name__ == "__main__":
    # Demo configuration: three backends on consecutive local ports.
    demo_servers = [("localhost", demo_port) for demo_port in (8000, 8001, 8002)]
    run_load_balancer(demo_servers, strategy=LoadBalancingStrategy.RESOURCE_AWARE)