# File size: 5,073 Bytes
import asyncio
import csv
from itertools import islice

from openai import AsyncOpenAI
from tqdm import tqdm  # standard (non-notebook) tqdm

from load_dataset import load_dataset, length_max
# Shared async client for an OpenAI-compatible server on localhost:8000
# (presumably the vLLM instance mentioned below — confirm). The key is a
# placeholder; presumably the local server does not validate it.
client = AsyncOpenAI(api_key="none", base_url="http://localhost:8000/v1")
# JSON Schema intended for vLLM structured output, expressed as a Python dict.
# It describes a list of functions judged relevant to scientific (especially
# chemistry) computing: name, inclusive line span, a 0-100 relevance score
# with a reason, and the (nullable) span of the associated doc comment.
# NOTE(review): this schema is not currently passed to any request — the
# response_format that would use it in process_one is disabled.
# NOTE(review): "minimum" is 0 although the description asks for score > 0
# only; the schema itself does not reject 0. Also, some servers' strict
# json_schema mode requires an object at the root — confirm vLLM accepts
# an array root.
scientific_func_schema = dict(
    type="array",
    description="List of functions related to scientific and especially chemistry-related computing.",
    items=dict(
        type="object",
        additionalProperties=False,
        properties=dict(
            function_name=dict(
                type="string",
                description="The function name.",
            ),
            function_start_line=dict(
                type="integer",
                description="The starting line number of the function definition (inclusive).",
            ),
            function_end_line=dict(
                type="integer",
                description="The ending line number of the function definition (inclusive).",
            ),
            relevance_score=dict(
                type="integer",
                minimum=0,
                maximum=100,
                description="Relevance score (0–100) for scientific/chemistry-related computing. Only include functions with score > 0.",
            ),
            relevance_reason=dict(
                type="string",
                description="Explanation of why this function is related to scientific/chemical computing and why it received that score.",
            ),
            doc_start_line=dict(
                type=["integer", "null"],
                description="Starting line number of the associated documentation comment, or null if none.",
            ),
            doc_end_line=dict(
                type=["integer", "null"],
                description="Ending line number of the associated documentation comment, or null if none.",
            ),
        ),
        # Every property is mandatory; nullable ones use an explicit null.
        required=[
            "function_name",
            "function_start_line",
            "function_end_line",
            "relevance_score",
            "relevance_reason",
            "doc_start_line",
            "doc_end_line",
        ],
    ),
)
async def process_one(code_file):
    """Send a single prompt to the model and return its raw reply.

    Args:
        code_file: A ``(prompt, row)`` pair. ``prompt`` is sent as the only
            user message; ``row`` is passed through untouched so the caller
            can pair the reply with its input.

    Returns:
        A ``(row, content)`` tuple, where ``content`` is the text of the
        first choice's message, unparsed. NOTE(review): the client may
        report ``None`` content if the server returns no text — confirm.
    """
    prompt, row = code_file
    resp = await client.chat.completions.create(
        model="Qwen3",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=8192,
        temperature=0.7,
        top_p=0.8,
        presence_penalty=1.5,
        frequency_penalty=1.5,
        # Options outside the OpenAI signature are merged into the request
        # body: top_k sampling, and chat-template kwargs. enable_thinking=False
        # presumably turns off the model's thinking mode — confirm per model.
        extra_body={
            "top_k": 20,
            "chat_template_kwargs": {
                "enable_thinking": False,
            },
        },
        # NOTE: structured output is currently disabled. To constrain the
        # reply to scientific_func_schema, add:
        #   response_format={"type": "json_schema", "json_schema": {
        #       "name": "scientific_functions_analysis",
        #       "schema": scientific_func_schema, "strict": True}},
    )
    # No post-processing here: the raw reply text is returned as-is.
    return row, resp.choices[0].message.content
# Output CSV shared by every batch; each request appends one [row, reply]
# record and repeated runs keep appending. The csv module requires
# newline="" so quoted fields containing newlines (model replies) are written
# verbatim and no extra "\r" is inserted where the platform translates "\n".
res_file = open('res2.csv', 'a+', encoding='utf-8', newline='')
writer = csv.writer(res_file)
async def process_batch(batch):
    """Run every request in ``batch`` concurrently, recording each result.

    Each item is handed to ``process_one``. As requests finish, their
    results are written to the module-level CSV ``writer`` as
    ``[row, content]`` rows while a per-batch progress bar advances.

    Args:
        batch: A list of ``(prompt, row)`` pairs.

    Returns:
        The list of ``(row, content)`` tuples in completion order (not
        input order, since ``asyncio.as_completed`` yields finished tasks
        first).

    Raises:
        Whatever ``process_one`` raises for the first failed request that is
        awaited; rows already written are flushed to disk before it
        propagates.
    """
    tasks = [asyncio.create_task(process_one(item)) for item in batch]
    results = []
    try:
        for fut in tqdm(asyncio.as_completed(tasks), total=len(tasks),
                        desc="Processing batch", unit="req", leave=False):
            result = await fut
            writer.writerow([result[0], result[1]])
            results.append(result)
    finally:
        # Push this batch's rows out of the write buffer so the CSV on disk
        # keeps up with reported progress even if the run later fails.
        res_file.flush()
    return results
async def process_dataset(dataset_iter, batch_size=200):
    """Process the dataset in batches of ``batch_size`` with an overall progress bar.

    Args:
        dataset_iter: Iterable of ``(prompt, row)`` pairs, as consumed by
            ``process_batch``. Any iterable is accepted (iterator, list, ...).
        batch_size: Number of requests sent concurrently per batch.

    Returns:
        An empty list. Per-row results are written to CSV by
        ``process_batch`` rather than accumulated here (the list is kept for
        interface compatibility). The processed count is printed and written
        to ``res.log`` instead.
    """
    results = []
    # iter() is a no-op on iterators but makes a list/tuple input advance
    # between batches; islice() on a re-iterable would restart at item 0
    # every time and resend the first batch.
    dataset_iter = iter(dataset_iter)
    # Batch count comes from length_max (imported from load_dataset).
    # NOTE(review): items beyond length_max are never read — confirm that
    # length_max is meant to cap the run.
    num_batches = (length_max + batch_size - 1) // batch_size
    amount = 0
    for _ in tqdm(range(num_batches), desc="Overall progress", unit="batch"):
        batch = list(islice(dataset_iter, batch_size))
        if not batch:
            # Dataset exhausted before num_batches; nothing left to send.
            break
        batch_results = await process_batch(batch)
        amount += len(batch_results)
    print("处理完成,共获得结果条数:", amount)
    with open("res.log", "w", encoding="utf-8") as f:
        f.write(str(amount))
    return results
if __name__ == "__main__":
    try:
        dataset_iter = load_dataset()
        final_results = asyncio.run(process_dataset(dataset_iter, batch_size=64))
    finally:
        # Close (and flush) the output CSV even if loading or processing fails.
        res_file.close()
    # process_dataset already prints the processed count; its returned list
    # can be empty when results are not accumulated, so only report it when
    # it actually holds results (avoids a misleading "0").
    if final_results:
        print("处理完成,共获得结果条数:", len(final_results))