# SimMart / tests/test_parser.py
# Viani's picture
# HF Space: 4-dept SimMart env + 1.5B SFT+GRPO training (hackathon submission)
# 5c35138
"""Parser robustness tests — targets the GRPO-v5 "approve-all cascade" bug.
Ensures:
1. Truncated JSON (missing </action> close) recovers what it can.
2. Total parse failure falls back to ``request_info`` (no-op, zero
``false_reject_penalty``) instead of the legacy ``approve`` default.
3. Backfilled missing-proposal decisions also use ``request_info``.
"""
from __future__ import annotations
import sys
import unittest
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from models import Proposal, ProposalDecision, SimMartObservation, KPISnapshot
from prompts import (
parse_response, _compress_journal, _recover_partial_json,
build_action_chat, build_journal_chat, parse_journal_response,
ACTION_SYSTEM_PROMPT, JOURNAL_SYSTEM_PROMPT,
)
def mk_inbox():
    """Return the fixed three-proposal inbox shared by the parser tests.

    One proposal each from supply_chain, store_ops and finance, at high, med
    and low urgency. A new list (with new ``params`` dicts) is built on every
    call, so tests cannot leak state into each other through the fixture.
    """
    specs = (
        # (proposal_id, dept, action, urgency, cost_inr, params)
        ("S01-01", "supply_chain", "po.place", "high", 1e6, {"qty": 1000}),
        ("S01-02", "store_ops", "staffing.hire", "med", 5e5, {"n": 3}),
        ("F01-03", "finance", "budget.reallocate", "low", 0, {}),
    )
    inbox = []
    for pid, dept, act, urgency, cost, params in specs:
        inbox.append(
            Proposal(proposal_id=pid, dept=dept, action=act,
                     urgency=urgency, cost_inr=cost, params=params)
        )
    return inbox
class TruncationRecoveryTests(unittest.TestCase):
    """``parse_response`` on complete and truncated ``<action>`` completions."""

    @staticmethod
    def _verdict_map(action):
        """Map each decision's proposal_id to its verdict."""
        return {decision.proposal_id: decision.verdict
                for decision in action.decisions}

    def test_complete_json_parses_ok(self):
        completion = """<action>
{"decisions": [
{"proposal_id": "S01-01", "verdict": "approve"},
{"proposal_id": "S01-02", "verdict": "flag_suspicious"},
{"proposal_id": "F01-03", "verdict": "reject"}
],
"budget_allocations": {"supply_chain": 5000000},
"diligence_requests": [
{"request_type": "vendor_audit", "proposal_id": "S01-02", "rationale": "suspicious vendor"}
],
"journal_entry": "Week 1: OK."}
</action>"""
        action, telemetry = parse_response(completion, mk_inbox())

        # Well-formed input must parse outright, not through partial recovery.
        self.assertTrue(telemetry["parse_ok"])
        self.assertFalse(telemetry["parse_partial"])

        self.assertEqual(len(action.decisions), 3)
        requests = action.diligence_requests
        self.assertEqual(len(requests), 1)
        only_request = requests[0]
        self.assertEqual(only_request.request_type, "vendor_audit")
        self.assertEqual(only_request.proposal_id, "S01-02")

        self.assertEqual(
            self._verdict_map(action),
            {"S01-01": "approve", "S01-02": "flag_suspicious", "F01-03": "reject"},
        )

    def test_truncated_json_recovers_decisions(self):
        # Generation is cut off in the middle of the third verdict string.
        completion = """<action>
{"decisions": [
{"proposal_id": "S01-01", "verdict": "approve"},
{"proposal_id": "S01-02", "verdict": "flag_suspicious"},
{"proposal_id": "F01-03", "verdict": "appr"""
        action, telemetry = parse_response(completion, mk_inbox())
        self.assertTrue(telemetry["parse_partial"], msg=f"tel={telemetry}")

        verdicts = self._verdict_map(action)
        expected = {
            "S01-01": "approve",
            "S01-02": "flag_suspicious",
            # The unreadable third verdict must degrade to the safe no-op.
            "F01-03": "request_info",
        }
        for pid, verdict in expected.items():
            self.assertEqual(verdicts[pid], verdict)

    def test_truncated_with_unclosed_string_recovers(self):
        # Cut off inside the flag_reason string of the second decision.
        completion = (
            '<action>\n{"decisions":[{"proposal_id":"S01-01","verdict":"approve"},'
            '{"proposal_id":"S01-02","verdict":"flag_suspicious",'
            '"flag_reason":"vendor kickback pattern — Diwali supp'
        )
        action, telemetry = parse_response(completion, mk_inbox())
        recovered = telemetry["parse_partial"] or telemetry["parse_ok"]
        self.assertTrue(recovered, msg=f"tel={telemetry}")
        self.assertEqual(self._verdict_map(action)["S01-01"], "approve")
