# -*- coding: utf-8 -*-
"""
Task 1: Next-Word Prediction using MLP on a Multi-GPU Cluster.

This script is designed to be run using torchrun for distributed training.
Example usage on a 5-GPU machine:

    torchrun --nproc_per_node=5 task_1_distributed.py --dataset shakespeare
    torchrun --nproc_per_node=5 task_1_distributed.py --dataset linux

The process-group size is taken from torchrun's WORLD_SIZE, so
--nproc_per_node may be smaller than the number of visible GPUs. Running the
script with plain ``python`` falls back to a single-process, single-GPU run.
"""
import os
import re
import json
import time
import argparse
import random
import urllib.request
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader, DistributedSampler
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# Index reserved for the empty/padding token in every vocabulary.
PAD_IDX = 0

# dataset name -> (download URL, local filename)
_DATASET_SOURCES = {
    'shakespeare': ('https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt',
                    'shakespeare_input.txt'),
    'linux': ('https://cs.stanford.edu/people/karpathy/char-rnn/linux_input.txt',
              'linux_input.txt'),
}

# Characters removed from Linux source lines: everything except word chars,
# whitespace and the listed C punctuation.
_LINUX_DISALLOWED = re.compile(r'[^\w\s\.\(\)\[\]\{\}\=\+\-\*\/,;:"\'#<>&|!~`?]')

# Prompts used for the qualitative next-word predictions after training.
_TEST_SENTENCES = {
    'shakespeare': ["to be or not to", "a horse a horse my", "shall i compare thee to"],
    'linux': ["if (err != 0)", "static const struct file_operations", "return -EINVAL;"],
}


# --- Utility Functions for Distributed Training ---

def setup(rank, world_size, backend=None):
    """Initialize the default process group.

    Args:
        rank: Local rank of this process (its GPU index on this node). The
            global rank is taken from torchrun's ``RANK`` variable when set,
            falling back to ``rank`` for single-node / manual launches.
        world_size: Total number of processes in the group.
        backend: Process-group backend. Defaults to NCCL when CUDA and NCCL
            are available, otherwise Gloo.
    """
    # torchrun already exports MASTER_ADDR/MASTER_PORT; only fill them in when
    # launched some other way instead of overwriting the launcher's values.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    if backend is None:
        backend = "nccl" if torch.cuda.is_available() and dist.is_nccl_available() else "gloo"
    global_rank = int(os.environ.get("RANK", rank))
    dist.init_process_group(backend, rank=global_rank, world_size=world_size)


