Spaces:
Sleeping
Sleeping
| """ | |
| llm_model.py | |
| ------------ | |
| LLM-based intersection conflict resolver. | |
| Uses prompt engineering (zero-shot + few-shot) with an OpenAI-compatible API. | |
| """ | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any | |
| import pandas as pd | |
# ─── Prompt templates ─────────────────────────────────────────────────────────
# System prompt for the base (non-fine-tuned) model. It pins the JSON response
# schema, describes which destinations each lane serves, and lists the
# conflict-detection and priority rules the model is asked to apply.
# The lane-layout lines use "→" between a lane's manoeuvre and its destinations;
# the previously mis-encoded "β" characters in those lines have been repaired.
SYSTEM_PROMPT = """You are an intelligent traffic intersection controller.
Your task is to analyze vehicle scenarios at a 4-way 8-lane intersection,
detect conflicts between vehicles, and output structured control decisions.
Always respond in valid JSON with this exact schema:
{
"is_conflict": "yes" | "no",
"number_of_conflicts": <int>,
"conflict_vehicles": [{"vehicle1_id": "...", "vehicle2_id": "..."}],
"decisions": ["..."],
"priority_order": {"<vehicle_id>": <rank_int>},
"waiting_times": {"<vehicle_id>": <seconds_int>}
}
Intersection layout (8 lanes, 4 directions):
- North: Lane 1 (right/straight → F,H), Lane 2 (left → E,D,C)
- East: Lane 3 (right/straight → H,B), Lane 4 (left → G,E,F)
- South: Lane 5 (right/straight → B,D), Lane 6 (left → A,G,H)
- West: Lane 7 (right/straight → D,F), Lane 8 (left → B,C,A)
Conflict detection rules:
1. Two vehicles CONFLICT if ALL of the following are true:
a) Their paths physically cross (opposing or perpendicular directions with crossing trajectories)
b) Both arrive within 5 seconds of each other (time = distance / speed_in_m_s)
c) Speed in m/s = speed_km_h * 1000 / 3600
2. Same direction vehicles NEVER conflict.
3. Right turns rarely conflict with other right turns.
Priority rules (apply when conflict exists):
1. Straight-going vehicle has priority over turning vehicle.
2. Right-turning vehicle has priority over left-turning vehicle.
3. Right-hand rule: vehicle coming from the right has priority.
4. If arriving more than 1 second earlier: earlier-arriving vehicle has priority.
5. Priority rank 1 = highest priority (does not wait). Rank 2+ must yield.
Decision format: "Potential conflict: Vehicle X must yield to Vehicle Y"
"""
# Key order of a vehicle record inside the JSON sent to the model.
_VEHICLE_KEYS = (
    "vehicle_id",
    "lane",
    "speed",
    "distance_to_intersection",
    "direction",
    "destination",
)


def _few_shot_turns(vehicle_rows, winner_loser_pairs, ranks, waits):
    """Build the user/assistant message pair for one demonstration scenario.

    Args:
        vehicle_rows: One tuple per vehicle, ordered like ``_VEHICLE_KEYS``.
        winner_loser_pairs: ``(winner_id, loser_id)`` per conflict; the loser
            is the vehicle told to yield. Empty means "no conflict".
        ranks: Priority rank per vehicle, aligned with ``vehicle_rows``.
        waits: Waiting time per vehicle, aligned with ``vehicle_rows``.

    Returns:
        A two-element list: the user message followed by the assistant message,
        each carrying a JSON string as its ``content``.
    """
    vehicles = [dict(zip(_VEHICLE_KEYS, row)) for row in vehicle_rows]
    ids = [vehicle["vehicle_id"] for vehicle in vehicles]
    answer = {
        "is_conflict": "yes" if winner_loser_pairs else "no",
        "number_of_conflicts": len(winner_loser_pairs),
        "conflict_vehicles": [
            {"vehicle1_id": winner, "vehicle2_id": loser}
            for winner, loser in winner_loser_pairs
        ],
        "decisions": [
            f"Potential conflict: Vehicle {loser} must yield to Vehicle {winner}"
            for winner, loser in winner_loser_pairs
        ],
        "priority_order": dict(zip(ids, ranks)),
        "waiting_times": dict(zip(ids, waits)),
    }
    return [
        {"role": "user", "content": json.dumps({"vehicles": vehicles})},
        {"role": "assistant", "content": json.dumps(answer)},
    ]


