| | |
| | """ |
| | Test script for plot arc classifier |
| | """ |
| |
|
| | import json |
| | import torch |
| | from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification |
| |
|
def load_tests():
    """Load the synthetic test cases from tests/synthetic_tests.json.

    Returns:
        list[dict]: test cases, each with at least the keys
        'description', 'expected_class', and 'reasoning'
        (as consumed by run_tests).

    Raises:
        FileNotFoundError: if the test file is missing.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Explicit UTF-8: test descriptions may contain non-ASCII text, and
    # the platform default encoding (e.g. cp1252 on Windows) would break.
    with open('tests/synthetic_tests.json', 'r', encoding='utf-8') as f:
        return json.load(f)
| |
|
def run_tests():
    """Run all synthetic tests against the plot-arc classifier.

    Loads the fine-tuned DeBERTa-v2 tokenizer/model from the current
    directory, classifies each synthetic test description, prints a
    per-test pass/fail report, and reports overall accuracy.

    Returns:
        float: fraction of tests whose predicted class matched the
        expected class; 0.0 when the test file is empty.
    """
    print("Loading model...")
    # Model artifacts are expected in the current working directory.
    tokenizer = DebertaV2Tokenizer.from_pretrained('.')
    model = DebertaV2ForSequenceClassification.from_pretrained('.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Index order must match the label ids the model was trained with.
    class_names = ['NONE', 'INTERNAL', 'EXTERNAL', 'BOTH']

    tests = load_tests()
    total = len(tests)
    # Guard: an empty test file would otherwise raise ZeroDivisionError
    # at the accuracy computation below.
    if total == 0:
        print("No tests found.")
        return 0.0

    correct = 0
    print(f"Running {total} synthetic tests...\n")

    for i, test in enumerate(tests, 1):
        text = test['description']
        expected = test['expected_class']

        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: disable autograd to save memory and time.
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=-1)
            predicted_idx = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0][predicted_idx].item()

        predicted = class_names[predicted_idx]
        is_correct = predicted == expected

        if is_correct:
            correct += 1
            status = "✅ PASS"
        else:
            status = "❌ FAIL"

        print(f"Test {i:2d}: {status}")
        print(f" Text: {text[:100]}{'...' if len(text) > 100 else ''}")
        print(f" Expected: {expected} | Predicted: {predicted} (conf: {confidence:.3f})")
        print(f" Reasoning: {test['reasoning']}")
        print()

    accuracy = correct / total
    print(f"Results: {correct}/{total} correct ({accuracy:.1%})")

    return accuracy
| |
|
# Script entry point: run the synthetic test suite when executed directly.
if __name__ == "__main__":
    run_tests()
| |
|