databoysu committed
Commit b4f37fd · 1 Parent(s): beb93a1

fixing state machine
README.md CHANGED
@@ -2,7 +2,6 @@
  title: TraceFix-RL
  emoji: 🧑‍💻
  colorFrom: blue
- colorTo: indigo
  sdk: docker
  pinned: false
  app_port: 7860
@@ -109,12 +108,12 @@ Server endpoints:
  - `--easy`: run episode using easy-tier curriculum sampling.
  - `--medium`: run episode using medium-tier curriculum sampling.
  - `--hard`: run episode using hard-tier curriculum sampling.
- - `--debug`: print raw model response snippets for troubleshooting.
+ - `--thought`: include model thought traces in internal prompt history.

  Example:

  ```bash
- python inference.py --medium --debug
+ python inference.py --medium --thought
  ```

  The script also enforces a model-thinking/output cap:
 
__pycache__/models.cpython-312.pyc CHANGED
Binary files a/__pycache__/models.cpython-312.pyc and b/__pycache__/models.cpython-312.pyc differ
 
__pycache__/tasks.cpython-312.pyc CHANGED
Binary files a/__pycache__/tasks.cpython-312.pyc and b/__pycache__/tasks.cpython-312.pyc differ
 
inference.py CHANGED
@@ -47,8 +47,7 @@ BENCHMARK = os.getenv("BENCHMARK", "tracefix_rl")
  MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.99"))
  THINKING_TOKEN_LIMIT = int(os.getenv("THINKING_TOKEN_LIMIT", "512"))
- # Approximation used for hard truncation before sending to server.
- THINKING_CHAR_LIMIT = THINKING_TOKEN_LIMIT * 4
+ MAX_PARSE_RETRIES = 3

  SYSTEM_PROMPT = (
      "You are controlling a Python debugging RL environment. "
@@ -100,7 +99,7 @@ def _extract_json(text: str) -> dict[str, Any]:
      except json.JSONDecodeError:
          pass

-     return {"action_type": "RUN_TESTS"}
+     raise ValueError("Invalid JSON response.")


  def _build_observation_text(observation: Any) -> str:
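The `_extract_json` hunk is the heart of the fail-loud change: malformed model output now raises instead of silently degrading to `RUN_TESTS`. A minimal sketch of that contract, assuming a simple brace-scan fallback (the parsing steps above the `except` are not shown in this hunk, so the helper below is illustrative, not the repo's code):

```python
import json
import re
from typing import Any


def extract_json_strict(text: str) -> dict[str, Any]:
    """Illustrative fail-loud extractor; only the final raise mirrors the diff."""
    candidates = [text]
    match = re.search(r"\{.*\}", text, re.DOTALL)  # assumed fallback: first {...} span
    if match:
        candidates.append(match.group(0))
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
            if isinstance(parsed, dict):
                return parsed
        except json.JSONDecodeError:
            pass
    # Previously this fell through to {"action_type": "RUN_TESTS"}; now the
    # caller is expected to catch the error, log feedback, and retry.
    raise ValueError("Invalid JSON response.")
```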
@@ -116,31 +115,26 @@ def _build_observation_text(observation: Any) -> str:


  def _get_model_action(
-     client: OpenAI, observation: Any, history: list[str], debug: bool = False
- ) -> dict[str, Any]:
+     client: OpenAI, observation: Any, history: list[str]
+ ) -> tuple[dict[str, Any], str]:
      obs_text = _build_observation_text(observation)
      user_prompt = (
          "Pick the single best next action and return only JSON.\n\n"
          f"{obs_text}\n\n"
          f"history:\n{chr(10).join(history[-5:]) if history else 'none'}"
      )
-     try:
-         completion = client.chat.completions.create(
-             model=MODEL_NAME,
-             messages=[
-                 {"role": "system", "content": SYSTEM_PROMPT},
-                 {"role": "user", "content": user_prompt},
-             ],
-             temperature=0.0,
-             max_tokens=THINKING_TOKEN_LIMIT,
-             stream=False,
-         )
-         response_text = (completion.choices[0].message.content or "").strip()
-         if debug:
-             print(f"[DEBUG] raw_model_response={response_text[:500]}", flush=True)
-         action = _extract_json(response_text)
-     except Exception:
-         action = {"action_type": "RUN_TESTS"}
+     completion = client.chat.completions.create(
+         model=MODEL_NAME,
+         messages=[
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": user_prompt},
+         ],
+         temperature=0.0,
+         max_tokens=THINKING_TOKEN_LIMIT,
+         stream=False,
+     )
+     response_text = (completion.choices[0].message.content or "").strip()
+     action = _extract_json(response_text)

      if action.get("action_type") not in {
          "VIEW_CODE",
@@ -150,27 +144,20 @@ def _get_model_action(
          "RESET_TO_ORIGINAL",
          "SUBMIT",
      }:
-         action = {"action_type": "RUN_TESTS"}
+         raise ValueError("Invalid action_type in model response.")

-     return action
+     return action, response_text


  def _to_code_action(action_dict: dict[str, Any]) -> CodeAction:
-     thought = action_dict.get("thought")
-     if isinstance(thought, str):
-         thought = thought[:THINKING_CHAR_LIMIT]
-
      payload = {
-         "action_type": action_dict.get("action_type", "RUN_TESTS"),
-         "thought": thought,
+         "action_type": action_dict.get("action_type"),
+         "thought": action_dict.get("thought"),
          "start_line": action_dict.get("start_line"),
          "end_line": action_dict.get("end_line"),
          "new_code_block": action_dict.get("new_code_block"),
      }
-     try:
-         return CodeAction(**payload)
-     except Exception:
-         return CodeAction(action_type="RUN_TESTS")
+     return CodeAction(**payload)


  def _compute_score(step_result: Any, rewards: list[float]) -> float:
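`_get_model_action` now returns a `(action_dict, response_text)` tuple and `_to_code_action` no longer swallows bad payloads, so the caller can both retry and log what the model actually said. A small sketch of that caller contract, using a hypothetical stub in place of the real OpenAI call:

```python
from typing import Any


def get_model_action_stub(history: list[str]) -> tuple[dict[str, Any], str]:
    # Hypothetical stand-in for _get_model_action: returns the parsed dict and the raw text.
    raw = '{"thought": "inspect the failing test", "action_type": "VIEW_CODE"}'
    return {"thought": "inspect the failing test", "action_type": "VIEW_CODE"}, raw


history: list[str] = []
model_response = ""  # initialised first, as in the diff, so a failed call still logs safely
try:
    action_dict, model_response = get_model_action_stub(history)
    # _to_code_action(action_dict) would run here and may raise on a schema violation
except Exception as exc:
    history.append(f"parse_failure cause={exc}")
    if model_response:
        history.append(f"raw_response={model_response[:500]}")
```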
@@ -184,7 +171,7 @@ def _compute_score(step_result: Any, rewards: list[float]) -> float:
      return max(0.0, min(1.0, float(raw)))


- async def run(difficulty: Optional[str] = None, debug: bool = False) -> None:
+ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> None:
      client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

      env: Optional[TraceFixRLEnv] = None
@@ -214,8 +201,34 @@ async def run(difficulty: Optional[str] = None, debug: bool = False) -> None:
              if result.done:
                  break

-             action_dict = _get_model_action(client, result.observation, history, debug=debug)
-             action = _to_code_action(action_dict)
+             action: Optional[CodeAction] = None
+             model_response = ""
+
+             for attempt in range(1, MAX_PARSE_RETRIES + 1):
+                 try:
+                     action_dict, model_response = _get_model_action(client, result.observation, history)
+                     action = _to_code_action(action_dict)
+                     if show_thought:
+                         history.append(f"thought={action.thought}")
+                     break
+                 except Exception as exc:
+                     cause = str(exc).replace("\n", " ")
+                     history.append(
+                         (
+                             f"parse_failure attempt={attempt} cause={cause}. "
+                             "Error: Invalid JSON or schema. Return a complete valid JSON object "
+                             "with fields: thought, action_type, start_line, end_line, new_code_block."
+                         )
+                     )
+                     if model_response:
+                         history.append(f"raw_response={model_response[:500]}")
+
+             if action is None:
+                 action = CodeAction(
+                     action_type="RUN_TESTS",
+                     thought="Fallback after repeated invalid JSON/schema responses.",
+                 )
+
              result = await env.step(action)

              reward = float(result.reward or 0.0)
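This hunk replaces the old single-shot call, which silently fell back to `RUN_TESTS`, with a bounded retry loop that feeds parse failures back into the prompt history. A self-contained sketch of that control flow, with a stub model in place of the real client, shows why the feedback lands in `history` before the next attempt:

```python
from typing import Any, Callable

MAX_PARSE_RETRIES = 3  # mirrors the constant added near the top of inference.py


def choose_action_with_retries(
    ask_model: Callable[[list[str]], dict[str, Any]],
    history: list[str],
) -> str:
    """Sketch only: ask_model stands in for _get_model_action/_to_code_action."""
    for attempt in range(1, MAX_PARSE_RETRIES + 1):
        try:
            return ask_model(history)["action_type"]  # may raise on bad JSON/schema
        except Exception as exc:
            # The failure message becomes part of the next prompt's history.
            history.append(f"parse_failure attempt={attempt} cause={exc}")
    return "RUN_TESTS"  # safe fallback only after every retry failed


calls = {"n": 0}


def flaky_model(history: list[str]) -> dict[str, Any]:
    # Fails twice with the same error the new _extract_json raises, then succeeds.
    calls["n"] += 1
    if calls["n"] < 3:
        raise ValueError("Invalid JSON response.")
    return {"thought": "run the suite", "action_type": "VIEW_CODE"}


hist: list[str] = []
print(choose_action_with_retries(flaky_model, hist))  # VIEW_CODE
print(len(hist))  # 2 parse_failure entries were fed back before the success
```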
@@ -229,7 +242,7 @@ async def run(difficulty: Optional[str] = None, debug: bool = False) -> None:

              rewards.append(reward)
              steps_taken = step
-             history.append(f"step={step} action={action_str} reward={reward:.2f}")
+             history.append(f"step={step} action={action_str} reward={reward:.2f} error={error or 'null'}")
              log_step(step=step, action=action_str, reward=reward, done=done, error=error)

              if done:
@@ -243,10 +256,6 @@ async def run(difficulty: Optional[str] = None, debug: bool = False) -> None:
              log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
              started = True
          msg = str(exc).replace("\n", " ")
-         if steps_taken == 0:
-             log_step(step=1, action="RUN_TESTS", reward=0.0, done=False, error=msg)
-             steps_taken = 1
-             rewards.append(0.0)
          score = 0.0
          success = False
      finally:
@@ -264,7 +273,7 @@ if __name__ == "__main__":
      group.add_argument("--easy", action="store_true", help="Run on easy curriculum tier.")
      group.add_argument("--medium", action="store_true", help="Run on medium curriculum tier.")
      group.add_argument("--hard", action="store_true", help="Run on hard curriculum tier.")
-     parser.add_argument("--debug", action="store_true", help="Print debug model output snippets.")
+     parser.add_argument("--thought", action="store_true", help="Include model thought traces in internal history.")
      args = parser.parse_args()

      difficulty: Optional[str] = None
@@ -275,4 +284,4 @@ if __name__ == "__main__":
      elif args.hard:
          difficulty = "hard"

-     asyncio.run(run(difficulty=difficulty, debug=args.debug))
+     asyncio.run(run(difficulty=difficulty, show_thought=args.thought))
 
models.py CHANGED
@@ -21,9 +21,9 @@ ActionType = Literal[
  class CodeAction(Action):
      """Structured action consumed by the environment."""

-     thought: Optional[str] = Field(
-         default=None,
-         description="Optional reasoning string for debugging/traceability.",
+     thought: str = Field(
+         ...,
+         description="Mandatory reasoning string before selecting an action.",
      )
      action_type: ActionType = Field(
          ...,
 
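`thought` goes from optional to required, so a response without reasoning now fails model validation instead of being silently accepted, and that validation error is exactly what the new retry loop in `inference.py` catches. A minimal Pydantic v2 sketch, with a plain `BaseModel` standing in for the environment's `Action` base class and a reduced `Literal` in place of the full `ActionType`:

```python
from typing import Literal, Optional

from pydantic import BaseModel, Field, ValidationError

ActionTypeSketch = Literal["VIEW_CODE", "RUN_TESTS", "SUBMIT"]  # reduced stand-in


class CodeActionSketch(BaseModel):
    thought: str = Field(..., description="Mandatory reasoning string before selecting an action.")
    action_type: ActionTypeSketch
    start_line: Optional[int] = None
    end_line: Optional[int] = None
    new_code_block: Optional[str] = None


CodeActionSketch(thought="run the suite first", action_type="RUN_TESTS")  # valid

try:
    CodeActionSketch(action_type="RUN_TESTS")  # thought omitted
except ValidationError as exc:
    print(exc.error_count(), "validation error for field:", exc.errors()[0]["loc"])
```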
requirements.txt CHANGED
@@ -1,5 +1,5 @@
  fastapi==0.111.0
  uvicorn[standard]==0.30.1
- pydantic==1.10.17
+ pydantic>=2.0.0
  websockets==12.0
  openai>=1.30.0
 
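The pin moves from Pydantic v1 to v2, which matches the stricter `CodeAction` model and the `ValidationError` behaviour the retry loop depends on. A hypothetical startup guard (not present in the repo) that fails fast if an old v1 install is still on the path:

```python
import pydantic

# pydantic.VERSION is available in both major versions, so the check itself is safe to run.
major = int(pydantic.VERSION.split(".")[0])
if major < 2:
    raise RuntimeError(f"pydantic>=2.0.0 is required, found {pydantic.VERSION}")
```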
tasks.py CHANGED
@@ -609,10 +609,10 @@ TASK_MI_STRICT_OVERLAP = _t(
          "    intervals.sort()",
          "    merged = []",
          "    for interval in intervals:",
-         "        if not merged or merged[-1][1] <= interval[0]:",
+         "        if not merged or merged[-1][1] < interval[0]:",
          "            merged.append(list(interval))",
          "        else:",
-         "            merged[-1][1] = min(merged[-1][1], interval[1])",
+         "            merged[-1][1] = max(merged[-1][1], interval[1])",
          "    return merged",
      ],
      tests=[_tmi_1, _tmi_2, _tmi_3],
@@ -680,4 +680,4 @@ TASKS_BY_DIFFICULTY: Dict[str, List[Dict]] = {
  # Flat list — used for random sampling when training_step is not set
  ALL_TASKS: List[Dict] = [
      t for bucket in TASKS_BY_DIFFICULTY.values() for t in bucket
- ]
+ ]
 
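This hunk edits the code block stored for the strict-overlap merge task, toggling the boundary comparison and swapping `min` for `max` when extending the previous interval. Without presuming which variant the task intends as the planted bug, the sketch below (not taken from the repo) shows how those two operators change behaviour on touching intervals such as `[1, 3]` and `[3, 5]`:

```python
def merge_intervals(intervals, strict=True):
    """strict=True keeps intervals that merely touch at an endpoint separate;
    strict=False merges them. In both cases max(), not min(), extends the end."""
    merged = []
    for start, end in sorted(intervals):
        starts_new = not merged or (
            merged[-1][1] <= start if strict else merged[-1][1] < start
        )
        if starts_new:
            merged.append([start, end])
        else:
            merged[-1][1] = max(merged[-1][1], end)  # min() here would shrink the interval
    return merged


print(merge_intervals([[1, 3], [3, 5]], strict=True))   # [[1, 3], [3, 5]]
print(merge_intervals([[1, 3], [3, 5]], strict=False))  # [[1, 5]]
print(merge_intervals([[1, 4], [2, 3]]))                # [[1, 4]]
```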