|
|
import re
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from collections import Counter
|
|
|
|
|
|
def key_step_single(content, key_step):
    """Score how many key-step keyword groups appear in the reasoning of *content*.

    Only the text returned by ``extract_key_steps(content)`` is searched (the
    part after a reasoning/step marker).

    Args:
        content: Completion text to score.
        key_step: A dict with a ``"Steps"`` list, or a list whose first element
            is such a dict. Each ``"Steps"`` entry is either a keyword string or
            a list of alternative keywords. A group counts as matched when any
            alternative occurs, case-insensitively, in the reasoning text.

    Returns:
        Fraction of keyword groups matched, in ``[0, 1]``. Returns ``0.0`` when
        there are no groups or *key_step* is an empty list.
    """
    if isinstance(key_step, list):
        if not key_step:
            return 0.0
        key_step = key_step[0]

    # Lower-case both sides so the match really is case-insensitive. Lowering
    # only the keywords would miss e.g. "Car" in the text for keyword "car".
    step_text = extract_key_steps(content).lower()

    groups = key_step.get("Steps", [])
    if not groups:
        return 0.0

    matched = 0
    for group in groups:
        alternatives = [group] if isinstance(group, str) else group
        if any(alt.lower() in step_text for alt in alternatives):
            matched += 1
    return matched / len(groups)
|
|
|
|
|
|
def extract_key_steps(text):
    """Return the reasoning text that follows the first recognised step marker.

    Markers are tried in a fixed priority order. For the first marker found in
    *text*, the text after its last occurrence is returned, stripped of
    surrounding whitespace. Returns ``""`` when no marker is present.
    """
    markers = (
        "Step-by-Step Reasoning:",
        "**Step-by-Step Reasoning**:",
        "### Step 1:",
        "### Step 2:",
        "### Step 3:",
        "### Step 4:",
        "### Step 5:",
    )
    hit = next((marker for marker in markers if marker in text), None)
    if hit is None:
        return ""
    # rpartition splits at the LAST occurrence, same as split(marker)[-1].
    _, _, tail = text.rpartition(hit)
    return tail.strip()
|
|
|
|
|
|
def extract_final_answer(text):
    """Return the text after the first recognised final-answer marker.

    Markers are tried in priority order. For the first marker present in
    *text*, everything after its last occurrence is returned, stripped of
    surrounding whitespace. Returns ``""`` when no marker is found.

    A more specific marker must come before any marker it contains as a
    substring. Here ``"**Final Answer**"`` must precede ``"Final Answer"`` and
    ``"Answer"``. Otherwise it can never match, and a bold heading without a
    colon leaves a stray ``"**"`` at the start of the extracted answer.
    """
    options = [
        "The final answer is:",
        "**Final Answer:**",
        "**Final Answer**:",
        "**Final Answer**",
        "Final Answer",
        "Answer",
        "Why take this action?:",
        "**Final Decision**:",
        "Final Step:",
        "<CONCLUSION>",
    ]
    for opt in options:
        if opt in text:
            return text.split(opt)[-1].strip()
    return ""
|
|
|
|
|
|
def extract_options(text):
    """Parse lettered multiple-choice options of the form ``"A) text"``.

    Returns a list of ``(letter, option_text)`` tuples, one per ``[A-F])``
    marker; option text runs to the end of its line. When no option is found,
    returns ``[('F', 'None of the option.')]`` if the text mentions "none of
    the " (case-insensitive), otherwise the placeholder ``[('none', 'none')]``.
    """
    found = re.findall(r'([A-F])\)\s*(.+)', text)
    if found:
        return list(map(tuple, found))
    if "none of the " in text.lower():
        return [('F', 'None of the option.')]
    return [('none', 'none')]
