| import asyncio |
| import re |
| from copy import deepcopy |
| from typing import List |
|
|
| import json |
| import torch |
|
|
| from swift.llm import Template, to_device |
| from swift.plugin import ORM, orms, rm_plugins |
| from swift.utils import get_logger |
|
|
| logger = get_logger() |
| """ |
| Step 1: Define a Reward Class |
| Implement your custom reward calculation logic within the __call__ method. |
| The method accepts the model's output completions and dataset columns (passed as kwargs) as input parameters. |
| |
| Step 2: Register the Reward Class in orms |
| For example: |
    orms['external_math_acc'] = MathAccuracy
| |
| Step 3: Configure the Arguments |
| Use the following arguments when running the script: |
    --plugin /path/to/plugin.py --reward_funcs external_math_acc
| """ |
|
|
|
|
| |
class MathAccuracy(ORM):
    """Accuracy reward: symbolically verify each completion's answer against
    the reference solution using ``math_verify`` (LaTeX-aware comparison)."""

    def __init__(self):
        # Fail fast if the optional math_verify dependency is missing.
        import importlib.util
        assert importlib.util.find_spec('math_verify') is not None, (
            "The math_verify package is required but not installed. Please install it using 'pip install math_verify'.")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Return one reward per (completion, solution) pair: 1.0 on verified
        match, 0.0 on mismatch, and 1.0 when the gold solution is unparseable
        (the sample is effectively skipped)."""
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import LatexExtractionConfig, parse, verify
        scores = []
        for completion_text, gold_text in zip(completions, solution):
            gold = parse(gold_text, extraction_mode='first_match', extraction_config=[LatexExtractionConfig()])
            if len(gold) == 0:
                # Gold solution could not be parsed; grant full reward so the
                # sample does not penalize the policy.
                scores.append(1.0)
                continue
            predicted = parse(
                completion_text,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed=True,
                            units=True,
                        ),
                        # NOTE(review): extraction tuned for boxed answers without
                        # anchors — confirm against math_verify docs if changed.
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode='first_match',
            )
            scores.append(float(verify(predicted, gold)))
        return scores
|
|
|
|
class MathFormat(ORM):
    """Format reward: 1.0 when a completion is exactly
    ``<think>...</think>`` followed by ``<answer>...</answer>``, else 0.0."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that checks if the completion has a specific format."""
        expected = r'^<think>.*?</think>\s*<answer>.*?</answer>(?![\s\S])'
        rewards = []
        for text in completions:
            is_well_formed = re.match(expected, text, re.DOTALL | re.MULTILINE) is not None
            rewards.append(1.0 if is_well_formed else 0.0)
        return rewards
|
|
|
|
class CountdownORM(ORM):
    """Reward for the Countdown numbers game: 1.0 when the completion's
    <answer> equation uses exactly the provided numbers and evaluates to the
    target value, else 0.0."""

    def __call__(self, completions, target, nums, **kwargs) -> List[float]:
        """
        Evaluates completions based on Mathematical correctness of the answer

        Args:
            completions (list[str]): Generated outputs
            target (list[str]): Expected answers
            nums (list[str]): Available numbers

        Returns:
            list[float]: Reward scores
        """
        rewards = []
        for completion, gt, numbers in zip(completions, target, nums):
            try:
                match = re.search(r'<answer>(.*?)<\/answer>', completion)
                if match is None:
                    rewards.append(0.0)
                    continue
                equation = match.group(1).strip()
                # Keep only the left-hand side if the model wrote "expr = result".
                if '=' in equation:
                    equation = equation.split('=')[0]
                used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
                # Every provided number must be used exactly once.
                # NOTE(review): assumes `numbers` contains ints — confirm dataset schema.
                if sorted(used_numbers) != sorted(numbers):
                    rewards.append(0.0)
                    continue
                # Only arithmetic characters are allowed in the equation.
                allowed_pattern = r'^[\d+\-*/().\s]+$'
                if not re.match(allowed_pattern, equation):
                    rewards.append(0.0)
                    continue
                # BUGFIX: the globals key was misspelled as "__builti'ns__",
                # which left the real builtins available to eval. The character
                # whitelist above is the primary guard; disabling builtins is
                # defense in depth.
                result = eval(equation, {'__builtins__': None}, {})
                if abs(float(result) - float(gt)) < 1e-5:
                    rewards.append(1.0)
                else:
                    rewards.append(0.0)
            except Exception:
                # Any parse/eval failure yields zero reward for this sample.
                rewards.append(0.0)
        return rewards
