File size: 4,241 Bytes
4f58e42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Edge-case tests for code_templates.py — covers AST fallback and tokenizer paths."""

from __future__ import annotations

from ml_training_debugger.code_templates import (
    _normalize_code,
    _tokenize_compare,
    generate_code_snippet,
    validate_fix,
)


class TestNormalizeCode:
    """Checks for the ``_normalize_code`` helper."""

    def test_strips_whitespace(self) -> None:
        # Padding around a single statement must be removed entirely.
        normalized = _normalize_code("  model.train()  ")
        assert normalized == "model.train()"

    def test_multiline(self) -> None:
        # The content of every line must survive normalization of a padded,
        # newline-terminated multi-line input.
        normalized = _normalize_code("  line1  \n  line2  \n")
        for fragment in ("line1", "line2"):
            assert fragment in normalized


class TestTokenizeCompare:
    """Checks for the ``_tokenize_compare`` helper."""

    def test_identical_tokens(self) -> None:
        # A snippet compared against itself must be reported as a match.
        snippet = "model.train()"
        assert _tokenize_compare(snippet, snippet)

    def test_whitespace_ignored(self) -> None:
        # Leading/trailing padding on one side must not break the match.
        bare = "model.train()"
        padded = "  model.train()  "
        assert _tokenize_compare(bare, padded)

    def test_different_tokens(self) -> None:
        # Different method names must be reported as a mismatch.
        outcome = _tokenize_compare("model.train()", "model.eval()")
        assert not outcome

    def test_invalid_syntax(self) -> None:
        # Unparseable input compared against itself is still expected to match.
        # NOTE(review): the original note says the tokenizer returns an empty
        # list for invalid syntax; that is not visible from this test file.
        broken = "((("
        assert _tokenize_compare(broken, broken)


class TestValidateFixASTFallback:
    """Exercise ``validate_fix`` with inputs aimed at its AST fallback branch.

    Which internal branch actually handles each input is an intent carried
    over from the test names; only the accepted/rejected outcome is asserted.
    """

    def test_eval_mode_ast_fallback_with_train_keyword(self) -> None:
        # Intended to miss the exact-string and tokenizer comparisons while
        # still being accepted (mentions 'train', does not mention 'eval').
        replacement = "model.train()  # fixed mode"
        assert validate_fix("eval_mode", 5, replacement) is True

    def test_detach_loss_ast_without_detach(self) -> None:
        # A loss line with no .detach() call is expected to be accepted.
        replacement = "        loss = criterion(output, batch_y)  # no detach"
        assert validate_fix("detach_loss", 14, replacement) is True

    def test_inplace_relu_ast_without_inplace(self) -> None:
        # A relu call without the inplace flag is expected to be accepted,
        # whether by the AST check or the semantic check.
        replacement = "        output = F.relu(output)  # fixed"
        assert validate_fix("inplace_relu", 15, replacement) is True

    def test_eval_mode_line_zero_invalid(self) -> None:
        # An otherwise valid replacement must be rejected at line number 0.
        accepted = validate_fix("eval_mode", 0, "model.train()")
        assert not accepted

    def test_detach_loss_syntax_error_rejected(self) -> None:
        # A replacement that is not valid Python must be rejected.
        garbage = "    ((( invalid syntax"
        accepted = validate_fix("detach_loss", 14, garbage)
        assert not accepted

    def test_zero_grad_with_comment(self) -> None:
        # An inline comment after zero_grad() must not prevent acceptance.
        replacement = "        optimizer.zero_grad()  # clear grads"
        assert validate_fix("zero_grad_missing", 11, replacement)

    def test_zero_grad_without_keyword(self) -> None:
        # A line that never mentions zero_grad must be rejected.
        accepted = validate_fix("zero_grad_missing", 11, "        pass")
        assert not accepted


class TestValidateFixSemanticPatterns:
    """Exercise ``validate_fix`` with inputs aimed at semantic-equivalence matching."""

    def test_eval_mode_semantic_train_present(self) -> None:
        # A plain model.train() call is expected to be accepted.
        replacement = "model.train()"
        assert validate_fix("eval_mode", 5, replacement)

    def test_eval_mode_with_eval_keyword_fails(self) -> None:
        # A replacement that still calls model.eval() must be rejected.
        accepted = validate_fix("eval_mode", 5, "model.eval()")
        assert not accepted

    def test_detach_loss_criterion_without_detach(self) -> None:
        # The criterion call with no .detach() is expected to be accepted.
        replacement = "        loss = criterion(output, batch_y)"
        assert validate_fix("detach_loss", 14, replacement)

    def test_inplace_relu_without_inplace_flag(self) -> None:
        # F.relu without the inplace flag is expected to be accepted.
        replacement = "        output = F.relu(output)"
        assert validate_fix("inplace_relu", 15, replacement)


class TestGenerateCodeSnippetHints:
    """Check the ``"hint"`` entry of the result of ``generate_code_snippet``."""

    def test_eval_mode_has_hint(self) -> None:
        # The eval_mode snippet must carry a hint.
        hint = generate_code_snippet("eval_mode")["hint"]
        assert hint is not None

    def test_detach_loss_has_hint(self) -> None:
        # The detach_loss snippet must carry a hint.
        hint = generate_code_snippet("detach_loss")["hint"]
        assert hint is not None

    def test_zero_grad_no_hint(self) -> None:
        # The zero_grad_missing snippet must have its hint set to None.
        hint = generate_code_snippet("zero_grad_missing")["hint"]
        assert hint is None

    def test_inplace_relu_no_hint(self) -> None:
        # The inplace_relu snippet must have its hint set to None.
        hint = generate_code_snippet("inplace_relu")["hint"]
        assert hint is None