File size: 2,046 Bytes
b0fdd8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import pytest

from src.pytorch_debug_env.bug_library import BUG_TEMPLATES
from src.pytorch_debug_env.environment import PyTorchDebugEnv
from src.pytorch_debug_env.graders import grade_easy, grade_medium, grade_hard
from src.pytorch_debug_env.models import FinalDiagnosis, Hypothesis, PyTorchDebugAction
from src.pytorch_debug_env.scenario_generator import ScenarioGenerator


def _build_action_from_gt(gt: dict) -> PyTorchDebugAction:
    """Return a committed action whose diagnosis mirrors the ground truth exactly.

    Args:
        gt: Scenario ground-truth mapping. It must provide the keys
            ``bug_type``, ``primary_bug_file``, ``line_range`` and
            ``fix_strategy``.

    Returns:
        A ``PyTorchDebugAction`` with ``commit_diagnosis=True``. Both the
        working hypothesis and the final diagnosis point at the true bug,
        each with confidence 0.9.
    """
    bug_type = gt["bug_type"]
    bug_file = gt["primary_bug_file"]
    confidence = 0.9

    return PyTorchDebugAction(
        current_hypothesis=Hypothesis(
            bug_type=bug_type,
            affected_file=bug_file,
            confidence=confidence,
        ),
        commit_diagnosis=True,
        final_diagnosis=FinalDiagnosis(
            bug_type=bug_type,
            affected_file=bug_file,
            line_range=gt["line_range"],
            fix_strategy=gt["fix_strategy"],
            confidence=confidence,
        ),
    )


@pytest.mark.parametrize(
    "task_id,grader",
    [
        ("easy", grade_easy),
        ("medium", grade_medium),
        ("hard", grade_hard),
    ],
)
@pytest.mark.asyncio
async def test_task_scores_strict_bounds(task_id, grader):
    """A perfect diagnosis must score strictly inside (0, 1) at every difficulty.

    The diagnosis is checked three ways: graded directly, as the reward
    returned by ``env.step``, and as the final score in ``env.state()``.
    """
    env = PyTorchDebugEnv(generator=ScenarioGenerator(BUG_TEMPLATES))
    await env.reset(task_id, seed=7)  # fixed seed keeps the scenario reproducible

    scenario = env.runtime.scenario
    action = _build_action_from_gt(scenario.ground_truth)
    diagnosis_payload = action.final_diagnosis.model_dump()

    # 1) Call the grader directly, without going through the environment.
    direct_score = grader(diagnosis_payload, scenario.ground_truth)
    assert 0.0 < direct_score < 1.0

    # 2) Check the reward the environment reports for the committed action.
    step_result = await env.step(action)
    assert 0.0 < step_result["reward"] < 1.0

    # 3) Check the final score exposed through the environment state.
    final_state = await env.state()
    assert 0.0 < final_state.final_score < 1.0


@pytest.mark.parametrize("grader", [grade_easy, grade_medium, grade_hard])
def test_empty_action_is_clamped(grader):
    """An empty diagnosis must still receive a score strictly inside (0, 1)."""
    ground_truth = dict(
        bug_type="missing_zero_grad",
        primary_bug_file="train.py",
        related_files=[],
        line_range=[10, 12],
        fix_strategy="Call optimizer.zero_grad() before loss.backward()",
    )
    empty_diagnosis = {}

    clamped_score = grader(empty_diagnosis, ground_truth)
    assert 0.0 < clamped_score < 1.0