kushalExplores commited on
Commit
8efdae2
·
verified ·
1 Parent(s): 20c29f9

Upload train_flatmate_rl_trl.ipynb with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_flatmate_rl_trl.ipynb +777 -0
train_flatmate_rl_trl.ipynb ADDED
@@ -0,0 +1,777 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "3905a08b",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Train a Flatmate RL Action Policy with TRL\n",
9
+ "\n",
10
+ "This notebook connects to the Hugging Face Space endpoint, collects rollout examples over OpenEnv websocket sessions, and fine-tunes a small causal language model to emit Flatmate RL JSON actions. The training path uses TRL `SFTTrainer`, which is the most stable starting point for this mixed natural-language plus structured-tool action space.\n",
11
+ "\n",
12
+ "Endpoint used here: `https://huggingface.co/spaces/kushalExplores/flatmate_rl`."
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "54f0ddc0",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "# Install notebook dependencies. Restart the kernel after this cell if Colab/Jupyter asks you to.\n",
23
+ "%pip install -q \"trl>=0.23.0\" \"transformers>=4.46.0\" accelerate datasets peft websockets huggingface_hub matplotlib pandas"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "id": "a6a37c34",
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "from __future__ import annotations\n",
34
+ "\n",
35
+ "import asyncio\n",
36
+ "import json\n",
37
+ "import random\n",
38
+ "from dataclasses import dataclass\n",
39
+ "from pathlib import Path\n",
40
+ "from typing import Any\n",
41
+ "from urllib.parse import urlparse\n",
42
+ "\n",
43
+ "import websockets\n",
44
+ "from datasets import Dataset\n",
45
+ "\n",
46
+ "SPACE_HTTP_URL = \"https://kushalexplores-flatmate-rl.hf.space\"\n",
47
+ "SCENARIOS = [\n",
48
+ " \"task_visit_single\",\n",
49
+ " \"task_visit_single_hidden_flex\",\n",
50
+ " \"task_visit_multi\",\n",
51
+ " \"task_visit_single_seller_followup\",\n",
52
+ "]\n",
53
+ "\n",
54
+ "def ws_url_from_http(base_url: str) -> str:\n",
55
+ " parsed = urlparse(base_url.rstrip(\"/\"))\n",
56
+ " scheme = \"wss\" if parsed.scheme == \"https\" else \"ws\"\n",
57
+ " return f\"{scheme}://{parsed.netloc}/ws\"\n",
58
+ "\n",
59
+ "SPACE_WS_URL = ws_url_from_http(SPACE_HTTP_URL)\n",
60
+ "SPACE_WS_URL"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "markdown",
65
+ "id": "3e10f23e",
66
+ "metadata": {},
67
+ "source": [
68
+ "## Endpoint Client\n",
69
+ "\n",
70
+ "OpenEnv's plain HTTP `/reset` and `/step` endpoints are stateless. Use `/ws` for multi-step episodes because the websocket session keeps one environment instance alive across reset and step calls."
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "f958cca7",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "class FlatmateEndpoint:\n",
81
+ " def __init__(self, ws_url: str = SPACE_WS_URL, timeout_s: float = 120.0):\n",
82
+ " self.ws_url = ws_url\n",
83
+ " self.timeout_s = timeout_s\n",
84
+ "\n",
85
+ " async def __aenter__(self):\n",
86
+ " self.ws = await websockets.connect(self.ws_url, open_timeout=self.timeout_s, ping_timeout=self.timeout_s)\n",
87
+ " return self\n",
88
+ "\n",
89
+ " async def __aexit__(self, exc_type, exc, tb):\n",
90
+ " try:\n",
91
+ " await self.ws.send(json.dumps({\"type\": \"close\"}))\n",
92
+ " finally:\n",
93
+ " await self.ws.close()\n",
94
+ "\n",
95
+ " async def _send(self, payload: dict[str, Any]) -> dict[str, Any]:\n",
96
+ " await self.ws.send(json.dumps(payload))\n",
97
+ " raw = await asyncio.wait_for(self.ws.recv(), timeout=self.timeout_s)\n",
98
+ " message = json.loads(raw)\n",
99
+ " if message.get(\"type\") == \"error\":\n",
100
+ " raise RuntimeError(message.get(\"data\", message))\n",
101
+ " data = message[\"data\"]\n",
102
+ " obs = data.get(\"observation\", {})\n",
103
+ " obs[\"reward\"] = data.get(\"reward\")\n",
104
+ " obs[\"done\"] = data.get(\"done\", False)\n",
105
+ " return obs\n",
106
+ "\n",
107
+ " async def reset(self, scenario_id: str, seed: int | None = None) -> dict[str, Any]:\n",
108
+ " data: dict[str, Any] = {\"scenario_id\": scenario_id}\n",
109
+ " if seed is not None:\n",
110
+ " data[\"seed\"] = seed\n",
111
+ " return await self._send({\"type\": \"reset\", \"data\": data})\n",
112
+ "\n",
113
+ " async def step(self, action: dict[str, Any]) -> dict[str, Any]:\n",
114
+ " return await self._send({\"type\": \"step\", \"data\": action})\n",
115
+ "\n",
116
+ "async def smoke_test_endpoint():\n",
117
+ " async with FlatmateEndpoint() as env:\n",
118
+ " obs = await env.reset(\"task_visit_single\", seed=1)\n",
119
+ " print(obs[\"scenario_id\"], obs[\"status\"])\n",
120
+ " print(obs.get(\"last_user_message\") or obs.get(\"current_user_request\"))\n",
121
+ "\n",
122
+ "await smoke_test_endpoint()"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "id": "fe2ad079",
128
+ "metadata": {},
129
+ "source": [
130
+ "## Rollout Policy for Data Collection\n",
131
+ "\n",
132
+ "This heuristic is intentionally simple. It produces valid-looking action examples from endpoint observations; after SFT, replace it with model generation and keep the same evaluator."
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": null,
138
+ "id": "611b1ac4",
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "def tool_names(obs: dict[str, Any]) -> list[str]:\n",
143
+ " return [str(t.get(\"tool\", t.get(\"tool_name\", \"\"))) for t in obs.get(\"tool_trace\", [])]\n",
144
+ "\n",
145
+ "def action_policy(obs: dict[str, Any]) -> dict[str, Any] | None:\n",
146
+ " tools = tool_names(obs)\n",
147
+ " phase = obs.get(\"phase\", \"buyer\")\n",
148
+ " remaining = set(obs.get(\"remaining_required_fields\", []))\n",
149
+ " scenario_id = obs.get(\"scenario_id\", \"task_visit_single\")\n",
150
+ "\n",
151
+ " if phase == \"seller\" and not obs.get(\"seller_profile_stored\"):\n",
152
+ " if remaining:\n",
153
+ " return {\"action_type\": \"assistant_message\", \"assistant_message\": \"Please share the household dietary setup, who the flat is for, and available visit time slots.\"}\n",
154
+ " return {\"action_type\": \"tool_call\", \"tool_name\": \"store_seller_details\", \"tool_arguments\": {}}\n",
155
+ "\n",
156
+ " if not obs.get(\"buyer_profile_stored\"):\n",
157
+ " if \"diet\" in remaining and \"visit_availability\" in remaining:\n",
158
+ " return {\"action_type\": \"assistant_message\", \"assistant_message\": \"Please share your dietary preference and visit availability.\"}\n",
159
+ " if \"diet\" in remaining:\n",
160
+ " return {\"action_type\": \"assistant_message\", \"assistant_message\": \"Please share your dietary preference.\"}\n",
161
+ " if \"visit_availability\" in remaining:\n",
162
+ " return {\"action_type\": \"assistant_message\", \"assistant_message\": \"Please share your visit availability.\"}\n",
163
+ " return {\"action_type\": \"tool_call\", \"tool_name\": \"store_user_details\", \"tool_arguments\": {}}\n",
164
+ "\n",
165
+ " if \"search_posts\" not in tools:\n",
166
+ " return {\"action_type\": \"tool_call\", \"tool_name\": \"search_posts\", \"tool_arguments\": {}}\n",
167
+ "\n",
168
+ " post_ids = [\"post_031\", \"post_052\"] if scenario_id == \"task_visit_multi\" else [\"post_031\"]\n",
169
+ " if \"match_location_preference\" not in tools:\n",
170
+ " return {\"action_type\": \"tool_call\", \"tool_name\": \"match_location_preference\", \"tool_arguments\": {\"post_ids\": post_ids}}\n",
171
+ " if \"get_commute_time\" not in tools:\n",
172
+ " return {\"action_type\": \"tool_call\", \"tool_name\": \"get_commute_time\", \"tool_arguments\": {\"post_ids\": post_ids}}\n",
173
+ " if \"check_calendar_slots\" not in tools:\n",
174
+ " return {\"action_type\": \"tool_call\", \"tool_name\": \"check_calendar_slots\", \"tool_arguments\": {\"post_ids\": post_ids}}\n",
175
+ " if \"shortlist\" not in tools:\n",
176
+ " return {\"action_type\": \"tool_call\", \"tool_name\": \"shortlist\", \"tool_arguments\": {\"post_ids\": post_ids}}\n",
177
+ " if \"contact_poster\" not in tools:\n",
178
+ " return {\"action_type\": \"tool_call\", \"tool_name\": \"contact_poster\", \"tool_arguments\": {\"post_id\": post_ids[0], \"time_text\": \"tomorrow 7pm\"}}\n",
179
+ " if \"book_viewing\" not in tools:\n",
180
+ " return {\"action_type\": \"tool_call\", \"tool_name\": \"book_viewing\", \"tool_arguments\": {\"post_id\": post_ids[0], \"time_text\": \"tomorrow 7pm\"}}\n",
181
+ "\n",
182
+ " return None\n",
183
+ "\n",
184
+ "def flatten_observation(obs: dict[str, Any]) -> str:\n",
185
+ " visible = {\n",
186
+ " \"scenario_id\": obs.get(\"scenario_id\"),\n",
187
+ " \"phase\": obs.get(\"phase\"),\n",
188
+ " \"status\": obs.get(\"status\"),\n",
189
+ " \"last_user_message\": obs.get(\"last_user_message\"),\n",
190
+ " \"current_user_request\": obs.get(\"current_user_request\"),\n",
191
+ " \"available_tools\": obs.get(\"available_tools\", []),\n",
192
+ " \"remaining_required_fields\": obs.get(\"remaining_required_fields\", []),\n",
193
+ " \"prerequisites_satisfied\": obs.get(\"prerequisites_satisfied\", {}),\n",
194
+ " \"recent_tool_calls\": obs.get(\"recent_tool_calls\", []),\n",
195
+ " \"last_tool_result\": obs.get(\"last_tool_result\", {}),\n",
196
+ " \"violations\": obs.get(\"violations\", []),\n",
197
+ " \"booked_visits\": obs.get(\"booked_visits\", []),\n",
198
+ " \"feedback_summary\": obs.get(\"feedback_summary\", \"\"),\n",
199
+ " }\n",
200
+ " return json.dumps(visible, ensure_ascii=False, sort_keys=True)\n",
201
+ "\n",
202
+ "def make_training_text(obs: dict[str, Any], action: dict[str, Any]) -> str:\n",
203
+ " return (\n",
204
+ " \"You are a broker policy for the Flatmate RL environment. \"\n",
205
+ " \"Given an observation, return exactly one JSON action.\\n\\n\"\n",
206
+ " f\"Observation:\\n{flatten_observation(obs)}\\n\\n\"\n",
207
+ " f\"Action:\\n{json.dumps(action, ensure_ascii=False, sort_keys=True)}\"\n",
208
+ " )"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "id": "7b22fa13",
215
+ "metadata": {},
216
+ "outputs": [],
217
+ "source": [
218
+ "@dataclass\n",
219
+ "class RolloutConfig:\n",
220
+ " train_episodes_per_task: int = 4\n",
221
+ " test_episodes_per_task: int = 2\n",
222
+ " max_steps: int = 20\n",
223
+ " seed: int = 7\n",
224
+ "\n",
225
+ "async def collect_one_episode(\n",
226
+ " scenario_id: str,\n",
227
+ " episode_id: str,\n",
228
+ " episode_idx: int,\n",
229
+ " split: str,\n",
230
+ " seed: int,\n",
231
+ " max_steps: int,\n",
232
+ ") -> list[dict[str, Any]]:\n",
233
+ " rows: list[dict[str, Any]] = []\n",
234
+ " async with FlatmateEndpoint() as env:\n",
235
+ " obs = await env.reset(scenario_id, seed=seed)\n",
236
+ " total_reward = 0.0\n",
237
+ " for step_idx in range(max_steps):\n",
238
+ " action = action_policy(obs)\n",
239
+ " if action is None or obs.get(\"done\"):\n",
240
+ " break\n",
241
+ " rows.append({\n",
242
+ " \"text\": make_training_text(obs, action),\n",
243
+ " \"episode_id\": episode_id,\n",
244
+ " \"episode_idx\": episode_idx,\n",
245
+ " \"split\": split,\n",
246
+ " \"scenario_id\": scenario_id,\n",
247
+ " \"seed\": seed,\n",
248
+ " \"step\": step_idx,\n",
249
+ " \"action\": json.dumps(action, sort_keys=True),\n",
250
+ " })\n",
251
+ " obs = await env.step(action)\n",
252
+ " total_reward += float(obs.get(\"reward\") or obs.get(\"step_reward\") or 0.0)\n",
253
+ " if obs.get(\"done\"):\n",
254
+ " break\n",
255
+ " print(f\"split={split:5s} episode={episode_id} scenario={scenario_id} rows={len(rows)} total_reward={total_reward:.2f}\")\n",
256
+ " return rows\n",
257
+ "\n",
258
+ "async def collect_balanced_rollouts(config: RolloutConfig = RolloutConfig()) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:\n",
259
+ " train_rows: list[dict[str, Any]] = []\n",
260
+ " test_rows: list[dict[str, Any]] = []\n",
261
+ " episode_idx = 0\n",
262
+ "\n",
263
+ " for scenario_idx, scenario_id in enumerate(SCENARIOS):\n",
264
+ " for task_episode_idx in range(config.train_episodes_per_task):\n",
265
+ " seed = config.seed + scenario_idx * 100 + task_episode_idx\n",
266
+ " episode_id = f\"train_{scenario_id}_{task_episode_idx:03d}\"\n",
267
+ " train_rows.extend(await collect_one_episode(scenario_id, episode_id, episode_idx, \"train\", seed, config.max_steps))\n",
268
+ " episode_idx += 1\n",
269
+ "\n",
270
+ " for task_episode_idx in range(config.test_episodes_per_task):\n",
271
+ " seed = 900 + config.seed + scenario_idx * 100 + task_episode_idx\n",
272
+ " episode_id = f\"test_{scenario_id}_{task_episode_idx:03d}\"\n",
273
+ " test_rows.extend(await collect_one_episode(scenario_id, episode_id, episode_idx, \"test\", seed, config.max_steps))\n",
274
+ " episode_idx += 1\n",
275
+ "\n",
276
+ " return train_rows, test_rows\n",
277
+ "\n",
278
+ "print(\"Note: seeded resets create value variants while preserving the same episode flow. Upload the updated Space before using this against the hosted endpoint.\")\n",
279
+ "train_rows, test_rows = await collect_balanced_rollouts(\n",
280
+ " RolloutConfig(train_episodes_per_task=4, test_episodes_per_task=2, max_steps=20, seed=7)\n",
281
+ ")\n",
282
+ "rows = train_rows + test_rows\n",
283
+ "dataset = Dataset.from_list(rows)\n",
284
+ "train_dataset = Dataset.from_list(train_rows)\n",
285
+ "test_dataset = Dataset.from_list(test_rows)\n",
286
+ "\n",
287
+ "print({\n",
288
+ " \"train_rows\": len(train_dataset),\n",
289
+ " \"test_rows\": len(test_dataset),\n",
290
+ " \"total_rows\": len(dataset),\n",
291
+ " \"train_episodes\": len(set(train_dataset[\"episode_id\"])),\n",
292
+ " \"test_episodes\": len(set(test_dataset[\"episode_id\"])),\n",
293
+ "})\n",
294
+ "print(\"train scenarios\", sorted(set(train_dataset[\"scenario_id\"])))\n",
295
+ "print(\"test scenarios\", sorted(set(test_dataset[\"scenario_id\"])))\n",
296
+ "print(\"train episodes by scenario\")\n",
297
+ "display(pd.DataFrame(train_rows).groupby(\"scenario_id\")[\"episode_id\"].nunique().rename(\"episodes\"))\n",
298
+ "print(\"test episodes by scenario\")\n",
299
+ "display(pd.DataFrame(test_rows).groupby(\"scenario_id\")[\"episode_id\"].nunique().rename(\"episodes\"))\n",
300
+ "{\"train\": train_dataset, \"test\": test_dataset}"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": null,
306
+ "id": "665b46fa",
307
+ "metadata": {},
308
+ "outputs": [],
309
+ "source": [
310
+ "from peft import LoraConfig\n",
311
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
312
+ "from trl import SFTConfig, SFTTrainer\n",
313
+ "\n",
314
+ "MODEL_NAME = \"Qwen/Qwen2.5-0.5B-Instruct\"\n",
315
+ "OUTPUT_DIR = \"flatmate-rl-action-policy\"\n",
316
+ "\n",
317
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
318
+ "if tokenizer.pad_token is None:\n",
319
+ " tokenizer.pad_token = tokenizer.eos_token\n",
320
+ "\n",
321
+ "model = AutoModelForCausalLM.from_pretrained(\n",
322
+ " MODEL_NAME,\n",
323
+ " trust_remote_code=True,\n",
324
+ " device_map=\"auto\",\n",
325
+ ")\n",
326
+ "model.config.use_cache = False\n",
327
+ "\n",
328
+ "peft_config = LoraConfig(\n",
329
+ " r=16,\n",
330
+ " lora_alpha=32,\n",
331
+ " lora_dropout=0.05,\n",
332
+ " bias=\"none\",\n",
333
+ " task_type=\"CAUSAL_LM\",\n",
334
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
335
+ ")\n",
336
+ "\n",
337
+ "training_args = SFTConfig(\n",
338
+ " output_dir=OUTPUT_DIR,\n",
339
+ " dataset_text_field=\"text\",\n",
340
+ " max_length=1536,\n",
341
+ " per_device_train_batch_size=1,\n",
342
+ " gradient_accumulation_steps=8,\n",
343
+ " num_train_epochs=1,\n",
344
+ " learning_rate=5e-5,\n",
345
+ " logging_steps=5,\n",
346
+ " save_steps=50,\n",
347
+ " save_total_limit=2,\n",
348
+ " packing=False,\n",
349
+ " report_to=\"none\",\n",
350
+ ")\n",
351
+ "\n",
352
+ "trainer = SFTTrainer(\n",
353
+ " model=model,\n",
354
+ " args=training_args,\n",
355
+ " train_dataset=train_dataset,\n",
356
+ " eval_dataset=test_dataset,\n",
357
+ " processing_class=tokenizer,\n",
358
+ " peft_config=peft_config,\n",
359
+ ")\n",
360
+ "\n",
361
+ "train_result = trainer.train()\n",
362
+ "test_metrics = trainer.evaluate(eval_dataset=test_dataset)\n",
363
+ "train_log_history = trainer.state.log_history\n",
364
+ "trainer.save_model(OUTPUT_DIR)\n",
365
+ "tokenizer.save_pretrained(OUTPUT_DIR)\n",
366
+ "print(\"heldout_test_metrics\", test_metrics)\n",
367
+ "train_result"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "markdown",
372
+ "id": "22d9fc14",
373
+ "metadata": {},
374
+ "source": [
375
+ "## Training Log\n",
376
+ "\n",
377
+ "Plot the logged training loss over optimizer steps."
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "code",
382
+ "execution_count": null,
383
+ "id": "c3e44d74",
384
+ "metadata": {},
385
+ "outputs": [],
386
+ "source": [
387
+ "import json\n",
388
+ "from pathlib import Path\n",
389
+ "\n",
390
+ "import matplotlib.pyplot as plt\n",
391
+ "import pandas as pd\n",
392
+ "\n",
393
+ "log_path = Path(OUTPUT_DIR) / \"train_log_history.json\"\n",
394
+ "log_path.parent.mkdir(parents=True, exist_ok=True)\n",
395
+ "log_path.write_text(json.dumps(train_log_history, indent=2))\n",
396
+ "\n",
397
+ "def plot_training_log(log_history, title: str = \"SFT training loss\"):\n",
398
+ " rows = [row for row in log_history if \"loss\" in row and \"step\" in row]\n",
399
+ " if not rows:\n",
400
+ " print(\"No loss rows found in trainer.state.log_history yet.\")\n",
401
+ " return None\n",
402
+ " df = pd.DataFrame(rows)\n",
403
+ " ax = df.plot(x=\"step\", y=\"loss\", marker=\"o\", figsize=(7, 4), title=title)\n",
404
+ " ax.set_xlabel(\"optimizer step\")\n",
405
+ " ax.set_ylabel(\"loss\")\n",
406
+ " ax.grid(True, alpha=0.3)\n",
407
+ " plt.show()\n",
408
+ " return df\n",
409
+ "\n",
410
+ "train_log_df = plot_training_log(train_log_history)\n",
411
+ "train_log_df.tail() if train_log_df is not None else None"
412
+ ]
413
+ },
414
+ {
415
+ "cell_type": "code",
416
+ "execution_count": null,
417
+ "id": "539548f7",
418
+ "metadata": {},
419
+ "outputs": [],
420
+ "source": [
421
+ "import torch\n",
422
+ "from peft import AutoPeftModelForCausalLM\n",
423
+ "\n",
424
+ "# Load both the base model and the saved fine-tuned adapter from disk for comparison.\n",
425
+ "try:\n",
426
+ " del model\n",
427
+ "except NameError:\n",
428
+ " pass\n",
429
+ "\n",
430
+ "base_model_for_eval = AutoModelForCausalLM.from_pretrained(\n",
431
+ " MODEL_NAME,\n",
432
+ " trust_remote_code=True,\n",
433
+ " device_map=\"auto\",\n",
434
+ ")\n",
435
+ "base_model_for_eval.eval()\n",
436
+ "base_model_for_eval.config.use_cache = False\n",
437
+ "\n",
438
+ "loaded_model_for_eval = AutoPeftModelForCausalLM.from_pretrained(OUTPUT_DIR, device_map=\"auto\")\n",
439
+ "loaded_model_for_eval.eval()\n",
440
+ "loaded_model_for_eval.config.use_cache = False\n",
441
+ "active_model = loaded_model_for_eval\n",
442
+ "print(f\"Loaded base model from {MODEL_NAME}\")\n",
443
+ "print(f\"Loaded saved SFT model from {OUTPUT_DIR}\")\n",
444
+ "\n",
445
+ "TEST_SEEDS = (901, 902)\n",
446
+ "\n",
447
+ "\n",
448
+ "def prompt_from_observation(obs: dict[str, Any]) -> str:\n",
449
+ " return (\n",
450
+ " \"You are a broker policy for the Flatmate RL environment. \"\n",
451
+ " \"Given an observation, return exactly one JSON action.\\n\\n\"\n",
452
+ " f\"Observation:\\n{flatten_observation(obs)}\\n\\nAction:\\n\"\n",
453
+ " )\n",
454
+ "\n",
455
+ "\n",
456
+ "def _first_balanced_json(text: str) -> str:\n",
457
+ " start = text.find(\"{\")\n",
458
+ " if start == -1:\n",
459
+ " raise ValueError(f\"No JSON object found in generation: {text!r}\")\n",
460
+ " depth = 0\n",
461
+ " in_string = False\n",
462
+ " escape = False\n",
463
+ " for index, char in enumerate(text[start:], start=start):\n",
464
+ " if escape:\n",
465
+ " escape = False\n",
466
+ " continue\n",
467
+ " if char == \"\\\\\" and in_string:\n",
468
+ " escape = True\n",
469
+ " continue\n",
470
+ " if char == '\\\"':\n",
471
+ " in_string = not in_string\n",
472
+ " continue\n",
473
+ " if in_string:\n",
474
+ " continue\n",
475
+ " if char == \"{\":\n",
476
+ " depth += 1\n",
477
+ " elif char == \"}\":\n",
478
+ " depth -= 1\n",
479
+ " if depth == 0:\n",
480
+ " return text[start : index + 1]\n",
481
+ " raise ValueError(f\"Unterminated JSON object in generation: {text!r}\")\n",
482
+ "\n",
483
+ "\n",
484
+ "def normalize_action(action: dict[str, Any]) -> dict[str, Any]:\n",
485
+ " if action.get(\"action_type\") == \"assistant_message\" and str(action.get(\"assistant_message\", \"\")).strip():\n",
486
+ " return {\n",
487
+ " \"action_type\": \"assistant_message\",\n",
488
+ " \"assistant_message\": str(action[\"assistant_message\"]),\n",
489
+ " }\n",
490
+ " if action.get(\"action_type\") == \"tool_call\" and str(action.get(\"tool_name\", \"\")).strip():\n",
491
+ " tool_arguments = action.get(\"tool_arguments\", {})\n",
492
+ " return {\n",
493
+ " \"action_type\": \"tool_call\",\n",
494
+ " \"tool_name\": str(action[\"tool_name\"]),\n",
495
+ " \"tool_arguments\": tool_arguments if isinstance(tool_arguments, dict) else {},\n",
496
+ " }\n",
497
+ " raise ValueError(f\"Invalid action shape: {action!r}\")\n",
498
+ "\n",
499
+ "\n",
500
+ "def parse_action(text: str) -> dict[str, Any]:\n",
501
+ " return normalize_action(json.loads(_first_balanced_json(text)))\n",
502
+ "\n",
503
+ "\n",
504
+ "def heuristic_policy(obs: dict[str, Any]) -> dict[str, Any]:\n",
505
+ " action = action_policy(obs)\n",
506
+ " if action is None:\n",
507
+ " return {\"action_type\": \"assistant_message\", \"assistant_message\": \"Could you confirm the details needed for scheduling?\"}\n",
508
+ " return action\n",
509
+ "\n",
510
+ "\n",
511
+ "def raw_generate_action_text(obs: dict[str, Any]) -> str:\n",
512
+ " prompt = prompt_from_observation(obs) + \"{\"\n",
513
+ " inputs = tokenizer(prompt, return_tensors=\"pt\").to(active_model.device)\n",
514
+ " active_model.generation_config.do_sample = False\n",
515
+ " active_model.generation_config.temperature = None\n",
516
+ " active_model.generation_config.top_p = None\n",
517
+ " active_model.generation_config.top_k = None\n",
518
+ " with torch.no_grad():\n",
519
+ " output = active_model.generate(\n",
520
+ " **inputs,\n",
521
+ " max_new_tokens=80,\n",
522
+ " do_sample=False,\n",
523
+ " repetition_penalty=1.15,\n",
524
+ " no_repeat_ngram_size=3,\n",
525
+ " eos_token_id=tokenizer.eos_token_id,\n",
526
+ " pad_token_id=tokenizer.eos_token_id,\n",
527
+ " )\n",
528
+ " return \"{\" + tokenizer.decode(output[0][inputs[\"input_ids\"].shape[-1]:], skip_special_tokens=True)\n",
529
+ "\n",
530
+ "\n",
531
+ "def model_action_or_error(obs: dict[str, Any]) -> tuple[dict[str, Any] | None, str, str]:\n",
532
+ " raw = raw_generate_action_text(obs)\n",
533
+ " try:\n",
534
+ " return parse_action(raw), raw, \"\"\n",
535
+ " except Exception as exc:\n",
536
+ " return None, raw, str(exc)\n",
537
+ "\n",
538
+ "\n",
539
+ "async def sanity_check_generations(model_label: str, limit: int = 4):\n",
540
+ " rows = []\n",
541
+ " for scenario_id in SCENARIOS[:limit]:\n",
542
+ " async with FlatmateEndpoint() as env:\n",
543
+ " obs = await env.reset(scenario_id, seed=TEST_SEEDS[0])\n",
544
+ " action, raw, error = model_action_or_error(obs)\n",
545
+ " rows.append({\n",
546
+ " \"model\": model_label,\n",
547
+ " \"scenario_id\": scenario_id,\n",
548
+ " \"json_ok\": action is not None,\n",
549
+ " \"raw\": raw[:240],\n",
550
+ " \"parsed_action\": action,\n",
551
+ " \"error\": error,\n",
552
+ " })\n",
553
+ " return pd.DataFrame(rows)\n",
554
+ "\n",
555
+ "\n",
556
+ "async def evaluate_heuristic(label: str = \"heuristic\", scenarios=SCENARIOS, seeds=TEST_SEEDS, max_steps: int = 20, verbose: bool = False):\n",
557
+ " rows = []\n",
558
+ " for scenario_id in scenarios:\n",
559
+ " for seed in seeds:\n",
560
+ " async with FlatmateEndpoint() as env:\n",
561
+ " obs = await env.reset(scenario_id, seed=seed)\n",
562
+ " total_reward = 0.0\n",
563
+ " steps = 0\n",
564
+ " for step_idx in range(max_steps):\n",
565
+ " action = heuristic_policy(obs)\n",
566
+ " if verbose:\n",
567
+ " print(label, scenario_id, seed, step_idx, action)\n",
568
+ " obs = await env.step(action)\n",
569
+ " steps = step_idx + 1\n",
570
+ " total_reward += float(obs.get(\"reward\") or obs.get(\"step_reward\") or 0.0)\n",
571
+ " if obs.get(\"done\"):\n",
572
+ " break\n",
573
+ " rows.append({\n",
574
+ " \"policy\": label,\n",
575
+ " \"scenario_id\": scenario_id,\n",
576
+ " \"seed\": seed,\n",
577
+ " \"total_reward\": total_reward,\n",
578
+ " \"done\": bool(obs.get(\"done\")),\n",
579
+ " \"bookings\": len(obs.get(\"booked_visits\", [])),\n",
580
+ " \"violations\": len(obs.get(\"violations\", [])),\n",
581
+ " \"steps\": steps,\n",
582
+ " \"parse_errors\": 0,\n",
583
+ " })\n",
584
+ " return rows\n",
585
+ "\n",
586
+ "\n",
587
+ "async def evaluate_model_policy(label: str, scenarios=SCENARIOS, seeds=TEST_SEEDS, max_steps: int = 20, verbose: bool = False):\n",
588
+ " rows = []\n",
589
+ " for scenario_id in scenarios:\n",
590
+ " for seed in seeds:\n",
591
+ " async with FlatmateEndpoint() as env:\n",
592
+ " obs = await env.reset(scenario_id, seed=seed)\n",
593
+ " total_reward = 0.0\n",
594
+ " steps = 0\n",
595
+ " parse_errors = 0\n",
596
+ " last_error = \"\"\n",
597
+ " for step_idx in range(max_steps):\n",
598
+ " action, raw, error = model_action_or_error(obs)\n",
599
+ " if action is None:\n",
600
+ " parse_errors += 1\n",
601
+ " last_error = error\n",
602
+ " if verbose:\n",
603
+ " print(label, scenario_id, seed, f\"step={step_idx:02d}\", \"PARSE_ERROR\", raw[:220])\n",
604
+ " total_reward -= 1.0\n",
605
+ " break\n",
606
+ " if verbose:\n",
607
+ " print(label, scenario_id, seed, f\"step={step_idx:02d}\", action)\n",
608
+ " obs = await env.step(action)\n",
609
+ " steps = step_idx + 1\n",
610
+ " total_reward += float(obs.get(\"reward\") or obs.get(\"step_reward\") or 0.0)\n",
611
+ " if obs.get(\"done\"):\n",
612
+ " break\n",
613
+ " rows.append({\n",
614
+ " \"policy\": label,\n",
615
+ " \"scenario_id\": scenario_id,\n",
616
+ " \"seed\": seed,\n",
617
+ " \"total_reward\": total_reward,\n",
618
+ " \"done\": bool(obs.get(\"done\")),\n",
619
+ " \"bookings\": len(obs.get(\"booked_visits\", [])),\n",
620
+ " \"violations\": len(obs.get(\"violations\", [])),\n",
621
+ " \"steps\": steps,\n",
622
+ " \"parse_errors\": parse_errors,\n",
623
+ " \"last_error\": last_error,\n",
624
+ " })\n",
625
+ " return rows\n",
626
+ "\n",
627
+ "\n",
628
+ "async def run_model_inference_each_task(label: str, seed: int = TEST_SEEDS[0], max_steps: int = 20):\n",
629
+ " rows = []\n",
630
+ " for scenario_id in SCENARIOS:\n",
631
+ " print(f\"\\n=== {label}: {scenario_id} ===\")\n",
632
+ " async with FlatmateEndpoint() as env:\n",
633
+ " obs = await env.reset(scenario_id, seed=seed)\n",
634
+ " total_reward = 0.0\n",
635
+ " steps = 0\n",
636
+ " parse_errors = 0\n",
637
+ " for step_idx in range(max_steps):\n",
638
+ " action, raw, error = model_action_or_error(obs)\n",
639
+ " if action is None:\n",
640
+ " parse_errors += 1\n",
641
+ " print(f\"step={step_idx:02d} PARSE_ERROR={error}\")\n",
642
+ " print(\"raw=\", repr(raw[:300]))\n",
643
+ " total_reward -= 1.0\n",
644
+ " break\n",
645
+ " print(f\"step={step_idx:02d} action={action}\")\n",
646
+ " obs = await env.step(action)\n",
647
+ " steps = step_idx + 1\n",
648
+ " total_reward += float(obs.get(\"reward\") or obs.get(\"step_reward\") or 0.0)\n",
649
+ " if obs.get(\"done\"):\n",
650
+ " break\n",
651
+ " result = {\n",
652
+ " \"policy\": label,\n",
653
+ " \"scenario_id\": scenario_id,\n",
654
+ " \"seed\": seed,\n",
655
+ " \"total_reward\": total_reward,\n",
656
+ " \"done\": bool(obs.get(\"done\")),\n",
657
+ " \"bookings\": len(obs.get(\"booked_visits\", [])),\n",
658
+ " \"violations\": len(obs.get(\"violations\", [])),\n",
659
+ " \"steps\": steps,\n",
660
+ " \"parse_errors\": parse_errors,\n",
661
+ " }\n",
662
+ " print(\"result=\", result)\n",
663
+ " rows.append(result)\n",
664
+ " return pd.DataFrame(rows)\n",
665
+ "\n",
666
+ "\n",
667
+ "active_model = base_model_for_eval\n",
668
+ "base_generation_sanity_df = await sanity_check_generations(\"base_model\")\n",
669
+ "base_per_task_inference_df = await run_model_inference_each_task(\"base_model\")\n",
670
+ "base_model_eval = await evaluate_model_policy(\"base_model\")\n",
671
+ "\n",
672
+ "active_model = loaded_model_for_eval\n",
673
+ "loaded_generation_sanity_df = await sanity_check_generations(\"sft_loaded\")\n",
674
+ "loaded_per_task_inference_df = await run_model_inference_each_task(\"sft_loaded\")\n",
675
+ "loaded_eval = await evaluate_model_policy(\"sft_loaded\")\n",
676
+ "\n",
677
+ "per_task_inference_df = pd.concat([base_per_task_inference_df, loaded_per_task_inference_df], ignore_index=True)\n",
678
+ "generation_sanity_df = pd.concat([base_generation_sanity_df, loaded_generation_sanity_df], ignore_index=True)\n",
679
+ "heuristic_eval = await evaluate_heuristic(\"heuristic\")\n",
680
+ "\n",
681
+ "eval_rows = heuristic_eval + base_model_eval + loaded_eval\n",
682
+ "eval_df = pd.DataFrame(eval_rows)\n",
683
+ "display(generation_sanity_df)\n",
684
+ "display(per_task_inference_df)\n",
685
+ "eval_df"
686
+ ]
687
+ },
688
+ {
689
+ "cell_type": "markdown",
690
+ "id": "e1e70c8f",
691
+ "metadata": {},
692
+ "source": [
693
+ "## Performance Comparison\n",
694
+ "\n",
695
+ "Compare heuristic rollout behavior against the trained SFT policy on the same scenarios and seeds."
696
+ ]
697
+ },
698
+ {
699
+ "cell_type": "code",
700
+ "execution_count": null,
701
+ "id": "e8931930",
702
+ "metadata": {},
703
+ "outputs": [],
704
+ "source": [
705
+ "def plot_policy_comparison(eval_df, title: str = \"Base vs SFT loaded-model comparison\"):\n",
706
+ " if eval_df is None or eval_df.empty or \"policy\" not in eval_df.columns:\n",
707
+ " print(\"eval_df is empty; run the evaluation cell first.\")\n",
708
+ " return pd.DataFrame()\n",
709
+ "\n",
710
+ " summary = (\n",
711
+ " eval_df.groupby(\"policy\", as_index=True)\n",
712
+ " .agg(\n",
713
+ " avg_reward=(\"total_reward\", \"mean\"),\n",
714
+ " completion_rate=(\"done\", \"mean\"),\n",
715
+ " avg_bookings=(\"bookings\", \"mean\"),\n",
716
+ " avg_violations=(\"violations\", \"mean\"),\n",
717
+ " avg_steps=(\"steps\", \"mean\"),\n",
718
+ " avg_parse_errors=(\"parse_errors\", \"mean\") if \"parse_errors\" in eval_df.columns else (\"steps\", \"size\"),\n",
719
+ " )\n",
720
+ " .sort_index()\n",
721
+ " )\n",
722
+ " plot_cols = [\"avg_reward\", \"completion_rate\", \"avg_bookings\", \"avg_violations\", \"avg_parse_errors\"]\n",
723
+ " axes = summary[plot_cols].plot(\n",
724
+ " kind=\"bar\",\n",
725
+ " subplots=True,\n",
726
+ " layout=(3, 2),\n",
727
+ " figsize=(10, 9),\n",
728
+ " legend=False,\n",
729
+ " title=title,\n",
730
+ " )\n",
731
+ " for ax in axes.ravel():\n",
732
+ " ax.grid(axis=\"y\", alpha=0.3)\n",
733
+ " ax.set_xlabel(\"\")\n",
734
+ " plt.tight_layout()\n",
735
+ " plt.show()\n",
736
+ " return summary\n",
737
+ "\n",
738
+ "comparison_summary = plot_policy_comparison(eval_df)\n",
739
+ "comparison_summary"
740
+ ]
741
+ },
742
+ {
743
+ "cell_type": "code",
744
+ "execution_count": null,
745
+ "id": "a9fd3807",
746
+ "metadata": {},
747
+ "outputs": [],
748
+ "source": [
749
+ "# Optional: upload the trained adapter/model to the Hub.\n",
750
+ "# from huggingface_hub import notebook_login\n",
751
+ "# notebook_login()\n",
752
+ "# trainer.push_to_hub(\"flatmate-rl-action-policy\")"
753
+ ]
754
+ }
755
+ ],
756
+ "metadata": {
757
+ "kernelspec": {
758
+ "display_name": "Python 3",
759
+ "language": "python",
760
+ "name": "python3"
761
+ },
762
+ "language_info": {
763
+ "codemirror_mode": {
764
+ "name": "ipython",
765
+ "version": 3
766
+ },
767
+ "file_extension": ".py",
768
+ "mimetype": "text/x-python",
769
+ "name": "python",
770
+ "nbconvert_exporter": "python",
771
+ "pygments_lexer": "ipython3",
772
+ "version": "3.11"
773
+ }
774
+ },
775
+ "nbformat": 4,
776
+ "nbformat_minor": 5
777
+ }