# Extracted from a web file viewer: file size 15,605 bytes, revision 7134ce7.
# (The viewer's line-number gutter, 1-308, was extraction residue and has been dropped.)
import re
import torch
import torch.nn.functional as F
from collections import Counter
def key_step_single(content, key_step):
    """Score how many expected key-step keyword groups occur in ``content``.

    Args:
        content: Completion text to inspect. Only the part returned by
            ``extract_key_steps`` is searched.
        key_step: A dict with a ``"Steps"`` entry, or a list whose first
            element is such a dict. Each entry of ``"Steps"`` is either a
            single keyword string or a list of alternative keywords; a group
            counts as matched when any alternative is found.

    Returns:
        The fraction of keyword groups that matched, or ``0.0`` when there
        are no groups to check.
    """
    if isinstance(key_step, list):
        if not key_step:
            return 0.0
        key_step = key_step[0]

    # Lower-case the searched text once: the keywords are lower-cased below,
    # so both sides must be normalised for a case-insensitive match.
    step_text = extract_key_steps(content).lower()

    groups = key_step.get("Steps", [])
    if not groups:
        return 0.0

    matched = 0
    for group in groups:
        if isinstance(group, str):
            group = [group]
        if any(alt.lower() in step_text for alt in group):
            matched += 1
    return matched / len(groups)
def extract_key_steps(text):
    """Return the reasoning part of ``text`` that follows a known marker.

    Markers are tried in a fixed priority order. The first marker present in
    ``text`` wins, and everything after its last occurrence is returned with
    surrounding whitespace removed. An empty string is returned when no
    marker is present.
    """
    markers = (
        "Step-by-Step Reasoning:",
        "**Step-by-Step Reasoning**:",
        "### Step 1:",
        "### Step 2:",
        "### Step 3:",
        "### Step 4:",
        "### Step 5:",
    )
    found = next((marker for marker in markers if marker in text), None)
    if found is None:
        return ""
    return text.split(found)[-1].strip()
def extract_final_answer(text):
    """Return the part of ``text`` that follows a known final-answer marker.

    Markers are tried in priority order. The first marker present in ``text``
    wins, and everything after its last occurrence is returned with
    surrounding whitespace removed. An empty string is returned when no
    marker is present.
    """
    # "**Final Answer**" must be tried before the bare "Final Answer" and
    # "Answer" markers: those are substrings of it, so listing them first
    # made the bold marker unreachable and left a stray "**" in the result.
    markers = (
        "The final answer is:",
        "**Final Answer:**",
        "**Final Answer**:",
        "**Final Answer**",
        "Final Answer",
        "Answer",
        "Why take this action?:",
        "**Final Decision**:",
        "Final Step:",
        "<CONCLUSION>",
    )
    for marker in markers:
        if marker in text:
            return text.split(marker)[-1].strip()
    return ""
def extract_options(text):
    """Collect ``(label, option_text)`` pairs for options written as ``A) ...``.

    Labels are the capital letters A-F followed by ``)``. When no such option
    is found, a single placeholder pair is returned instead: the ``F`` /
    "none of the options" entry if the text mentions "none of the ", and
    ``('none', 'none')`` otherwise.
    """
    labelled = [found.groups() for found in re.finditer(r'([A-F])\)\s*(.+)', text)]
    if labelled:
        return labelled
    if "none of the " in text.lower():
        return [('F', 'None of the option.')]
    return [('none', 'none')]
def description_accuracy_reward(completions, solution, key_steps, **kwargs):
    """Reward each completion by the share of fixed keywords in its final answer.

    Args:
        completions: Iterable of completion strings.
        solution: Unused; kept so the signature matches the other reward
            functions. The keyword list below is fixed, not derived from it.
        key_steps: Unused; kept for signature compatibility.
        **kwargs: Ignored.

    Returns:
        One reward in ``[0.0, 1.0]`` per completion: the fraction of keywords
        found (case-insensitively) in the text after the final-answer marker.
    """
    keywords = ["ego vehicle", "car directly ahead", "distance management", "barrier", "restricted access", "pedestrian", "safety", "navigation", "path", "caution", "speed"]
    rewards = []
    for content in completions:
        answer = extract_final_answer(content).lower()
        hits = sum(1 for word in keywords if word in answer)
        rewards.append(hits / len(keywords))
    return rewards
