Kimodo_OneClickStart / Codes /start_model.py
Ye-Song's picture
Add files using upload-large-folder tool
7b853a5 verified
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch
import argparse
import time
# 命令行参数支持
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, default="请介绍一下量子计算的基本原理", help="输入提示词")
parser.add_argument("--model_path", type=str, default="/workspace/kimodo/McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp", help="模型路径")
parser.add_argument("--max_length", type=int, default=512, help="最大生成长度")
parser.add_argument("--temperature", type=float, default=0.7, help="温度参数")
parser.add_argument("--num_beams", type=int, default=1, help="束搜索宽度,1表示贪心搜索")
parser.add_argument("--stream_output", action="store_true", help="流式输出响应")
args = parser.parse_args()
# 检查设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Using device: {device}")
# 处理模型路径,移除可能的引号
model_path = args.model_path.strip('"').strip("'")
# 检查路径是否存在
if not os.path.exists(model_path):
raise ValueError(f"模型路径不存在: {model_path}")
# 加载Tokenizer
print(f"🔄 正在加载分词器: {model_path}")
start_time = time.time()
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
print(f"✅ 分词器加载完成,耗时: {time.time() - start_time:.2f}秒")
# 加载模型
print(f"🔄 正在加载模型: {model_path}")
start_time = time.time()
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16,
device_map="auto",
low_cpu_mem_usage=True,
trust_remote_code=True # 如果需要使用远程代码
)
print(f"✅ 模型加载完成,耗时: {time.time() - start_time:.2f}秒")
# 编码输入
inputs = tokenizer(args.prompt, return_tensors="pt").to(device)
input_length = len(inputs["input_ids"][0])
print(f"📝 输入长度: {input_length} tokens")
# 准备生成参数
generate_args = {
"input_ids": inputs["input_ids"],
"max_length": min(args.max_length, input_length + 128), # 限制最大生成长度
"temperature": args.temperature,
"num_beams": args.num_beams,
"early_stopping": True,
"no_repeat_ngram_size": 3, # 避免重复
"pad_token_id": tokenizer.eos_token_id # 确保有pad token
}
# 生成文本
print(f"🤖 正在生成回复... (输入: {input_length} tokens, 最大生成长度: {generate_args['max_length']} tokens)")
start_time = time.time()
if args.stream_output:
# 流式输出
streamer = TextStreamer(tokenizer, skip_prompt=True)
with torch.no_grad():
model.generate(**generate_args, streamer=streamer)
print(f"\n✅ 生成完成,耗时: {time.time() - start_time:.2f}秒")
else:
# 非流式输出
with torch.no_grad():
outputs = model.generate(**generate_args)
# 解码输出
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# 提取生成的部分(不包括原始提示)
generated_text = response[len(args.prompt):].strip()
print(f"\n🧠 模型输出 ({len(outputs[0]) - input_length} tokens):\n{generated_text}")
print(f"✅ 生成完成,耗时: {time.time() - start_time:.2f}秒")