def cleanup():
    """Destroy the default process group if one was created."""
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def is_main_process():
    """Return True on global rank 0, or when no process group is initialized."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True


# --- Data Preprocessing ---

def _fetch(url, filename):
    """Download ``url`` to ``filename`` unless the file already exists."""
    if not os.path.exists(filename):
        # urlretrieve raises on network/HTTP errors instead of silently
        # continuing like an unchecked os.system("wget ...") call.
        urllib.request.urlretrieve(url, filename)


def download_and_preprocess_text(dataset_name):
    """Download (if needed) and normalize the requested corpus.

    Args:
        dataset_name: ``'shakespeare'`` or ``'linux'``.

    Returns:
        For Shakespeare: lowercased text containing only letters, digits and
        full stops, with all whitespace collapsed to single spaces.
        For Linux: one cleaned source line per segment, segments joined by
        ``' \\n '`` so that the newline becomes its own token.

    Raises:
        ValueError: If ``dataset_name`` is not a known dataset.
    """
    if dataset_name not in _DATASET_SOURCES:
        raise ValueError("Invalid dataset name. Choose 'shakespeare' or 'linux'.")
    url, filename = _DATASET_SOURCES[dataset_name]
    _fetch(url, filename)

    if dataset_name == 'shakespeare':
        with open(filename, "r", encoding='utf-8') as f:
            text = f.read()
        # Keep letters, digits, full stops AND whitespace. Whitespace (notably
        # newlines) must survive this step; otherwise the last word of one
        # line is glued onto the first word of the next.
        text = re.sub(r'[^a-zA-Z0-9\s\.]', '', text.lower())
        # Collapse every whitespace run into a single space.
        return re.sub(r'\s+', ' ', text).strip()

    with open(filename, "r", encoding='utf-8', errors='ignore') as f:
        text = f.read()
    # For code, newlines are kept as separator tokens and case is preserved;
    # a more lenient character filter keeps common C punctuation.
    processed_lines = []
    for line in text.split('\n'):
        cleaned = _LINUX_DISALLOWED.sub('', line)
        # Collapse tabs / space runs so tokens split cleanly on ' ' later.
        processed_lines.append(' '.join(cleaned.split()))
    return ' \n '.join(processed_lines)  # Use newline as a token


def create_vocabulary_and_pairs(text, context_window_size):
    """Build the vocabulary and sliding-window (context, target) pairs.

    Tokens are the non-empty pieces of ``text.split(' ')``. Index ``PAD_IDX``
    (0) is reserved for the empty/padding token; real words get 1..V in
    sorted order.

    Args:
        text: Preprocessed corpus.
        context_window_size: Number of context tokens per example (>= 1).

    Returns:
        ``(contexts, targets, word_to_idx, idx_to_word)`` where ``contexts`` is
        a LongTensor of shape (N, context_window_size), ``targets`` a
        LongTensor of shape (N,), and N = max(0, num_tokens - context_window_size).

    Raises:
        ValueError: If ``context_window_size`` < 1.
    """
    if context_window_size < 1:
        raise ValueError("context_window_size must be >= 1")

    if is_main_process():
        print("Tokenizing text...")
    tokens = [token for token in text.split(' ') if token]  # Remove empty strings

    if is_main_process():
        # Report word frequencies
        word_counts = Counter(tokens)
        print("\n--- Vocabulary Report ---")
        print(f"10 Most Frequent Words: {word_counts.most_common(10)}")
        print(f"10 Least Frequent Words: {word_counts.most_common()[:-11:-1]}")

    # Build vocabulary; index 0 is reserved for padding.
    vocab = sorted(set(tokens))
    word_to_idx = {word: i for i, word in enumerate(vocab, start=1)}
    word_to_idx[''] = PAD_IDX
    idx_to_word = {i: word for word, i in word_to_idx.items()}
    if is_main_process():
        print(f"Vocabulary Size: {len(word_to_idx)}")

    indexed = torch.tensor([word_to_idx[word] for word in tokens], dtype=torch.long)
    num_pairs = max(len(tokens) - context_window_size, 0)
    if num_pairs == 0:
        contexts = torch.empty((0, context_window_size), dtype=torch.long)
        targets = torch.empty(0, dtype=torch.long)
    else:
        # Row i is indexed[i:i+k] and the target is indexed[i+k]. unfold returns
        # a strided view, so the (N, k) matrix shares storage with ``indexed``
        # instead of materializing k Python ints per example.
        contexts = indexed.unfold(0, context_window_size, 1)[:num_pairs]
        targets = indexed[context_window_size:]
    return contexts, targets, word_to_idx, idx_to_word


def encode_context(tokens, word_to_idx, context_size):
    """Encode ``tokens`` as a (1, context_size) LongTensor of vocabulary indices.

    Only the last ``context_size`` tokens are used. Unknown words map to
    ``PAD_IDX``, and shorter inputs are left-padded with ``PAD_IDX`` so the
    tensor always matches the model's fixed input width.
    """
    indices = [word_to_idx.get(w, PAD_IDX) for w in tokens][-context_size:]
    indices = [PAD_IDX] * (context_size - len(indices)) + indices
    return torch.tensor([indices], dtype=torch.long)


# --- Model Definition ---

class NextWordPredictor(nn.Module):
    """MLP that maps a fixed-size window of word indices to next-word logits.

    The window's embeddings are concatenated and passed through two ReLU
    hidden layers, then projected to ``vocab_size`` logits.
    """

    def __init__(self, vocab_size, embedding_dim, context_size, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=PAD_IDX)
        self.fc1 = nn.Linear(context_size * embedding_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return logits of shape (batch, vocab_size) for index tensor ``x`` (batch, context_size)."""
        embedded = self.embedding(x).view(x.size(0), -1)
        out = self.relu(self.fc1(embedded))
        out = self.relu(self.fc2(out))
        return self.fc3(out)


# --- Training and Evaluation Helpers ---

def _global_mean(total, count, device):
    """Sum ``total`` and ``count`` over all ranks and return total / count (NaN if count is 0)."""
    if dist.is_available() and dist.is_initialized():
        # NCCL requires CUDA tensors; other backends reduce on the CPU here.
        reduce_device = device if dist.get_backend() == "nccl" else torch.device("cpu")
        stats = torch.tensor([total, float(count)], dtype=torch.float64, device=reduce_device)
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
        total, count = stats.tolist()
    return total / count if count else float('nan')


