| | import asyncio |
| | from openai import AsyncOpenAI |
| | from tqdm import tqdm |
| | from load_dataset import load_dataset, length_max |
| | from itertools import islice |
| | import csv |
| |
|
# Shared async client for an OpenAI-compatible chat-completions server
# running on localhost. A placeholder API key is passed; presumably the local
# server does not validate it — NOTE(review): confirm.
client = AsyncOpenAI(
    api_key="none",
    base_url="http://localhost:8000/v1",
)
| |
|
| | |
# Per-item fields of the schema below, in declaration order. Every field is
# also listed as required, so the item order here doubles as the "required"
# order.
_SCIENTIFIC_FUNC_PROPERTIES = {
    "function_name": {
        "type": "string",
        "description": "The function name.",
    },
    "function_start_line": {
        "type": "integer",
        "description": "The starting line number of the function definition (inclusive).",
    },
    "function_end_line": {
        "type": "integer",
        "description": "The ending line number of the function definition (inclusive).",
    },
    "relevance_score": {
        "type": "integer",
        "minimum": 0,
        "maximum": 100,
        "description": "Relevance score (0–100) for scientific/chemistry-related computing. Only include functions with score > 0.",
    },
    "relevance_reason": {
        "type": "string",
        "description": "Explanation of why this function is related to scientific/chemical computing and why it received that score.",
    },
    # The doc_* fields are nullable rather than optional: they must be
    # present, but may be null when a function has no doc comment.
    "doc_start_line": {
        "type": ["integer", "null"],
        "description": "Starting line number of the associated documentation comment, or null if none.",
    },
    "doc_end_line": {
        "type": ["integer", "null"],
        "description": "Ending line number of the associated documentation comment, or null if none.",
    },
}

# JSON Schema (as a plain dict) for an array of objects, each describing one
# scientific/chemistry-related function: its name, inclusive line span,
# a 0-100 relevance score with a justification, and the line span of its
# documentation comment. Extra keys are disallowed.
# NOTE(review): nothing visible in this file passes the schema to the model
# request — confirm whether it is used elsewhere.
scientific_func_schema = {
    "type": "array",
    "description": "List of functions related to scientific and especially chemistry-related computing.",
    "items": {
        "type": "object",
        "additionalProperties": False,
        "properties": _SCIENTIFIC_FUNC_PROPERTIES,
        "required": list(_SCIENTIFIC_FUNC_PROPERTIES),
    },
}
| |
|
| |
|
async def process_one(code_file):
    """Send a single prompt to the chat-completions endpoint.

    Args:
        code_file: A two-item pair ``(prompt, row)``. ``prompt`` is sent as
            the only user message; ``row`` is passed through unchanged so the
            caller can pair the reply with its input.

    Returns:
        A ``(row, content)`` tuple where ``content`` is the message text of
        the first returned choice. It may be ``None`` if the server returns
        no text content — NOTE(review): confirm server behavior.

    Raises:
        Whatever ``client.chat.completions.create`` raises; no retry or
        error handling is performed here.
    """
    prompt, row = code_file
    resp = await client.chat.completions.create(
        model="Qwen3",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=8192,
        temperature=0.7,
        top_p=0.8,
        presence_penalty=1.5,
        frequency_penalty=1.5,
        # Parameters outside the standard OpenAI request schema, forwarded
        # verbatim in the request body for the server to interpret.
        extra_body={
            "top_k": 20,
            # Handed to the server-side chat template; turns the model's
            # "thinking" mode off.
            "chat_template_kwargs": {
                "enable_thinking": False,
            },
        },
    )
    return row, resp.choices[0].message.content
| |
|
# Output CSV, opened at import time in append mode. process_batch writes one
# (row, reply) record per completed request through `writer`; the __main__
# block closes the file.
# newline='' is required by the csv module: model replies may contain
# embedded newlines, and without it quoted multi-line fields are not
# round-tripped correctly (and on Windows every record gains an extra "\r").
res_file = open('res2.csv', 'a+', encoding='utf-8', newline='')
writer = csv.writer(res_file)
| |
|
async def process_batch(batch):
    """Run every item of *batch* concurrently and record each result.

    Each item is handed to ``process_one``. Results are written to the
    module-level CSV ``writer`` as they complete, so output order is
    completion order, not input order. A per-batch tqdm bar shows progress.

    Args:
        batch: A list of ``(prompt, row)`` pairs accepted by ``process_one``.

    Returns:
        The list of ``(row, content)`` tuples, in completion order.

    Raises:
        The first exception raised by any request in the batch. Requests
        still in flight are cancelled before it propagates.
    """
    tasks = [asyncio.create_task(process_one(item)) for item in batch]
    results = []
    try:
        for future in tqdm(asyncio.as_completed(tasks), total=len(tasks),
                           desc="Processing batch", unit="req", leave=False):
            row, content = await future
            writer.writerow([row, content])
            results.append((row, content))
    finally:
        # If the batch failed part-way, stop the remaining requests instead of
        # leaving them running unobserved. On success every task is already
        # done and cancel() is a no-op.
        for task in tasks:
            task.cancel()
        # Push this batch's rows out of the write buffer so an interruption in
        # a later batch does not lose them.
        res_file.flush()
    return results
| |
|
async def process_dataset(dataset_iter, batch_size=200):
    """Process the dataset in consecutive batches and collect all results.

    At most ``ceil(length_max / batch_size)`` batches are drawn. The loop
    stops early once the dataset is exhausted, so no empty batches are
    submitted. An overall tqdm bar counts batches.

    Args:
        dataset_iter: Iterable of ``(prompt, row)`` pairs. It is consumed once,
            front to back. Any iterable is accepted, not only iterators.
        batch_size: Number of requests issued concurrently per batch.

    Returns:
        All ``(row, content)`` results, batch by batch (completion order
        within each batch).

    Side effects:
        Prints the total result count and overwrites ``res.log`` with it.
    """
    # islice() on a list/tuple restarts at index 0 on every call; iterating a
    # single iterator makes each batch continue where the previous one ended.
    dataset_iter = iter(dataset_iter)
    results = []
    num_batches = (length_max + batch_size - 1) // batch_size  # ceil division
    for _ in tqdm(range(num_batches), desc="Overall progress", unit="batch"):
        batch = list(islice(dataset_iter, batch_size))
        if not batch:
            # Dataset ran out before length_max items: nothing left to send.
            break
        results.extend(await process_batch(batch))

    amount = len(results)
    print("处理完成,共获得结果条数:", amount)
    with open("res.log", "w", encoding="utf-8") as f:
        f.write(str(amount))
    return results
| |
|
if __name__ == "__main__":
    try:
        dataset_iter = load_dataset()
        final_results = asyncio.run(process_dataset(dataset_iter, batch_size=64))
        print("处理完成,共获得结果条数:", len(final_results))
    finally:
        # Close (and thereby flush) the shared CSV output even if the run
        # fails part-way, so rows already written are not left in the buffer.
        res_file.close()