# Controller / Inference_&_LLM / inference_server.py
# Author: Gen-HVAC
# Commit message: "Upload 4 files"
# Commit 0575976 (verified)
import os
import sys
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# Make modules that sit next to this file (e.g. policy.py, LLM_part/) importable
# regardless of the working directory the server is launched from.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Prefer the packaged policy; fall back to a sibling policy.py module.
# NOTE(review): an ImportError raised *inside* unihvac.policy (e.g. one of its
# own dependencies missing) also triggers this fallback and can hide the real
# error — confirm that this is intended.
try:
    from unihvac.policy import DecisionTransformerPolicy5Zone
except ImportError:
    from policy import DecisionTransformerPolicy5Zone

# The LLM "digital human" sensor is optional. When it cannot be imported,
# DigitalHumanSensor is None, no sensor is created at startup, and /predict
# skips the comfort-vote loop.
try:
    from LLM_part.digital_human_manager import DigitalHumanSensor
except ImportError as exc:
    # Include the reason so a missing third-party dependency is not mistaken
    # for a missing LLM_part package.
    print(f"LLM features disabled: {exc}")
    DigitalHumanSensor = None

app = FastAPI()
# --- 2. CONFIGURATION ---
# Artifact locations. Relative paths resolve against the process working
# directory. Each one can be overridden through an environment variable so the
# server can be started from elsewhere without editing this file; the defaults
# are the original hard-coded values.
BASE_PATH = os.environ.get("GEN_HVAC_BASE_PATH", "gen_hvac")
CKPT_PATH = os.environ.get(
    "GEN_HVAC_CKPT_PATH",
    os.path.join(BASE_PATH, "training-runs/run_001/last.pt"),
)
MODEL_CONFIG = os.environ.get(
    "GEN_HVAC_MODEL_CONFIG",
    os.path.join(BASE_PATH, "training-runs/run_001/model_config.json"),
)
# Unlike the two paths above, the normalisation stats do not live under BASE_PATH.
NORM_STATS = os.environ.get(
    "GEN_HVAC_NORM_STATS",
    "TrajectoryData_from_docker/norm_stats_v4_topk.npz",
)

# Targets handed to the Decision Transformer policy (target_energy /
# target_comfort). Presumably return-to-go style conditioning values; confirm
# against the policy implementation.
FIXED_ENERGY_TARGET = -40000.0
# Comfort targets the SafetyCheck governor chooses between.
# NOTE(review): RELAXED and STRICT are currently equal. As a result, the
# governor's "MILD DISCOMFORT" and "CRITICAL COMPLAINT" tiers produce the same
# goal as "SATISFIED"; only the energy-limit override changes the goal.
# Confirm whether STRICT was meant to differ.
COMFORT_RELAXED = -1000.0
COMFORT_STRICT = -1000.0
class SafetyCheck:
    """Turn LLM comfort votes into a smoothed comfort target for the DT policy.

    On each ``update`` the governor does three things:

    1. It picks a goal target from the largest-magnitude vote.
    2. It caps that goal when the reported power exceeds ``power_limit``.
    3. It moves the current target toward the goal with an exponential moving
       average, so one outlier (e.g. hallucinated) vote cannot swing the
       policy abruptly.
    """

    def __init__(self, ema_alpha: float = 0.3, power_limit: float = 12000.0):
        """Create a governor whose target starts at COMFORT_RELAXED.

        Args:
            ema_alpha: Weight given to the new goal in each EMA step; must be
                in (0, 1]. 1.0 disables smoothing.
            power_limit: Power reading above which the energy override applies
                (presumably watts, per the ``current_power_watts`` name).

        Raises:
            ValueError: If ``ema_alpha`` is outside (0, 1].
        """
        if not 0.0 < ema_alpha <= 1.0:
            raise ValueError(f"ema_alpha must be in (0, 1], got {ema_alpha!r}")
        self.current_comfort_target = COMFORT_RELAXED
        self.ema_alpha = ema_alpha
        self.power_limit = power_limit

    @staticmethod
    def _vote_magnitudes(llm_votes):
        """Return |vote| for every vote that converts to a finite float."""
        magnitudes = []
        for vote in (llm_votes or {}).values():
            try:
                magnitude = abs(float(vote))
            except (TypeError, ValueError):
                # Non-numeric LLM output: ignore it rather than abort the update.
                continue
            # NaN would make the threshold comparisons below silently false.
            if np.isfinite(magnitude):
                magnitudes.append(magnitude)
        return magnitudes

    def update(self, llm_votes: Dict[str, float], current_power_watts: float):
        """Fold one round of votes into the comfort target.

        Args:
            llm_votes: Mapping of voter name to comfort vote. Only the absolute
                values are used. Votes that are non-numeric or non-finite are
                ignored, and ``None`` is treated as no votes.
            current_power_watts: Latest power reading, compared against
                ``power_limit``.

        Returns:
            A tuple ``(current_comfort_target, status)``: the smoothed target
            after this update and a human-readable status string.
        """
        max_discomfort = max(self._vote_magnitudes(llm_votes), default=0.0)
        if max_discomfort >= 1.5:
            goal_target = COMFORT_STRICT
            status = "CRITICAL COMPLAINT"
        elif max_discomfort >= 0.5:
            goal_target = (COMFORT_RELAXED + COMFORT_STRICT) / 2
            status = "MILD DISCOMFORT"
        else:
            goal_target = COMFORT_RELAXED
            status = "SATISFIED"

        # Energy override: above the power limit, the goal is capped at -25000.
        if current_power_watts > self.power_limit:
            goal_target = min(goal_target, -25000.0)
            status += " [ENERGY LIMIT EXCEEDED]"

        # EMA smoothing damps spikes from noisy or hallucinated LLM votes.
        self.current_comfort_target = (
            (1 - self.ema_alpha) * self.current_comfort_target
            + self.ema_alpha * goal_target
        )
        return self.current_comfort_target, status
