File size: 2,168 Bytes
e2f8b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Test code bug generation and fix validation."""

from __future__ import annotations

import pytest

from ml_training_debugger.code_templates import generate_code_snippet, validate_fix


class TestGenerateCodeSnippet:
    """Checks on the snippet mapping returned by ``generate_code_snippet``."""

    def test_eval_mode(self):
        """The ``eval_mode`` snippet embeds ``model.eval()`` and carries metadata."""
        result = generate_code_snippet("eval_mode")
        source_text = result["code"]
        assert "model.eval()" in source_text
        # Metadata fields are read in the same order the assertions need them.
        assert result["filename"] == "train.py"
        line_total = result["line_count"]
        assert line_total > 0
        import_list = result["imports"]
        assert len(import_list) > 0

    def test_detach_loss(self):
        """The ``detach_loss`` snippet contains a ``.detach()`` call."""
        source_text = generate_code_snippet("detach_loss")["code"]
        assert ".detach()" in source_text

    def test_zero_grad_missing(self):
        """The ``zero_grad_missing`` snippet never mentions ``zero_grad``."""
        source_text = generate_code_snippet("zero_grad_missing")["code"]
        mentions_zero_grad = "zero_grad" in source_text
        assert not mentions_zero_grad

    def test_inplace_relu(self):
        """The ``inplace_relu`` snippet contains ``inplace=True``."""
        source_text = generate_code_snippet("inplace_relu")["code"]
        assert "inplace=True" in source_text

    def test_unknown_bug_raises(self):
        """Asking for an unrecognised bug name raises ``ValueError``."""
        unknown_name = "nonexistent_bug"
        with pytest.raises(ValueError):
            generate_code_snippet(unknown_name)


class TestValidateFix:
    """Checks on ``validate_fix(bug_type, line_number, replacement_line)``."""

    def test_eval_mode_correct_fix(self):
        """``model.train()`` on line 5 is accepted for ``eval_mode``."""
        accepted = validate_fix("eval_mode", 5, "model.train()")
        assert accepted

    def test_eval_mode_with_whitespace(self):
        """Surrounding whitespace on the replacement line does not matter."""
        padded_fix = "  model.train()  "
        accepted = validate_fix("eval_mode", 5, padded_fix)
        assert accepted

    def test_eval_mode_wrong_fix(self):
        """A bare ``pass`` on line 5 is rejected for ``eval_mode``."""
        accepted = validate_fix("eval_mode", 5, "pass")
        assert not accepted

    def test_detach_loss_correct_fix(self):
        """The loss line without ``.detach()`` on line 14 is accepted."""
        fixed_line = "        loss = criterion(output, batch_y)"
        accepted = validate_fix("detach_loss", 14, fixed_line)
        assert accepted

    def test_detach_loss_with_trailing_spaces(self):
        """Trailing spaces after the corrected loss line are still accepted."""
        fixed_line = "        loss = criterion(output, batch_y)   "
        accepted = validate_fix("detach_loss", 14, fixed_line)
        assert accepted

    def test_zero_grad_correct_fix(self):
        """``optimizer.zero_grad()`` on line 11 is accepted."""
        fixed_line = "        optimizer.zero_grad()"
        accepted = validate_fix("zero_grad_missing", 11, fixed_line)
        assert accepted

    def test_inplace_relu_correct_fix(self):
        """A plain ``F.relu(output)`` assignment on line 15 is accepted."""
        fixed_line = "        output = F.relu(output)"
        accepted = validate_fix("inplace_relu", 15, fixed_line)
        assert accepted

    def test_wrong_line_number(self):
        """The right replacement text at line 999 is rejected."""
        accepted = validate_fix("eval_mode", 999, "model.train()")
        assert not accepted

    def test_unknown_bug_type(self):
        """An unrecognised bug type yields a falsy result."""
        accepted = validate_fix("nonexistent", 1, "pass")
        assert not accepted