# pytorch-training-debugger / tests/test_code_templates_edge.py
# Author: omkarrr88 -- "Major fixes + gap fixes" (commit 4f58e42)
"""Edge-case tests for code_templates.py β€” covers AST fallback and tokenizer paths."""
from __future__ import annotations
from ml_training_debugger.code_templates import (
_normalize_code,
_tokenize_compare,
generate_code_snippet,
validate_fix,
)
class TestNormalizeCode:
    """Behaviour of _normalize_code on padded and multi-line input."""

    def test_strips_whitespace(self) -> None:
        # Surrounding whitespace is removed from a single line.
        padded = " model.train() "
        assert _normalize_code(padded) == "model.train()"

    def test_multiline(self) -> None:
        # Both logical lines survive normalization of a multi-line snippet.
        normalized = _normalize_code(" line1 \n line2 \n")
        for expected in ("line1", "line2"):
            assert expected in normalized
class TestTokenizeCompare:
    """Token-level comparison semantics of _tokenize_compare."""

    def test_identical_tokens(self) -> None:
        snippet = "model.train()"
        assert _tokenize_compare(snippet, snippet)

    def test_whitespace_ignored(self) -> None:
        # Leading/trailing whitespace must not affect the token stream.
        assert _tokenize_compare("model.train()", " model.train() ")

    def test_different_tokens(self) -> None:
        # train vs eval produces different tokens, so comparison fails.
        assert not _tokenize_compare("model.train()", "model.eval()")

    def test_invalid_syntax(self) -> None:
        # Invalid syntax tokenizes to an empty list on both sides,
        # so the two empty streams compare equal.
        broken = "((("
        assert _tokenize_compare(broken, broken)
class TestValidateFixASTFallback:
    """Exercise the AST fallback branch inside validate_fix."""

    def test_eval_mode_ast_fallback_with_train_keyword(self) -> None:
        # Neither an exact string match nor a token match, but the AST
        # check accepts it: 'train' is present and 'eval' is absent.
        accepted = validate_fix("eval_mode", 5, "model.train() # fixed mode")
        assert accepted is True

    def test_detach_loss_ast_without_detach(self) -> None:
        # No .detach() call -- the AST check should still accept the line.
        replacement = " loss = criterion(output, batch_y) # no detach"
        assert validate_fix("detach_loss", 14, replacement) is True

    def test_inplace_relu_ast_without_inplace(self) -> None:
        # inplace flag removed -- accepted by the AST or semantic check.
        replacement = " output = F.relu(output) # fixed"
        assert validate_fix("inplace_relu", 15, replacement) is True

    def test_eval_mode_line_zero_invalid(self) -> None:
        # Line number 0 is out of range and must be rejected.
        assert not validate_fix("eval_mode", 0, "model.train()")

    def test_detach_loss_syntax_error_rejected(self) -> None:
        # A replacement that cannot be parsed at all must be rejected.
        assert not validate_fix("detach_loss", 14, " ((( invalid syntax")

    def test_zero_grad_with_comment(self) -> None:
        # An inline comment must not hide the zero_grad() call.
        replacement = " optimizer.zero_grad() # clear grads"
        assert validate_fix("zero_grad_missing", 11, replacement)

    def test_zero_grad_without_keyword(self) -> None:
        # The zero_grad keyword is absent entirely, so validation fails.
        assert not validate_fix("zero_grad_missing", 11, " pass")
class TestValidateFixSemanticPatterns:
    """Semantic-equivalence pattern matching inside validate_fix."""

    def test_eval_mode_semantic_train_present(self) -> None:
        # model.train() is present, so the semantic pattern matches.
        assert validate_fix("eval_mode", 5, "model.train()")

    def test_eval_mode_with_eval_keyword_fails(self) -> None:
        # model.eval() is present, so the semantic pattern rejects it.
        accepted = validate_fix("eval_mode", 5, "model.eval()")
        assert not accepted

    def test_detach_loss_criterion_without_detach(self) -> None:
        replacement = " loss = criterion(output, batch_y)"
        assert validate_fix("detach_loss", 14, replacement)

    def test_inplace_relu_without_inplace_flag(self) -> None:
        replacement = " output = F.relu(output)"
        assert validate_fix("inplace_relu", 15, replacement)
class TestGenerateCodeSnippetHints:
    """Presence or absence of the 'hint' field in generated snippets."""

    def test_eval_mode_has_hint(self) -> None:
        assert generate_code_snippet("eval_mode")["hint"] is not None

    def test_detach_loss_has_hint(self) -> None:
        assert generate_code_snippet("detach_loss")["hint"] is not None

    def test_zero_grad_no_hint(self) -> None:
        assert generate_code_snippet("zero_grad_missing")["hint"] is None

    def test_inplace_relu_no_hint(self) -> None:
        assert generate_code_snippet("inplace_relu")["hint"] is None