|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pickle
|
|
|
import json
|
|
|
from pathlib import Path
|
|
|
from tqdm import tqdm
|
|
|
from scripts.tools.tool_libraries import FuncAgent
|
|
|
from scripts.tools.tool_prompts import get_system_prompt, get_ego_prompts, get_detection_prompt
|
|
|
|
|
|
|
|
|
class AgentThink:
    """Replay recorded tool calls for one driving frame against cached tool results.

    An instance loads the pickled tool results of a single frame, wraps them in
    a ``FuncAgent`` and executes the tool calls listed in a sample's
    chain-of-thought data by looking the tool names up on that agent.
    """

    def __init__(self, token: str = None, split: str = 'train',
                 data_path: str = 'DriveLMM-o1-main/data/tool_results',
                 drivelmm_json_file: str = 'Drive-MLLM-main/data/DriveLMMo1/DriveLMMo1_TEST.json',
                 model_name: str = "qwen2.5-VL", verbose: bool = False) -> None:
        """
        Initialize AgentThink class for processing driving scenario data.

        Args:
            token (str): Frame token; ``<token>.pkl`` is loaded from the split folder
            split (str): Data split type; any value containing 'val' selects the
                ``val`` folder, everything else selects ``train``
            data_path (str): Path to tool results data
            drivelmm_json_file (str): Path to DriveLMM JSON file
            model_name (str): Name of the model being used
            verbose (bool): Whether to show detailed logs

        Raises:
            OSError: If the pickle file for ``token`` cannot be opened.
        """
        self.token = token
        self.split = split
        self.data_path = data_path
        # Kept for reference; nothing in this class reads the JSON file itself.
        self.drivelmm_json_file = drivelmm_json_file
        self.model_name = model_name
        self.verbose = verbose

        self.file_name = self._tool_results_path(self.token)
        self.data_dict = self._load_tool_results(self.token)

        # Tools are resolved by name on this object in tool_call().
        self.func_agent = FuncAgent(self.data_dict)

        # Per-tool call budgets. Only the detection budget is enforced by
        # get_tool_results() in this class.
        self.num_call_detection_times = 3
        self.num_call_prediction_times = 1
        self.num_call_occupancy_times = 1
        self.num_call_map_times = 1

    def _tool_results_path(self, frame_token):
        """Return the pickle path holding the tool results of ``frame_token``."""
        folder_name = "val" if "val" in self.split else "train"
        return Path(self.data_path) / folder_name / f"{frame_token}.pkl"

    def _load_tool_results(self, frame_token):
        """Load and return the pickled tool results of ``frame_token``."""
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # data_path at tool-result files you produced or otherwise trust.
        with open(self._tool_results_path(frame_token), "rb") as f:
            return pickle.load(f)

    def _preprocess_tool_results(self, json_data):
        """
        Convert agent-driver data into the scene in drivelmm-o1 format.

        Each sample is updated in place with a ``'tool_results'`` entry loaded
        from the pickle named after the second ``'_'``-separated field of its
        ``'idx'``. The result is also written to
        ``cot_<split>_<model_name>.json`` in the current working directory.

        Args:
            json_data: JSON data to preprocess

        Returns:
            Preprocessed JSON data with tool results
        """
        new_json_data = []
        for sample in json_data:
            frame_token = sample['idx'].split('_')[1]
            sample['tool_results'] = self._load_tool_results(frame_token)
            new_json_data.append(sample)

        output_file = f'cot_{self.split}_{self.model_name}.json'
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(new_json_data, f, indent=4)

        return new_json_data

    @staticmethod
    def _build_keyword_args(tool_name, positional_args):
        """Translate a non-empty positional parameter list into keyword arguments.

        Returns:
            ``(tool_name, kwargs)``, where ``tool_name`` may have been
            normalised, or ``None`` when required parameters are missing.
        """
        if 'occupancy' in tool_name:
            # Occupancy tools need both a location and a timestep.
            if len(positional_args) < 2:
                return None
            locations, timestep = positional_args[0], positional_args[1]
            if isinstance(locations, list):
                locations = tuple(locations)
            return tool_name, {'locations': [locations], 'timestep': timestep}

        if 'location' in tool_name:
            locations = positional_args[0]
            if isinstance(locations, list):
                locations = tuple(locations)
            return tool_name, {'locations': [locations]}

        if 'open' in tool_name:
            # Every "open" variant is routed to the single open-vocabulary tool.
            return ('get_open_world_vocabulary_detection',
                    {'object_names': positional_args[0]})

        return tool_name, {'object_ids': positional_args[0]}

    def tool_call(self, response_message):
        """
        Execute a tool call based on the response message.

        ``response_message['Tool']`` must provide ``'function_name'`` and
        ``'parameters'``. Parameters are either a non-empty list (``['']``
        means "no arguments"; otherwise the items are mapped to keyword
        arguments according to the tool name) or a non-empty dict that is
        passed through as keyword arguments.

        Args:
            response_message: Message containing tool call information

        Returns:
            Tool response dictionary with keys ``'name'``, ``'args'`` and
            ``'prompt'``, or None if the message is malformed, the tool is
            unknown, or the call fails
        """
        try:
            tool_spec = response_message['Tool']
            tool_name = tool_spec['function_name']
            function_args = tool_spec['parameters']
        except (KeyError, TypeError):
            return None

        if not isinstance(tool_name, str) or tool_name in ('', 'none'):
            return None

        # Empty parameters mean there is nothing to call; any other container
        # type cannot be expanded into keyword arguments.
        if not isinstance(function_args, (list, dict)) or not function_args:
            return None

        if isinstance(function_args, list):
            if function_args[0] == '':
                function_args = {}
            else:
                converted = self._build_keyword_args(tool_name, function_args)
                if converted is None:
                    return None
                tool_name, function_args = converted

        # NOTE(review): tool_name comes from the message and is resolved with
        # getattr, so any attribute of func_agent is reachable by name.
        try:
            function_to_call = getattr(self.func_agent, tool_name)
        except AttributeError:
            return None

        if not callable(function_to_call):
            print(f"Function {tool_name} is not callable!")
            return None

        try:
            tool_returns = function_to_call(**function_args)
        except Exception as exc:
            # Best effort by design: a failing tool must not abort the chain.
            if self.verbose:
                print(f"Tool {tool_name} failed: {exc!r}")
            return None

        try:
            # Tools return (prompt, data); only the prompt is forwarded.
            tool_prompt, _ = tool_returns
        except (TypeError, ValueError):
            if self.verbose:
                print(f"Tool {tool_name} returned an unexpected value: {tool_returns!r}")
            return None

        if tool_prompt is None:
            tool_prompt = ""

        tool_response = {
            "name": tool_name,
            "args": function_args,
            "prompt": tool_prompt,
        }

        if self.verbose:
            print(f"Tool: {tool_name}")
            print(f"Args: {function_args}")
            print(f"Prompt: {tool_prompt}")

        return tool_response

    def get_tool_results(self, sample, ego_prompts=None):
        """
        Collect information from driving scenarios using chain-of-thought reasoning with function calls.

        Walks ``sample['cot_data']['Chain']`` in order and executes each node's
        tool call. Nodes without a tool name are skipped, and detection tools
        are executed at most ``num_call_detection_times`` times.

        Args:
            sample: Data sample to process
            ego_prompts: Optional ego prompts to include; None is treated as
                an empty string

        Returns:
            Tuple of (full_messages, system_message, tool_responses)
        """
        init_system_message = get_system_prompt()
        full_messages = []
        tool_responses = []

        ego_text = "" if ego_prompts is None else ego_prompts
        system_message = init_system_message + "\n" + ego_text + "\n"

        if self.verbose:
            print("System Message:", system_message)
            print("Detection Prompt:", get_detection_prompt())

        tool_chain = sample['cot_data']['Chain']

        cur_num_det_tool_call = 0
        for chain_node in tool_chain:
            try:
                tool_name = chain_node['Tool']['function_name']
            except (KeyError, TypeError):
                continue

            if isinstance(tool_name, str) and 'detection' in tool_name:
                # Every detection request counts against the budget, whether
                # or not the call succeeds.
                cur_num_det_tool_call += 1
                if cur_num_det_tool_call > self.num_call_detection_times:
                    continue

            tool_response = self.tool_call(chain_node)
            if tool_response is None:
                continue

            full_messages.append({
                'role': 'function',
                'name': tool_response['name'],
                'content': tool_response['prompt'],
            })
            tool_responses.append(tool_response)

        return full_messages, system_message, tool_responses