def description_accuracy_rewardV2(completions, solution, key_steps, **kwargs):
    """Reward completions on final-answer keywords and, optionally, key steps.

    Args:
        completions: Iterable of completion strings.
        solution: Unused; kept so the signature matches the other reward
            functions. The keyword list below is fixed, not derived from it.
        key_steps: Key-step specification passed to ``key_step_single``. When
            falsy, only the keyword score is used.
        **kwargs: Ignored.

    Returns:
        One reward per completion. With ``key_steps`` it is the equal-weight
        average of the keyword score and the key-step score; otherwise it is
        the keyword score alone.
    """
    keywords = ["ego vehicle", "car directly ahead", "distance management", "barrier", "restricted access", "pedestrian", "safety", "navigation", "path", "caution", "speed"]
    rewards = []
    for content in completions:
        answer = extract_final_answer(content).lower()
        hits = sum(1 for word in keywords if word in answer)
        keyword_reward = hits / len(keywords)
        if key_steps:
            step_reward = key_step_single(content, key_steps)
            rewards.append(0.5 * keyword_reward + 0.5 * step_reward)
        else:
            rewards.append(keyword_reward)
    return rewards
def process_text_embed(model, tokenizer, sentence):
    """Embed ``sentence`` with ``model`` using masked mean pooling.

    The tokenizer is called with padding and truncation enabled and asked for
    PyTorch tensors. The model runs without gradient tracking, its first
    output is averaged over the positions selected by the attention mask, and
    the pooled vectors are L2-normalised along dimension 1.
    """
    encoded = tokenizer(sentence, padding=True, truncation=True, return_tensors='pt')

    with torch.no_grad():
        output = model(**encoded)

    # The first element of the model output holds the per-token embeddings.
    token_embeddings = output[0]
    mask = encoded['attention_mask'].unsqueeze(-1).expand_as(token_embeddings).float()

    # Masked mean; the clamp guards against division by zero for an all-zero mask.
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    pooled = summed / counts

    return F.normalize(pooled, p=2, dim=1)
def description_semantic_reward(completions, solution, key_steps, **kwargs):
    """Reward completions by cosine similarity to the reference final answer.

    Args:
        completions: Iterable of completion strings.
        solution: Sequence whose first element is the reference text.
        key_steps: Unused; kept so the signature matches the other reward
            functions.
        **kwargs: Ignored.

    Returns:
        One float per completion: the cosine similarity between the embedding
        of the completion's final answer and that of the reference's.
    """
    # Imported lazily so the module stays importable without these packages.
    from sentence_transformers import util
    from transformers import AutoTokenizer, AutoModel

    # NOTE(review): the encoder is reloaded on every call; consider caching it
    # at the call site if this reward runs once per training step.
    tokenizer = AutoTokenizer.from_pretrained('all-MiniLM-L6-v2')
    model = AutoModel.from_pretrained('all-MiniLM-L6-v2')

    sol_final_answer = extract_final_answer(solution[0])
    solution_embedding = process_text_embed(model, tokenizer, [sol_final_answer])

    rewards = []
    for content in completions:
        student_final_answer = extract_final_answer(content)
        student_embedding = process_text_embed(model, tokenizer, [student_final_answer])
        similarity = util.cos_sim(solution_embedding, student_embedding)
        rewards.append(float(similarity))
    return rewards