# Few-shot demonstrations: a flat list of alternating user/assistant messages
# that is spliced between the system prompt and the real query.
FEW_SHOT_EXAMPLES = [
    # Example 1: north/south pair arriving close together -> one conflict.
    *_few_shot_turns(
        [
            ("V1001", 1, 60, 50, "north", "F"),
            ("V1002", 5, 45, 55, "south", "D"),
        ],
        [("V1001", "V1002")],
        ranks=(1, 2),
        waits=(0, 3),
    ),
    # Example 2: arrival times too far apart -> no conflict.
    *_few_shot_turns(
        [
            ("V2001", 1, 70, 40, "north", "H"),
            ("V2002", 3, 25, 380, "east", "B"),
        ],
        [],
        ranks=(1, 2),
        waits=(0, 0),
    ),
    # Example 3: north/east pair, straight vs right turn -> one conflict.
    *_few_shot_turns(
        [
            ("V3001", 1, 55, 70, "north", "H"),
            ("V3002", 3, 50, 75, "east", "B"),
        ],
        [("V3001", "V3002")],
        ranks=(1, 2),
        waits=(0, 3),
    ),
    # Example 4: three vehicles, two conflicts that both involve V4001.
    *_few_shot_turns(
        [
            ("V4001", 1, 60, 65, "north", "F"),
            ("V4002", 5, 55, 70, "south", "D"),
            ("V4003", 3, 50, 72, "east", "B"),
        ],
        [("V4001", "V4002"), ("V4001", "V4003")],
        ranks=(1, 2, 2),
        waits=(0, 3, 3),
    ),
    # Example 5: both vehicles come from the same direction -> no conflict.
    *_few_shot_turns(
        [
            ("V5001", 1, 50, 100, "north", "F"),
            ("V5002", 2, 45, 120, "north", "E"),
        ],
        [],
        ranks=(1, 2),
        waits=(0, 0),
    ),
]
# ─── LLM Client ───────────────────────────────────────────────────────────────
# ─── Masri et al. fine-tuning system prompt ───────────────────────────────────
# Sent as the system message whenever a fine-tuned model is called. The wording
# is meant to match the fine-tuning data exactly, so it is assembled line by
# line here without altering a single character of the resulting text.
MASRI_SYSTEM_PROMPT = "\n".join(
    (
        "You are an Urban Intersection Traffic Conflict Detector, responsible for "
        "monitoring a four-way intersection with traffic coming from the north, "
        "east, south, and west. Each direction has two lanes guiding vehicles to "
        "different destinations:",
        "- North: Lane 1 directs vehicles to F and H, "
        "Lane 2 directs vehicles to E, D, and C.",
        "- East: Lane 3 leads to H and B, Lane 4 leads to G, E, and F.",
        "- South: Lane 5 directs vehicles to B and D, "
        "Lane 6 directs vehicles to A, G, and H.",
        "- West: Lane 7 directs vehicles to D and F, "
        "Lane 8 directs vehicles to B, C, and A.",
        "Analyze the traffic data from all directions and lanes, and determine if "
        "there is a potential conflict between vehicles at the intersection. "
        "Respond only with yes or no.",
    )
)
def _vehicles_to_text(vehicles: list) -> str:
    """Describe every vehicle in one English sentence, joined by single spaces.

    The sentence wording follows the Masri et al. training format. Speed and
    distance are coerced with ``float`` and printed with two decimals.

    Args:
        vehicles: Mappings with the keys ``vehicle_id``, ``lane``,
            ``direction``, ``speed``, ``distance_to_intersection`` and
            ``destination``.

    Returns:
        The concatenated description; an empty string for an empty list.
    """
    sentence = (
        "Vehicle {vehicle_id} is in lane {lane}, moving {direction} "
        "at a speed of {speed:.2f} km/h, and is "
        "{distance:.2f} meters away from the intersection, "
        "heading towards {destination}."
    )
    return " ".join(
        sentence.format(
            vehicle_id=vehicle["vehicle_id"],
            lane=vehicle["lane"],
            direction=vehicle["direction"],
            speed=float(vehicle["speed"]),
            distance=float(vehicle["distance_to_intersection"]),
            destination=vehicle["destination"],
        )
        for vehicle in vehicles
    )
