| |
| """ |
| Task 1: Next-Word Prediction using MLP on a Multi-GPU Cluster. |
| |
| This script is designed to be run using torchrun for distributed training. |
| Example usage on a 5-GPU machine: |
| torchrun --nproc_per_node=5 task_1_distributed.py --dataset shakespeare |
| torchrun --nproc_per_node=5 task_1_distributed.py --dataset linux |
| """ |
| import os |
| import re |
| import json |
| import time |
| import argparse |
| from collections import Counter |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.data import TensorDataset, DataLoader, DistributedSampler |
| from tqdm import tqdm |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from sklearn.manifold import TSNE |
| from sklearn.decomposition import PCA |
| import random |
|
|
| |
|
|
def setup(rank, world_size):
    """Initialize the distributed process group for this process.

    Args:
        rank: Global rank of this process.
        world_size: Total number of participating processes.

    Uses the ``nccl`` backend when CUDA is available (the supported backend
    for multi-GPU DDP training); falls back to ``gloo`` for CPU-only runs.
    """
    # setdefault keeps any rendezvous address exported by torchrun intact;
    # hardcoding MASTER_ADDR/MASTER_PORT would break multi-node launches and
    # concurrent jobs on the same host.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
|
|
def cleanup():
    """Tear down the distributed process group for this process."""
    dist.destroy_process_group()
|
|
def is_main_process():
    """Return True iff this process holds global rank 0."""
    rank = dist.get_rank()
    return rank == 0
|
|
| |
|
|
def download_and_preprocess_text(dataset_name):
    """Download (if not cached locally) and normalize the requested corpus.

    Args:
        dataset_name: Either ``'shakespeare'`` or ``'linux'``.

    Returns:
        The preprocessed corpus as a single string.

    Raises:
        ValueError: If ``dataset_name`` is not one of the two known corpora.
    """
    # stdlib download instead of shelling out to wget: portable, and it
    # raises on failure rather than silently leaving the file missing.
    from urllib.request import urlretrieve

    if dataset_name == 'shakespeare':
        url = 'https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt'
        filename = 'shakespeare_input.txt'
        if not os.path.exists(filename):
            urlretrieve(url, filename)
        with open(filename, "r", encoding='utf-8') as f:
            text = f.read()
        # Lowercase and keep only alphanumerics, spaces, and periods.
        text = re.sub(r'[^a-zA-Z0-9 \.]', '', text.lower())
        # Collapse all whitespace runs into single spaces.
        text = re.sub(r'\s+', ' ', text).strip()
        return text
    elif dataset_name == 'linux':
        url = 'https://cs.stanford.edu/people/karpathy/char-rnn/linux_input.txt'
        filename = 'linux_input.txt'
        if not os.path.exists(filename):
            urlretrieve(url, filename)
        # errors='ignore' drops undecodable bytes in the kernel sources.
        with open(filename, "r", encoding='utf-8', errors='ignore') as f:
            text = f.read()

        # Strip characters outside the C-source-ish whitelist, line by line,
        # preserving line boundaries as ' \n ' tokens.
        lines = text.split('\n')
        processed_lines = []
        for line in lines:
            processed_line = re.sub(r'[^\w\s\.\(\)\[\]\{\}\=\+\-\*\/,;:"\'#<>&|!~`?]', '', line)
            processed_lines.append(processed_line.strip())
        return ' \n '.join(processed_lines)
    else:
        raise ValueError("Invalid dataset name. Choose 'shakespeare' or 'linux'.")
|
|
def create_vocabulary_and_pairs(text, context_window_size):
    """Build the vocabulary, print frequency stats on rank 0, and produce
    (context, target) next-word training pairs as LongTensors."""
    if is_main_process():
        print("Tokenizing text...")
    tokens = [tok for tok in text.split(' ') if tok]

    if is_main_process():
        word_counts = Counter(tokens)
        print("\n--- Vocabulary Report ---")
        print(f"10 Most Frequent Words: {word_counts.most_common(10)}")
        print(f"10 Least Frequent Words: {word_counts.most_common()[:-11:-1]}")

    # Index 0 is reserved for <pad>; real words are numbered from 1 in
    # sorted order, so every rank derives the identical mapping.
    vocab = sorted(set(tokens))
    word_to_idx = {word: pos + 1 for pos, word in enumerate(vocab)}
    word_to_idx['<pad>'] = 0
    idx_to_word = {idx: word for word, idx in word_to_idx.items()}
    vocab_size = len(word_to_idx)

    if is_main_process():
        print(f"Vocabulary Size: {vocab_size}")

    # Slide a fixed-size window over the encoded corpus: the window is the
    # context, the token immediately after it is the target.
    encoded = [word_to_idx[word] for word in tokens]
    window = context_window_size
    contexts, targets = [], []
    for start in range(len(encoded) - window):
        contexts.append(encoded[start:start + window])
        targets.append(encoded[start + window])

    return (torch.tensor(contexts, dtype=torch.long),
            torch.tensor(targets, dtype=torch.long),
            word_to_idx,
            idx_to_word)
