| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| from datasets import load_dataset |
| from transformers import AutoTokenizer |
| from tqdm import tqdm |
| import math |
| import speech_recognition as sr |
| import pyttsx3 |
| from googlesearch import search |
| import warnings |
| from typing import List, Dict, Union |
|
|
| |
# Suppress all library warnings globally (transformers/datasets are noisy at
# import and load time).  NOTE(review): this also hides deprecation warnings.
warnings.filterwarnings("ignore")
|
|
class WebSearchWrapper:
    """Thin wrapper around googlesearch with a bounded LRU result cache.

    Queries are cached case-insensitively; when the cache is full, the
    least recently *used* entry is evicted.
    """

    def __init__(self, cache_size: int = 100):
        # An insertion-ordered dict (Python 3.7+) doubles as the LRU
        # structure: hits are re-inserted at the end, evictions pop the front.
        self.cache: Dict[str, List[str]] = {}
        self.cache_size = cache_size

    def search(self, query: str, num_results: int = 3) -> List[str]:
        """Return up to ``num_results`` result URLs for ``query``.

        Falls back to an empty list on any search failure (best-effort).
        """
        key = query.lower()
        if key in self.cache:
            # BUG FIX: refresh recency on a hit so eviction is true LRU;
            # the original never touched order, making the policy FIFO.
            self.cache[key] = self.cache.pop(key)
            return self.cache[key]

        try:
            search_results = list(search(query, num_results=num_results, stop=num_results, pause=2))
        except Exception as e:
            # Network/search errors degrade to "no results" rather than raising.
            print(f"Web search error: {e}")
            return []
        self._add_to_cache(key, search_results)
        return search_results

    def _add_to_cache(self, query: str, results: List[str]):
        """Store ``results`` under the lowercased query, evicting the LRU entry."""
        key = query.lower()
        if key not in self.cache and len(self.cache) >= self.cache_size:
            # Front of the insertion-ordered dict is the least recently used.
            self.cache.pop(next(iter(self.cache)))
        self.cache[key] = results
|
|
class FullChatDataset(Dataset):
    """Concatenation of several HuggingFace chat datasets.

    Indexing yields a dict of flat tensors — ``input_ids``,
    ``attention_mask``, and ``labels`` (a copy of ``input_ids`` for
    LM-style training) — tokenized with the BERT tokenizer as a
    (context, response) pair.
    """

    # Immutable default; see __init__'s None-sentinel below.
    DEFAULT_DATASETS = ("blended_skill_talk", "conv_ai_2", "social_i_qa")

    def __init__(self, dataset_names=None, max_length=256):
        # BUG FIX: the original used a mutable list as the default argument
        # (shared across all calls); use a None sentinel instead.
        if dataset_names is None:
            dataset_names = self.DEFAULT_DATASETS
        self.datasets = []

        for name in dataset_names:
            try:
                dataset = load_dataset(name, split="train")
                self.datasets.append(dataset)
            except Exception as e:
                # Best-effort: skip datasets that fail to download/parse.
                print(f"Failed to load dataset {name}: {e}")

        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.max_length = max_length

    def __len__(self):
        # Total size is the sum over all successfully loaded sub-datasets.
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        # Locate the sub-dataset containing global index idx.
        item = None
        for dataset in self.datasets:
            if idx < len(dataset):
                item = dataset[idx]
                break
            idx -= len(dataset)
        if item is None:
            # BUG FIX: the original fell through with `item` unbound
            # (UnboundLocalError); raise IndexError per the Dataset protocol.
            raise IndexError("dataset index out of range")

        # Normalize the heterogeneous per-dataset schemas into a turn list.
        if 'dialog' in item:
            dialog = item['dialog']
        elif 'messages' in item:
            dialog = [msg['text'] for msg in item['messages']]
        else:
            dialog = [v for k, v in item.items() if isinstance(v, str)]

        # BUG FIX: a 0-turn item crashed on dialog[-1]; degrade to empty text.
        if not dialog:
            dialog = [""]
        context = " [SEP] ".join(dialog[:-1])  # empty for single-turn items
        response = dialog[-1]

        inputs = self.tokenizer(
            context,
            text_pair=response,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'labels': inputs['input_ids'].flatten()
        }
