Spaces:
Sleeping
Sleeping
| """ | |
| utils.py — Shared Utilities: Download, Seeding, Plotting | |
| """ | |
| import os | |
| import random | |
| import numpy as np | |
| import argparse | |
| import matplotlib.pyplot as plt | |
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and (if installed) PyTorch random generators.

    PyTorch is optional: if it cannot be imported, the function stops quietly
    after seeding ``random`` and NumPy. When CUDA is available, every GPU
    generator is seeded too, and cuDNN is put into deterministic mode with
    benchmarking turned off.

    Args:
        seed: Value handed to each generator's seeding function.
    """
    # The standard-library and NumPy generators are always seeded.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    try:
        import torch

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            cudnn = torch.backends.cudnn
            cudnn.deterministic = True
            cudnn.benchmark = False
    except ImportError:
        # torch is an optional dependency; nothing more to seed without it.
        pass
def download_dataset(out_path: str = "data/raw/IMDB Dataset.csv"):
    """Download the IMDB dataset via the ``datasets`` package and save it as CSV.

    The ``train`` and ``test`` splits of ``load_dataset("imdb")`` are
    concatenated, the ``text`` column is renamed to ``review``, and the
    integer ``label`` column is mapped to a ``sentiment`` string
    (1 -> "positive", 0 -> "negative"). Only the ``review`` and ``sentiment``
    columns are written.

    This is a best-effort helper: any failure (missing ``datasets`` or
    ``pandas`` package, download error, unwritable path, ...) is reported on
    stdout together with a manual download link instead of being raised.

    Args:
        out_path: Destination CSV file. Its parent directory is created if
            needed. Defaults to the location used by the rest of the project.

    Returns:
        ``out_path`` when the CSV was written, ``None`` when anything failed.
    """
    try:
        # Imported lazily so the module stays importable without these
        # optional, heavyweight dependencies.
        from datasets import load_dataset
        import pandas as pd

        os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
        print("📥 Downloading IMDB dataset from HuggingFace...")
        raw = load_dataset("imdb")

        # Merge both labelled splits into one frame with a fresh index.
        frames = [raw[split].to_pandas() for split in ("train", "test")]
        df = pd.concat(frames, ignore_index=True)
        df = df.rename(columns={"text": "review"})
        df["sentiment"] = df["label"].map({1: "positive", 0: "negative"})

        df[["review", "sentiment"]].to_csv(out_path, index=False)
        print(f"✅ Dataset saved to {out_path} ({len(df):,} rows)")
        return out_path
    except Exception as e:
        # Deliberately broad: this is a convenience command, so every failure
        # is turned into a readable hint rather than a traceback.
        print(f"❌ Download failed: {e}")
        print(" Please manually download from:")
        print(" https://www.kaggle.com/datasets/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews")
        return None
def plot_training_history(history: dict, model_name: str, save: bool = True):
    """Plot loss curves and one validation metric side by side.

    The left panel shows ``history["train_loss"]`` and, when present,
    ``history["val_loss"]``. The right panel shows ``history["val_f1"]`` if
    that key exists, otherwise ``history["val_acc"]``. Each series is drawn
    against a 1-based epoch axis derived from its own length.

    Args:
        history: Mapping from metric name to a per-epoch sequence of values.
            ``"train_loss"`` and one of ``"val_f1"`` / ``"val_acc"`` are
            required; ``"val_loss"`` is optional.
        model_name: Used in the panel titles and in the output filename.
        save: When true, write the figure to
            ``results/<model_name>_training_history.png``. Spaces and path
            separators in ``model_name`` are replaced by underscores.

    Returns:
        The path of the saved PNG, or ``None`` when ``save`` is false.

    Raises:
        KeyError: If ``"train_loss"`` is missing, or neither ``"val_f1"`` nor
            ``"val_acc"`` is present.
    """

    def _epochs(values):
        # 1-based epoch numbers sized to the series being plotted.
        return range(1, len(values) + 1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))
    # NOTE(review): label and tick colours are left at matplotlib defaults,
    # which may be hard to read on this dark background unless a dark style
    # is active elsewhere - confirm.
    fig.patch.set_facecolor("#0f0f1a")
    for ax in (ax1, ax2):
        ax.set_facecolor("#1a1a2e")

    # Loss
    train_loss = history["train_loss"]
    ax1.plot(_epochs(train_loss), train_loss, color="#6c63ff", lw=2,
             label="Train Loss")
    val_loss = history.get("val_loss")
    if val_loss is not None:
        ax1.plot(_epochs(val_loss), val_loss, color="#ff6584", lw=2,
                 label="Val Loss")
    ax1.set_xlabel("Epoch", fontsize=12)
    ax1.set_ylabel("Loss", fontsize=12)
    ax1.set_title(f"{model_name} — Loss Curves", fontsize=13, fontweight="bold")
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    # Accuracy / F1 (F1 takes precedence when both are recorded)
    metric_key = "val_f1" if "val_f1" in history else "val_acc"
    metric_label = metric_key.replace("_", " ").title()
    metric = history[metric_key]
    ax2.plot(_epochs(metric), metric, color="#43aa8b", lw=2, label=metric_label)
    ax2.set_xlabel("Epoch", fontsize=12)
    ax2.set_ylabel("Score", fontsize=12)
    ax2.set_title(f"{model_name} — {metric_label}",
                  fontsize=13, fontweight="bold")
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()

    path = None
    if save:
        os.makedirs("results", exist_ok=True)
        # Path separators would point into a sub-directory that does not
        # exist, so they are flattened together with spaces.
        safe = model_name
        for ch in (" ", "/", "\\"):
            safe = safe.replace(ch, "_")
        path = f"results/{safe}_training_history.png"
        fig.savefig(path, dpi=150, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
        print(f"📈 Training history saved → {path}")

    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)
    return path
def ensure_directories(base_dir: str = "."):
    """Create all required project directories.

    Directories that already exist are left untouched, so the function can be
    called repeatedly.

    Args:
        base_dir: Directory under which the layout is created. The default
            (``"."``) creates it relative to the current working directory.
    """
    dirs = [
        "data/raw", "data/processed", "data/glove",
        "models", "models/bert_finetuned",
        "results", "results/confusion_matrices",
        "notebooks",
    ]
    for d in dirs:
        # makedirs also creates missing parents (e.g. "data" for "data/raw").
        os.makedirs(os.path.join(base_dir, d), exist_ok=True)
    print("📁 All directories created.")
if __name__ == "__main__":
    # Command-line entry point: each flag triggers one maintenance task.
    cli = argparse.ArgumentParser(description="Sentiment Analysis Utilities")
    cli.add_argument(
        "--download",
        action="store_true",
        help="Download IMDB dataset via HuggingFace",
    )
    cli.add_argument(
        "--setup",
        action="store_true",
        help="Create all project directories",
    )
    opts = cli.parse_args()

    if opts.download:
        download_dataset()
    if opts.setup:
        ensure_directories()

    # No flag given: show usage instead of exiting silently.
    if not (opts.download or opts.setup):
        cli.print_help()