# Process-wide singletons shared by the endpoints.
# - dt_policy and llm_sensor are filled in by the startup hook; each stays
#   None if its loading step fails.
# - governor is replaced with a fresh SafetyCheck on every /reset.
dt_policy = llm_sensor = None
governor = SafetyCheck()
# Names for each position of the incoming observation vector, in order:
# /predict zips these with the observation, so obs[i] is labelled ENV_KEYS[i]
# (e.g. obs[9] -> "elec_power").
_ZONES = ("core", "perim1", "perim2", "perim3", "perim4")
ENV_KEYS = [
    "month", "day_of_month", "hour",
    "outdoor_temp",
    *(f"{zone}_temp" for zone in _ZONES),
    "elec_power",
    *(f"{zone}_occ_count" for zone in _ZONES),
    "outdoor_dewpoint", "outdoor_wetbulb",
    *(f"{zone}_rh" for zone in _ZONES),
    "core_ash55_notcomfortable_summer",
    "core_ash55_notcomfortable_winter",
    "core_ash55_notcomfortable_any",
    *(f"p{idx}_ash55_notcomfortable_any" for idx in range(1, 5)),
    "total_electricity_HVAC",
]
@app.on_event("startup")
def load_model():
    """Load the Decision Transformer policy and, when available, the LLM sensor.

    Registered as a FastAPI startup hook. Load failures are printed, not
    raised, so the server still starts:

    - If the policy fails to load, ``dt_policy`` keeps its previous value
      (None at startup) and /predict answers 503.
    - If the sensor class is unavailable or fails to construct,
      ``llm_sensor`` keeps its previous value and the LLM loop is skipped.
    """
    global dt_policy, llm_sensor
    # Prefer the GPU when torch can see one.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Step 1: Decision Transformer policy.
    try:
        policy_kwargs = dict(
            ckpt_path=CKPT_PATH,
            model_config_path=MODEL_CONFIG,
            norm_stats_path=NORM_STATS,
            context_len=48,
            max_tokens_per_step=64,
            device=device,
            temperature=0.5,
            target_energy=FIXED_ENERGY_TARGET,
            target_comfort=COMFORT_RELAXED,
        )
        dt_policy = DecisionTransformerPolicy5Zone(**policy_kwargs)
    except Exception as e:
        print(f"DT Load Error: {e}")
    else:
        print("DT Policy Loaded.")

    # Step 2: optional LLM sensor (the class is None when its import failed).
    if not DigitalHumanSensor:
        return
    try:
        llm_sensor = DigitalHumanSensor(model_name="deepseek-v2")
    except Exception as e:
        print(f"LLM Error: {e}")
    else:
        print("LLM Sensor Loaded.")
class ObsPayload(BaseModel):
    """Request body for /predict."""

    # Step index: /predict queries the LLM when step % 4 == 0 and also passes
    # it to the policy.
    step: int
    # Observation vector, labelled positionally by ENV_KEYS.
    obs: List[float]
    # Extra info forwarded unchanged to the policy. default_factory gives each
    # request its own dict explicitly, instead of relying on pydantic's
    # implicit copying of a shared mutable default.
    info: Dict[str, Any] = Field(default_factory=dict)


class ResetPayload(BaseModel):
    """Request body for /reset; the handler does not read ``message``."""

    message: str = "reset"
@app.post("/reset")
def reset_policy(payload: ResetPayload):
    """Reset the policy's internal state and start a fresh comfort governor.

    Returns {"status": "success"} after resetting. If no policy is loaded it
    returns {"status": "error"} and resets nothing, not even the governor.
    """
    global governor
    if not dt_policy:
        return {"status": "error"}

    dt_policy.reset()
    governor = SafetyCheck()
    # Restore the default targets; /predict may have moved target_comfort.
    dt_policy.target_energy = FIXED_ENERGY_TARGET
    dt_policy.target_comfort = COMFORT_RELAXED
    return {"status": "success"}
@app.post("/predict")
def get_action(payload: ObsPayload):
    """Return the policy's action for one observation.

    On steps where ``step % 4 == 0`` and an LLM sensor is loaded, the sensor's
    comfort votes are passed through the SafetyCheck governor, and the result
    becomes the policy's new ``target_comfort``. That update is best-effort:
    on any failure the previous target is kept and the error is printed.

    Raises:
        HTTPException: 503 if the policy has not been loaded.
    """
    if dt_policy is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    obs_arr = np.array(payload.obs, dtype=np.float32)

    # 1. LLM comfort loop, run only every 4th step (presumably to bound LLM calls).
    if llm_sensor and payload.step % 4 == 0:
        try:
            env_map = dict(zip(ENV_KEYS, obs_arr))
            votes = llm_sensor.get_comfort_votes(env_map)
            # Look power up by name instead of a hard-coded observation index.
            new_target, status = governor.update(votes, env_map["elec_power"])
            dt_policy.target_comfort = new_target
            print(f"[Step {payload.step}] LLM: {votes} | Status: {status} | Target: {new_target:.0f}")
        except Exception as e:
            # Deliberately non-fatal, but never silent: keep the old target and
            # report why the update was skipped.
            print(f"[Step {payload.step}] LLM update skipped: {e!r}")

    # 2. Policy step.
    action, _, _ = dt_policy.act(obs_arr, payload.info, payload.step)
    return {"action": action.tolist()}
if __name__ == "__main__":
    # When run directly as a script, serve the API on every network interface,
    # port 8000.
    serve_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **serve_options)