| """Test code bug generation and fix validation.""" | |
| from __future__ import annotations | |
| import pytest | |
| from ml_training_debugger.code_templates import generate_code_snippet, validate_fix | |
class TestGenerateCodeSnippet:
    """Checks on the dict returned by ``generate_code_snippet`` per bug type."""

    def test_eval_mode(self):
        """The ``eval_mode`` snippet has the buggy call and filled-in metadata."""
        result = generate_code_snippet("eval_mode")
        source_text = result["code"]
        assert "model.eval()" in source_text
        file_name = result["filename"]
        assert file_name == "train.py"
        total_lines = result["line_count"]
        assert total_lines > 0
        import_list = result["imports"]
        assert len(import_list) > 0

    def test_detach_loss(self):
        """The ``detach_loss`` snippet contains a ``.detach()`` call."""
        source_text = generate_code_snippet("detach_loss")["code"]
        assert ".detach()" in source_text

    def test_zero_grad_missing(self):
        """The ``zero_grad_missing`` snippet never mentions ``zero_grad``."""
        source_text = generate_code_snippet("zero_grad_missing")["code"]
        assert not ("zero_grad" in source_text)

    def test_inplace_relu(self):
        """The ``inplace_relu`` snippet contains ``inplace=True``."""
        source_text = generate_code_snippet("inplace_relu")["code"]
        assert "inplace=True" in source_text

    def test_unknown_bug_raises(self):
        """An unrecognised bug name is rejected with ``ValueError``."""
        expect_value_error = pytest.raises(ValueError)
        with expect_value_error:
            generate_code_snippet("nonexistent_bug")
class TestValidateFix:
    """Checks on the truthiness of ``validate_fix(bug, line, replacement)``."""

    # NOTE(review): several replacement literals below carry a single leading
    # and/or trailing space; the original padding may have been wider before
    # this file was pasted -- confirm against the code templates.

    def test_eval_mode_correct_fix(self):
        """``model.train()`` on line 5 is accepted for ``eval_mode``."""
        outcome = validate_fix("eval_mode", 5, "model.train()")
        assert outcome

    def test_eval_mode_with_whitespace(self):
        """Surrounding spaces do not stop the ``eval_mode`` fix being accepted."""
        padded_fix = " model.train() "
        outcome = validate_fix("eval_mode", 5, padded_fix)
        assert outcome

    def test_eval_mode_wrong_fix(self):
        """A bare ``pass`` on line 5 is not accepted for ``eval_mode``."""
        outcome = validate_fix("eval_mode", 5, "pass")
        assert not outcome

    def test_detach_loss_correct_fix(self):
        """The loss assignment on line 14 is accepted for ``detach_loss``."""
        replacement = " loss = criterion(output, batch_y)"
        outcome = validate_fix("detach_loss", 14, replacement)
        assert outcome

    def test_detach_loss_with_trailing_spaces(self):
        """Trailing whitespace does not stop the ``detach_loss`` fix passing."""
        replacement = " loss = criterion(output, batch_y) "
        outcome = validate_fix("detach_loss", 14, replacement)
        assert outcome

    def test_zero_grad_correct_fix(self):
        """``optimizer.zero_grad()`` on line 11 is accepted for the missing call."""
        replacement = " optimizer.zero_grad()"
        outcome = validate_fix("zero_grad_missing", 11, replacement)
        assert outcome

    def test_inplace_relu_correct_fix(self):
        """The ``F.relu(output)`` assignment on line 15 is accepted."""
        replacement = " output = F.relu(output)"
        outcome = validate_fix("inplace_relu", 15, replacement)
        assert outcome

    def test_wrong_line_number(self):
        """The right text on line 999 is not accepted for ``eval_mode``."""
        outcome = validate_fix("eval_mode", 999, "model.train()")
        assert not outcome

    def test_unknown_bug_type(self):
        """An unrecognised bug name yields a falsy result, not an error."""
        outcome = validate_fix("nonexistent", 1, "pass")
        assert not outcome