# Destinations reachable from each lane, keyed by approach direction, in the
# structure handed to the rule-based engine's parse_intersection_layout().
_INTERSECTION_LAYOUT = {
    "north": {"1": ["F", "H"], "2": ["E", "D", "C"]},
    "east": {"3": ["H", "B"], "4": ["G", "E", "F"]},
    "south": {"5": ["B", "D"], "6": ["A", "G", "H"]},
    "west": {"7": ["D", "F"], "8": ["B", "C", "A"]},
}


def _minimal_decision(vehicles: list, is_conflict: bool) -> dict:
    """Build a decision from the yes/no flag alone, without the rule engine.

    With no engine output the only pair that can be named is the first two
    vehicles, so a conflict is reported only when at least two vehicles exist.
    That keeps ``is_conflict``, ``number_of_conflicts`` and the pair list in
    agreement. Priorities and waiting times are unknown and stay empty.
    """
    vids = [v["vehicle_id"] for v in vehicles]
    has_pair = bool(is_conflict) and len(vids) >= 2
    pairs = [{"vehicle1_id": vids[0], "vehicle2_id": vids[1]}] if has_pair else []
    decisions = (
        [f"Potential conflict: Vehicle {vids[1]} must yield to Vehicle {vids[0]}"]
        if has_pair
        else []
    )
    return {
        "is_conflict": "yes" if has_pair else "no",
        "number_of_conflicts": len(pairs),
        "conflict_vehicles": pairs,
        "decisions": decisions,
        "priority_order": {},
        "waiting_times": {},
    }


def _build_full_decision(vehicles: list, is_conflict: bool) -> dict:
    """Expand a yes/no conflict flag into the full structured decision.

    The rule-based engine (``conflict_detection_orig``, looked up in the
    sibling ``poc`` directory) supplies conflict pairs, priorities and waiting
    times. The engine's own verdict is used unless it warned about unknown or
    inaccessible lanes; in that case the LLM flag decides, and the returned
    lists are made to agree with it.

    Args:
        vehicles: Vehicle dicts of one scenario (each needs ``vehicle_id``).
        is_conflict: The LLM's yes/no answer.

    Returns:
        A dict with ``is_conflict``, ``number_of_conflicts``,
        ``conflict_vehicles``, ``decisions``, ``priority_order`` and
        ``waiting_times``. ``number_of_conflicts`` always equals
        ``len(conflict_vehicles)``, and ``is_conflict`` is "yes" exactly when
        that count is non-zero.
    """
    import logging
    import sys
    import warnings
    from pathlib import Path

    # Make the rule-based engine importable; guarded so repeated calls do not
    # keep growing sys.path.
    poc_path = str(Path(__file__).parent.parent / "poc")
    if poc_path not in sys.path:
        sys.path.insert(0, poc_path)

    try:
        from conflict_detection_orig import (
            detect_conflicts,
            parse_intersection_layout,
            parse_vehicles,
        )

        layout = parse_intersection_layout({"intersection_layout": _INTERSECTION_LAYOUT})
        # Record warnings instead of printing them: they are the only signal
        # that the engine could not place a vehicle on a known lane.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            vobjs = parse_vehicles({"vehicles_scenario": vehicles}, layout)
            conflicts = detect_conflicts(vobjs) or []

        lanes_unreliable = any(
            "unknown lane" in str(w.message) or "not accessible" in str(w.message)
            for w in caught
        )
        if lanes_unreliable and is_conflict is not None:
            if not is_conflict:
                # Trust the LLM's "no": drop pairs found on unreliable lane data
                # so the result does not say "no" while listing conflicts.
                conflicts = []
            elif not conflicts:
                # LLM says "yes" but the engine has nothing to report.
                return _minimal_decision(vehicles, True)

        # Aggregate per vehicle across conflicts: the first rank seen is kept,
        # and the waiting time is the maximum over all conflicts.
        priority_order: dict = {}
        waiting_times: dict = {}
        for conflict in conflicts:
            for vid, rank in conflict["priority_order"].items():
                priority_order.setdefault(vid, rank)
            for vid, wait in conflict["waiting_times"].items():
                waiting_times[vid] = max(waiting_times.get(vid, 0), int(wait))

        return {
            "is_conflict": "yes" if conflicts else "no",
            "number_of_conflicts": len(conflicts),
            "conflict_vehicles": [
                {"vehicle1_id": c["vehicle1_id"], "vehicle2_id": c["vehicle2_id"]}
                for c in conflicts
            ],
            "decisions": [c["decision"] for c in conflicts],
            "priority_order": priority_order,
            "waiting_times": waiting_times,
        }
    except Exception:
        # Deliberate best-effort: the engine is optional and may be missing or
        # fail on unusual input. Leave a trace for debugging, then degrade to
        # the flag-only result instead of failing the prediction.
        logging.getLogger(__name__).debug(
            "Rule-based engine unavailable or failed; using LLM flag only.",
            exc_info=True,
        )
        return _minimal_decision(vehicles, is_conflict)