|
|
class SimpleTransformerModel(nn.Module):
    """Small encoder-only language model: token embedding + sinusoidal
    positional encoding + TransformerEncoder + vocab projection head.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # BUG FIX: the embedding output is (batch, seq, d_model), but the
        # default TransformerEncoderLayer expects (seq, batch, d_model);
        # batch_first=True makes the shapes line up.
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """Compute per-position vocab logits.

        Args:
            x: (batch, seq) token ids.
            mask: optional (batch, seq) attention mask as produced by HF
                tokenizers (1 = real token, 0 = padding).

        Returns:
            (batch, seq, vocab_size) logits.
        """
        # BUG FIX: the original passed the (batch, seq) attention mask as the
        # (seq, seq) attention-weight `mask` argument; it belongs in
        # src_key_padding_mask, inverted (True = ignore this position).
        pad_mask = (mask == 0) if mask is not None else None
        h = self.embedding(x)
        h = self.pos_encoder(h)
        h = self.transformer(h, src_key_padding_mask=pad_mask)
        return self.fc(h)
|
|
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position embeddings (Vaswani et al., 2017).

    Adds the first ``seq_len`` rows of a precomputed (max_len, d_model)
    sin/cos table to a (batch, seq_len, d_model) input.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        table = torch.zeros(max_len, d_model)
        positions = torch.arange(max_len).unsqueeze(1)
        # Geometric frequency ladder: 10000^(-2i/d_model) for even indices i.
        freqs = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = positions * freqs
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Registered as a buffer: follows .to(device)/state_dict without
        # being a trainable parameter.
        self.register_buffer('pe', table)

    def forward(self, x):
        # Slice the table to the sequence length and broadcast over batch.
        seq_len = x.size(1)
        return x + self.pe[:seq_len]
|
|
class VoiceInterface:
    """Microphone input via Google speech recognition, plus TTS output."""

    def __init__(self):
        self.recognizer = sr.Recognizer()
        self.engine = pyttsx3.init()

    def listen(self) -> Union[str, None]:
        """Capture one utterance from the default microphone.

        Returns the recognized text, or None when recognition fails
        (unintelligible audio, network error, etc.).
        """
        with sr.Microphone() as mic:
            print("Listening...")
            audio = self.recognizer.listen(mic)
        try:
            text = self.recognizer.recognize_google(audio)
        except Exception as e:
            print(f"Error recognizing speech: {e}")
            return None
        print(f"You said: {text}")
        return text

    def speak(self, text: str):
        """Speak ``text`` aloud and echo it to stdout."""
        print(f"Bot: {text}")
        self.engine.say(text)
        self.engine.runAndWait()
|
|
class ChatBot:
    """Voice-capable chatbot: trains a small transformer LM on chat datasets
    and answers prompts, optionally augmented with web-search context.
    """

    def __init__(self):
        self.dataset = FullChatDataset()
        self.model = SimpleTransformerModel(len(self.dataset.tokenizer))
        self.voice_interface = VoiceInterface()
        self.web_searcher = WebSearchWrapper()

    def train(self, epochs=3, lr=3e-4):
        """Train the language model.

        Args:
            epochs: number of passes over the combined dataset.
            lr: Adam learning rate.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(device)
        # ignore_index=0 skips [PAD] positions (BERT's pad id is 0).
        criterion = nn.CrossEntropyLoss(ignore_index=0)
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        dataloader = DataLoader(self.dataset, batch_size=8, shuffle=True)

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0
            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

            for batch in pbar:
                inputs = batch['input_ids'].to(device)
                masks = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                optimizer.zero_grad()
                outputs = self.model(inputs, masks)
                loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})

            print(f"Epoch {epoch+1} - Avg loss: {total_loss/len(dataloader):.4f}")

    def generate_response(self, prompt: str, max_length: int = 100, use_web: bool = True) -> str:
        """Generate a reply to ``prompt``.

        BUG FIX: SimpleTransformerModel is a plain nn.Module and has no
        ``.generate()`` method (that API belongs to transformers models), so
        the original call raised AttributeError.  Replaced with an explicit
        autoregressive top-k sampling loop over the model's forward pass.

        Args:
            prompt: user text.
            max_length: maximum number of NEW tokens to sample.
            use_web: prepend web-search context when the prompt looks like
                a question.
        """
        device = next(self.model.parameters()).device
        self.model.eval()
        tokenizer = self.dataset.tokenizer

        if use_web and self._needs_web_search(prompt):
            web_results = self.web_searcher.search(prompt)
            if web_results:
                prompt = f"Web context: {', '.join(web_results[:3])}. User question: {prompt}"

        # No padding here — generation continues from the real prompt tokens.
        input_ids = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=256,
            truncation=True
        )['input_ids'].to(device)

        generated = input_ids
        with torch.no_grad():
            for _ in range(max_length):
                # Logits for the last position; temperature 0.7 sharpens sampling.
                logits = self.model(generated)[:, -1, :] / 0.7
                # Top-k (k=50) filtering, then sample from the renormalized set.
                k = min(50, logits.size(-1))
                top_vals, top_idx = torch.topk(logits, k, dim=-1)
                probs = torch.softmax(top_vals, dim=-1)
                next_tok = top_idx.gather(-1, torch.multinomial(probs, 1))
                generated = torch.cat([generated, next_tok], dim=-1)
                # Stop early when the model emits [SEP] (sequence terminator).
                if tokenizer.sep_token_id is not None and next_tok.item() == tokenizer.sep_token_id:
                    break

        # Decode only the newly generated continuation, not the prompt.
        response = tokenizer.decode(generated[0, input_ids.size(1):], skip_special_tokens=True)
        return response

    def _needs_web_search(self, text: str) -> bool:
        """Heuristic: questions are candidates for a web search.

        BUG FIX: the original substring test fired on words that merely
        contain a question word (e.g. 'somehow' contains 'how'); match on
        whole words instead.  A '?' anywhere still counts.
        """
        if '?' in text:
            return True
        question_words = {'what', 'when', 'where', 'who', 'why', 'how', 'which'}
        return any(word in question_words for word in text.lower().split())