| import torch |
| import torch.nn as nn |
| from tokenizers import Tokenizer |
| import re |
| import argparse |
| import sys |
|
|
| |
| |
| |
|
|
class StabilizedDenoisingModel(nn.Module):
    """Token-sequence model: embedding -> hidden projection -> stack of gated
    residual "denoise" layers -> vocabulary logits.

    A single LayerNorm instance is shared by the input projection and every
    residual step. Attribute names are part of the checkpoint format
    (state_dict keys) and therefore must not be renamed.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.row_transform = nn.Linear(embed_dim, hidden_dim)
        self.dim_transform = nn.Linear(hidden_dim, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)

        # Each denoise layer is a two-layer MLP operating on the hidden dim.
        self.denoise_layers = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            for _ in range(num_layers)
        )

        self.output_layer = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_seq):
        """Map integer token ids (batch, seq) to logits (batch, seq, vocab)."""
        h = self.embedding(input_seq)
        h = self.norm(self.dim_transform(self.row_transform(h)))

        for layer in self.denoise_layers:
            signal = layer(h)
            gate = torch.sigmoid(signal)
            # Gated mix: subtract the gated signal and add back the rectified
            # complement, then residual-add and re-normalize with the shared
            # LayerNorm.
            update = h - gate * signal + (1 - gate) * torch.relu(signal)
            h = self.norm(h + update)

        return self.output_layer(h)
|
|
| |
| |
| |
|
|
def clean_text(text):
    """Normalize raw input: lowercase, drop disallowed characters, collapse whitespace.

    Keeps lowercase letters, digits, whitespace and basic punctuation
    (.,!?;:'"-); every other character is removed, then runs of whitespace
    are squeezed to single spaces and the ends trimmed.
    """
    lowered = text.lower()
    kept = re.sub(r'[^a-z0-9\s.,!?;:\'"-]', '', lowered)
    return re.sub(r'\s+', ' ', kept).strip()
|
|
| |
| |
| |
|
|
def stream_generate_text(model, tokenizer, device, start_text, max_len=100, temperature=0.8):
    """Autoregressively generate text, streaming each new chunk to stdout.

    Args:
        model: sequence model returning logits of shape (batch, seq, vocab).
        tokenizer: a `tokenizers.Tokenizer` (encode / decode / token_to_id).
        device: torch device the input tensors are placed on.
        start_text: prompt text; passed through clean_text() before encoding.
        max_len: maximum number of new tokens to sample.
        temperature: softmax temperature (> 0); lower is greedier.

    Returns:
        The full decoded string (prompt plus generated continuation).
    """
    model.eval()

    start_text = clean_text(start_text)

    input_ids = tokenizer.encode(start_text).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)

    generated_ids = input_ids.copy()

    # Tracks how much decoded text has already been printed so only the new
    # suffix is emitted each step.
    # NOTE(review): assumes tokenizer.decode(input_ids) reproduces the cleaned
    # prompt exactly; if decoding normalizes differently the first printed
    # chunk may be offset — confirm against the tokenizer config.
    last_output_length = len(start_text)

    print(start_text, end="", flush=True)

    # Hoisted out of the loop. May be None if the vocabulary has no <SEP>
    # token, in which case generation only stops at max_len.
    sep_id = tokenizer.token_to_id("<SEP>")

    for _ in range(max_len):
        with torch.no_grad():
            # Sliding context window: keep only the most recent 100 tokens.
            if input_tensor.size(1) > 100:
                input_tensor = input_tensor[:, -100:]

            logits = model(input_tensor)
            next_token_logits = logits[:, -1, :] / temperature
            probs = torch.softmax(next_token_logits, dim=-1)

            # Truncate the tail of the distribution (tokens below 1%).
            filtered = probs.clone()
            filtered[filtered < 0.01] = 0
            total = filtered.sum()
            if total.item() > 0:
                probs = filtered / total
            # else: every token fell below the threshold (near-flat
            # distribution). Fall back to the unfiltered probabilities
            # instead of dividing by zero, which would feed NaNs into
            # torch.multinomial and crash.

            next_token = torch.multinomial(probs, num_samples=1).item()

            # Stop on the separator token, if one exists in the vocabulary.
            if sep_id is not None and next_token == sep_id:
                break

            generated_ids.append(next_token)
            next_token_tensor = torch.tensor([[next_token]], device=device, dtype=torch.long)
            input_tensor = torch.cat([input_tensor, next_token_tensor], dim=1)

            # Re-decode the whole sequence and print only the new suffix;
            # decoding per-token can split multi-piece characters.
            current_text = tokenizer.decode(generated_ids)
            new_text = current_text[last_output_length:]
            last_output_length = len(current_text)

            print(new_text, end="", flush=True)

    return tokenizer.decode(generated_ids)
|
|
| |
| |
| |
|
|
def main(model_path, tokenizer_path):
    """Interactive generation REPL: load tokenizer and model, then loop on prompts.

    Args:
        model_path: path to a torch checkpoint — either a bare state_dict or a
            dict containing a 'model_state_dict' entry.
        tokenizer_path: path to a serialized `tokenizers.Tokenizer` JSON file.
    """
    # Prefer CUDA, then Apple MPS, then CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    print(f"使用设备: {device}")

    tokenizer = Tokenizer.from_file(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()
    print(f"加载分词器成功,词汇表大小: {vocab_size}")

    # NOTE(review): these hyperparameters must match the checkpoint being
    # loaded or load_state_dict below will fail — confirm against the
    # training script.
    model_params = {
        "vocab_size": vocab_size,
        "embed_dim": 256,
        "hidden_dim": 512,
        "num_layers": 16
    }

    model = StabilizedDenoisingModel(**model_params).to(device)

    try:
        checkpoint = torch.load(model_path, map_location=device)
        # Support both full training checkpoints and bare state_dicts.
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        print(f"加载模型成功: {model_path}")
    except Exception as e:
        # Nothing useful can run without weights; abort the REPL.
        print(f"模型加载失败: {str(e)}")
        return

    print("\n===== GTC-2 Large Base Model Text Generator (Early Research Preview) =====")
    print("输入文本后按回车生成,输入'quit'退出")

    while True:
        user_input = input("\n输入: ")
        # Guard against a shell command pasted into the prompt by mistake.
        if "activate" in user_input and "venv" in user_input:
            print("检测到虚拟环境激活命令,已忽略")
            continue

        if user_input.lower() == 'quit':
            break

        sys.stdout.flush()

        print("生成: ", end="", flush=True)
        # stream_generate_text prints tokens as they are produced; its full
        # return string is not needed here (was previously bound to an unused
        # local).
        stream_generate_text(
            model,
            tokenizer,
            device,
            user_input,
            max_len=100,
            temperature=0.8
        )

        print("\n")
|
|
if __name__ == "__main__":
    # Command-line entry point: resolve file paths, then hand off to main().
    cli = argparse.ArgumentParser(description='GTC-2 Large Base Model 文本生成器')
    cli.add_argument('--model', type=str, default='best_model.pth',
                     help='模型文件路径 (默认: best_model.pth)')
    cli.add_argument('--tokenizer', type=str, default='bpe_tokenizer.json',
                     help='分词器文件路径 (默认: bpe_tokenizer.json)')

    parsed = cli.parse_args()
    main(parsed.model, parsed.tokenizer)