|
|
|
|
|
|
def description_accuracy_reward(completions, solution, key_steps, **kwargs):
    """Reward function that checks if the completion contains key phrases.

    Each completion scores the fraction of a fixed list of driving-scene
    keywords that appear (case-insensitively) in its extracted final answer.

    Args:
        completions: Iterable of completion strings.
        solution: Accepted for a uniform reward-function signature; not used.
        key_steps: Accepted for a uniform reward-function signature; not used.
        **kwargs: Ignored.

    Returns:
        List of floats in ``[0, 1]``, one per completion.
    """
    keywords = ["ego vehicle", "car directly ahead", "distance management", "barrier", "restricted access",
                "pedestrian", "safety", "navigation", "path", "caution", "speed"]

    rewards = []
    for content in completions:
        answer = extract_final_answer(content).lower()
        hits = sum(1 for word in keywords if word in answer)
        rewards.append(hits / len(keywords))
    return rewards
|
|
|
|
|
|
|
|
|
def description_accuracy_rewardV2(completions, solution, key_steps, **kwargs):
    """Reward function that checks if the completion contains key phrases and key steps.

    The keyword score is the fraction of a fixed list of driving-scene keywords
    found (case-insensitively) in the completion's extracted final answer. When
    *key_steps* is truthy, the reward is the average of the keyword score and
    ``key_step_single(content, key_steps)``. Otherwise it is the keyword score
    alone.

    Args:
        completions: Iterable of completion strings.
        solution: Accepted for a uniform reward-function signature; not used.
        key_steps: Passed straight to ``key_step_single``. That function uses
            only the first element when given a list.
        **kwargs: Ignored.

    Returns:
        List of floats in ``[0, 1]``, one per completion.
    """
    keywords = ["ego vehicle", "car directly ahead", "distance management", "barrier", "restricted access",
                "pedestrian", "safety", "navigation", "path", "caution", "speed"]

    rewards = []
    for content in completions:
        answer = extract_final_answer(content).lower()
        keyword_reward = sum(1 for word in keywords if word in answer) / len(keywords)

        if key_steps:
            # NOTE(review): key_step_single reads key_steps[0] for every
            # completion. If a batch can mix prompts, this should index per
            # completion; confirm against the caller.
            step_reward = key_step_single(content, key_steps)
            rewards.append(0.5 * keyword_reward + 0.5 * step_reward)
        else:
            rewards.append(keyword_reward)
    return rewards
|
|
|
|
|
|
def process_text_embed(model, tokenizer, sentence):
    """Embed *sentence* (a list of strings) as L2-normalised mean-pooled vectors.

    The tokenizer is called with padding, truncation and PyTorch tensors. The
    model runs under ``torch.no_grad()``. The first element of the model output
    is averaged over the non-padding positions given by ``attention_mask``.
    Each resulting row is then normalised to unit L2 length.

    Returns:
        A tensor with one row per input sentence.
    """
    batch = tokenizer(sentence, padding=True, truncation=True, return_tensors='pt')

    with torch.no_grad():
        outputs = model(**batch)

    token_vectors = outputs[0]
    # Broadcast the mask over the hidden dimension so padded tokens add zero.
    weights = batch['attention_mask'].unsqueeze(-1).expand(token_vectors.size()).float()
    summed = (token_vectors * weights).sum(dim=1)
    # Clamp the token count to avoid division by zero on an all-padding row.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    pooled = summed / counts

    return F.normalize(pooled, p=2, dim=1)
|
|
|
|
|
|
_SEMANTIC_MODEL_NAME = 'all-MiniLM-L6-v2'
# Cache of name -> (tokenizer, model) so weights are loaded once per process,
# not on every reward call.
_SEMANTIC_MODEL_CACHE = {}


def _load_semantic_model(name=_SEMANTIC_MODEL_NAME):
    """Return the cached ``(tokenizer, model)`` pair for *name*, loading it on first use."""
    cached = _SEMANTIC_MODEL_CACHE.get(name)
    if cached is None:
        # Imported lazily so this module stays importable without transformers.
        from transformers import AutoTokenizer, AutoModel
        cached = (AutoTokenizer.from_pretrained(name), AutoModel.from_pretrained(name))
        _SEMANTIC_MODEL_CACHE[name] = cached
    return cached


