import re
import textwrap
from copy import deepcopy
from typing import Dict, List, Optional

import torch

from swift.llm import PtEngine, RequestConfig, Template, to_device
from swift.llm.infer.protocol import ChatCompletionResponse
from swift.utils import get_logger

logger = get_logger()


class DefaultRMPlugin:
    """
    Default Reward Model Plugin.

    This class implements the default processing logic for reward models.
    It assumes that `self.model` is a classification model with a value head
    (output dimension 1). The first logit of the model's output is used as
    the reward score.
    """

    def __init__(self, model, template):
        self.model = model
        self.template: Template = template

    def __call__(self, inputs):
        # Encode each request with the template and collate into a single batch
        # on the model's device.
        batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs]
        reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
        # Labels are not needed for reward scoring; drop them if present.
        reward_inputs.pop('labels', None)

        with torch.inference_mode():
            # The value head's first logit is the scalar reward for each sample.
            return self.model(**reward_inputs).logits[:, 0]


class GenRMPlugin(DefaultRMPlugin):

    def __init__(self, model, template):
        """
        Generative Reward Model Plugin Example.

        This method sets up the reward model plugin by initializing the PtEngine for efficient inference,
        configuring the request parameters, and defining the system prompt that guides the reward model in
        evaluating responses.

        Args:
            model (torch.nn.Module): The generative reward model.
            template (Template): The template used for encoding input data.
        """
        super().__init__(model, template)
        self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0)
        self.request_config = RequestConfig()
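        # RequestConfig() keeps the engine defaults, i.e. a single sampled
        # judgment per request. To average the reward over several judgments,
        # one could instead pass sampling parameters here, e.g.
        # RequestConfig(n=3, temperature=0.8) -- an optional setup choice, not
        # something this plugin requires.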
        self.system = textwrap.dedent("""
            Based on the dialogue history, analyze in detail whether the model's response is accurate, complete, and relevant.
            Assign a reward score between 0 and 1, where 0 indicates completely incorrect and 1 indicates fully correct.
            Before finishing your response, please assign a reward using the following format:

            Reward: {reward}

            For example:
            Reward: 0.85
        """)

    def __call__(self, inputs):
        """
        Compute reward scores for the provided inputs.

        This method processes each input by converting dialogue messages into a query, sending the query to the
        reward model for inference, and extracting the reward scores from the model's responses. The final reward
        for each input is the average of all extracted scores.

        Args:
            inputs (List[Dict]): A list of input requests. Each input request is a dictionary containing:
                - 'messages' (List[Dict]): Messages from the training model. Each message dictionary includes:
                    - 'role' (str): The role of the speaker (e.g., 'user', 'assistant').
                    - 'content' (str): The content of the message.
                - Additional dataset columns as key-value pairs (e.g., 'solutions', 'images').

        Returns:
            torch.Tensor: A tensor containing the average reward scores for each input, with shape (N,),
                where N is the number of input requests.
        """
        rm_inputs = self.prepare_rm_inputs(inputs)
        results = self.engine.infer(rm_inputs, self.request_config, use_tqdm=False)
        rewards = self.compute_rewards(results)
        return torch.tensor(rewards, dtype=torch.float32)

    def prepare_rm_inputs(self, inputs: List[Dict]) -> List[Dict]:
        """
        Prepare inputs for the reward model by converting messages into queries.

        Args:
            inputs (List[Dict]): A list of input requests.

        Returns:
            List[Dict]: Processed inputs for the reward model.
        """
        rm_inputs = []
        for infer_request in inputs:
            rm_infer_request = deepcopy(infer_request)

            # Flatten the dialogue into a single query string.
            messages = rm_infer_request.get('messages')
            query = self.messages_to_query(messages)

            # Rebuild the conversation for the reward model.
            rm_messages = [{'role': 'system', 'content': self.system}, {'role': 'user', 'content': query}]
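            # Resulting layout (illustrative):
            #     [{'role': 'system', 'content': <grading instructions>},
            #      {'role': 'user', 'content': 'User: ...\nAssistant: ...'}]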

            rm_infer_request['messages'] = rm_messages
            rm_inputs.append(rm_infer_request)
        return rm_inputs

    @staticmethod
    def extract_reward(model_output: str) -> Optional[float]:
        """
        Extract the reward score from the model's output.

        Args:
            model_output (str): The model's output string, expected to contain "Reward: {reward}".

        Returns:
            Optional[float]: The extracted reward score, or None if no score could be extracted.
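
        Example:
            >>> GenRMPlugin.extract_reward('The response is accurate. Reward: 0.85')
            0.85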
| | """ |
| | match = re.search(r'Reward:\s*([0-1](?:\.\d+)?)', model_output) |
| | if match: |
| | return float(match.group(1)) |
| | else: |
| | logger.warning("Unable to extract reward score from the model's output, set reward to 0") |
| | return None |

    @staticmethod
    def messages_to_query(messages):
        """
        Compress a list of message dictionaries into a single query string.

        Args:
            messages (list[dict]): A list of message dictionaries, each containing:
                - 'role' (str): The role of the speaker (e.g., 'user', 'assistant').
                - 'content' (str): The content of the message.

        Returns:
            str: A single string that concatenates all messages in a formatted manner.

        Example:
            >>> messages = [
            ...     {'role': 'user', 'content': 'Hello, how are you?'},
            ...     {'role': 'assistant', 'content': 'I am fine, thank you! How can I assist you today?'},
            ...     {'role': 'user', 'content': 'Can you help me with my homework?'}
            ... ]
            >>> print(GenRMPlugin.messages_to_query(messages))
            User: Hello, how are you?
            Assistant: I am fine, thank you! How can I assist you today?
            User: Can you help me with my homework?
        """
        formatted_messages = []

        role_mapping = {
            'user': 'User',
            'assistant': 'Assistant',
            'system': 'System',
        }

        for idx, message in enumerate(messages):
            if not isinstance(message, dict):
                raise TypeError(f'Each message must be a dictionary. Found {type(message)} at index {idx}.')

            role = message.get('role')
            content = message.get('content')
            # Skip messages that lack a role or content.
            if not role or not content:
                continue

            # Map known roles to display names; capitalize unknown roles as a fallback.
            role_formatted = role_mapping.get(role.lower(), role.capitalize())
            formatted_messages.append(f'{role_formatted}: {content}')

        query = '\n'.join(formatted_messages)
        return query

    def compute_rewards(self, results: List[ChatCompletionResponse]) -> List[float]:
        """
        Compute average reward scores from the reward model's outputs.

        Args:
            results (List[ChatCompletionResponse]): A list of results from the reward model.

        Returns:
            List[float]: A list of average reward scores.
        """
        rewards = []
        for output in results:
            try:
                cur_rewards = []
                for choice in output.choices:
                    response = choice.message.content
                    reward = self.extract_reward(response)
                    cur_rewards.append(reward)
                # Keep only the choices that produced a parsable reward.
                cur_rewards = [r for r in cur_rewards if r is not None]
                if cur_rewards:
                    average_reward = sum(cur_rewards) / len(cur_rewards)
                else:
                    average_reward = 0.0
                    logger.warning('No valid rewards extracted. Assigning reward score of 0.0.')

                rewards.append(average_reward)
            except Exception as e:
                logger.error(f'Error computing reward: {e}')
                rewards.append(0.0)
        return rewards


rm_plugins = {
    'default': DefaultRMPlugin,
    'genrm': GenRMPlugin,
}
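
# A custom plugin can be registered under a new key so it is selectable by
# name. Minimal sketch (the key 'my_rm' and the override are illustrative,
# not part of this module):
#
#     class MyRMPlugin(DefaultRMPlugin):
#         def __call__(self, inputs):
#             rewards = super().__call__(inputs)
#             return rewards.clamp(0, 1)  # e.g., bound scores to [0, 1]
#
#     rm_plugins['my_rm'] = MyRMPlugin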