| class IntersectionLLM: | |
| """ | |
| Wrapper around an OpenAI-compatible LLM for intersection conflict resolution. | |
| Supports: | |
| - Zero-shot inference | |
| - Few-shot prompting | |
| - Fine-tuned model (pass fine_tuned_model_id) | |
| """ | |
| def __init__( | |
| self, | |
| model: str = "gpt-4o-mini", | |
| api_key: str | None = None, | |
| few_shot: bool = True, | |
| temperature: float = 0.0, | |
| fine_tuned_model_id: str | None = None, | |
| ): | |
| self.model = fine_tuned_model_id or model | |
| self.few_shot = few_shot | |
| self.temperature = temperature | |
| self._api_key = api_key or os.environ.get("OPENAI_API_KEY", "") | |
| # Fine-tuned models use Masri et al. format (natural language β yes/no) | |
| self._is_finetuned = bool(fine_tuned_model_id) | |
| try: | |
| from openai import OpenAI | |
| self._client = OpenAI(api_key=self._api_key) | |
| self._available = True | |
| except ImportError: | |
| self._client = None | |
| self._available = False | |
| def _build_messages(self, scenario: dict) -> list[dict]: | |
| messages = [{"role": "system", "content": SYSTEM_PROMPT}] | |
| if self.few_shot: | |
| messages.extend(FEW_SHOT_EXAMPLES) | |
| messages.append({"role": "user", "content": json.dumps(scenario)}) | |
| return messages | |
| def predict(self, scenario: dict) -> dict: | |
| """ | |
| Run inference on a single scenario dict with a 'vehicles' list. | |
| Fine-tuned model: uses Masri et al. natural language format β yes/no, | |
| then enriches with rule-based engine for full structured output. | |
| Base model: uses JSON system prompt β full structured JSON output. | |
| """ | |
| if not self._available: | |
| raise RuntimeError("openai package not installed. Run: pip install openai") | |
| if not self._api_key: | |
| raise RuntimeError("OPENAI_API_KEY environment variable not set.") | |
| vehicles = scenario.get("vehicles", []) | |
| if self._is_finetuned: | |
| # ββ Fine-tuned model: Masri et al. format ββββββββββββββββββββββββ | |
| text = _vehicles_to_text(vehicles) | |
| messages = [ | |
| {"role": "system", "content": MASRI_SYSTEM_PROMPT}, | |
| { | |
| "role": "user", | |
| "content": f"Conflict? (yes/no): {text}", | |
| }, | |
| ] | |
| response = self._client.chat.completions.create( | |
| model=self.model, | |
| messages=messages, | |
| temperature=self.temperature, | |
| max_tokens=5, | |
| ) | |
| answer = response.choices[0].message.content.strip().lower() | |
| is_conflict = "yes" in answer | |
| # Enrich with rule-based engine for full structured output | |
| return _build_full_decision(vehicles, is_conflict) | |
| else: | |
| # ββ Base model: full JSON format ββββββββββββββββββββββββββββββββββ | |
| messages = self._build_messages(scenario) | |
| response = self._client.chat.completions.create( | |
| model=self.model, | |
| messages=messages, | |
| temperature=self.temperature, | |
| response_format={"type": "json_object"}, | |
| ) | |
| content = response.choices[0].message.content | |
| return json.loads(content) | |
| def predict_batch(self, scenarios: list[dict]) -> list[dict]: | |
| """Run inference on a list of scenario dicts.""" | |
| return [self.predict(s) for s in scenarios] | |
| def predict_from_df(self, df: pd.DataFrame) -> pd.DataFrame: | |
| """ | |
| Given a DataFrame (one row per vehicle), group by scenario_id, | |
| build scenarios, run inference, and return a results DataFrame. | |
| """ | |
| results = [] | |
| for scenario_id, group in df.groupby("scenario_id"): | |
| vehicles = group[ | |
| [ | |
| "vehicle_id", | |
| "lane", | |
| "speed", | |
| "distance_to_intersection", | |
| "direction", | |
| "destination", | |
| ] | |
| ].to_dict(orient="records") | |
| scenario = {"vehicles": vehicles} | |
| try: | |
| pred = self.predict(scenario) | |
| pred["scenario_id"] = scenario_id | |
| results.append(pred) | |
| except Exception as exc: | |
| results.append( | |
| { | |
| "scenario_id": scenario_id, | |
| "error": str(exc), | |
| } | |
| ) | |
| return pd.DataFrame(results) | |
| # βββ Fine-tuning helper βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def build_finetune_example(row: pd.Series) -> dict: | |
| """ | |
| Convert a raw dataset row into an OpenAI fine-tuning JSONL record. | |
| """ | |
| import ast | |
| vehicles = [ | |
| { | |
| "vehicle_id": row["vehicle_id"], | |
| "lane": int(row["lane"]), | |
| "speed": float(row["speed"]), | |
| "distance_to_intersection": float(row["distance_to_intersection"]), | |
| "direction": row["direction"], | |
| "destination": row["destination"], | |
| } | |
| ] | |
| try: | |
| priority_order = ast.literal_eval(row["priority_order"]) | |
| waiting_times = ast.literal_eval(row["waiting_times"]) | |
| conflict_vehicles = ast.literal_eval(row["conflict_vehicles"]) | |
| decisions = ast.literal_eval(row["decisions"]) | |
| except Exception: | |
| priority_order = {} | |
| waiting_times = {} | |
| conflict_vehicles = [] | |
| decisions = [] | |
| assistant_response = { | |
| "is_conflict": row["is_conflict"], | |
| "number_of_conflicts": int(row["number_of_conflicts"]), | |
| "conflict_vehicles": conflict_vehicles, | |
| "decisions": decisions, | |
| "priority_order": priority_order, | |
| "waiting_times": waiting_times, | |
| } | |
| return { | |
| "messages": [ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": json.dumps({"vehicles": vehicles})}, | |
| {"role": "assistant", "content": json.dumps(assistant_response)}, | |
| ] | |
| } | |
| def prepare_finetune_dataset( | |
| csv_path: str | Path, | |
| output_path: str | Path, | |
| max_examples: int = 500, | |
| ) -> Path: | |
| """ | |
| Convert the raw CSV dataset into an OpenAI fine-tuning JSONL file. | |
| Groups by scenario so each example covers one scenario. | |
| """ | |
| df = pd.read_csv(csv_path) | |
| out = Path(output_path) | |
| out.parent.mkdir(parents=True, exist_ok=True) | |
| with open(out, "w") as f: | |
| count = 0 | |
| for scenario_id, group in df.groupby("scenario_id"): | |
| if count >= max_examples: | |
| break | |
| # Use first row for labels (all rows in scenario share same labels) | |
| row = group.iloc[0] | |
| import ast | |
| try: | |
| conflict_vehicles = ast.literal_eval(row["conflict_vehicles"]) | |
| decisions = ast.literal_eval(row["decisions"]) | |
| priority_order = ast.literal_eval(row["priority_order"]) | |
| waiting_times = ast.literal_eval(row["waiting_times"]) | |
| except Exception: | |
| conflict_vehicles = [] | |
| decisions = [] | |
| priority_order = {} | |
| waiting_times = {} | |
| vehicles = group[ | |
| [ | |
| "vehicle_id", | |
| "lane", | |
| "speed", | |
| "distance_to_intersection", | |
| "direction", | |
| "destination", | |
| ] | |
| ].to_dict(orient="records") | |
| assistant_response = { | |
| "is_conflict": row["is_conflict"], | |
| "number_of_conflicts": int(row["number_of_conflicts"]), | |
| "conflict_vehicles": conflict_vehicles, | |
| "decisions": decisions, | |
| "priority_order": priority_order, | |
| "waiting_times": waiting_times, | |
| } | |
| example = { | |
| "messages": [ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": json.dumps({"vehicles": vehicles})}, | |
| {"role": "assistant", "content": json.dumps(assistant_response)}, | |
| ] | |
| } | |
| f.write(json.dumps(example) + "\n") | |
| count += 1 | |
| print(f"β Fine-tuning dataset saved β {out} ({count} examples)") | |
| return out | |