|
|
|
|
class MultiModalAccuracyORM(ORM):
    """Accuracy reward for multimodal tasks: try symbolic verification first,
    then fall back to literal comparison of <answer> tag contents."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """
        Reward function that checks if the completion is correct.
        Args:
            completions (list[str]): Generated outputs
            solution (list[str]): Ground Truths.

        Returns:
            list[float]: Reward scores
        """
        from math_verify import parse, verify
        scores = []
        for text, gt in zip(completions, solution):
            score = 0.0
            # Attempt 1: symbolic verification via math_verify.
            try:
                if float(verify(parse(text), parse(gt))) > 0:
                    score = 1.0
            except Exception:
                pass
            # Attempt 2: exact string match on <answer>...</answer> contents
            # (falls back to the whole string when no tag is present).
            if score == 0.0:
                try:
                    gt_match = re.search(r'<answer>(.*?)</answer>', gt)
                    expected = gt_match.group(1).strip() if gt_match else gt.strip()
                    ans_match = re.search(r'<answer>(.*?)</answer>', text)
                    predicted = ans_match.group(1).strip() if ans_match else text.strip()
                    if predicted == expected:
                        score = 1.0
                except Exception:
                    pass
            scores.append(score)
        return scores
|
|
|
|
| |
class CodeReward(ORM):
    """Reward that executes generated code against test cases inside an E2B
    cloud sandbox and returns the per-sample pass rate (0.0 on any failure)."""

    def __init__(self):
        import importlib.util
        assert importlib.util.find_spec('e2b') is not None, (
            "The e2b package is required but not installed. Please install it using 'pip install e2b-code-interpreter'."
        )
        # E2B reads its API key from the environment; load a local .env if present.
        from dotenv import load_dotenv
        load_dotenv()

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        """Return the LAST ```<language> fenced code block in the completion, or ''.

        BUGFIX: the language name is now escaped before being interpolated into
        the regex — names containing metacharacters (e.g. 'c++') previously
        raised ``re.error: multiple repeat``.
        """
        pattern = re.compile(rf'```{re.escape(language)}\n(.*?)```', re.DOTALL)
        matches = pattern.findall(completion)
        extracted_answer = matches[-1] if len(matches) >= 1 else ''
        return extracted_answer

    def run_async_from_sync(self, scripts: List[str], languages: List[str]) -> List[float]:
        """Function wrapping the `run_async` function."""
        # Use a private event loop so this works from synchronous trainer code.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            rewards = loop.run_until_complete(self.run_async(scripts, languages))
        finally:
            loop.close()
        return rewards

    async def run_async(self, scripts: List[str], languages: List[str]) -> List[float]:
        """Run all evaluation scripts concurrently in one sandbox."""
        from e2b_code_interpreter import AsyncSandbox

        # Create the sandbox; on failure give every sample zero reward.
        try:
            sbx = await AsyncSandbox.create(timeout=30, request_timeout=3)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            return [0.0] * len(scripts)
        # BUGFIX: kill the sandbox in a finally block so it is not leaked when
        # asyncio.gather (or task creation) raises.
        try:
            tasks = [self.run_script(sbx, script, language) for script, language in zip(scripts, languages)]
            results = await asyncio.gather(*tasks)
            rewards = list(results)
        finally:
            await sbx.kill()
        return rewards

    async def run_script(self, sbx, script: str, language: str) -> float:
        """Run one script in the sandbox; return its printed float result or 0.0."""
        try:
            execution = await sbx.run_code(script, language=language, timeout=30)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            return 0.0
        try:
            return float(execution.text)
        except (TypeError, ValueError):
            # No numeric output (e.g. the evaluated code crashed).
            return 0.0

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that evaluates code snippets using the E2B code interpreter.

        Assumes the dataset contains a `verification_info` column with test cases.
        """
        evaluation_script_template = """
import subprocess
import json

def evaluate_code(code, test_cases):
    passed = 0
    total = len(test_cases)
    exec_timeout = 5

    for case in test_cases:
        process = subprocess.run(
            ["python3", "-c", code],
            input=case["input"],
            text=True,
            capture_output=True,
            timeout=exec_timeout
        )

        if process.returncode != 0:  # Error in execution
            continue

        output = process.stdout.strip()
        if output.strip() == case["output"].strip():
            passed += 1

    success_rate = (passed / total)
    return success_rate

code_snippet = {code}
test_cases = json.loads({test_cases})

evaluate_code(code_snippet, test_cases)
"""
        verification_info = kwargs['verification_info']
        languages = [info['language'] for info in verification_info]
        code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]
        # Double json.dumps: the inner one serializes the cases, the outer one
        # makes the result a quoted literal for json.loads inside the template.
        scripts = [
            evaluation_script_template.format(
                code=json.dumps(code), test_cases=json.dumps(json.dumps(info['test_cases'])))
            for code, info in zip(code_snippets, verification_info)
        ]
        try:
            rewards = self.run_async_from_sync(scripts, languages)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            rewards = [0.0] * len(completions)
        return rewards
|
|
|
|
class CodeFormat(ORM):
    """Format reward: 1.0 when a completion is exactly
    ``<think>...</think><answer>... ```<lang> ... ``` ...</answer>`` with a
    fenced code block in the sample's expected language, else 0.0."""

    def __call__(self, completions, **kwargs) -> List[float]:
        verification_info = kwargs['verification_info']
        rewards = []
        for content, info in zip(completions, verification_info):
            # BUGFIX: escape the language name before interpolating it into the
            # regex — names with metacharacters (e.g. 'c++') previously raised
            # re.error or silently matched the wrong thing.
            pattern = r'^<think>.*?</think>\s*<answer>.*?```{}.*?```.*?</answer>(?![\s\S])'.format(
                re.escape(info['language']))
            match = re.match(pattern, content, re.DOTALL | re.MULTILINE)
            reward = 1.0 if match else 0.0
            rewards.append(reward)
        return rewards
|
|
|
|
class CodeRewardByJudge0(ORM):
    """Reward that executes generated code against test cases via a Judge0
    server and returns the per-sample fraction of accepted test cases."""

    # Judge0 language name -> language id mapping.
    LANGUAGE_ID_MAP = {
        'assembly': 45,
        'bash': 46,
        'basic': 47,
        'c': 50,
        'c++': 54,
        'clojure': 86,
        'c#': 51,
        'cobol': 77,
        'common lisp': 55,
        'd': 56,
        'elixir': 57,
        'erlang': 58,
        'executable': 44,
        'f#': 87,
        'fortran': 59,
        'go': 60,
        'groovy': 88,
        'haskell': 61,
        'java': 62,
        'javascript': 63,
        'kotlin': 78,
        'lua': 64,
        'multi-file program': 89,
        'objective-c': 79,
        'ocaml': 65,
        'octave': 66,
        'pascal': 67,
        'perl': 85,
        'php': 68,
        'plain text': 43,
        'prolog': 69,
        'python': 71,
        'python2': 70,
        'python3': 71,
        'r': 80,
        'ruby': 72,
        'rust': 73,
        'scala': 81,
        'sql': 82,
        'swift': 83,
        'typescript': 74,
        'visual basic.net': 84
    }
    # Fallback language id (Python 3) for unknown/missing languages.
    PYTHON_ID = 71

    def __init__(self):
        import os
        # Judge0 server location; required for submissions.
        self.endpoint = os.getenv('JUDGE0_ENDPOINT')
        assert self.endpoint is not None, (
            'Judge0 endpoint is not set. Please set the JUDGE0_ENDPOINT environment variable.')
        x_auth_token = os.getenv('JUDGE0_X_AUTH_TOKEN')
        self.headers = {'Content-Type': 'application/json'}
        if x_auth_token is not None:
            self.headers['X-Auth-Token'] = x_auth_token

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        """Return the LAST ```<language> fenced code block in the completion, or ''.

        BUGFIX: escape the language name before interpolating it into the
        regex — names with metacharacters (e.g. 'c++') previously raised
        ``re.error: multiple repeat``.
        """
        pattern = re.compile(rf'```{re.escape(language)}\n(.*?)```', re.DOTALL)
        matches = pattern.findall(completion)
        extracted_answer = matches[-1] if len(matches) >= 1 else ''
        return extracted_answer

    @classmethod
    def get_language_id(cls, language):
        """Map a language name to a Judge0 id, defaulting to Python 3."""
        if language is None:
            return cls.PYTHON_ID
        return cls.LANGUAGE_ID_MAP.get(language.lower().strip(), cls.PYTHON_ID)

    async def _evaluate_code(self, code, test_cases, language_id):
        """Submit `code` once per test case; return the accepted fraction, or
        0.0 on any error (including an empty test-case list)."""
        import aiohttp
        try:
            passed = 0
            total = len(test_cases)
            for case in test_cases:
                # Skip submission entirely when no code was extracted.
                if code is not None and code != '':
                    async with aiohttp.ClientSession() as session:
                        payload = {
                            'source_code': code,
                            'language_id': language_id,
                            'stdin': case['input'],
                            'expected_output': case['output']
                        }
                        logger.debug(f'Payload: {payload}')
                        # ?wait=true makes Judge0 block until the run finishes.
                        async with session.post(
                                self.endpoint + '/submissions/?wait=true', json=payload,
                                headers=self.headers) as response:
                            response_json = await response.json()
                            logger.debug(f'Response: {response_json}')
                            if response_json['status']['description'] == 'Accepted':
                                passed += 1
            success_rate = (passed / total)
            return success_rate
        except Exception as e:
            # Also covers ZeroDivisionError when test_cases is empty.
            logger.warning(f'Error from Judge0 executor: {e}')
            return 0.0

    def run_async_from_sync(self):
        """Run the async evaluation on a private event loop from sync code."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            rewards = loop.run_until_complete(self.run_async())
        finally:
            loop.close()
        return rewards

    async def run_async(self):
        """Evaluate all snippets concurrently; one reward per snippet."""
        tasks = [
            self._evaluate_code(code, info['test_cases'], CodeRewardByJudge0.get_language_id(info['language']))
            for code, info in zip(self.code_snippets, self.verification_info)
        ]
        results = await asyncio.gather(*tasks)
        rewards = list(results)
        return rewards

    def __call__(self, completions, **kwargs) -> List[float]:
        """Score completions; expects a `verification_info` column in kwargs.

        NOTE(review): state is stashed on `self` for run_async — this instance
        is not safe for concurrent __call__ invocations.
        """
        self.verification_info = kwargs['verification_info']
        languages = [info['language'] for info in self.verification_info]
        self.code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]
        try:
            rewards = self.run_async_from_sync()
        except Exception as e:
            logger.warning(f'Error from Judge0 executor: {e}')
            rewards = [0.0] * len(completions)
        return rewards