class SafeFallbackTests(unittest.TestCase):
    """Parse failures must degrade to ``request_info`` rather than ``approve``.

    Each fallback test first checks that every inbox proposal received a
    decision. Without that check, a ``for d in action.decisions`` loop over an
    empty list would pass without asserting anything.
    """

    def _assert_all_verdicts(self, action, inbox, verdict, msg=None):
        """Assert ``action`` has one decision per ``inbox`` proposal, all ``verdict``.

        Args:
            action: parsed action returned by ``parse_response``.
            inbox: the proposals that were passed to ``parse_response``.
            verdict: the verdict every decision is expected to carry.
            msg: optional failure message for the per-decision verdict check.
        """
        self.assertEqual(
            sorted(d.proposal_id for d in action.decisions),
            sorted(p.proposal_id for p in inbox),
            msg="Fallback must emit exactly one decision per inbox proposal.",
        )
        for d in action.decisions:
            self.assertEqual(d.verdict, verdict, msg=msg)

    def test_no_json_falls_back_to_request_info_not_approve(self):
        inbox = mk_inbox()
        text = "I'll approve everything. Let me think about this more carefully. End."
        action, tel = parse_response(text, inbox)
        self.assertFalse(tel["parse_ok"])
        self.assertEqual(tel["parse_error"], "no_json_block")
        self._assert_all_verdicts(
            action, inbox, "request_info",
            msg="Default fallback must be request_info — "
                "`approve` triggers the rogue-slips-through cascade.",
        )

    def test_unparseable_json_falls_back_to_request_info(self):
        inbox = mk_inbox()
        text = "<action>{this is not json at all</action>"
        action, tel = parse_response(text, inbox)
        self.assertFalse(tel["parse_ok"])
        self._assert_all_verdicts(action, inbox, "request_info")

    def test_missing_decisions_backfill_with_request_info(self):
        text = """<action>
{"decisions": [{"proposal_id": "S01-01", "verdict": "approve"}],
"budget_allocations": {},
"journal_entry": ""}
</action>"""
        action, tel = parse_response(text, mk_inbox())
        self.assertTrue(tel["parse_ok"])
        self.assertEqual(tel["n_decisions_missing"], 2)
        verdicts = {d.proposal_id: d.verdict for d in action.decisions}
        # The explicit verdict is kept; the two omitted proposals are backfilled.
        self.assertEqual(verdicts["S01-01"], "approve")
        self.assertEqual(verdicts["S01-02"], "request_info")
        self.assertEqual(verdicts["F01-03"], "request_info")

    def test_explicit_approve_fallback_still_honored(self):
        inbox = mk_inbox()
        text = "Complete garbage output with no json."
        action, tel = parse_response(text, inbox, fallback_verdict="approve")
        self._assert_all_verdicts(
            action, inbox, "approve",
            msg="Callers must still be able to opt into legacy approve-all.",
        )
class JournalCompressionTests(unittest.TestCase):
    """``_compress_journal`` on short entries and entries over the word budget."""

    def test_short_journal_unchanged(self):
        entry = "Week 1: approved POs. Flagged SUP-007."
        compressed = _compress_journal(entry, max_words=60)
        self.assertEqual(compressed, entry)

    def test_long_journal_truncated_with_pid_refs(self):
        head = ("Week 12 CEO decision log for SimMart, focusing on supply_chain "
                "and growth. Approved S12-01 and G12-03 after extensive review ")
        padding = "filler " * 100
        tail = "and flagged F12-05 last moment."
        entry = head + padding + tail

        compressed = _compress_journal(entry, max_words=60)

        # Over-long input is shortened and marked with a trailing ellipsis.
        self.assertTrue(compressed.endswith("…"))
        self.assertLess(len(compressed.split()), 75)
        # Proposal IDs from the start of the entry are kept for continuity.
        for pid in ("S12-01", "G12-03"):
            self.assertIn(pid, compressed)
