"""FastAPI inference server for a Decision Transformer HVAC policy.

Exposes ``/reset`` and ``/predict``. When the optional LLM "digital human"
sensor is importable and loads, its comfort votes are queried every 4th step
and passed through :class:`SafetyCheck`, whose smoothed output becomes the
policy's ``target_comfort``.
"""
import os
import sys
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Make modules next to this file (e.g. ``policy``) importable when the file is
# run directly instead of as part of a package.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

try:
    from unihvac.policy import DecisionTransformerPolicy5Zone
except ImportError:
    from policy import DecisionTransformerPolicy5Zone

# Import LLM Sensor (optional: the server still runs without it).
try:
    from LLM_part.digital_human_manager import DigitalHumanSensor
except ImportError:
    print(" LLM features disabled.")
    DigitalHumanSensor = None

app = FastAPI()

# --- CONFIGURATION ---
BASE_PATH = "gen_hvac"
CKPT_PATH = os.path.join(BASE_PATH, "training-runs/run_001/last.pt")
MODEL_CONFIG = os.path.join(BASE_PATH, "training-runs/run_001/model_config.json")
NORM_STATS = "TrajectoryData_from_docker/norm_stats_v4_topk.npz"

FIXED_ENERGY_TARGET = -40000.0
# NOTE(review): COMFORT_RELAXED and COMFORT_STRICT are identical, so all three
# discomfort tiers in SafetyCheck.update produce the same goal; only the
# power-limit override changes it. Confirm these values are intended.
COMFORT_RELAXED = -1000.0
COMFORT_STRICT = -1000.0


class SafetyCheck:
    """Convert LLM comfort votes and a power reading into a smoothed comfort target.

    A raw goal is picked from the largest absolute vote. It may be capped when
    power exceeds ``power_limit``. It is then blended into the running target
    with an exponential moving average, so one outlier round of votes cannot
    make the target jump.
    """

    def __init__(self):
        self.current_comfort_target = COMFORT_RELAXED
        self.ema_alpha = 0.3  # weight given to the newest goal in the EMA
        self.power_limit = 12000.0  # threshold compared with current_power_watts

    def update(self, llm_votes: Optional[Dict[str, float]], current_power_watts: float):
        """Fold one round of votes into the comfort target.

        Args:
            llm_votes: Mapping whose values are numeric comfort votes (keys are
                ignored; only each vote's magnitude is used). ``None`` or an
                empty mapping counts as no discomfort.
            current_power_watts: Current electrical power reading.

        Returns:
            Tuple ``(current_comfort_target, status)``: the EMA-smoothed target
            and a human-readable status label.
        """
        # Guard against a sensor that returns None instead of a dict.
        votes = list(llm_votes.values()) if llm_votes else []
        max_discomfort = max((abs(v) for v in votes), default=0.0)

        if max_discomfort >= 1.5:
            goal_target = COMFORT_STRICT
            status = "CRITICAL COMPLAINT"
        elif max_discomfort >= 0.5:
            goal_target = (COMFORT_RELAXED + COMFORT_STRICT) / 2
            status = "MILD DISCOMFORT"
        else:
            goal_target = COMFORT_RELAXED
            status = "SATISFIED"

        if current_power_watts > self.power_limit:
            # Over the power limit: force the goal to be at most -25000.
            # Presumably a looser comfort budget so the policy can save
            # energy -- confirm against the DT's return-to-go convention.
            goal_target = min(goal_target, -25000.0)
            status += " [ENERGY LIMIT EXCEEDED]"

        # Prevent hallucination spikes: exponential moving average of goals.
        self.current_comfort_target = (
            (1 - self.ema_alpha) * self.current_comfort_target
            + self.ema_alpha * goal_target
        )
        return self.current_comfort_target, status


dt_policy = None
llm_sensor = None
governor = SafetyCheck()

