File size: 3,115 Bytes
81e328b
 
 
 
 
 
 
 
20bc5e4
81e328b
 
542893e
 
 
 
81e328b
 
 
 
 
 
 
 
 
 
 
 
542893e
81e328b
 
 
 
542893e
81e328b
 
 
 
 
 
 
 
 
 
542893e
81e328b
542893e
 
 
 
 
 
 
 
 
 
 
81e328b
 
 
 
542893e
 
 
 
 
81e328b
 
 
 
542893e
 
81e328b
 
 
 
20bc5e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
FastAPI application for the PLL Cyberattack Detection OpenEnv.

Exposes HTTP endpoints for environment interaction:
  POST /reset   — Reset environment with task_id
  POST /step    — Submit an action and advance one step
  GET  /state   — Get current internal state
  GET  /health  — Health check (returns 200)
  GET  /tasks   — List available tasks
"""

import asyncio
import json
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel

from src.models import Observation, Action, Reward, State
from src.env import PLLAttackEnv

# ASGI application object served by the HTTP server.
app = FastAPI(
    version="1.0.0",
    title="PLL Cyberattack Detection OpenEnv",
    description="OpenEnv for AI-driven cyberattack detection on SRF-PLLs",
)

# One environment instance shared by every request. The /reset, /step and
# /state handlers acquire env_lock before touching it so their calls into the
# environment never interleave.
env = PLLAttackEnv()
# NOTE(review): on Python < 3.10 an asyncio.Lock created at import time binds
# to the loop returned by get_event_loop(); confirm that is the loop the
# server runs on, or create the lock in a startup hook instead.
env_lock = asyncio.Lock()


class ResetRequest(BaseModel):
    # Optional JSON body accepted by POST /reset; omitted fields take these defaults.
    # Presumably one of the ids listed by GET /tasks (0-2); not range-checked here.
    task_id: int = 0
    # Passed through as env.reset(seed=...); what None means is up to PLLAttackEnv.
    seed: Optional[int] = None


class StepResponse(BaseModel):
    # Response body of POST /step: the 4-tuple returned by env.step(action),
    # exposed as named fields.
    observation: Observation
    reward: Reward
    # The done flag from env.step() (presumably end-of-episode).
    done: bool
    # Extra info dict from env.step(); its keys are not fixed by this module.
    info: Dict[str, Any]


@app.post("/reset", response_model=Observation)
async def reset(req: Request):
    """Reset the environment and return the initial observation.

    The JSON body is optional. An empty body, or one that is not a JSON
    object accepted by ``ResetRequest``, falls back to the defaults
    (``task_id=0``, ``seed=None``) so clients that POST without a payload
    still work.

    The body is read and parsed *before* ``env_lock`` is acquired: awaiting
    network I/O while holding the shared lock would let a single slow client
    block every /reset, /step and /state request.
    """
    request = ResetRequest()
    body = await req.body()
    if body:
        try:
            request = ResetRequest(**json.loads(body))
        except (ValueError, TypeError):
            # ValueError: malformed JSON, non-UTF-8 bytes, or field validation
            # failure (pydantic's ValidationError subclasses ValueError).
            # TypeError: payload is valid JSON but not an object (list, null...).
            # Both are a deliberate best-effort fallback to the defaults.
            request = ResetRequest()
    async with env_lock:
        return env.reset(task_id=request.task_id, seed=request.seed)


@app.post("/step", response_model=StepResponse)
async def step(action: Action):
    """Advance the active episode by one step using the submitted action.

    Responds with HTTP 400 while ``env.attack_generator`` is still None,
    i.e. before POST /reset has started an episode.
    """
    async with env_lock:
        no_episode = env.attack_generator is None
        if no_episode:
            raise HTTPException(status_code=400, detail="Call /reset before /step")
        observation, reward, done, info = env.step(action)
        response = StepResponse(
            observation=observation,
            reward=reward,
            done=done,
            info=info,
        )
    return response


@app.get("/state", response_model=State)
async def get_state():
    """Return the environment's internal state, read under ``env_lock``."""
    async with env_lock:
        snapshot = env.get_state()
    return snapshot


@app.get("/health")
async def health():
    """Health check endpoint (always answers with HTTP 200).

    ``version`` and ``tasks_available`` are derived from ``app.version`` and
    the ``/tasks`` catalogue instead of repeated literals, so they cannot drift
    out of sync with those definitions.
    """
    task_catalogue = await list_tasks()
    return {
        "status": "ok",
        "version": app.version,
        "environment": "pll-cyberattack-detection",
        "tasks_available": len(task_catalogue["tasks"]),
        # Read without env_lock. There is no await between these two reads,
        # so (assuming a single event loop) they cannot interleave with the
        # synchronous env.step()/env.reset() calls made by other handlers.
        "episode_active": env.attack_generator is not None,
        "current_step": env.step_count,
    }


@app.get("/tasks")
async def list_tasks():
    """Return the catalogue of tasks this environment exposes."""
    # (id, name, difficulty, description), in catalogue order.
    catalogue = (
        (0, "sinusoidal_fdi", "easy", "Detect sinusoidal FDI attack"),
        (1, "multi_attack_classification", "medium", "Classify attack type from observations"),
        (2, "stealthy_attack_detection", "hard", "Detect stealthy attack before PLL lock loss"),
    )
    entries = [
        {"id": task_id, "name": name, "difficulty": level, "description": summary}
        for task_id, name, level, summary in catalogue
    ]
    return {"tasks": entries}