| """ Verifiable Logic Rewards """ |
|
|
| import os |
| import subprocess |
| import tempfile |
| import evaluate |
| import logging |
| import datasets |
| from tqdm import tqdm |
| import time |
| import multiprocessing as mp |
| import re |
|
|
| logger = logging.getLogger(__name__) |
|
|
| _CITATION = """\ |
| @misc{anonymous2025slr, |
| author = {Anonymous}, |
| title = {SLR: Anonymous}, |
| year = {2025} |
| } |
| """ |

_DESCRIPTION = """\
Verifiable Rewards for Scalable Logical Reasoning (**SLR**) provides verifiable rewards via logic program execution.
It deterministically evaluates candidate hypotheses by executing them against a validation program and verifying that all positive examples ($E^+$) are entailed and that no negative example ($E^-$) is entailed.
Every evaluation is fully verifiable and grounded in formal logic, providing an automatic, transparent, and reproducible standard for evaluation and reward in both supervised and reinforcement learning settings.

How it Works:
- Input: a candidate hypothesis (logic rule) and an executable validation program containing background knowledge and examples.
- Execution: the candidate rule is executed against the validation program using a Prolog interpreter.
- Correctness Criteria: the rule is considered correct if it entails all positive examples and rejects all negative examples.
- Metrics: a range of evaluation metrics is provided (detailed below).
- Usage: see the **Documentation tab** for details on how to use Verifiable Rewards for Scalable Logical Reasoning in your own projects.

Example usage:
    from evaluate import load

    symbolic_judge = load("LG-Anonym/VerifiableRewardsForScalableLogicalReasoning")

    validation_program = \"\"\"
    eastbound(train0).
    has_car(train0, car0_1).
    car_num(car0_1, 1).
    car_color(car0_1, white).
    car_len(car0_1, short).
    has_wall(car0_1, full).

    westbound(train1).
    has_car(train1, car1_1).
    car_num(car1_1, 1).
    car_color(car1_1, yellow).
    car_len(car1_1, short).
    has_wall(car1_1, full).
    \"\"\"

    predicted_rule = "eastbound(Train):- has_car(Train, Car1), car_color(Car1, white)."

    results = symbolic_judge.compute(
        predictions=[predicted_rule],
        references=[{"validation_program": validation_program,
                     "evaluation_config": {
                         "positive_predicate": "eastbound",
                         "negative_predicate": "westbound"
                     }}]
    )

Note: A local SWI-Prolog interpreter is required to execute validation programs.
"""

_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `str`): Each prediction should be a Prolog rule such as "pred(T) :- Body."
    references (`list` of `dict`): Each reference should contain:
        - 'validation_program' (`str`): Background knowledge and examples in Prolog syntax.
        - 'evaluation_config' (`dict`, optional): Configuration of the predicates used for evaluation.
          Defines 'positive_predicate' and 'negative_predicate'; the positive predicate must match the head of the rule being evaluated.
Returns:
    accuracy (`float`): Proportion of predictions that correctly classify all examples. Value is between 0 and 1.
    partial_score (`float`): Average proportion of correctly classified examples across all predictions. Value is between 0 and 1.
    syntax_score (`float`): Proportion of rules with valid syntax. Value is between 0 and 1.
    detailed_results (`list` of `dict`): Per-example results, including correctness, partial score, execution time, and any errors encountered.
"""


def validate_rule_no_hardcoded_cars(prediction):
    """Reject rules that hardcode specific car identifiers instead of quantifying over variables."""
    # In Prolog, identifiers starting with a lowercase letter are constants;
    # a constant in the car position of has_car/2 means the rule names a
    # specific car rather than using a variable.
    hardcoded_pattern = r'has_car\([^,]+,\s*([a-z][a-z0-9_]*)\)'
    matches = re.findall(hardcoded_pattern, prediction)

    if matches:
        return False, f"Cars must be variables, but found hardcoded constant: {matches[0]}"

    return True, "Rule is valid"
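
# Illustrative behavior (hypothetical inputs):
#   validate_rule_no_hardcoded_cars("eastbound(T) :- has_car(T, Car).")
#     -> (True, "Rule is valid")                    # Car is a variable
#   validate_rule_no_hardcoded_cars("eastbound(T) :- has_car(T, car0_1).")
#     -> (False, "Cars must be variables, but found hardcoded constant: car0_1")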


def _evaluate_with_prolog(prediction, validation_program, eval_config, timeout=5):
    """
    Evaluates a predicted rule against the validation program using Prolog.
    """
    positive_pred = eval_config.get("positive_predicate", "eastbound")
    negative_pred = eval_config.get("negative_predicate", "westbound")
    allow_multiple_rules = eval_config.get("allow_multiple_rules", False)

    rule_to_evaluate = extract_ilp_from_text_v2(prediction, positive_pred, allow_multiple_rules)

    is_valid, validation_msg = validate_rule_no_hardcoded_cars(rule_to_evaluate)
    if not is_valid:
        return {
            "is_correct": False,
            "partial_score": 0.0,
            "syntax_valid": False,
            "error": f"Rule validation failed: {validation_msg}"
        }

    if positive_pred not in rule_to_evaluate:
        p = prediction.replace('\n', ' ')
        return {
            "is_correct": False,
            "partial_score": 0.0,
            "syntax_valid": False,
            "error": f"Invalid Syntax: Logic Rule not found for symbol '{positive_pred}': {p}"
        }

    pos_examples = re.findall(rf'{positive_pred}\(([^)]+)\)', validation_program)
    neg_examples = re.findall(rf'{negative_pred}\(([^)]+)\)', validation_program)

    # Infer the arity of the target predicate from the first example found.
    arity = 1
    if pos_examples:
        arity = pos_examples[0].count(',') + 1
    elif neg_examples:
        arity = neg_examples[0].count(',') + 1

    vars = ", ".join([f"X{i}" for i in range(1, arity + 1)])

    symbolic_judge = f"""
% Dynamic evaluation predicates
check({vars}) :- pos({vars}), {positive_pred}({vars}).      % positive covered
check({vars}) :- neg({vars}), \\+ {positive_pred}({vars}).  % negative rejected

% Count successful checks
check_count(Count) :-
    (setof(({vars}), ((pos({vars}); neg({vars})), check({vars})), CorrectExamples) ->
        length(CorrectExamples, Count)
    ;
        Count = 0
    ).

check_all :- forall((pos({vars});neg({vars})), check({vars})).
"""

    # Rename the labeled example predicates to pos/neg so they do not clash
    # with the head of the candidate rule appended below.
    validation_program = re.sub(rf'\b{positive_pred}\b', 'pos', validation_program)
    validation_program = re.sub(rf'\b{negative_pred}\b', 'neg', validation_program)

    pos_negs = validation_program.count("pos(") + validation_program.count("neg(")
    # Sorting the lines groups clauses of the same predicate together
    # (the validation program is expected to hold one fact per line).
    validation_program = '\n'.join(sorted(validation_program.splitlines()))
    full_program = validation_program + "\n\n" + symbolic_judge + "\n\n" + rule_to_evaluate + "\n\n"

    with tempfile.NamedTemporaryFile(suffix='.pl', mode='w', delete=False) as f:
        f.write(full_program)
        temp_file = f.name

    result = None
    try:
        eval_start_time = time.time()
        cmd = ['swipl', '-s', temp_file, '-g', 'check_count(Count), writeln(Count)', '-t', 'halt']
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
            text=True
        )
        partial_score = 0.0 if result.stdout.strip() == '' else int(result.stdout.strip())
        partial_score = partial_score / pos_negs if pos_negs > 0 else 0.0

        is_correct = partial_score == 1.0

        error = f'{result.stderr} -> Eval Rule "{rule_to_evaluate}"' if result.stderr else None
        t1 = time.time()

        return {
            "is_correct": is_correct,
            "partial_score": partial_score,
            "syntax_valid": True,
            "error": error,
            "exec_time1": t1 - eval_start_time,
        }

    except subprocess.TimeoutExpired:
        r = rule_to_evaluate.replace('\n', ' ')
        return {"is_correct": False, "partial_score": 0.0, "syntax_valid": False,
                "error": f"Evaluation timed out after {timeout} seconds for rule: '{r}'"}
    except Exception as e:
        stdout = result.stdout.strip() if result else 'No error message'
        return {"is_correct": False, "partial_score": 0.0, "syntax_valid": False,
                "error": f"Error evaluating rule '{rule_to_evaluate}' returns: '{stdout}' with error: {e}"}
    finally:
        if os.path.exists(temp_file):
            os.remove(temp_file)


def extract_ilp_from_text(text):
    """Extracts every Prolog rule of the form 'head(...) :- body.' from free text."""
    rule_patterns = [
        r'([a-zA-Z_][a-zA-Z0-9_]*\([^)]*\)\s*:-[^.]*\.)',
    ]
    p_code = ''
    for pattern in rule_patterns:
        matches = re.findall(pattern, text)
        for match in matches:
            statement = match.strip()
            if not statement.endswith('.'):
                statement += '.'
            p_code += statement + '\n'
    return p_code
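
# Illustrative call (hypothetical input):
#   extract_ilp_from_text("Noise. eastbound(T) :- short(T). More noise.")
#     -> "eastbound(T) :- short(T).\n"
# Every 'head(...) :- body.' clause is collected, whatever its head predicate.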


def extract_ilp_from_text_v2(text, target_predicate=None, allow_multiple_rules=False):
    """Extracts the rule(s) whose head matches target_predicate from free text."""
    # Strip Prolog line comments, then collapse the text onto one line so
    # that rules spanning several lines can still be matched.
    text = re.sub(r'%.*?(?=\n|$)', '', text)
    text = re.sub(r'\n\s*', ' ', text)

    rule_pattern = re.compile(rf'({target_predicate}\([^()]*\)\s*:-.*?\.)')
    rules = list(rule_pattern.findall(text))
    if len(rules) > 1 and not allow_multiple_rules:
        # Keep only the last rule found, i.e. the model's final answer.
        rules = rules[-1:]

    p_code = ''
    for rule in rules:
        statement = rule.strip()
        if not statement.endswith('.'):
            statement += '.'
        p_code += statement + '\n'
    return p_code.strip()
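
# Illustrative call (hypothetical model output):
#   text = "The answer is:\neastbound(T) :-\n    has_car(T, C), car_len(C, short)."
#   extract_ilp_from_text_v2(text, "eastbound")
#     -> "eastbound(T) :- has_car(T, C), car_len(C, short)."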


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class VerifiableRewardsForScalableLogicalReasoning(evaluate.Metric):
    def __init__(self, config_name=None, **kwargs):
        """
        Initializes the symbolic-judge metric.

        Args:
            config_name (str, optional): Name of the configuration to use.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(config_name=config_name, **kwargs)
        self.config_name = config_name or "default"
        # Check for SWI-Prolog eagerly so a missing interpreter is reported
        # at construction time rather than on the first compute() call.
        self._download_and_prepare(dl_manager=None)

    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features({
                'predictions': datasets.Value('string'),
                'references': {
                    'validation_program': datasets.Value('string'),
                    'evaluation_config': {
                        'positive_predicate': datasets.Value('string'),
                        'negative_predicate': datasets.Value('string')
                    }
                },
            }),
            codebase_urls=["https://github.com/LG-Anonym/SLR-Bench"],
            reference_urls=["https://huggingface.co/datasets/LG-Anonym/SLR-Bench"]
        )

    def _download_and_prepare(self, dl_manager):
        """Checks that SWI-Prolog is installed, or warns the user."""
        try:
            subprocess.run(
                ["swipl", "--version"],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                check=True
            )
        except (subprocess.CalledProcessError, FileNotFoundError):
            logger.warning(
                "SWI-Prolog not found. Please install it:\n"
                "Ubuntu/Debian: sudo apt-get install swi-prolog\n"
                "macOS: brew install swi-prolog\n"
                "Windows: download from https://www.swi-prolog.org/download/stable"
            )

    def _compute(self, predictions: list, references: list):
        """Calculates the accuracy of predictions, using Prolog for evaluation (parallelized for large batches)."""
        if not isinstance(predictions, list):
            predictions = [predictions]

        if len(predictions) != len(references):
            raise ValueError(
                f"Number of predictions ({len(predictions)}) and references ({len(references)}) don't match")

        # Larger batches are evaluated in parallel, so grant a longer per-rule timeout.
        TIMEOUT = 10 if len(predictions) > 500 else 5
        eval_inputs = []
        for i, (prediction, reference) in enumerate(zip(predictions, references)):
            # Also accept the alternate key spelling 'validation program'.
            validation_program = reference.get("validation_program", reference.get("validation program"))

            eval_config = reference.get("evaluation_config", {
                "positive_predicate": "eastbound",
                "negative_predicate": "westbound"
            })

            if not validation_program:
                raise ValueError(f"Example {i} does not contain a 'validation_program' field")

            eval_inputs.append((prediction, validation_program, eval_config, TIMEOUT))

        if len(eval_inputs) > 500:
            # Parallel evaluation; pool.starmap returns only once all tasks
            # are done, so the progress bar completes in a single step.
            num_cpus = max(1, mp.cpu_count() - 1)
            with mp.Pool(processes=num_cpus) as pool:
                results = list(tqdm(
                    pool.starmap(_evaluate_with_prolog, eval_inputs),
                    total=len(eval_inputs),
                    desc=f"Evaluating rules (parallel processing with {num_cpus} CPUs)"
                ))
        else:
            results = []
            for prediction, validation_program, eval_config, t in tqdm(eval_inputs, total=len(predictions), desc="Evaluating rules"):
                results.append(_evaluate_with_prolog(prediction, validation_program, eval_config, timeout=t))

        partial_scores = [result["partial_score"] for result in results]
        correct_count = sum(1 for result in results if result["is_correct"])
        syntax_valid_count = sum(1 for result in results if result["syntax_valid"])

        accuracy = correct_count / len(predictions) if predictions else 0
        partial_score = sum(partial_scores) / len(predictions) if partial_scores else 0
        syntax_score = syntax_valid_count / len(predictions) if predictions else 0

        return {
            "accuracy": accuracy,
            "partial_score": partial_score,
            "syntax_score": syntax_score,
            "detailed_results": results
        }
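

if __name__ == "__main__":
    # Minimal smoke test mirroring the docstring example; a sketch that
    # assumes SWI-Prolog is installed and this module is run as a script.
    judge = VerifiableRewardsForScalableLogicalReasoning()
    validation_program = """
eastbound(train0).
has_car(train0, car0_1).
car_color(car0_1, white).

westbound(train1).
has_car(train1, car1_1).
car_color(car1_1, yellow).
"""
    predicted_rule = "eastbound(Train) :- has_car(Train, Car1), car_color(Car1, white)."
    results = judge.compute(
        predictions=[predicted_rule],
        references=[{"validation_program": validation_program,
                     "evaluation_config": {"positive_predicate": "eastbound",
                                           "negative_predicate": "westbound"}}]
    )
    print(results)  # expected: accuracy 1.0, as the rule separates the trains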