| | from typing import Dict, List, Optional, Tuple |
| |
|
| | from swift.llm.template import split_str_parts_by |
| |
|
| |
|
def calculate_loss_scale(query: str,
                         response: str,
                         response_loss_scale_map: Dict[str, list],
                         query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]:
    """Calculate the loss scale by splitting the agent response.

    This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf

    Agent response format:

    ```text
    Thought: you should always think about what to do
    Action: the action to take, should be one of the above tools[fire_recognition,
        fire_alert, call_police, call_fireman]
    Action Input: the input to the action
    Observation: the result of the action
    ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
    Thought: I now know the final answer
    Final Answer: the final answer to the original input question
    ```

    Args:
        query: The user query. It is only inspected for substrings that are
            keys of ``query_loss_scale_map``.
        response: The agent response to split into weighted parts.
        response_loss_scale_map: Maps a delimiter to a list of one or two
            weights. With two weights the delimiter text itself is emitted as
            its own part (first weight) followed by its content (second
            weight); with one weight only the content is emitted.
        query_loss_scale_map: Optional map from a query substring to a weight
            (a bare number, or a list whose first element is used). If any key
            occurs in ``query``, the whole response becomes a single part with
            that weight and ``response_loss_scale_map`` is not consulted.
            This mapping is never modified.

    Returns:
        A tuple of agent response parts and their weights.
    """
    if query_loss_scale_map is not None:
        for key, scale in query_loss_scale_map.items():
            if key in query:
                # A bare number is shorthand for a one-element list. Normalise
                # it locally instead of writing it back, so the caller's
                # (usually shared) configuration dict is left untouched.
                if isinstance(scale, (float, int)):
                    scale = [scale]
                return [response], [float(scale[0])]

    # Two-weight entries take precedence: when at least one exists the response
    # is split on those delimiters only. Otherwise the one-weight keys are
    # handed to the splitter with regex_mode=True.
    delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 2]
    if delimiters:
        agent_parts = split_str_parts_by(response, delimiters)
    else:
        regex_delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 1]
        agent_parts = split_str_parts_by(response, regex_delimiters, regex_mode=True)

    weights: List[float] = []
    agent_content: List[str] = []
    for part in agent_parts:
        part_key = part['key']
        if part_key not in response_loss_scale_map:
            # Text not attributed to a configured delimiter gets the neutral weight.
            weights.append(1.)
            agent_content.append(part['content'])
            continue

        loss_scale = response_loss_scale_map[part_key]
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert len(loss_scale) in {1, 2}, f'loss_scale: {loss_scale}'
        weights += loss_scale
        if len(loss_scale) == 2:
            # The delimiter text is its own part, paired with the first weight.
            agent_content.append(part_key)
        agent_content.append(part['content'])
    return agent_content, weights
| |
|