# Copyright (c) OpenMMLab. All rights reserved.
# Authors: Kangan Qian (Tsinghua University, Xiaomi Corporation)
# Description: AgentThink class for processing driving scenario data with tool calls

import pickle
import json
from pathlib import Path
from tqdm import tqdm
from scripts.tools.tool_libraries import FuncAgent
from scripts.tools.tool_prompts import get_system_prompt, get_ego_prompts, get_detection_prompt


class AgentThink:
    def __init__(self, token: str = None, split: str = 'train',
                 data_path: str = 'DriveLMM-o1-main/data/tool_results',
                 drivelmm_json_file: str = 'Drive-MLLM-main/data/DriveLMMo1/DriveLMMo1_TEST.json',
                 model_name: str = "qwen2.5-VL", verbose: bool = False) -> None:
        """
        Initialize AgentThink for processing driving scenario data.

        Args:
            token (str): Frame token used to locate the per-frame tool results
            split (str): Data split type ('train' or 'val')
            data_path (str): Path to tool results data
            drivelmm_json_file (str): Path to the DriveLMM-o1 JSON file
            model_name (str): Name of the model being used
            verbose (bool): Whether to print detailed logs
        """
        self.token = token
        self.split = split
        self.data_path = data_path
        self.model_name = model_name
        self.verbose = verbose
        
        # Load data from pickle file
        folder_name = "val" if "val" in split else "train"
        self.file_name = Path(data_path) / folder_name / f"{self.token}.pkl"
        with open(self.file_name, "rb") as f:
            self.data_dict = pickle.load(f)
        
        # Initialize function agent
        self.func_agent = FuncAgent(self.data_dict)
        
        # Set limits for tool calls
        self.num_call_detection_times = 3
        self.num_call_prediction_times = 1
        self.num_call_occupancy_times = 1
        self.num_call_map_times = 1
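        # Of these per-sample budgets, only the detection limit is currently
        # enforced in get_tool_results; the others are defined but not checked.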

    def _preprocess_tool_results(self, json_data):
        """
        Convert agent-driver tool results into the DriveLMM-o1 scene format.

        Args:
            json_data: JSON data to preprocess

        Returns:
            Preprocessed JSON data with per-frame tool results attached
        """
        # TODO: Implement matching and alignment between JSON data and tool results
        new_json_data = []
        for sample in json_data:
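            # `idx` is assumed to follow "<scene_token>_<frame_token>"; only the
            # frame token is needed to locate the per-frame tool pickle.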
            sample_idx = sample['idx']
            scene_token = sample_idx.split('_')[0]
            frame_token = sample_idx.split('_')[1]
            
            # Load corresponding tool data
            folder_name = "val" if "val" in self.split else "train"
            file_name = Path(self.data_path) / folder_name / f"{frame_token}.pkl"
            with open(file_name, "rb") as f:
                data_dict = pickle.load(f)
            
            # Add tool results to sample
            sample['tool_results'] = data_dict
            new_json_data.append(sample)
        
        # Save processed data
        output_file = f'cot_{self.split}_{self.model_name}.json'
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(new_json_data, f, indent=4)
        
        return new_json_data

    def tool_call(self, response_message):
        """

        Execute a tool call based on the response message.

        

        Args:

            response_message: Message containing tool call information

            

        Returns:

            Tool response dictionary or None if call fails

        """
        try:
            tool_name = response_message['Tool']['function_name']
        except (KeyError, TypeError):
            return None
            
        if tool_name in ('', 'none'):
            return None

        function_args = response_message['Tool'].get('parameters', [])
        if not function_args:
            return None
        if isinstance(function_args, list) and function_args[0] == '':
            # An empty-string placeholder means the tool takes no arguments
            function_args = {}
            
        # Normalize positional arguments into keyword arguments per tool type
        if isinstance(function_args, list):
            if 'occupancy' in tool_name:
                # Occupancy tools take a location (list converted to a tuple) and a timestep
                locations, timestep = function_args[0], function_args[1]
                if isinstance(locations, list):
                    locations = tuple(locations)
                function_args = {'locations': [locations], 'timestep': timestep}
            elif 'location' in tool_name:
                # Location-based tools take a location (list converted to a tuple)
                locations = function_args[0]
                if isinstance(locations, list):
                    locations = tuple(locations)
                function_args = {'locations': [locations]}
            elif 'open' in tool_name:
                # Open-vocabulary detection takes object names rather than ids
                tool_name = 'get_open_world_vocabulary_detection'
                function_args = {'object_names': function_args[0]}
            else:
                # Remaining tools operate on tracked object ids
                function_args = {'object_ids': function_args[0]}
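        # Example (illustrative): parameters [["<obj_1>", "<obj_2>"]] for a
        # detection-style tool become {"object_ids": ["<obj_1>", "<obj_2>"]}.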
        
        # Get the function to call
        try:
            function_to_call = getattr(self.func_agent, tool_name)
        except AttributeError:
            return None
        
        # Execute the function call
        if not callable(function_to_call):
            print(f"Function {tool_name} is not callable!")
            return None
        else:
            try:
                tool_returns = function_to_call(**function_args)
            except Exception:
                return None
                
            tool_prompt, tool_result_data = tool_returns
            
            if tool_prompt is None:
                tool_prompt = ""
                
            # Create tool response dictionary
            tool_response = {
                "name": tool_name,
                "args": function_args,
                "prompt": tool_prompt,
            }
            
            if self.verbose:
                print(f"Tool: {tool_name}")
                print(f"Args: {function_args}")
                print(f"Prompt: {tool_prompt}")
                
            return tool_response

    def get_tool_results(self, sample, ego_prompts=None):
        """

        Collect information from driving scenarios using chain-of-thought reasoning with function calls.

        

        Args:

            sample: Data sample to process

            ego_prompts: Optional ego prompts to include

            

        Returns:

            Tuple of (full_messages, system_message, tool_responses)

        """
        # Initialize system message
        init_system_message = get_system_prompt()
        full_messages = []
        tool_responses = []
        
        # Combine system message with ego prompts (ego_prompts may be None)
        system_message = init_system_message + "\n" + (ego_prompts or "") + "\n"
        
        if self.verbose:
            print("System Message:", system_message)
            print("Detection Prompt:", get_detection_prompt())
        
        # Process chain of thought data
        cot_data = sample['cot_data']
        tool_chain = cot_data['Chain']
        
        # Execute tool calls iteratively
        cur_num_det_tool_call = 0
        for chain_node in tool_chain:
            try:
                tool_name = chain_node['Tool']['function_name']
            except (KeyError, TypeError):
                continue
                
            # Limit detection tool calls
            if 'detection' in tool_name:
                cur_num_det_tool_call += 1
                if cur_num_det_tool_call > self.num_call_detection_times:
                    continue
                    
            # Execute tool call
            tool_response = self.tool_call(chain_node)
            
            # Add tool response to messages
            if tool_response is not None:
                full_messages.append({
                    'role': 'function',
                    'name': tool_response['name'],
                    'content': tool_response['prompt'],
                })
                
            tool_responses.append(tool_response)
            
        return full_messages, system_message, tool_responses


def main(drivelmm_json_file="/path/to/final_cot_test_gpt-4.1-mini.json"):
    """

    Main function to process DriveLMM JSON data with AgentThink.

    

    Args:

        drivelmm_json_file: Path to DriveLMM JSON file

        

    Returns:

        Processed JSON data

    """
    # Load JSON data
    with open(drivelmm_json_file, "r", encoding="utf-8") as file:
        json_data = json.load(file)
    
    # Process each sample in the JSON data
    new_json_data = []
    for sample in tqdm(json_data, desc="Processing JSON samples"):
        sample_idx = sample['idx']
        scene_token = sample_idx.split('_')[0]
        frame_token = sample_idx.split('_')[1]
        
        # Initialize agent and get ego prompts
        agent = AgentThink(
            token=frame_token, 
            split='val', 
            data_path="/path/to/tool_results",
            drivelmm_json_file=drivelmm_json_file,
            model_name='Qwen2.5-VL'
        )
        
        cur_data_dict = agent.data_dict
        ego_prompts = get_ego_prompts(cur_data_dict)
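        # get_ego_prompts is assumed to render the ego-vehicle state from the
        # loaded tool-result dict into text; get_tool_results appends it to the
        # system prompt.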
        
        # Get tool results
        full_messages, system_prompts, tool_responses = agent.get_tool_results(
            sample=sample, 
            ego_prompts=ego_prompts
        )
        
        # Update sample with tool results
        sample['tool_result'] = tool_responses
        sample['system_prompts'] = system_prompts
        new_json_data.append(sample)
    
    # Save processed data
    output_file = f'{agent.data_path}/cot_{agent.split}_{agent.model_name}.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(new_json_data, f, indent=4)
        
    return new_json_data


if __name__ == "__main__":
    # Example usage
    main()