|
|
|
|
|
|
|
|
|
def main(drivelmm_json_file="/path/to/final_cot_test_gpt-4.1-mini.json",
         data_path="/path/to/tool_results", split='val', model_name='Qwen2.5-VL'):
    """
    Main function to process DriveLMM JSON data with AgentThink.

    For every sample, the frame token (second ``'_'``-separated field of
    ``sample['idx']``) selects the cached tool results, the sample's tool
    chain is replayed, and the sample is extended in place with
    ``'tool_result'`` and ``'system_prompts'``. The processed samples are
    written to ``<data_path>/cot_<split>_<model_name>.json``.

    Args:
        drivelmm_json_file: Path to DriveLMM JSON file
        data_path: Directory holding the tool results; also receives the output file
        split: Data split handed to AgentThink and used in the output file name
        model_name: Model name handed to AgentThink and used in the output file name

    Returns:
        Processed JSON data
    """
    with open(drivelmm_json_file, "r", encoding="utf-8") as file:
        json_data = json.load(file)

    new_json_data = []
    for sample in tqdm(json_data, desc="Processing JSON samples"):
        frame_token = sample['idx'].split('_')[1]

        agent = AgentThink(
            token=frame_token,
            split=split,
            data_path=data_path,
            drivelmm_json_file=drivelmm_json_file,
            model_name=model_name,
        )

        ego_prompts = get_ego_prompts(agent.data_dict)

        _, system_prompts, tool_responses = agent.get_tool_results(
            sample=sample,
            ego_prompts=ego_prompts,
        )

        sample['tool_result'] = tool_responses
        sample['system_prompts'] = system_prompts
        new_json_data.append(sample)

    # Built from the arguments rather than the last agent, so an empty input
    # file still yields an (empty) output file instead of a NameError.
    output_file = f'{data_path}/cot_{split}_{model_name}.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(new_json_data, f, indent=4)

    return new_json_data
|
|
|
|
|
|
|
|
|
# Script entry point: run the conversion with main()'s default arguments.
if __name__ == "__main__":
    main()