# Names for observation entries; /predict pairs these positionally with
# payload.obs before handing them to the LLM sensor.
ENV_KEYS = [
    'month', 'day_of_month', 'hour',
    'outdoor_temp',
    'core_temp', 'perim1_temp', 'perim2_temp', 'perim3_temp', 'perim4_temp',
    'elec_power',
    'core_occ_count', 'perim1_occ_count', 'perim2_occ_count',
    'perim3_occ_count', 'perim4_occ_count',
    'outdoor_dewpoint', 'outdoor_wetbulb',
    'core_rh', 'perim1_rh', 'perim2_rh', 'perim3_rh', 'perim4_rh',
    'core_ash55_notcomfortable_summer', 'core_ash55_notcomfortable_winter',
    'core_ash55_notcomfortable_any',
    'p1_ash55_notcomfortable_any', 'p2_ash55_notcomfortable_any',
    'p3_ash55_notcomfortable_any', 'p4_ash55_notcomfortable_any',
    'total_electricity_HVAC',
]


@app.on_event("startup")
def load_model():
    """Load the DT policy and, if available, the LLM sensor at app startup.

    Load failures are printed rather than raised so the server still starts.
    ``/predict`` then answers 503 while ``dt_policy`` is still ``None``.
    """
    global dt_policy, llm_sensor
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1. Load DT Policy
    try:
        dt_policy = DecisionTransformerPolicy5Zone(
            ckpt_path=CKPT_PATH,
            model_config_path=MODEL_CONFIG,
            norm_stats_path=NORM_STATS,
            context_len=48,
            max_tokens_per_step=64,
            device=device,
            temperature=0.5,
            target_energy=FIXED_ENERGY_TARGET,
            target_comfort=COMFORT_RELAXED,
        )
        print("DT Policy Loaded.")
    except Exception as e:
        print(f"DT Load Error: {e}")

    # 2. Load LLM (only if its module imported successfully)
    if DigitalHumanSensor:
        try:
            llm_sensor = DigitalHumanSensor(model_name="deepseek-v2")
            print("LLM Sensor Loaded.")
        except Exception as e:
            print(f"LLM Error: {e}")


class ObsPayload(BaseModel):
    """Request body for /predict.

    ``obs`` values are paired positionally with ENV_KEYS; index 9 is read as
    the power reading.
    """
    step: int
    obs: List[float]
    info: Dict[str, Any] = {}


class ResetPayload(BaseModel):
    """Request body for /reset (the message is not inspected)."""
    message: str = "reset"


@app.post("/reset")
def reset_policy(payload: ResetPayload):
    """Reset the policy, replace the governor and restore the default targets.

    Returns ``{"status": "error"}`` when the policy has not been loaded.
    """
    global governor
    if dt_policy:
        dt_policy.reset()
        governor = SafetyCheck()
        dt_policy.target_energy = FIXED_ENERGY_TARGET
        dt_policy.target_comfort = COMFORT_RELAXED
        return {"status": "success"}
    return {"status": "error"}


@app.post("/predict")
def get_action(payload: ObsPayload):
    """Return the policy's action for one observation.

    Every 4th step, when the LLM sensor is loaded, comfort votes are queried
    and the governor updates ``dt_policy.target_comfort`` before acting. A
    failure in that optional loop is reported and the previous target is kept;
    it never blocks the action.

    Raises:
        HTTPException: 503 if the policy failed to load.
    """
    if dt_policy is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    obs_arr = np.array(payload.obs, dtype=np.float32)

    # 1. LLM Loop
    if llm_sensor and (payload.step % 4 == 0):
        try:
            env_map = dict(zip(ENV_KEYS, obs_arr))
            votes = llm_sensor.get_comfort_votes(env_map)
            # obs_arr[9] corresponds to 'elec_power' in ENV_KEYS.
            new_target, status = governor.update(votes, obs_arr[9])
            dt_policy.target_comfort = new_target
            print(f"[Step {payload.step}] LLM: {votes} | Status: {status} | Target: {new_target:.0f}")
        except Exception as exc:
            # Best-effort: keep acting with the previous target, but do not
            # hide the failure (previously swallowed silently).
            print(f"[Step {payload.step}] LLM/governor update skipped: {exc!r}")

    action, _, _ = dt_policy.act(obs_arr, payload.info, payload.step)
    return {"action": action.tolist()}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)