File size: 2,483 Bytes
558b89d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Task graders for syntax and bug-fix tasks."""

from __future__ import annotations

from .common import clamp_score, compiles, normalized_diff_score, style_score, syntax_error_message
from .optimization import grade_optimization_task
from .pytest_runner import run_pytest_suite
from ..models import TaskGrade
from ..tasks.task_bank import TaskSpec


def grade_syntax_task(candidate_code: str, task: TaskSpec) -> TaskGrade:
    """Grade a syntax-fix submission.

    When ``syntax_error_message`` reports no error, the submission earns full
    credit (score and syntax score of 1.0) and its style score is reported as
    the quality score.

    Otherwise the submission earns partial credit proportional to its
    similarity with ``task.reference_code``: ``clamp_score(0.15 + 0.55 * similarity)``.
    Its quality score is that similarity scaled by the style score, and the
    error message is recorded under ``details["compile_error"]``.
    """
    compile_error = syntax_error_message(candidate_code)
    similarity = normalized_diff_score(candidate_code, task.reference_code)
    style = style_score(candidate_code, task.style_max_line_length)

    if compile_error:
        # Still broken: reward closeness to the reference, never full credit.
        return TaskGrade(
            score=clamp_score(0.15 + 0.55 * similarity),
            syntax_score=0.0,
            quality_score=similarity * style,
            details={"compile_error": compile_error},
        )

    return TaskGrade(
        score=1.0,
        syntax_score=1.0,
        quality_score=style,
        details={"compile_error": ""},
    )


def grade_bug_fix_task(
    candidate_code: str,
    task: TaskSpec,
    include_hidden: bool = True,
    *,
    timeout_s: float = 3.0,
) -> TaskGrade:
    """Grade a bug-fix submission by running the task's tests against it.

    Args:
        candidate_code: Source code submitted for the task.
        task: Task specification supplying ``visible_tests``, ``hidden_tests``
            and ``style_max_line_length``.
        include_hidden: When true, ``task.hidden_tests`` are appended to the
            visible tests before running.
        timeout_s: Time limit forwarded to ``run_pytest_suite``. Defaults to
            3.0, the value that was previously hard-coded.

    Returns:
        A ``TaskGrade`` whose score is the fraction of tests passed (clamped).
        The score is 0.0 in three cases:

        - the code fails ``compiles`` (no tests are run);
        - the test run timed out;
        - the suite reported zero tests.
    """
    if not compiles(candidate_code):
        return TaskGrade(
            score=0.0,
            syntax_score=0.0,
            details={"compile_error": syntax_error_message(candidate_code)},
        )

    tests = list(task.visible_tests)
    if include_hidden:
        tests.extend(task.hidden_tests)

    execution = run_pytest_suite(candidate_code, tests, timeout_s=timeout_s)
    details = {"compile_error": "", "tests": execution.output}

    if execution.timed_out:
        # A timed-out run scores 0.0 regardless of execution.passed;
        # the counts are still reported for diagnostics.
        return TaskGrade(
            score=0.0,
            syntax_score=1.0,
            tests_passed=execution.passed,
            tests_total=execution.total,
            timed_out=True,
            details=details,
        )

    # Guard against division by zero when the suite reports no tests.
    pass_fraction = execution.passed / execution.total if execution.total else 0.0
    return TaskGrade(
        score=clamp_score(pass_fraction),
        syntax_score=1.0,
        tests_passed=execution.passed,
        tests_total=execution.total,
        quality_score=style_score(candidate_code, task.style_max_line_length),
        details=details,
    )


def grade_task(candidate_code: str, task: TaskSpec, include_hidden: bool = True) -> TaskGrade:
    """Dispatch a submission to the grader matching ``task.task_kind``.

    ``"syntax_fix"`` goes to ``grade_syntax_task`` and ``"bug_fix"`` goes to
    ``grade_bug_fix_task``, which honours ``include_hidden``. Any other kind
    falls through to ``grade_optimization_task``, which does not receive
    ``include_hidden``.
    """
    kind = task.task_kind
    if kind == "syntax_fix":
        return grade_syntax_task(candidate_code, task)
    if kind == "bug_fix":
        return grade_bug_fix_task(candidate_code, task, include_hidden=include_hidden)
    # NOTE(review): unknown kinds are silently graded as optimization tasks.
    return grade_optimization_task(candidate_code, task)