"""Edge-case tests for code_templates.py — covers AST fallback and tokenizer paths."""
|
|
| from __future__ import annotations |
|
|
| from ml_training_debugger.code_templates import ( |
| _normalize_code, |
| _tokenize_compare, |
| generate_code_snippet, |
| validate_fix, |
| ) |
|
|
|
|
class TestNormalizeCode:
    """Checks for the ``_normalize_code`` helper."""

    def test_strips_whitespace(self) -> None:
        """A padded single-line snippet comes back without the padding."""
        normalized = _normalize_code(" model.train() ")
        assert normalized == "model.train()"

    def test_multiline(self) -> None:
        """Both lines of a padded two-line snippet are present in the result."""
        normalized = _normalize_code(" line1 \n line2 \n")
        for expected_fragment in ("line1", "line2"):
            assert expected_fragment in normalized
|
|
|
|
class TestTokenizeCompare:
    """Checks for the ``_tokenize_compare`` helper."""

    def test_identical_tokens(self) -> None:
        """Two byte-identical snippets compare as equal."""
        snippet = "model.train()"
        matched = _tokenize_compare(snippet, "model.train()")
        assert matched

    def test_whitespace_ignored(self) -> None:
        """Surrounding spaces on one side do not break the match."""
        matched = _tokenize_compare("model.train()", " model.train() ")
        assert matched

    def test_different_tokens(self) -> None:
        """Snippets calling different methods do not compare as equal."""
        matched = _tokenize_compare("model.train()", "model.eval()")
        assert not matched

    def test_invalid_syntax(self) -> None:
        """Identical inputs still match even when they are not valid Python."""
        broken = "((("
        matched = _tokenize_compare(broken, "(((")
        assert matched
|
|
|
|
class TestValidateFixASTFallback:
    """Tests targeting the AST fallback branch in validate_fix."""

    def test_eval_mode_ast_fallback_with_train_keyword(self) -> None:
        """A ``model.train()`` line with a trailing comment is accepted."""
        assert validate_fix("eval_mode", 5, "model.train() # fixed mode") is True

    def test_detach_loss_ast_without_detach(self) -> None:
        """A commented loss assignment is accepted for ``detach_loss``."""
        candidate = " loss = criterion(output, batch_y) # no detach"
        assert validate_fix("detach_loss", 14, candidate) is True

    def test_inplace_relu_ast_without_inplace(self) -> None:
        """A commented ``F.relu`` call is accepted for ``inplace_relu``."""
        candidate = " output = F.relu(output) # fixed"
        assert validate_fix("inplace_relu", 15, candidate) is True

    def test_eval_mode_line_zero_invalid(self) -> None:
        """Passing 0 as the second argument makes validation fail."""
        accepted = validate_fix("eval_mode", 0, "model.train()")
        assert not accepted

    def test_detach_loss_syntax_error_rejected(self) -> None:
        """A candidate line with unbalanced parentheses is rejected."""
        accepted = validate_fix("detach_loss", 14, " ((( invalid syntax")
        assert not accepted

    def test_zero_grad_with_comment(self) -> None:
        """A trailing comment does not stop the zero_grad fix from passing."""
        candidate = " optimizer.zero_grad() # clear grads"
        accepted = validate_fix("zero_grad_missing", 11, candidate)
        assert accepted

    def test_zero_grad_without_keyword(self) -> None:
        """A bare ``pass`` statement is not accepted as the zero_grad fix."""
        accepted = validate_fix("zero_grad_missing", 11, " pass")
        assert not accepted
|
|
|
|
class TestValidateFixSemanticPatterns:
    """Tests targeting semantic equivalence pattern matching."""

    def test_eval_mode_semantic_train_present(self) -> None:
        """A plain ``model.train()`` call is accepted for ``eval_mode``."""
        accepted = validate_fix("eval_mode", 5, "model.train()")
        assert accepted

    def test_eval_mode_with_eval_keyword_fails(self) -> None:
        """A ``model.eval()`` call is rejected for ``eval_mode``."""
        accepted = validate_fix("eval_mode", 5, "model.eval()")
        assert not accepted

    def test_detach_loss_criterion_without_detach(self) -> None:
        """A loss assignment built from ``criterion(...)`` is accepted."""
        candidate = " loss = criterion(output, batch_y)"
        accepted = validate_fix("detach_loss", 14, candidate)
        assert accepted

    def test_inplace_relu_without_inplace_flag(self) -> None:
        """An ``F.relu`` call with no extra arguments is accepted."""
        candidate = " output = F.relu(output)"
        accepted = validate_fix("inplace_relu", 15, candidate)
        assert accepted
|
|
|
|
class TestGenerateCodeSnippetHints:
    """Test hint generation for code snippets."""

    def test_eval_mode_has_hint(self) -> None:
        """The ``eval_mode`` snippet has a non-None ``hint`` entry."""
        hint = generate_code_snippet("eval_mode")["hint"]
        assert hint is not None

    def test_detach_loss_has_hint(self) -> None:
        """The ``detach_loss`` snippet has a non-None ``hint`` entry."""
        hint = generate_code_snippet("detach_loss")["hint"]
        assert hint is not None

    def test_zero_grad_no_hint(self) -> None:
        """The ``zero_grad_missing`` snippet has ``hint`` set to None."""
        hint = generate_code_snippet("zero_grad_missing")["hint"]
        assert hint is None

    def test_inplace_relu_no_hint(self) -> None:
        """The ``inplace_relu`` snippet has ``hint`` set to None."""
        hint = generate_code_snippet("inplace_relu")["hint"]
        assert hint is None
|
|