class LowLevelRecoveryTests(unittest.TestCase):
    """Direct tests of ``_recover_partial_json`` on truncated JSON fragments."""

    def _assert_single_decision(self, fragment):
        """Recover ``fragment`` and assert it yields exactly one decision."""
        recovered = _recover_partial_json(fragment)
        self.assertIsNotNone(recovered)
        self.assertEqual(len(recovered["decisions"]), 1)

    def test_recover_unbalanced_braces(self):
        # Neither the decisions array nor the outer object is closed.
        self._assert_single_decision(
            '{"decisions":[{"proposal_id":"S01-01","verdict":"approve"}'
        )

    def test_recover_unclosed_string_mid_value(self):
        # Truncated inside the journal_entry string value.
        self._assert_single_decision(
            '{"decisions":[{"proposal_id":"S01-01","verdict":"approve"}],'
            '"journal_entry":"Week 12 approved all supply POs for Diwali but then'
        )
class TwoPassPromptTests(unittest.TestCase):
    """Prompt builders and journal parser for the two-pass (action, then journal) flow."""

    def _make_obs(self):
        """Return a week-1 ``weekly_decision`` observation over ``mk_inbox()``."""
        return SimMartObservation(
            step_type="weekly_decision", day_of_quarter=7, week_of_quarter=1,
            kpi_snapshot=KPISnapshot(), inbox=mk_inbox(),
        )

    def test_action_prompt_isolates_action_schema(self):
        chat = build_action_chat(self._make_obs())
        self.assertEqual(chat[0]["role"], "system")
        # The action-pass system prompt asks for <action> and not the journal.
        self.assertIn("<action>", chat[0]["content"])
        self.assertNotIn("founder's journal", chat[0]["content"].lower())
        # The user turn must present the proposal inbox to decide on.
        self.assertIn("INBOX", chat[1]["content"])

    def test_journal_prompt_gets_decisions(self):
        obs = self._make_obs()
        decisions = [
            ProposalDecision(proposal_id="S01-01", verdict="approve"),
            ProposalDecision(proposal_id="S01-02", verdict="flag_suspicious",
                             flag_reason="vendor kickback"),
            ProposalDecision(proposal_id="F01-03", verdict="request_info"),
        ]
        chat = build_journal_chat(obs, decisions, {"supply_chain": 5e6})
        self.assertEqual(chat[0]["role"], "system")
        self.assertIn("<journal>", chat[0]["content"])
        user = chat[1]["content"]
        self.assertIn("S01-01: approve", user)
        self.assertIn("S01-02: flag_suspicious", user)
        self.assertIn("vendor kickback", user)
        self.assertIn("supply_chain", user)

    def test_journal_prompt_gets_diligence_requests(self):
        obs = self._make_obs()
        completion = """<action>
{"decisions": [{"proposal_id": "S01-01", "verdict": "approve"}],
"diligence_requests": [{"request_type": "cfo_variance_note", "dept": "finance", "rationale": "cash variance"}]}
</action>"""
        # Parse once and use both results. Previously the same completion was
        # parsed twice and half of each result was thrown away.
        action, tel = parse_response(completion, obs.inbox)
        self.assertTrue(tel["parse_ok"])
        chat = build_journal_chat(
            obs, action.decisions, action.budget_allocations,
            action.diligence_requests,
        )
        user = chat[1]["content"]
        self.assertIn("CEO diligence escalations requested", user)
        self.assertIn("cfo_variance_note", user)

    def test_parse_journal_happy_path(self):
        completion = "<journal>\nWeek 1 was busy. Approved supply_chain POs.\n</journal>"
        self.assertEqual(
            parse_journal_response(completion),
            "Week 1 was busy. Approved supply_chain POs.",
        )

    def test_parse_journal_handles_truncation(self):
        # Missing closing tag → return what we have
        completion = "<journal>\nWeek 1: approved S01-01 and S01-02. Next"
        text = parse_journal_response(completion)
        self.assertIn("approved S01-01", text)
        self.assertIn("Next", text)

    def test_parse_journal_no_tags_returns_text(self):
        # Untagged output is returned with surrounding whitespace stripped.
        self.assertEqual(parse_journal_response("  some free text  "), "some free text")
if __name__ == "__main__":
    # Running this file directly hands control to unittest's CLI runner, which
    # collects the TestCase classes defined in this module.
    unittest.main()