| |
|
| |
|
| |
|
| | import torch
|
| | from transformers import T5Tokenizer
|
| | from sentence_transformers import SentenceTransformer
|
| | import torch.nn as nn
|
| |
|
| |
|
# Checkpoint paths and runtime configuration.
MAPPER_PTH = "semantic_mapper.pth"      # trained SemanticMapper checkpoint
DECODER_PTH = "embedding_decoder.pth"   # trained EmbeddingDecoder checkpoint
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"  # sentence embedder
MAX_LEN = 4096  # hard cap on the number of tokens generated per reply
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer used only to map generated ids back to text (and for the
# special-token ids below). Downloads t5-small vocab on first run.
# NOTE(review): T5's pad token doubles as the decoder start token in chat();
# presumably the decoder was trained the same way — confirm against training code.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
pad_id = tokenizer.pad_token_id  # 0 for t5-small
eos_id = tokenizer.eos_token_id  # 1 for t5-small
|
| |
|
| |
|
class SemanticMapper(torch.nn.Module):
    """Two-layer MLP that maps an embedding to another vector of the
    same dimensionality (expand to 2*dim with ReLU, project back)."""

    def __init__(self, dim):
        super().__init__()
        hidden = dim * 2
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return self.net(x)
|
| |
|
class EmbeddingDecoder(nn.Module):
    """GRU decoder that autoregressively generates token ids from a single
    conditioning embedding.

    The embedding is (a) bridged into the GRU's initial hidden state and
    (b) re-concatenated onto every step's input, so the conditioning signal
    is available at each step. The output projection shares weights with
    the input token embedding table (weight tying).
    """

    def __init__(self, input_dim, hidden_dim, vocab_size, p=0.2):
        """
        Args:
            input_dim: dimensionality of the conditioning embedding vector.
            hidden_dim: GRU hidden size; also the token-embedding dim
                (required for weight tying with the output projection).
            vocab_size: output vocabulary size.
            p: dropout probability (identity in eval mode).
        """
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim + input_dim, hidden_dim, batch_first=True)
        self.ln = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size, bias=True)
        self.drop = nn.Dropout(p)

        # Weight tying: both tensors are (vocab_size, hidden_dim), so the
        # embedding table doubles as the output projection matrix.
        self.fc.weight = self.embed.weight

    @torch.no_grad()
    def greedy_decode(self, emb_vec, max_len, start_id, eos_id, pad_token_id=None):
        """Greedy autoregressive decoding.

        Args:
            emb_vec: (B, input_dim) conditioning embeddings.
            max_len: maximum number of generation steps (>= 0).
            start_id: id fed as the first decoder input.
            eos_id: decoding stops once every sequence has emitted this id.
            pad_token_id: token id masked out of the logits so it is never
                generated. Defaults to the module-level ``pad_id`` for
                backward compatibility with existing callers.

        Returns:
            (B, T) LongTensor of generated ids, T <= max_len; the eos id is
            included when emitted. Empty (B, 0) tensor when max_len == 0.
        """
        if pad_token_id is None:
            # Fall back to the script-level global used by the original code.
            pad_token_id = pad_id
        B, _ = emb_vec.shape
        # Initial hidden state from the conditioning vector: (1, B, hidden).
        h = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)
        inp = torch.full((B, 1), start_id, dtype=torch.long, device=emb_vec.device)
        out_ids = []
        for _ in range(max_len):
            # Dropout is the identity in eval mode; callers should .eval() first.
            token_h = self.drop(self.embed(inp))
            # Re-inject the conditioning embedding at every step.
            step_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1)
            out, h = self.gru(step_in, h)
            out = self.ln(out.squeeze(1))
            logits = self.fc(self.drop(out))
            logits[:, pad_token_id] = -1e9  # never emit padding
            next_id = torch.argmax(logits, dim=-1)
            out_ids.append(next_id.unsqueeze(1))
            if (next_id == eos_id).all():
                break
            inp = next_id.unsqueeze(1)
        # Guard against max_len == 0: torch.cat([]) would raise.
        if not out_ids:
            return torch.empty((B, 0), dtype=torch.long, device=emb_vec.device)
        return torch.cat(out_ids, dim=1)
|
| |
|
| |
|
| |
|
# Load the trained mapper. Checkpoint is expected to be a dict with keys
# "dim" and "state_dict" — TODO confirm against the training script.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
mapper_ckpt = torch.load(MAPPER_PTH, map_location=DEVICE)
mapper = SemanticMapper(mapper_ckpt["dim"]).to(DEVICE)
mapper.load_state_dict(mapper_ckpt["state_dict"])
mapper.eval()  # disable dropout/batchnorm-style layers for inference

# Load the trained decoder (checkpoint keys: "dim", "vocab_size", "state_dict").
# 512 is presumably the GRU hidden size used at training time — verify;
# load_state_dict will fail on a mismatch.
dec_ckpt = torch.load(DECODER_PTH, map_location=DEVICE)
decoder = EmbeddingDecoder(dec_ckpt["dim"], 512, dec_ckpt["vocab_size"]).to(DEVICE)
decoder.load_state_dict(dec_ckpt["state_dict"])
decoder.eval()  # greedy_decode applies Dropout; eval makes it the identity

# Frozen sentence embedder producing the input vectors for the mapper.
embedder = SentenceTransformer(MODEL_NAME, device=DEVICE)
|
| |
|
| |
|
def chat():
    """Interactive console loop: embed the user's message, map it through
    the semantic mapper, then greedily decode a textual reply."""
    print("Chat ready. Type 'quit' to exit.")
    while True:
        user = input("User: ").strip()
        # Empty line or an exit word ends the session.
        if not user or user.lower() in {"quit", "exit"}:
            break
        query_emb = (
            embedder.encode([user], convert_to_tensor=True, device=DEVICE)
            .detach()
            .clone()
        )
        mapped = mapper(query_emb)
        token_ids = decoder.greedy_decode(
            mapped, max_len=MAX_LEN, start_id=pad_id, eos_id=eos_id
        )[0].tolist()
        reply = tokenizer.decode(token_ids, skip_special_tokens=True)
        print("Bot:", reply)
|
| |
|
# Script entry point: start the interactive chat loop.
if __name__ == "__main__":
    chat()
|
| |
|