def mc_accuracy_reward(completions, solution, **kwargs):
    """Reward completions whose chosen option letter matches the solution's.

    Args:
        completions: Iterable of completion strings.
        solution: Sequence whose first element is the reference text.
        **kwargs: Ignored.

    Returns:
        One reward per completion: ``1.0`` when the completion's option label
        equals the reference label, ``0.0`` otherwise (including when either
        label cannot be determined).
    """
    # NOTE(review): this matches the first capital A-F anywhere in the final
    # answer, so a leading word such as "Change" yields "C". Kept as-is to
    # preserve existing scores.
    label_pattern = re.compile(r'([A-F])')

    # Prefer an option in the solution that is explicitly flagged as correct.
    sol_correct_label = None
    for label, option in extract_options(solution[0]):
        lowered = option.lower()
        if "correct" in lowered or "answer" in lowered:
            sol_correct_label = label
            break

    # Otherwise fall back to the first label found in the solution's final answer.
    if not sol_correct_label:
        found = label_pattern.search(extract_final_answer(solution[0]))
        sol_correct_label = found.group(1) if found else None

    rewards = []
    for content in completions:
        found = label_pattern.search(extract_final_answer(content))
        student_label = found.group(1) if found else None
        is_correct = bool(student_label and sol_correct_label and student_label == sol_correct_label)
        rewards.append(1.0 if is_correct else 0.0)
    return rewards
def check_reasoning_completeness(text):
    """Return True when ``text`` mentions every required section title."""
    required_sections = ("Step-by-Step Reasoning", "Final Answer")
    return all(section in text for section in required_sections)
def check_reasoning_logic(text):
    """Check that the reasoning and final-answer sections are ordered and non-empty.

    Returns True only when ``**Step-by-Step Reasoning**`` appears, the first
    ``**Final Answer**`` comes after it, and both sections contain something
    other than whitespace and a leading colon.
    """
    reasoning_marker = "**Step-by-Step Reasoning**"
    answer_marker = "**Final Answer**"

    reasoning_start = text.find(reasoning_marker)
    if reasoning_start == -1:
        return False

    # The first final-answer marker must come after the reasoning marker.
    answer_start = text.find(answer_marker)
    if answer_start == -1 or answer_start < reasoning_start:
        return False

    # Slice after each marker. Slicing from the marker's start would always
    # include the marker text, so the emptiness checks could never fail.
    reasoning_body = text[reasoning_start + len(reasoning_marker):answer_start]
    answer_body = text[answer_start + len(answer_marker):]

    def has_content(section):
        # Ignore the ":" that conventionally follows the bold heading.
        return bool(section.strip().lstrip(":").strip())

    return has_content(reasoning_body) and has_content(answer_body)
def description_validity_reward(completions, **kwargs):
    """Reward completions for following the expected answer layout.

    Args:
        completions: Iterable of completion strings.
        **kwargs: Ignored.

    Returns:
        One reward per completion. ``validity`` is ``1.0`` when both
        ``check_reasoning_completeness`` and ``check_reasoning_logic`` pass,
        else ``0.0``. If the completion also matches the strict format
        pattern, the reward is ``0.5 + 0.5 * validity``; otherwise it is
        ``validity``.
    """
    # Strict layout: heading, blank line, one or more numbered steps each
    # followed by a blank line, then the final answer.
    format_pattern = re.compile(
        r"\*\*Step-by-Step Reasoning\*\*:\n\n(\d+\. .+\n\n)+\*\*Final Answer\*\*: .+",
        re.DOTALL,
    )

    rewards = []
    for content in completions:
        is_valid = check_reasoning_completeness(content) and check_reasoning_logic(content)
        validity = 1.0 if is_valid else 0.0
        if format_pattern.fullmatch(content):
            rewards.append(0.5 + validity * 0.5)  # 0.5 for format, 0.5 for validity
        else:
            rewards.append(validity)  # only the validity reward when the format is off
    return rewards