def description_semantic_reward(completions, solution, key_steps, **kwargs):
    """Reward function that uses semantic similarity to evaluate the answer.

    Each completion's extracted final answer is embedded with
    ``process_text_embed``, as is the final answer of ``solution[0]``. The
    reward is the cosine similarity of the two embeddings, a float in
    ``[-1, 1]``.

    Args:
        completions: Iterable of completion strings.
        solution: Reference texts. Only ``solution[0]`` is used, for every
            completion.
        key_steps: Accepted for a uniform reward-function signature; not used.
        **kwargs: Ignored.

    Returns:
        List of floats, one per completion.
    """
    tokenizer, model = _load_semantic_model()

    reference = extract_final_answer(solution[0])
    reference_embedding = process_text_embed(model, tokenizer, [reference])

    rewards = []
    for content in completions:
        answer = extract_final_answer(content)
        answer_embedding = process_text_embed(model, tokenizer, [answer])
        # Both sides are single-row tensors, so the result has one element.
        similarity = F.cosine_similarity(reference_embedding, answer_embedding, dim=1)
        rewards.append(float(similarity))
    return rewards
|
|
|
|
|
|
def mc_accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is correct for multiple-choice questions.

    The reference label is found in two steps:
      1. the label of the first option parsed from ``solution[0]`` (via
         ``extract_options``) whose text mentions "correct" or "answer";
      2. failing that, the first standalone letter A-F in the final answer of
         ``solution[0]``.

    Each completion's label is the first standalone letter A-F in its
    extracted final answer.

    Returns:
        List with ``1.0`` where the completion's label equals the reference
        label, and ``0.0`` otherwise (including when either label is missing).
    """
    # Match a capital A-F only when it stands alone as a word ("B", "(C)",
    # "D)", "E."). A bare [A-F] would pick up the first letter of words such as
    # "Based" or "Decelerate".
    label_pattern = re.compile(r'\b([A-F])\b')

    def first_label(text):
        """Return the first standalone A-F letter in *text*, or None."""
        match = label_pattern.search(text)
        return match.group(1) if match else None

    sol_correct_label = None
    for label, option in extract_options(solution[0]):
        lowered = option.lower()
        if "correct" in lowered or "answer" in lowered:
            sol_correct_label = label
            break
    if not sol_correct_label:
        sol_correct_label = first_label(extract_final_answer(solution[0]))

    rewards = []
    for content in completions:
        student_label = first_label(extract_final_answer(content))
        is_correct = student_label is not None and student_label == sol_correct_label
        rewards.append(1.0 if is_correct else 0.0)
    return rewards
|
|
|
|
|
|
def check_reasoning_completeness(text):
    """Return True when *text* mentions every required section heading.

    Headings are matched as plain, case-sensitive substrings, so both
    ``"**Step-by-Step Reasoning**"`` and ``"Step-by-Step Reasoning:"`` count.
    """
    headings = ("Step-by-Step Reasoning", "Final Answer")
    return all(heading in text for heading in headings)
|
|
|
|
|
|
def check_reasoning_logic(text):
    """Check that the reasoning and final-answer sections are ordered and non-empty.

    Returns True only when all of the following hold:
      * ``**Step-by-Step Reasoning**`` is present;
      * the first ``**Final Answer**`` in the text appears after it;
      * each section has content beyond its heading. A trailing ``:`` and
        surrounding whitespace do not count as content.
    """
    reasoning_header = "**Step-by-Step Reasoning**"
    answer_header = "**Final Answer**"

    reasoning_start = text.find(reasoning_header)
    if reasoning_start == -1:
        return False

    # Use the first occurrence anywhere: an answer heading placed before the
    # reasoning heading counts as an ordering failure.
    answer_start = text.find(answer_header)
    if answer_start == -1 or answer_start < reasoning_start:
        return False

    # Slice after each heading so the heading text itself is not mistaken for
    # section content.
    filler = ": \t\r\n"
    reasoning_body = text[reasoning_start + len(reasoning_header):answer_start].strip(filler)
    answer_body = text[answer_start + len(answer_header):].strip(filler)
    return bool(reasoning_body) and bool(answer_body)
|
|
|
|
|
|
def description_validity_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format.

    For each completion:
      * ``validity`` is 1.0 when ``check_reasoning_completeness`` and
        ``check_reasoning_logic`` both pass, else 0.0;
      * if the whole completion matches the strict layout
        ``**Step-by-Step Reasoning**:\\n\\n<N>. ...\\n\\n**Final Answer**: ...``,
        the reward is ``0.5 + 0.5 * validity``; otherwise it is ``validity``.

    NOTE(review): a valid completion scores 1.0 with or without the strict
    layout, so the layout only raises the floor to 0.5. Confirm this is
    intended.

    Returns:
        List of floats in ``{0.0, 0.5, 1.0}``, one per completion.
    """
    # With DOTALL, `(\d+\. .+\n\n)+` accepts exactly the same strings as a
    # single `\d+\. .+\n\n`, because `.+` can span later numbered steps. The
    # repeated group backtracks exponentially on malformed text with many
    # steps. The single-step form is equivalent and linear.
    format_pattern = re.compile(
        r"\*\*Step-by-Step Reasoning\*\*:\n\n\d+\. .+\n\n\*\*Final Answer\*\*: .+",
        re.DOTALL,
    )

    rewards = []
    for content in completions:
        is_valid = check_reasoning_completeness(content) and check_reasoning_logic(content)
        validity = 1.0 if is_valid else 0.0
        if format_pattern.fullmatch(content):
            rewards.append(0.5 + validity * 0.5)
        else:
            rewards.append(validity)
    return rewards