|
|
|
|
# Register the custom reward functions so they can be selected on the
# command line via `--reward_funcs <name>` (see module docstring, Step 2/3).
orms['external_math_acc'] = MathAccuracy
orms['external_math_format'] = MathFormat
orms['external_countdown'] = CountdownORM
orms['external_r1v_acc'] = MultiModalAccuracyORM
orms['external_code_reward'] = CodeReward
orms['external_code_format'] = CodeFormat
orms['external_code_reward_by_judge0'] = CodeRewardByJudge0
|
|
|
|
| |
class CustomizedRMPlugin:
    """Customized reward-model plugin (equivalent to DefaultRMPlugin).

    Assumes `self.model` is a classification model with a value head
    (output dimension 1); the first logit of the model output is used
    as the reward score.
    """

    def __init__(self, model, template):
        self.model = model
        self.template: Template = template

    def __call__(self, inputs):
        # Encode each request with the template (deepcopy so the originals
        # are left untouched), then collate into one batch on the model's device.
        encoded = [self.template.encode(deepcopy(request)) for request in inputs]
        batch = to_device(self.template.data_collator(encoded), self.model.device)
        # Remove the collator's `labels` before the forward pass.
        batch.pop('labels')
        with torch.inference_mode():
            logits = self.model(**batch).logits
            return logits[:, 0]
|
|
|
|
# Register the customized reward-model plugin under the name 'my_rmplugin'.
rm_plugins['my_rmplugin'] = CustomizedRMPlugin
|
|