| import os |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer |
| import torch |
| import argparse |
| import time |
|
|
| |
# Command-line interface. Value-taking options are declared as data and
# registered in declaration order, so `--help` lists them in the same order.
parser = argparse.ArgumentParser()
_VALUE_OPTIONS = (
    # (flag, type, default, help)
    # prompt text fed to the model
    ("--prompt", str, "请介绍一下量子计算的基本原理", "输入提示词"),
    # local checkpoint directory
    ("--model_path", str, "/workspace/kimodo/McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp", "模型路径"),
    # generation length budget
    ("--max_length", int, 512, "最大生成长度"),
    # decoding temperature
    ("--temperature", float, 0.7, "温度参数"),
    # beam width; 1 means no beam search
    ("--num_beams", int, 1, "束搜索宽度,1表示贪心搜索"),
)
for _flag, _type, _default, _help in _VALUE_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
# Boolean switch: present -> True, absent -> False; enables token streaming.
parser.add_argument("--stream_output", action="store_true", help="流式输出响应")
args = parser.parse_args()
|
|
| |
# Pick the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"✅ Using device: {device}")
|
|
| |
# The path may arrive wrapped in quotes (e.g. pasted from a shell). Peel off
# surrounding double quotes first, then surrounding single quotes.
model_path = args.model_path
for _quote in ('"', "'"):
    model_path = model_path.strip(_quote)

# Fail fast, before any loading starts, when the path does not exist.
if not os.path.exists(model_path):
    raise ValueError(f"模型路径不存在: {model_path}")
|
|
|
|
| |
# Load tokenizer and model from `model_path`, the quote-stripped and
# existence-checked path. Loading the tokenizer from the raw
# `args.model_path` would fail whenever the user quoted the path.
print(f"🔄 正在加载分词器: {model_path}")
start_time = time.perf_counter()
tokenizer = AutoTokenizer.from_pretrained(model_path)
print(f"✅ 分词器加载完成,耗时: {time.perf_counter() - start_time:.2f}秒")

print(f"🔄 正在加载模型: {model_path}")
start_time = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    # float16 halves memory on GPU; on CPU many half-precision kernels are
    # slow or unsupported, so use float32 there.
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True,
    # NOTE: this executes Python code shipped inside the checkpoint directory.
    trust_remote_code=True,
)
print(f"✅ 模型加载完成,耗时: {time.perf_counter() - start_time:.2f}秒")
|
|
| |
# Tokenize the prompt and move the tensors to where the model was placed.
# The model is loaded with device_map="auto", so its placement is decided at
# load time; follow model.device rather than the independently chosen
# `device`.
inputs = tokenizer(args.prompt, return_tensors="pt").to(model.device)
# Number of prompt tokens (batch of one, so the last dim is sequence length).
input_length = inputs["input_ids"].shape[-1]
print(f"📝 输入长度: {input_length} tokens")
|
|
| |
# Generation budget: bound the number of *new* tokens rather than the total
# sequence length. The previous `min(args.max_length, input_length + 128)`
# counted prompt tokens against --max_length, so a prompt at or beyond
# --max_length tokens left essentially no room to generate. The original
# 128-token ceiling is preserved.
_NEW_TOKEN_CAP = 128
max_new_tokens = min(args.max_length, _NEW_TOKEN_CAP)

generate_args = {
    "input_ids": inputs["input_ids"],
    # Pass the mask explicitly: pad_token_id is set to eos_token_id below, so
    # generate() cannot infer padding from the ids.
    "attention_mask": inputs.get("attention_mask"),
    "max_new_tokens": max_new_tokens,
    # NOTE(review): temperature only affects decoding when sampling is enabled
    # (do_sample=True here or in the checkpoint's generation_config). With
    # num_beams=1 and no sampling, decoding is greedy and temperature is
    # ignored -- confirm which behaviour is intended.
    "temperature": args.temperature,
    "num_beams": args.num_beams,
    "no_repeat_ngram_size": 3,
    "pad_token_id": tokenizer.eos_token_id,
}
if args.num_beams > 1:
    # early_stopping is a beam-search option; only pass it when beams are used.
    generate_args["early_stopping"] = True

print(f"🤖 正在生成回复... (输入: {input_length} tokens, 最大生成长度: {max_new_tokens} tokens)")
start_time = time.perf_counter()

if args.stream_output:
    # Print tokens as they are produced, skipping the echoed prompt.
    # skip_special_tokens mirrors the non-streaming decode below.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    with torch.no_grad():
        model.generate(**generate_args, streamer=streamer)
    print(f"\n✅ 生成完成,耗时: {time.perf_counter() - start_time:.2f}秒")
else:
    with torch.no_grad():
        outputs = model.generate(**generate_args)

    # The returned sequence starts with the prompt tokens. Decode only what
    # follows them. Slicing the decoded text by len(args.prompt) assumed
    # decode(encode(prompt)) reproduces the prompt character-for-character,
    # which tokenizers do not guarantee.
    new_tokens = outputs[0][input_length:]
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    print(f"\n🧠 模型输出 ({len(new_tokens)} tokens):\n{generated_text}")
    print(f"✅ 生成完成,耗时: {time.perf_counter() - start_time:.2f}秒")