|
|
|
|
|
|
if __name__ == "__main__":
    # --- debug: description rewards ---
    solution = ["**Step-by-Step Reasoning**:\n\n1. In the top middle image, the ego vehicle is right behind another vehicle. This car is directly in its path and needs to be considered for speed and distance management.\n2. The top middle image also shows a barrier ahead. This indicates restricted access which the ego vehicle must consider before proceeding.\n3. The top right image includes a person standing near the road. The ego vehicle should be cautious of pedestrians and maintain a safe distance. \n\n**Final Answer**: The three most important objects the ego vehicle must consider are the car directly ahead, the barrier indicating restricted access, and the pedestrian standing on the side of the road. These objects are crucial for safety and navigation."]

    completion = ["**Step-by-Step Reasoning**:\n\n1. In the 'front' image, there is a car directly ahead of the ego vehicle. This is crucial for the ego vehicle to consider due to potential obstacles or traffic flow.\n\n2. In the 'front right' image, there are pedestrians visible on the pathway. The presence of pedestrians is significant as they may cross the path of the ego vehicle or be influenced its movement.\n\n3. In the 'back right' image, there is another vehicle approaching from the opposite direction. This vehicle is important for the ego vehicle to monitor to avoid collisions and ensure safe passage.\n\n**Final Answer**: The three most important objects in the scene are:\n1. The car directly in front of the ego vehicle in the 'front' image.\n2. The pedestrians on the sidewalk in the 'front right' image.\n3. The approaching vehicle in the 'back right' image."]

    key_steps = [
        {
            "Steps": [
                ["car", "directly ahead"],
                ["barrier", "ahead", "restricted access"],
                ["person", "near road", " pedestrians"]
            ]
        }
    ]

    # Print the scores so the harness actually reports what it computed.
    print("validity reward:", description_validity_reward(completion))
    print("keyword accuracy reward:", description_accuracy_reward(completion, solution, key_steps))

    # --- debug: multiple choice (no cases yet) ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|