updated-policy / simulation /infrastructure.py
Yaswanth-Bolla's picture
Initial commit
1175c0b
"""
Infrastructure state machine.
Manages the full service topology, cascade propagation, and the
validate → mutate → tick execution ordering.
This is the central coordinator that scenarios inject faults into
and the environment executes actions against.
"""
from __future__ import annotations
import random
from typing import Any, Dict, List, Optional, Set, Tuple
from .service import Deploy, ServiceState
from .alerts import evaluate_alerts
from .logs import (
generate_noise_logs,
generate_red_herring_logs,
)
# ------------------------------------------------------------------
# Service topology — the dependency graph
# ------------------------------------------------------------------
SERVICE_NAMES = [
"api_gateway", "auth", "orders", "payment",
"cache", "database", "queue",
]
# depends_on: service → list of services it calls
DEPENDENCY_GRAPH: Dict[str, List[str]] = {
"api_gateway": ["auth", "orders", "cache"],
"auth": ["database"],
"orders": ["database", "payment", "auth"],
"payment": ["queue", "database"],
"cache": [],
"database": [],
"queue": [],
}
def _depended_by_graph() -> Dict[str, List[str]]:
"""Invert the dependency graph: service → who depends on it."""
inv: Dict[str, List[str]] = {name: [] for name in SERVICE_NAMES}
for service, deps in DEPENDENCY_GRAPH.items():
for dep in deps:
inv[dep].append(service)
return inv
DEPENDED_BY = _depended_by_graph()
class Infrastructure:
"""
Virtual infrastructure state machine.
Owns all services, handles cascade propagation, tracks simulation
time, and enforces action validation.
"""
def __init__(self) -> None:
self.services: Dict[str, ServiceState] = {}
self.current_minute: int = 0
self.time_budget_minutes: int = 30
self._actions_taken: List[Tuple[str, Optional[str]]] = [] # (action_type, target)
self._all_logs: List[Dict[str, Any]] = []
self._setup_services()
def _setup_services(self) -> None:
"""Create all seven services with their dependency graphs."""
for name in SERVICE_NAMES:
svc = ServiceState(
name=name,
dependencies=list(DEPENDENCY_GRAPH.get(name, [])),
)
# Give each service a "good" deploy history baseline
svc.deploy_history = [
Deploy(
version=f"v1.{random.randint(0, 9)}.{random.randint(0, 9)}",
timestamp_minutes=-120,
author=random.choice(["alice", "bob", "charlie", "deploy-bot"]),
commit_hash=f"{random.randint(0, 0xFFFFFF):06x}",
description="Routine release — bug fixes and performance improvements",
),
]
# Populate 30 minutes of healthy metric history
from .metrics import generate_healthy_history
svc.metric_history = generate_healthy_history(30, start_minute=0)
svc._reset_metrics_healthy()
self.services[name] = svc
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def get_service(self, name: str) -> Optional[ServiceState]:
return self.services.get(name)
def get_all_statuses(self) -> Dict[str, str]:
return {name: svc.status for name, svc in self.services.items()}
def get_alerts(self) -> List[Dict[str, Any]]:
return evaluate_alerts(self.services, self.current_minute)
def record_action(self, action_type: str, target: Optional[str]) -> None:
self._actions_taken.append((action_type, target))
def was_action_taken(self, action_type: str, target: Optional[str] = None) -> bool:
"""Check if this exact action was already taken (for repeat detection)."""
return (action_type, target) in self._actions_taken
def action_count(self) -> int:
return len(self._actions_taken)
# ------------------------------------------------------------------
# Tick — advances simulation by one minute (called after every step)
# ------------------------------------------------------------------
def tick(self) -> None:
"""
Advance the simulation by one minute.
Order: propagate cascades → tick all services → generate noise logs.
"""
self.current_minute += 1
self._propagate_cascades()
for name, svc in self.services.items():
new_logs = svc.tick(self.current_minute)
self._all_logs.extend(new_logs)
# Mix in noise and red herrings
if random.random() < 0.4:
noise = generate_noise_logs(name, self.current_minute, count=1)
svc.logs.extend(noise)
self._all_logs.extend(noise)
if random.random() < 0.15:
herrings = generate_red_herring_logs(name, self.current_minute, count=1)
svc.logs.extend(herrings)
self._all_logs.extend(herrings)
# ------------------------------------------------------------------
# Cascade propagation
# ------------------------------------------------------------------
def _propagate_cascades(self) -> None:
"""
If a service is DOWN or DEGRADED, its downstream dependents
should accumulate dependency_degraded faults.
If a service recovers, clear the cascaded faults.
"""
for name, svc in self.services.items():
dependents = DEPENDED_BY.get(name, [])
if svc.status in ("down", "degraded") and svc.active_faults:
# Cascade to dependents
for dep_name in dependents:
dep_svc = self.services[dep_name]
if not dep_svc.has_fault("dependency_degraded"):
dep_svc.inject_fault("dependency_degraded", upstream=name)
elif svc.status == "healthy" and not svc.active_faults:
# Service recovered — clear cascade on dependents
for dep_name in dependents:
dep_svc = self.services[dep_name]
params = dep_svc.fault_params.get("dependency_degraded", {})
if params.get("upstream") == name:
dep_svc.recover_from_dependency(self.current_minute)
# ------------------------------------------------------------------
# Validation — Layer 5: validate BEFORE mutating
# ------------------------------------------------------------------
def validate_action(
self,
action_type: str,
target_service: Optional[str],
) -> Tuple[bool, str]:
"""
Validate that an action is legal in the current state.
Returns (is_valid, error_message).
"""
from ..models import ActionType, TARGETED_ACTIONS
try:
at = ActionType(action_type)
except ValueError:
return False, f"Unknown action type: {action_type}"
if at in TARGETED_ACTIONS:
if not target_service:
return False, f"Action {action_type} requires a target_service"
if target_service not in self.services:
return False, f"Unknown service: {target_service}"
# Specific validations (action masking)
if at == ActionType.ROLLBACK_DEPLOY:
svc = self.services.get(target_service, None)
if svc and len(svc.deploy_history) < 2:
return False, f"No previous deploy to rollback to for {target_service}"
if at == ActionType.SCALE_SERVICE:
svc = self.services.get(target_service, None)
if svc and svc.status == "down":
return False, f"Cannot scale {target_service} — service is DOWN"
return True, ""
def get_valid_actions(self) -> List[str]:
"""
Return list of valid (action_type, target) descriptions.
Used to populate valid_actions[] in the observation.
"""
from ..models import ActionType, TARGETED_ACTIONS
valid = []
for at in ActionType:
if at in TARGETED_ACTIONS:
for svc_name in self.services:
is_valid, _ = self.validate_action(at.value, svc_name)
if is_valid:
valid.append(f"{at.value}:{svc_name}")
else:
valid.append(at.value)
return valid
# ------------------------------------------------------------------
# Service queries (observation builders)
# ------------------------------------------------------------------
def get_logs_for_service(
self,
service_name: str,
level_filter: Optional[str] = None,
keyword: Optional[str] = None,
limit: int = 20,
) -> List[Dict[str, Any]]:
"""Query logs for a service with optional filtering."""
svc = self.services.get(service_name)
if not svc:
return []
logs = list(svc.logs)
if level_filter:
logs = [l for l in logs if l.get("level", "").upper() == level_filter.upper()]
if keyword:
kw = keyword.lower()
logs = [l for l in logs if kw in l.get("message", "").lower()]
return logs[-limit:]
def get_metrics_for_service(self, service_name: str) -> List[Dict[str, float]]:
svc = self.services.get(service_name)
return list(svc.metric_history) if svc else []
def get_dependencies_for_service(self, service_name: str) -> Dict[str, List[str]]:
return {
"depends_on": list(DEPENDENCY_GRAPH.get(service_name, [])),
"depended_by": list(DEPENDED_BY.get(service_name, [])),
}
def get_deploy_history_for_service(self, service_name: str) -> List[Dict[str, Any]]:
svc = self.services.get(service_name)
if not svc:
return []
return [
{
"version": d.version,
"timestamp": f"2025-01-15T{14 + d.timestamp_minutes // 60:02d}:"
f"{d.timestamp_minutes % 60:02d}:00Z"
if d.timestamp_minutes >= 0
else f"2025-01-15T{12 + (d.timestamp_minutes + 120) // 60:02d}:"
f"{(d.timestamp_minutes + 120) % 60:02d}:00Z",
"author": d.author,
"commit_hash": d.commit_hash,
"description": d.description,
}
for d in svc.deploy_history
]
def run_health_check(self, service_name: str) -> Dict[str, Any]:
svc = self.services.get(service_name)
if not svc:
return {"status": "unknown", "response_time_ms": 0}
response_time = {
"healthy": random.randint(5, 50),
"degraded": random.randint(200, 2000),
"down": 30000, # timeout
}.get(svc.status, 0)
return {
"status": svc.status,
"response_time_ms": response_time,
"replicas": svc.replica_count,
"active_faults_count": len(svc.active_faults), # agent can see SOMETHING is wrong
}
def all_services_healthy(self) -> bool:
return all(svc.status == "healthy" for svc in self.services.values())