# dataset-builder / data3 / vllm_high.py
# (uploaded to the Hugging Face Hub via huggingface_hub, commit 89c39b8)
import asyncio
from openai import AsyncOpenAI
from tqdm import tqdm  # use the standard (synchronous) tqdm
from load_dataset import load_dataset, length_max
from itertools import islice
import csv
# Async OpenAI-compatible client pointed at a locally served vLLM endpoint.
# vLLM ignores the API key, but the client constructor requires a non-empty one.
client = AsyncOpenAI(
    base_url="http://localhost:8000/v1",
    api_key="none"
)
# JSON Schema (plain Python dict) for vLLM structured output: an array of
# per-function records, each giving the function's line span, a 0-100
# relevance score for scientific/chemistry computing, and the optional
# documentation-comment span.  Currently unused at call time (the
# response_format argument is disabled in process_one) but kept available.
_ITEM_PROPERTIES = {
    "function_name": {
        "type": "string",
        "description": "The function name.",
    },
    "function_start_line": {
        "type": "integer",
        "description": "The starting line number of the function definition (inclusive).",
    },
    "function_end_line": {
        "type": "integer",
        "description": "The ending line number of the function definition (inclusive).",
    },
    "relevance_score": {
        "type": "integer",
        "minimum": 0,
        "maximum": 100,
        "description": "Relevance score (0–100) for scientific/chemistry-related computing. Only include functions with score > 0.",
    },
    "relevance_reason": {
        "type": "string",
        "description": "Explanation of why this function is related to scientific/chemical computing and why it received that score.",
    },
    "doc_start_line": {
        "type": ["integer", "null"],
        "description": "Starting line number of the associated documentation comment, or null if none.",
    },
    "doc_end_line": {
        "type": ["integer", "null"],
        "description": "Ending line number of the associated documentation comment, or null if none.",
    },
}

scientific_func_schema = {
    "type": "array",
    "description": "List of functions related to scientific and especially chemistry-related computing.",
    "items": {
        "type": "object",
        "additionalProperties": False,
        "properties": _ITEM_PROPERTIES,
        # Every property is mandatory; derive the list so the two stay in sync.
        "required": list(_ITEM_PROPERTIES),
    },
}
async def process_one(code_file):
    """Send one prompt to the local vLLM server and return its raw answer.

    Args:
        code_file: a ``(prompt, row)`` pair — the chat prompt to send and an
            opaque row identifier that is passed back unchanged so the caller
            can match results to inputs.

    Returns:
        ``(row, content)`` where ``content`` is the raw completion text.
    """
    # Fix: the original placed the docstring AFTER the first statement, making
    # it a dead string expression rather than a real docstring.
    prompt, row = code_file
    resp = await client.chat.completions.create(
        model="Qwen3",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=8192,
        temperature=0.7,
        top_p=0.8,
        presence_penalty=1.5,
        frequency_penalty=1.5,
        extra_body={
            "top_k": 20,
            # Disable Qwen3 "thinking" so the reply contains only the answer.
            "chat_template_kwargs": {
                "enable_thinking": False,
            },
        },
        # NOTE(review): guided JSON decoding via response_format with
        # scientific_func_schema was tried and disabled; re-enable here if
        # strict structured output is needed again.
    )
    return row, resp.choices[0].message.content
# Shared results sink: CSV opened in append mode at import time and written to
# from process_batch; closed in the __main__ guard at the bottom of the file.
res_file = open('res2.csv', 'a+', encoding='utf-8')
writer = csv.writer(res_file)
async def process_batch(batch):
    """Process one batch of ``(prompt, row)`` pairs concurrently.

    Fires all requests at once, shows a per-batch progress bar, and writes
    each ``[row, model_output]`` to the shared CSV writer as it completes
    (completion order, not submission order).

    Returns:
        list of ``(row, result)`` tuples for the batch.
    """
    tasks = [asyncio.create_task(process_one(p)) for p in batch]
    results = []
    for fut in tqdm(asyncio.as_completed(tasks), total=len(tasks),
                    desc="Processing batch", unit="req", leave=False):
        # as_completed yields futures as they finish, which keeps the
        # progress bar honest but makes output order nondeterministic.
        result = await fut
        writer.writerow([result[0], result[1]])
        results.append(result)
    return results
async def process_dataset(dataset_iter, batch_size=200):
    """Consume the dataset iterator in chunks of ``batch_size``.

    The number of batches is estimated from the module-level ``length_max``;
    if the iterator runs dry earlier, the loop exits as soon as an empty
    batch is drawn.  The running total is also persisted to ``res.log``.

    Returns:
        list of all ``(row, result)`` tuples produced.
    """
    results = []
    amount = 0
    num_batches = (length_max + batch_size - 1) // batch_size
    for _ in tqdm(range(num_batches), desc="Overall progress", unit="batch"):
        batch = list(islice(dataset_iter, batch_size))
        if not batch:
            # Iterator exhausted before length_max predicted — stop early
            # instead of spinning through empty batches.
            break
        batch_results = await process_batch(batch)
        amount += len(batch_results)
        # Fix: this was commented out, so the function always returned []
        # and the caller's final count printed 0 regardless of work done.
        results.extend(batch_results)
    print("处理完成,共获得结果条数:", amount)
    with open("res.log", "w", encoding="utf-8") as f:
        f.write(str(amount))
    return results
if __name__ == "__main__":
    dataset_iter = load_dataset()
    try:
        final_results = asyncio.run(process_dataset(dataset_iter, batch_size=64))
        print("处理完成,共获得结果条数:", len(final_results))
    finally:
        # Fix: close (and thereby flush) the shared CSV even if processing
        # raises; previously an exception left buffered rows unwritten.
        res_file.close()