def _load_processed_text(dataset_name):
    """Preprocess the corpus on rank 0, then have every rank read the same cached file."""
    path = f"{dataset_name}_processed.txt"
    if is_main_process():
        raw_text = download_and_preprocess_text(dataset_name)
        # Save preprocessed text for other processes to load
        with open(path, "w", encoding='utf-8') as f:
            f.write(raw_text)
    # Ensure all processes have the preprocessed file before continuing
    dist.barrier()
    with open(path, "r", encoding='utf-8') as f:
        return f.read()


def _build_loaders(contexts, targets, batch_size, seed):
    """Split pairs 80/20 identically on every rank and shard both splits across ranks.

    Returns:
        ``(train_loader, val_loader, train_sampler)``; the sampler is returned
        so the caller can call ``set_epoch`` on it.
    """
    dataset = TensorDataset(contexts, targets)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    # Every rank must draw the same permutation. An unseeded split gives each
    # rank a different train/val partition, leaking validation data into training.
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=generator)

    # Shard by GLOBAL rank so multi-node runs do not duplicate shards.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, seed=seed)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler, pin_memory=True)
    return train_loader, val_loader, train_sampler


def _run_epoch(model, loader, criterion, device, desc, optimizer=None, scaler=None):
    """Run one pass over ``loader``, training when ``optimizer`` is given.

    Returns:
        Mean per-example loss aggregated over all ranks.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, count = 0.0, 0
    # Use tqdm only on the main process
    pbar = tqdm(loader, desc=desc, disable=not is_main_process())
    with torch.set_grad_enabled(training):
        for inputs, labels in pbar:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            if training:
                optimizer.zero_grad()
            with torch.autocast(device_type=device.type):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            if training:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            # CrossEntropyLoss averages over non-ignored targets; weight by that
            # count so batches of different sizes (e.g. the last) average correctly.
            n_valid = int((labels != PAD_IDX).sum())
            if n_valid:
                loss_sum += loss.item() * n_valid
                count += n_valid
    return _global_mean(loss_sum, count, device)


def _plot_losses(history, dataset):
    """Save the training/validation loss curves to ``<dataset>_loss_curve.png``."""
    plt.figure(figsize=(10, 5))
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title(f'Training vs. Validation Loss ({dataset.capitalize()})')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'{dataset}_loss_curve.png')
    plt.close()
    print(f"Loss curve saved to {dataset}_loss_curve.png")


def _print_predictions(model, word_to_idx, idx_to_word, device, args):
    """Print the model's top-1 next word for each built-in test sentence."""
    print("\n--- Example Predictions ---")
    # ``model`` already holds the final-epoch weights that were just
    # checkpointed, so reloading the checkpoint from disk is unnecessary.
    model.eval()
    for sentence in _TEST_SENTENCES[args.dataset]:
        tokens = sentence.lower().split() if args.dataset == 'shakespeare' else sentence.split()
        context_tensor = encode_context(tokens, word_to_idx, args.context_size).to(device)
        with torch.no_grad():
            prediction = model(context_tensor)
        predicted_index = torch.argmax(prediction, dim=1).item()
        predicted_word = idx_to_word.get(predicted_index, '')
        print(f"'{sentence}' -> '{predicted_word}'")


def _plot_embeddings(model, word_to_idx, dataset, seed, num_words_to_visualize=200):
    """Project a sample of word embeddings with t-SNE and save ``<dataset>_embeddings.png``."""
    print("\n--- Visualizing Embeddings with t-SNE ---")
    # Skip the padding token: with padding_idx its row is initialised to zeros
    # and never receives gradient updates, so it carries no information.
    words = [w for w in word_to_idx if w]
    if len(words) > num_words_to_visualize:
        words = random.Random(seed).sample(words, num_words_to_visualize)
    if len(words) < 2:
        print("Not enough words to visualize.")
        return

    indices = [word_to_idx[w] for w in words]
    embeddings = model.embedding.weight.detach()[indices].float().cpu().numpy()
    # t-SNE requires perplexity < number of samples.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings) - 1))
    embeddings_2d = tsne.fit_transform(embeddings)

    plt.figure(figsize=(16, 16))
    plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1])
    for word, (x, y) in zip(words, embeddings_2d):
        plt.annotate(word, (x, y), alpha=0.7)
    plt.title(f't-SNE Visualization of Word Embeddings ({dataset.capitalize()})')
    plt.grid(True)
    plt.savefig(f'{dataset}_embeddings.png')
    plt.close()
    print(f"Embedding visualization saved to {dataset}_embeddings.png")


# --- Training and Evaluation ---

