| |
|
| |
|
| |
|
| |
|
| | import json
|
| | import os
|
| | import argparse
|
| | import logging
|
| | import time
|
| | import threading
|
| | import functools
|
| | import concurrent.futures
|
| | import platform
|
| | from tqdm import tqdm
|
| | from typing import Any, List, Tuple, Dict
|
| | from tenacity import retry, stop_after_attempt, wait_random_exponential
|
| | from concurrent.futures import ThreadPoolExecutor
|
| | from scripts.tools.tool_libraries import FuncAgent
|
| | from openai import OpenAI
|
| |
|
| |
|
# Shared OpenAI client used by every completion request in this module.
# The credential and endpoint are taken from the environment when the
# variables are set, so a real key never has to be written into the source
# file; the original placeholder values are kept as fallbacks.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY", "your_api_key_here"),
    base_url=os.environ.get("OPENAI_BASE_URL", "https://api.example.com/v1"),
)
|
| |
|
| |
|
class TimeoutError(Exception):
    """Signals that a call exceeded its allotted run time.

    Note: inside this module the name shadows Python's built-in
    ``TimeoutError``; this class derives directly from ``Exception``.
    """
|
| |
|
| |
|
def timeout(seconds):
    """
    Decorator factory that bounds the run time of the decorated function.

    Two strategies are used, chosen on every call:

    * On non-Windows platforms, when the call is made from the main thread,
      a ``SIGALRM`` interval timer interrupts the function in place.
    * Otherwise (Windows, or a call made from a worker thread, where Python
      refuses to install signal handlers) the function runs in a daemon
      thread that is joined for at most ``seconds``. The worker cannot be
      stopped, so after a timeout it keeps running in the background.

    Args:
        seconds (int | float): Timeout duration in seconds.

    Returns:
        A decorator. The wrapped function returns the original result,
        re-raises the original exception, or raises ``TimeoutError`` when
        the time limit is exceeded.
    """
    def decorator(func):
        import signal

        # SIGALRM is unavailable on Windows; decide once per decorated function.
        signal_capable = platform.system() != "Windows" and hasattr(signal, "SIGALRM")

        def _timeout_error():
            # Built lazily so the message is only formatted when it is needed.
            return TimeoutError(f"Function '{func.__name__}' timed out after {seconds}s")

        def _handle(signum, frame):
            raise _timeout_error()

        def _call_with_alarm(args, kwargs):
            # Install our handler, arm the timer, and always restore both.
            old_handler = signal.signal(signal.SIGALRM, _handle)
            signal.setitimer(signal.ITIMER_REAL, seconds)
            try:
                return func(*args, **kwargs)
            finally:
                signal.setitimer(signal.ITIMER_REAL, 0)
                signal.signal(signal.SIGALRM, old_handler)

        def _call_in_thread(args, kwargs):
            # Separate slots for value and error: a function that *returns*
            # an exception object must not have it raised by the wrapper.
            outcome = {}

            def target():
                try:
                    outcome["value"] = func(*args, **kwargs)
                except BaseException as exc:  # re-raised in the calling thread
                    outcome["error"] = exc

            worker = threading.Thread(target=target, daemon=True)
            worker.start()
            worker.join(seconds)

            if worker.is_alive():
                raise _timeout_error()
            if "error" in outcome:
                raise outcome["error"]
            return outcome.get("value")

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # signal.signal() raises ValueError outside the main thread, so
            # worker-thread callers fall back to the thread-based strategy.
            on_main_thread = threading.current_thread() is threading.main_thread()
            if signal_capable and on_main_thread:
                return _call_with_alarm(args, kwargs)
            return _call_in_thread(args, kwargs)

        return wrapper
    return decorator
|
| |
|
| |
|
@retry(stop=stop_after_attempt(5), wait=wait_random_exponential(min=3, max=10))
def completion_with_backoff(**kwargs):
    """Forward ``kwargs`` to the chat-completions endpoint of the shared client.

    The ``retry`` decorator is configured with a random exponential wait
    (min=3, max=10) between tries and stops after 5 attempts.
    """
    chat_completions = client.chat.completions
    return chat_completions.create(**kwargs)
