File size: 4,404 Bytes
c247f12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""

utils.py — Shared Utilities: Download, Seeding, Plotting

"""

import os
import random
import numpy as np
import argparse
import matplotlib.pyplot as plt


def set_seed(seed: int = 42):
    """Seed every random number generator used by the project.

    Python's ``random`` module and NumPy's global generator are always
    seeded. If PyTorch can be imported it is seeded as well; when CUDA is
    available, all GPU generators are seeded and cuDNN is put into
    deterministic, non-benchmarking mode. A missing PyTorch install is
    tolerated silently.

    Args:
        seed: Seed value applied to every generator (default 42).
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    try:
        import torch

        torch.manual_seed(seed)
        if not torch.cuda.is_available():
            return
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for run-to-run determinism.
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False
    except ImportError:
        # PyTorch is optional for this module.
        return


def download_dataset(out_dir: str = "data/raw"):
    """Download the IMDB dataset from HuggingFace and save it as one CSV.

    Loads the ``imdb`` dataset via ``datasets.load_dataset``, concatenates
    its ``train`` and ``test`` splits, renames ``text`` to ``review``, maps
    the integer ``label`` (1 -> "positive", 0 -> "negative") into a
    ``sentiment`` column, and writes the ``review``/``sentiment`` columns to
    ``<out_dir>/IMDB Dataset.csv``.

    This is best-effort: any failure (missing ``datasets``/``pandas``,
    network or disk errors) prints a message with a manual-download link
    instead of raising.

    Args:
        out_dir: Directory the CSV is written to; created if missing.
            Defaults to "data/raw".

    Returns:
        The path of the written CSV on success, or ``None`` on failure.
    """
    try:
        # Heavy optional dependencies are imported lazily so the rest of
        # this module stays usable without them.
        from datasets import load_dataset
        import pandas as pd

        os.makedirs(out_dir, exist_ok=True)
        print("📥 Downloading IMDB dataset from HuggingFace...")
        raw = load_dataset("imdb")
        # to_pandas() converts via Arrow instead of iterating row dicts.
        df = pd.concat(
            [raw["train"].to_pandas(), raw["test"].to_pandas()],
            ignore_index=True,
        )
        df = df.rename(columns={"text": "review"})
        df["sentiment"] = df["label"].map({1: "positive", 0: "negative"})
        out_path = os.path.join(out_dir, "IMDB Dataset.csv")
        df[["review", "sentiment"]].to_csv(out_path, index=False)
    except Exception as e:  # deliberate best-effort boundary
        print(f"❌ Download failed: {e}")
        print("   Please manually download from:")
        print("   https://www.kaggle.com/datasets/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews")
        return None

    print(f"✅ Dataset saved to {out_path} ({len(df):,} rows)")
    return out_path


def plot_training_history(history: dict, model_name: str, save: bool = True):
    """
    Plot loss curves and a validation metric side by side.

    The left panel shows ``train_loss`` and ``val_loss`` per epoch. The right
    panel shows ``val_f1`` when present in *history*, otherwise ``val_acc``.
    The figure is always closed before returning, even on error.

    Args:
        history: Dict of per-epoch sequences with keys 'train_loss',
            'val_loss', and either 'val_f1' or 'val_acc'. The epoch axis
            length is taken from 'train_loss'; the other sequences are
            presumably the same length (matplotlib raises otherwise).
        model_name: Used for the plot titles and the output filename.
        save: If True, save the figure to
            results/<model_name>_training_history.png.

    Returns:
        The saved image path when ``save`` is True, otherwise None.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))
    try:
        fig.patch.set_facecolor("#0f0f1a")
        for ax in (ax1, ax2):
            ax.set_facecolor("#1a1a2e")

        epochs = range(1, len(history["train_loss"]) + 1)

        # Loss
        ax1.plot(epochs, history["train_loss"], color="#6c63ff", lw=2, label="Train Loss")
        ax1.plot(epochs, history["val_loss"], color="#ff6584", lw=2, label="Val Loss")
        ax1.set_xlabel("Epoch", fontsize=12)
        ax1.set_ylabel("Loss", fontsize=12)
        ax1.set_title(f"{model_name} — Loss Curves", fontsize=13, fontweight="bold")
        ax1.legend(fontsize=10)
        ax1.grid(True, alpha=0.3)

        # Validation metric: prefer F1 when the history recorded it.
        metric_key = "val_f1" if "val_f1" in history else "val_acc"
        metric_label = metric_key.replace("_", " ").title()
        ax2.plot(epochs, history[metric_key], color="#43aa8b", lw=2, label=metric_label)
        ax2.set_xlabel("Epoch", fontsize=12)
        ax2.set_ylabel("Score", fontsize=12)
        ax2.set_title(f"{model_name} — {metric_label}", fontsize=13, fontweight="bold")
        ax2.legend(fontsize=10)
        ax2.grid(True, alpha=0.3)

        fig.tight_layout()

        if not save:
            return None

        os.makedirs("results", exist_ok=True)
        # Spaces become underscores as before; path separators are replaced
        # too so a name like "bert/base" cannot point into a missing subdir.
        safe = model_name.replace(" ", "_").replace("/", "_").replace(os.sep, "_")
        path = f"results/{safe}_training_history.png"
        fig.savefig(path, dpi=150, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
        print(f"📈 Training history saved → {path}")
        return path
    finally:
        # Release this specific figure even if a history key was missing,
        # so repeated calls don't accumulate open pyplot figures.
        plt.close(fig)


def ensure_directories(base_dir: str = "."):
    """Create every directory the project layout expects.

    Existing directories are left untouched (``exist_ok=True``), so the call
    is safe to repeat.

    Args:
        base_dir: Root under which the directories are created. Defaults to
            the current working directory, which matches the previous
            behavior.
    """
    dirs = [
        "data/raw", "data/processed", "data/glove",
        "models", "models/bert_finetuned",
        "results", "results/confusion_matrices",
        "notebooks",
    ]
    for d in dirs:
        os.makedirs(os.path.join(base_dir, d), exist_ok=True)
    print("📁 All directories created.")


if __name__ == "__main__":
    # Command-line entry point: run the requested setup steps, or show usage
    # when no step flag was given.
    cli = argparse.ArgumentParser(description="Sentiment Analysis Utilities")
    cli.add_argument(
        "--download",
        action="store_true",
        help="Download IMDB dataset via HuggingFace",
    )
    cli.add_argument(
        "--setup",
        action="store_true",
        help="Create all project directories",
    )
    opts = cli.parse_args()

    # Steps run in this fixed order: download first, then directory setup.
    steps = (
        (opts.download, download_dataset),
        (opts.setup, ensure_directories),
    )
    for enabled, step in steps:
        if enabled:
            step()

    if not (opts.download or opts.setup):
        cli.print_help()