def train(rank, world_size, args):
    """Train, evaluate and (on rank 0) report on the next-word MLP in one DDP process.

    Args:
        rank: Local rank, i.e. the GPU index on this node (torchrun's LOCAL_RANK).
        world_size: Total number of processes (torchrun's WORLD_SIZE).
        args: Parsed CLI namespace with ``dataset``, ``context_size``,
            ``embedding_dim``, ``hidden_dim``, ``epochs``, ``batch_size``,
            ``lr`` and optionally ``seed`` (defaults to 42).

    Raises:
        ValueError: If the corpus is too short to form any training pair.
    """
    device = torch.device(f"cuda:{rank}")
    # Bind this process to its GPU before creating NCCL communicators so that
    # collectives (barrier, all_reduce) run on the right device.
    torch.cuda.set_device(device)
    setup(rank, world_size)
    try:
        seed = getattr(args, 'seed', 42)
        torch.manual_seed(seed)

        # --- 1. Data Loading and Preprocessing ---
        if is_main_process():
            print(f"--- Starting training for dataset: {args.dataset} ---")
        raw_text = _load_processed_text(args.dataset)
        contexts, targets, word_to_idx, idx_to_word = create_vocabulary_and_pairs(raw_text, args.context_size)
        if len(targets) == 0:
            raise ValueError("Corpus has too few tokens to form any (context, target) pair.")
        vocab_size = len(word_to_idx)

        # Save vocabulary only from the main process
        if is_main_process():
            with open(f'{args.dataset}_word_to_idx.json', 'w') as f:
                json.dump(word_to_idx, f)

        train_loader, val_loader, train_sampler = _build_loaders(contexts, targets, args.batch_size, seed)

        # --- 2. Model, Optimizer, and Loss ---
        model = NextWordPredictor(vocab_size, args.embedding_dim, args.context_size, args.hidden_dim).to(device)
        ddp_model = DDP(model, device_ids=[rank])
        criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # Ignore padding
        optimizer = optim.AdamW(ddp_model.parameters(), lr=args.lr)
        scaler = torch.cuda.amp.GradScaler()  # For mixed precision

        # --- 3./4. Training and Validation Loops ---
        history = {'train_loss': [], 'val_loss': []}
        for epoch in range(args.epochs):
            train_sampler.set_epoch(epoch)
            tag = f"Epoch {epoch+1}/{args.epochs}"
            avg_train_loss = _run_epoch(ddp_model, train_loader, criterion, device,
                                        f"{tag} [Train]", optimizer, scaler)
            avg_val_loss = _run_epoch(ddp_model, val_loader, criterion, device, f"{tag} [Val]")
            history['train_loss'].append(avg_train_loss)
            history['val_loss'].append(avg_val_loss)

            if is_main_process():
                print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
                # Save model checkpoint
                torch.save(ddp_model.module.state_dict(), f'{args.dataset}_model.pth')

        # --- 5. Reporting and Visualization ---
        if is_main_process():
            print("--- Training Complete ---")
            if history['val_loss']:
                print(f"Final Validation Loss: {history['val_loss'][-1]:.4f}")
            _plot_losses(history, args.dataset)
            _print_predictions(model, word_to_idx, idx_to_word, device, args)
            _plot_embeddings(model, word_to_idx, args.dataset, seed)
    finally:
        cleanup()


# --- Main Execution ---

def main():
    """Parse CLI arguments and launch training for this torchrun worker."""
    parser = argparse.ArgumentParser(description="Multi-GPU Next Word Prediction Trainer")
    parser.add_argument('--dataset', type=str, required=True, choices=['shakespeare', 'linux'], help='Dataset to use.')
    # Model Hyperparameters
    parser.add_argument('--context_size', type=int, default=5, help='Number of context words.')
    parser.add_argument('--embedding_dim', type=int, default=64, help='Dimension of word embeddings.')
    parser.add_argument('--hidden_dim', type=int, default=1024, help='Dimension of hidden layers.')
    # Training Hyperparameters
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
    parser.add_argument('--batch_size', type=int, default=16384, help='Batch size per GPU.')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate.')
    parser.add_argument('--seed', type=int, default=42, help='Seed for the train/val split, sampling and init.')
    args = parser.parse_args()

    if torch.cuda.device_count() < 1:
        print("This script requires at least one GPU.")
        return

    # torchrun exports LOCAL_RANK (GPU index on this node) and WORLD_SIZE
    # (total processes). Using torch.cuda.device_count() as the world size
    # hangs init_process_group when --nproc_per_node < number of GPUs.
    # The defaults allow a plain single-GPU `python task_1_distributed.py` run.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    train(local_rank, world_size, args)


if __name__ == "__main__":
    main()