|
| |
|
| |
|
def read_json(json_file: str) -> Any:
    """Parse the UTF-8 encoded JSON document at ``json_file`` and return the result."""
    with open(json_file, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
|
| |
|
| |
|
def save_progress(file_path: str, index: int) -> None:
    """
    Persist the index of the last processed sample.

    The file is overwritten with a JSON object of the form
    ``{"processed_index": <index>}`` (4-space indent, UTF-8).

    Args:
        file_path: Destination of the progress file
        index: Index to record
    """
    payload = {"processed_index": index}
    with open(file_path, mode="w", encoding="utf-8") as handle:
        handle.write(json.dumps(payload, ensure_ascii=False, indent=4))
|
| |
|
| |
|
def get_system_prompt() -> str:
    """Return the fixed system prompt used for the tool-selection request."""
    # NOTE(review): the extracted source lost its indentation, so the literal's
    # lines are reproduced flush-left — confirm against the original file.
    system_prompt = """
**A Language Agent for Autonomous Driving**
Role: You are the brain of an autonomous vehicle (a.k.a. ego-vehicle).
Your task is to extract necessary information from the driving scenario that will be useful for driving question answering.

Necessary information might include:
- Visual infos: Visual information from specific cameras
- Detections: Detected objects that require attention
- Predictions: Estimated future motions of detected objects
- Maps: Traffic lanes and road boundaries
- Occupancy: Whether locations are occupied by other objects

Task:
- Determine what types of information (Visual info, Detections, Predictions, Maps, Occupancy) to extract
- Prioritize Detections and Predictions for motion planning questions
- Focus on Maps information for lane maintenance and road layout questions
"""
    return system_prompt
|
| |
|
| |
|
def get_user_prompt(raw_question: str, reason_part: str, raw_answer: str) -> str:
    """
    Assemble the user prompt that asks the model to pick tools.

    The prompt is made of the question/answer pair, the tool catalogue
    produced by ``generate_func_prompt()``, and the required response format.

    Args:
        raw_question: The driving question to answer
        reason_part: Reasoning steps from existing CoT data (accepted but
            not inserted into the prompt)
        raw_answer: Final answer to the question

    Returns:
        The complete user prompt as a single string
    """
    question_line = f"The question you need to answer is: {raw_question}"
    answer_line = f"The final answer to this question is: {raw_answer}"
    tools_overview = generate_func_prompt()

    # NOTE(review): literal lines are reproduced flush-left because the
    # extracted source lost its indentation — confirm against the original.
    response_rules = """
Please choose tools to answer this problem and support the final answer.
Return a list of tool names (no more than 4) with "tool_name" and related "parameters".
If parameters are blank, use [""] as the value.

Output format example:
[
{"tool_name": "function_name1", "parameters": ["arg1", "arg2"]},
{"tool_name": "function_name2", "parameters": ["arg1", ["option1", "option2"]]}
]

STRICTLY FOLLOW THE JSON RESPONSE FORMAT.
RESPONSE MUST START WITH "[{" AND END WITH "}]".
DO NOT START WITH "```json" OR ANY MARKDOWN.
"""

    sections = [
        question_line,
        answer_line,
        f"Available tools: {tools_overview}",
        response_rules,
    ]
    return "\n".join(sections)
|
| |
|
| |
|
def generate_func_prompt(debug: bool = False) -> str:
    """
    Generate prompt listing available functions and their parameters

    The function descriptions are gathered from a freshly constructed
    ``FuncAgent`` (detection, map, prediction, occupancy and visual groups,
    in that order). Each entry is rendered as
    ``- name(required, params) # description``.

    Args:
        debug: Whether to print the generated prompt

    Returns:
        Formatted function prompt
    """
    def _collect_function_infos() -> List[Dict]:
        # Build a fresh agent and concatenate its per-category descriptions.
        func_agent = FuncAgent()
        return (
            func_agent.detection_func_infos
            + func_agent.map_func_infos
            + func_agent.prediction_func_infos
            + func_agent.occupancy_func_infos
            + func_agent.visual_func_infos
        )

    try:
        function_list = _collect_function_infos()
    except Exception as first_error:
        # The previous code repeated the identical construction in its
        # except branch while hiding the first failure. The single retry is
        # kept, but the reason for it is now recorded.
        logging.warning("Building the tool list failed (%s); retrying once", first_error)
        function_list = _collect_function_infos()

    lines = ["Available functions:\n"]
    for info in function_list:
        required = info["parameters"].get("required")
        param_str = ", ".join(required) if required else ""
        lines.append(f"- {info['name']}({param_str}) # {info['description']}\n")
    prompt = "".join(lines)

    if debug:
        print(prompt)

    return prompt
|
| |
|
| |
|
def generate_choose_func_prompt(tool_list: List[Dict]) -> str:
    """
    Describe only the tools that were selected.

    For every entry of ``tool_list`` (in order) the function looks up the
    descriptions whose ``name`` equals the entry's ``tool_name`` and renders
    them as ``- name(required, params) # description``.

    Args:
        tool_list: Selected tools; each item must provide a "tool_name" key

    Returns:
        The "Selected tools:" prompt section
    """
    agent = FuncAgent()
    catalogue = agent.detection_func_infos
    for group in (
        agent.map_func_infos,
        agent.prediction_func_infos,
        agent.occupancy_func_infos,
        agent.visual_func_infos,
    ):
        catalogue = catalogue + group

    def _render(entry: Dict) -> str:
        # One prompt line for a single function description.
        spec = entry["parameters"]
        required = ", ".join(spec["required"]) if spec.get("required") else ""
        return f"- {entry['name']}({required}) # {entry['description']}\n"

    rendered = []
    for chosen in tool_list:
        wanted = chosen['tool_name']
        rendered.extend(_render(entry) for entry in catalogue if entry['name'] == wanted)

    return "Selected tools:\n" + "".join(rendered)
|
| |
|
| |
|
def get_cot_system_prompt() -> str:
    """Return the fixed system prompt used for the CoT reconstruction request."""
    # NOTE(review): continuation lines are assumed to carry no leading
    # whitespace (the extracted source lost indentation) — confirm.
    prompt_lines = (
        "You're an autonomous driving inference optimization expert.",
        "Your task is to decompose short CoT data into finer-grained atomic steps and reorganize them for optimal reasoning path.",
        "Break down reasoning into minimal units, dynamically cluster atomic steps, extract and label key steps, add tool invocations,",
        "and form the optimal reasoning path for the current problem.",
    )
    return "\n".join(prompt_lines)
|
| |
|
| |
|
def get_cot_user_prompt(raw_question: str, reason_part: str, final_answer: str, tool_choose_list: List[Dict]) -> str:
    """
    Assemble the user prompt for the CoT reconstruction request.

    The result is the concatenation of three sections: the task input
    (question, existing CoT, final answer), the selected-tool description
    from ``generate_choose_func_prompt``, and the output-format rules.

    Args:
        raw_question: The driving question
        reason_part: Existing reasoning steps
        final_answer: Final answer to the question
        tool_choose_list: List of selected tools

    Returns:
        The complete CoT user prompt
    """
    task_section = "".join([
        f"Raw question: {raw_question}\n",
        f"Existing CoT data: {reason_part}\n",
        f"Final answer: {final_answer}\n\n",
        "Please reconstruct the raw CoT data into atomic CoT data:",
    ])

    selected_tools = generate_choose_func_prompt(tool_choose_list)
    tool_section = (
        "For each atomic step, choose appropriate tools and parameters.\n"
        + f"Available tools: {selected_tools}\n"
    )

    # NOTE(review): literal lines are reproduced flush-left because the
    # extracted source lost its indentation — confirm against the original.
    format_section = """
Also generate:
- Sub-question for each action/reasoning step (perception, prediction, planning)
- Guess answer for each Sub if you're confident
- Keywords (2-5 synonyms/alternatives) related to the Guess Answer
- Missing_flag: "False" if you can't answer, "True" otherwise
- next_action: "continue reasoning" or "conclude"
- Continue until reasoning chain is complete
- Final answer and its keywords

Output format example:
{
"Question": "",
"Chain": [
{
"Tool": {"function_name": "tool1", "parameters": ["param1"]},
"Sub": "Sub-question 1",
"Guess_Answer": "Answer 1",
"key_words": ["word1", "word2"],
"Missing_flag": "False",
"next_action": "continue reasoning"
},
...
],
"final_answer_keywords": ["keyword1", "keyword2"],
"final_answer": "Final answer"
}

STRICTLY FOLLOW THE JSON RESPONSE FORMAT.
RESPONSE MUST START WITH "{".
DO NOT START WITH "```json" OR ANY MARKDOWN.
"""

    return "".join((task_section, tool_section, format_section))
|
| |
|
| |
|
def extract_key_steps(text: str) -> str:
    """
    Extract key reasoning steps from text

    The reasoning is the text that follows the first
    "**Step-by-Step Reasoning**:" marker (up to a second such marker, if
    any), cut at the earliest occurrence of any stop marker.

    Cutting at the earliest occurrence matters for overlapping markers:
    truncating at "Final Answer" before "**Final Answer**" is considered
    would leave a dangling "**" at the end of the result.

    Args:
        text: Text containing reasoning steps

    Returns:
        Cleaned reasoning steps, or "" when the start marker is absent
    """
    stop_markers = [
        "The final answer is:", "**Final Answer:**", "**Final Answer**:",
        "Final Answer", "Answer", "Why take this action?:",
        "**Final Answer**", "**Final Decision**:", "Final Step:", "<CONCLUSION>"
    ]
    start_marker = "**Step-by-Step Reasoning**:"

    _, found, remainder = text.partition(start_marker)
    if not found:
        return ""

    # Keep only the segment between the first and (optional) second start marker.
    relevant_part = remainder.split(start_marker, 1)[0].strip()

    cut_positions = [
        position
        for position in (relevant_part.find(marker) for marker in stop_markers)
        if position != -1
    ]
    if cut_positions:
        relevant_part = relevant_part[:min(cut_positions)].strip()

    return relevant_part
|
| |
|
| |
|
def extract_final_answer(text: str) -> str:
    """
    Extract final answer from text

    Markers are tried in priority order; for the first marker present in
    ``text`` the function returns everything after its last occurrence,
    stripped of surrounding whitespace.

    "**Final Answer**" is listed before its substring "Final Answer":
    with the reverse order it could never match, and the returned answer
    started with a stray "**".

    Args:
        text: Text containing final answer

    Returns:
        Extracted final answer, or "" when no marker is found
    """
    options = [
        "The final answer is:", "**Final Answer:**", "**Final Answer**:",
        "**Final Answer**", "Final Answer", "Answer", "Why take this action?:",
        "**Final Decision**:", "Final Step:", "<CONCLUSION>"
    ]

    for opt in options:
        if opt in text:
            # rpartition keeps the text after the LAST occurrence of the marker.
            return text.rpartition(opt)[2].strip()

    return ""
|
| |
|
| |
|
@timeout(100)
def run_one_round_conversation(
    full_messages: List[Dict],
    system_message: str,
    user_message: str,
    temperature: float = 0.0,
    model_name: str = "gpt-4o-mini"
) -> Tuple[List[Dict], str]:
    """
    Send a single system/user exchange to the model and record the reply.

    Only the given system and user messages are sent; ``full_messages`` is
    not part of the request. The reply text is appended to ``full_messages``
    (mutating the caller's list) and also returned. The whole call is
    wrapped by ``timeout(100)``.

    Args:
        full_messages: List that collects the model's replies
        system_message: System prompt; omitted from the request when falsy
        user_message: User prompt
        temperature: Sampling temperature
        model_name: Model to use

    Returns:
        The (same) ``full_messages`` list and the reply text
    """
    user_entry = {"role": "user", "content": [{"type": "text", "text": user_message}]}
    messages = [user_entry]
    if system_message:
        messages.insert(0, {"role": "system", "content": system_message})

    response = completion_with_backoff(
        model=model_name,
        messages=messages,
        temperature=temperature,
    )

    reply_text = response.choices[0].message.content
    full_messages.append(reply_text)

    return full_messages, reply_text
|
| |
|
| |
|
def _parse_json_response(response_text: str) -> Any:
    """Parse a model reply as JSON, tolerating a surrounding Markdown code fence.

    The reply is parsed as-is first. Only when that fails and the reply
    starts with a ``` fence are the opening fence line and the closing fence
    removed before a second parse. A ``json.JSONDecodeError`` propagates
    when neither attempt succeeds.
    """
    try:
        return json.loads(response_text)
    except json.JSONDecodeError:
        stripped = response_text.strip()
        if not stripped.startswith("```"):
            raise
        # Drop the opening fence line (``` or ```json) and the closing fence.
        body = stripped.split("\n", 1)[1] if "\n" in stripped else ""
        body = body.rsplit("```", 1)[0]
        return json.loads(body)


def _request_json_with_retries(
    history: List[Dict],
    system_prompt: str,
    user_prompt: str,
    temperature: float,
    model_name: str,
    max_retries: int,
) -> Tuple[Any, List[Dict], str]:
    """Ask the model until its reply parses as JSON or the attempts run out.

    Returns the parsed JSON, the conversation history and the raw reply.
    Raises ``ValueError`` (chained to the last decode error) after
    ``max_retries`` unparsable replies. Errors raised by the request itself
    are not caught here.
    """
    last_error = None
    for attempt in range(1, max_retries + 1):
        history, response_message = run_one_round_conversation(
            full_messages=history,
            system_message=system_prompt,
            user_message=user_prompt,
            temperature=temperature,
            model_name=model_name,
        )
        try:
            parsed = _parse_json_response(response_message)
        except json.JSONDecodeError as error:
            last_error = error
            logging.warning(
                "Attempt %d/%d: JSON decode failed - %s; Response: %s",
                attempt, max_retries, error, response_message[:200]
            )
            continue
        return parsed, history, response_message

    raise ValueError(f"Failed to get valid JSON after {max_retries} attempts") from last_error


def ask_tool_choice(
    full_messages: List[Dict],
    top_system_prompt: str,
    tool_choose_user_messages: str,
    temperature: float = 0.0,
    model_name: str = "gpt-4o-mini",
    max_retries: int = 10
) -> Tuple[List[Dict], List[Dict], str]:
    """
    Request tool selection from GPT

    Args:
        full_messages: Conversation history
        top_system_prompt: System prompt
        tool_choose_user_messages: User prompt for tool selection
        temperature: Sampling temperature
        model_name: Model to use
        max_retries: Maximum retry attempts

    Returns:
        tool_choose_list, full_messages, tool_choose_response_message

    Raises:
        ValueError: If no reply could be parsed as JSON within ``max_retries``
    """
    return _request_json_with_retries(
        history=full_messages,
        system_prompt=top_system_prompt,
        user_prompt=tool_choose_user_messages,
        temperature=temperature,
        model_name=model_name,
        max_retries=max_retries,
    )


def ask_cot_data(
    full_message: List[Dict],
    cot_system_prompt: str,
    cot_user_prompt: str,
    temperature: float = 0.0,
    model_name: str = "gpt-4o-mini",
    max_retries: int = 10
) -> Tuple[Dict, List[Dict], str]:
    """
    Request CoT data from GPT

    Args:
        full_message: Conversation history
        cot_system_prompt: System prompt for CoT generation
        cot_user_prompt: User prompt for CoT generation
        temperature: Sampling temperature
        model_name: Model to use
        max_retries: Maximum retry attempts

    Returns:
        cot_data, full_message, cot_response_message

    Raises:
        ValueError: If no reply could be parsed as JSON within ``max_retries``
    """
    return _request_json_with_retries(
        history=full_message,
        system_prompt=cot_system_prompt,
        user_prompt=cot_user_prompt,
        temperature=temperature,
        model_name=model_name,
        max_retries=max_retries,
    )
|
| |
|
| |
|
def gen_pipeline(args: argparse.Namespace) -> None:
    """
    Main pipeline for generating CoT data

    For every sample the pipeline asks the model to select tools, then asks
    it to rebuild the reasoning chain, stores the result under
    ``sample["cot_data"]`` and rewrites the intermediate output file and the
    progress file. A previous run is resumed only when both the progress
    file and the intermediate output file can be loaded.

    Args:
        args: Command line arguments (uses ``split``, ``model_name``, ``debug``)
    """
    json_file = 'data/DriveLMMo1_TRAIN.json'
    all_samples = read_json(json_file)
    progress_file_path = 'progress.json'
    output_file = f'final_cot_{args.split}_{args.model_name}.json'

    new_sample = []
    start_index = 0

    # Adopt the saved state only when BOTH files exist. Previously a present
    # progress file with a missing output file advanced start_index while
    # the result list stayed empty, silently skipping the earlier samples.
    try:
        with open(progress_file_path, 'r', encoding='utf-8') as f:
            progress_data = json.load(f)

        with open(output_file, 'r', encoding='utf-8') as f:
            saved_samples = json.load(f)
    except FileNotFoundError:
        pass
    else:
        start_index = progress_data.get('processed_index', -1) + 1
        new_sample = saved_samples

    if args.debug:
        all_samples = all_samples[:20]
        logging.info("Running in debug mode with %d samples", len(all_samples))

    # total is given explicitly: the iterable is only the remaining slice,
    # so with initial=start_index the bar would otherwise overshoot.
    progress_bar = tqdm(
        all_samples[start_index:],
        initial=start_index,
        total=len(all_samples),
        desc="Processing samples",
    )
    for index, sample in enumerate(progress_bar):
        current_index = start_index + index

        try:
            raw_question = sample['question']

            if args.split == 'train':
                raw_answer = sample['answer']
                reason_part = extract_key_steps(raw_answer)
                final_answer = extract_final_answer(raw_answer)
            elif args.split == 'test':
                final_answer = sample['final_answer']
                reason_part = sample['steps']

            # Step 1: let the model pick the tools.
            top_system_prompt = get_system_prompt()
            tool_choose_user_messages = get_user_prompt(
                raw_question=raw_question,
                reason_part=reason_part,
                raw_answer=final_answer
            )

            tool_choose_list, full_message, tool_choose_response_message = ask_tool_choice(
                full_messages=[],
                top_system_prompt=top_system_prompt,
                tool_choose_user_messages=tool_choose_user_messages,
                model_name="gpt-4o-mini"
            )

            if args.debug:
                logging.info("Tool selection response: %s", tool_choose_response_message[:200])

            # Step 2: rebuild the reasoning chain with the selected tools.
            cot_system_prompt = get_cot_system_prompt()
            cot_user_prompt = get_cot_user_prompt(
                raw_question=raw_question,
                reason_part=reason_part,
                final_answer=final_answer,
                tool_choose_list=tool_choose_list
            )

            cot_data, full_message, cot_response_message = ask_cot_data(
                full_message=full_message,
                cot_system_prompt=cot_system_prompt,
                cot_user_prompt=cot_user_prompt,
                model_name=args.model_name
            )

            if args.debug:
                logging.info("CoT response: %s", cot_response_message[:200])

            sample["cot_data"] = cot_data
            new_sample.append(sample)

            # Checkpoint after every sample: results first, then the index.
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(new_sample, f, indent=4)

            save_progress(progress_file_path, current_index)

        except Exception as e:
            # A failing sample is logged and skipped; debug mode re-raises.
            logging.error("Error processing sample %d: %s", current_index, str(e))
            if args.debug:
                raise

    with open(f'cot_{args.split}_{args.model_name}.json', 'w', encoding='utf-8') as f:
        json.dump(new_sample, f, indent=4)

    logging.info("Process completed successfully")
|
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: parse the CLI, configure file logging, run the pipeline.
    parser = argparse.ArgumentParser(description="Generate CoT data for autonomous driving scenarios")
    # Both options were declared with a default AND required=True, which made
    # the defaults unreachable. They are now optional and fall back to the
    # declared defaults; passing them explicitly works exactly as before.
    parser.add_argument('--split', type=str, default='train',
                        choices=['train', 'test'],
                        help="Dataset split: 'train' or 'test'")
    parser.add_argument('--model_name', type=str, default='gpt-4o-mini',
                        choices=['gpt-4o', 'gpt-4o-mini', 'gpt-4', 'gpt-4.1-mini'],
                        help="OpenAI model to use for CoT generation")
    parser.add_argument('--debug', action="store_true",
                        help="Enable debug mode with detailed logging")

    args = parser.parse_args()

    # NOTE(review): this instance is not used by the visible code below; it is
    # kept in case constructing it has side effects — confirm.
    func_agent = FuncAgent()

    # Log to a per-run file; the file is overwritten on every start.
    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(levelname)s - %(message)s',
        filename=f'gen_{args.split}_data_{args.model_name}.log',
        filemode='w'
    )

    if args.debug:
        logging.debug("Arguments: %s", vars(args))

    gen_pipeline(args)