if __name__ == "__main__":
    # --- Debug: description-style rewards ---------------------------------
    # Reference answer and a sample completion for manual inspection.
    solution = [
        "**Step-by-Step Reasoning**:\n\n1. In the top middle image, the ego vehicle is right behind another vehicle. This car is directly in its path and needs to be considered for speed and distance management.\n2. The top middle image also shows a barrier ahead. This indicates restricted access which the ego vehicle must consider before proceeding.\n3. The top right image includes a person standing near the road. The ego vehicle should be cautious of pedestrians and maintain a safe distance. \n\n**Final Answer**: The three most important objects the ego vehicle must consider are the car directly ahead, the barrier indicating restricted access, and the pedestrian standing on the side of the road. These objects are crucial for safety and navigation.",
    ]
    completion = [
        "**Step-by-Step Reasoning**:\n\n1. In the 'front' image, there is a car directly ahead of the ego vehicle. This is crucial for the ego vehicle to consider due to potential obstacles or traffic flow.\n\n2. In the 'front right' image, there are pedestrians visible on the pathway. The presence of pedestrians is significant as they may cross the path of the ego vehicle or be influenced its movement.\n\n3. In the 'back right' image, there is another vehicle approaching from the opposite direction. This vehicle is important for the ego vehicle to monitor to avoid collisions and ensure safe passage.\n\n**Final Answer**: The three most important objects in the scene are:\n1. The car directly in front of the ego vehicle in the 'front' image.\n2. The pedestrians on the sidewalk in the 'front right' image.\n3. The approaching vehicle in the 'back right' image.",
    ]
    # Each inner list is a group of alternative keywords for one key step.
    key_steps = [
        {
            "Steps": [
                ["car", "directly ahead"],
                ["barrier", "ahead", "restricted access"],
                ["person", "near road", " pedestrians"],
            ],
        },
    ]

    description_validity_reward(completion)
    # Alternative entry points for the same fixtures:
    # description_accuracy_rewardV2(completions=completion, solution=solution, key_steps=key_steps)
    # description_semantic_reward(completions=completion, solution=solution, key_steps=key_steps)

    # --- Debug: multiple-choice reward (disabled) --------------------------
    # solution = ["**Step-by-Step Reasoning**:\n\n1. In the \"front\" image of the top row, there is a black car ahead, in a lane that is narrowed by a divider. This suggests limited space for maneuvering.\n2. In the \"back\" image of the bottom row, the car is in a lane that has space on the left, but the right side is blocked by another car.\n3. Considering the options, \"Change lane to right\" (D) is not feasible due to the blockage. \"Come to a complete stop\" (A) or \"Brake suddenly\" (B) might not ensure collision avoidance given the proximity of the car ahead.\n4. \"Speed up\" (C) is not ideal in a potentially congested area. The safest way could be to \"Change lane to left\" (E) since it appears more space is available.\n\n**Final Answer**: E) Change lane to left."]
    # completion = ["**Step-by-Step Reasoning**:\n\n1. In the 'front' image, there's a car directly ahead of the ego vehicle, which indicates potential for collision if movement continues.\n2. The current speed of the ego vehicle is very low at 0.82 km/h, suggesting that stopping would be effective without skidding or abrupt maneuvers.\n3. There are no visible obstacles or safe lanes to the left or right that are clear for immediate lane changes.\n4. Given the proximity to the car ahead and lack of safe maneuver spaces, continuing forward is unsafe.\n\n**Final Answer**: A) Come to a complete stop."]
    # completion = ["**Step-by-Step Reasoning**:\n\n1. In the 'front' image, there's a car directly ahead of the ego vehicle, which indicates potential for collision if movement continues.\n2. The current speed of the ego vehicle is very low at 0.82 km/h, suggesting that stopping would be effective without skidding or abrupt maneuvers.\n3. There are no visible obstacles or safe lanes to the left or right that are clear for immediate lane changes.\n4. Given the proximity to the car ahead and lack of safe maneuver spaces, continuing forward is unsafe.\n\n**Final Answer**: E) Change lane to left."]
    # mc_accuracy_reward(completions=completion, solution=solution)
|