|
|
| |
|
|
class NextWordPredictor(nn.Module):
    """MLP that predicts the next word from a fixed-size context window.

    The context word embeddings are concatenated and passed through two
    hidden layers to produce a logit per vocabulary entry.
    """

    def __init__(self, vocab_size, embedding_dim, context_size, hidden_dim):
        super().__init__()
        # Index 0 is the <pad> token; padding_idx keeps its embedding at zero.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.fc1 = nn.Linear(context_size * embedding_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map LongTensor (batch, context_size) -> logits (batch, vocab_size)."""
        batch = x.size(0)
        flat = self.embedding(x).view(batch, -1)  # concat context embeddings
        hidden = self.relu(self.fc1(flat))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)
|
|
| |
|
|
def train(rank, world_size, args):
    """Run distributed training and validation; rank 0 also reports results.

    Args:
        rank: Local rank of this process; also selects device ``cuda:<rank>``.
        world_size: Total number of participating processes.
        args: Parsed CLI namespace (dataset, context_size, embedding_dim,
            hidden_dim, epochs, batch_size, lr).

    Side effects (rank 0 only): writes the processed corpus, the vocabulary
    JSON, the model checkpoint, a loss-curve PNG, example predictions to
    stdout, and a t-SNE embedding PNG.
    """
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}")

    # Rank 0 downloads/normalizes the corpus and writes it to disk so all
    # ranks can read identical text after the barrier.
    if is_main_process():
        print(f"--- Starting training for dataset: {args.dataset} ---")
        raw_text = download_and_preprocess_text(args.dataset)
        with open(f"{args.dataset}_processed.txt", "w", encoding='utf-8') as f:
            f.write(raw_text)

    dist.barrier()  # wait until the processed file exists

    with open(f"{args.dataset}_processed.txt", "r", encoding='utf-8') as f:
        raw_text = f.read()

    # Every rank rebuilds the same vocabulary deterministically (sorted set).
    contexts, targets, word_to_idx, idx_to_word = create_vocabulary_and_pairs(raw_text, args.context_size)
    vocab_size = len(word_to_idx)

    if is_main_process():
        with open(f'{args.dataset}_word_to_idx.json', 'w') as f:
            json.dump(word_to_idx, f)

    # 80/20 train/validation split.
    dataset = TensorDataset(contexts, targets)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    # BUG FIX: seed the split so every rank computes the identical train/val
    # partition. With the default process-local RNG each rank would split
    # differently, and the DistributedSamplers would then shard inconsistent
    # datasets (validation examples leaking into other ranks' training sets).
    split_generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=split_generator)

    # Each sampler yields this rank's shard of the (shared) split.
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, sampler=val_sampler, pin_memory=True)

    model = NextWordPredictor(vocab_size, args.embedding_dim, args.context_size, args.hidden_dim).to(device)
    ddp_model = DDP(model, device_ids=[rank])

    criterion = nn.CrossEntropyLoss(ignore_index=0)  # index 0 is <pad>
    optimizer = optim.AdamW(ddp_model.parameters(), lr=args.lr)
    scaler = torch.cuda.amp.GradScaler()  # mixed-precision loss scaling

    history = {'train_loss': [], 'val_loss': []}
    for epoch in range(args.epochs):
        ddp_model.train()
        train_sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        total_train_loss = 0.0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs} [Train]", disable=not is_main_process())
        for inputs, labels in train_pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                outputs = ddp_model(inputs)
                loss = criterion(outputs, labels)

            # Scaled backward + unscaled optimizer step (AMP pattern).
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # Validation pass (no gradients).
        ddp_model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{args.epochs} [Val]", disable=not is_main_process())
            for inputs, labels in val_pbar:
                inputs, labels = inputs.to(device), labels.to(device)
                with torch.cuda.amp.autocast():
                    outputs = ddp_model(inputs)
                    loss = criterion(outputs, labels)
                total_val_loss += loss.item()

        # NOTE(review): these losses are rank-local (not all-reduced), so the
        # printed numbers reflect rank 0's shard only — confirm acceptable.
        avg_val_loss = total_val_loss / len(val_loader)
        history['val_loss'].append(avg_val_loss)

        if is_main_process():
            print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
            # Checkpoint the unwrapped module each epoch (overwrites).
            torch.save(ddp_model.module.state_dict(), f'{args.dataset}_model.pth')

    if is_main_process():
        print("--- Training Complete ---")
        print(f"Final Validation Loss: {history['val_loss'][-1]:.4f}")

        # Loss curves.
        plt.figure(figsize=(10, 5))
        plt.plot(history['train_loss'], label='Training Loss')
        plt.plot(history['val_loss'], label='Validation Loss')
        plt.title(f'Training vs. Validation Loss ({args.dataset.capitalize()})')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        plt.savefig(f'{args.dataset}_loss_curve.png')
        print(f"Loss curve saved to {args.dataset}_loss_curve.png")

        # Qualitative check: greedy next-word predictions on fixed prompts.
        print("\n--- Example Predictions ---")
        model.load_state_dict(torch.load(f'{args.dataset}_model.pth'))
        model.to(device)
        model.eval()
        test_sentences = {
            'shakespeare': ["to be or not to", "a horse a horse my", "shall i compare thee to"],
            'linux': ["if (err != 0)", "static const struct file_operations", "return -EINVAL;"]
        }
        for sentence in test_sentences[args.dataset]:
            context_tokens = sentence.lower().split() if args.dataset == 'shakespeare' else sentence.split()
            # Unknown words map to 0 (<pad>); only the last context_size
            # tokens are fed to the model.
            context_indices = [word_to_idx.get(w, 0) for w in context_tokens]
            context_tensor = torch.tensor([context_indices[-args.context_size:]], dtype=torch.long).to(device)
            with torch.no_grad():
                prediction = model(context_tensor)
                predicted_index = torch.argmax(prediction, dim=1).item()
                predicted_word = idx_to_word.get(predicted_index, '<unk>')
            print(f"'{sentence}' -> '{predicted_word}'")

        # Project a random sample of learned embeddings to 2-D with t-SNE.
        print("\n--- Visualizing Embeddings with t-SNE ---")
        num_words_to_visualize = 200
        words = list(word_to_idx.keys())
        if len(words) > num_words_to_visualize:
            words_to_visualize = random.sample(words, num_words_to_visualize)
        else:
            words_to_visualize = words

        indices = [word_to_idx[w] for w in words_to_visualize]
        embeddings = model.embedding.weight.data[indices].cpu().numpy()

        # perplexity must be < n_samples for t-SNE.
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings)-1))
        embeddings_2d = tsne.fit_transform(embeddings)

        plt.figure(figsize=(16, 16))
        for i, word in enumerate(words_to_visualize):
            x, y = embeddings_2d[i, :]
            plt.scatter(x, y)
            plt.annotate(word, (x, y), alpha=0.7)
        plt.title(f't-SNE Visualization of Word Embeddings ({args.dataset.capitalize()})')
        plt.grid(True)
        plt.savefig(f'{args.dataset}_embeddings.png')
        print(f"Embedding visualization saved to {args.dataset}_embeddings.png")

    cleanup()
|
|
| |
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-GPU Next Word Prediction Trainer")
    parser.add_argument('--dataset', type=str, required=True, choices=['shakespeare', 'linux'], help='Dataset to use.')

    parser.add_argument('--context_size', type=int, default=5, help='Number of context words.')
    parser.add_argument('--embedding_dim', type=int, default=64, help='Dimension of word embeddings.')
    parser.add_argument('--hidden_dim', type=int, default=1024, help='Dimension of hidden layers.')

    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
    parser.add_argument('--batch_size', type=int, default=16384, help='Batch size per GPU.')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate.')

    args = parser.parse_args()

    # BUG FIX: trust torchrun's WORLD_SIZE environment variable (correct when
    # --nproc_per_node differs from the GPU count, and for multi-node jobs);
    # fall back to the local GPU count if launched without torchrun.
    world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
    if torch.cuda.device_count() < 1:
        print("This script requires at least one GPU.")
    else:
        # torchrun exports LOCAL_RANK for each process it spawns; it doubles
        # here as the CUDA device index inside train().
        rank = int(os.environ["LOCAL_RANK"])
        train(rank, world_size, args)