bbc-document-classifier / src /models /train_compare_models.py
pearlll's picture
Deploy document classifier app
492754f
import os
import mlflow
import mlflow.sklearn
import pandas as pd
import joblib
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report
# Load the pre-split training and validation sets.
train_df = pd.read_csv("data/splits/train.csv")
val_df = pd.read_csv("data/splits/val.csv")

# pandas parses empty CSV cells as NaN by default, and TfidfVectorizer raises
# on NaN documents. Replacing missing text with "" keeps every row aligned
# with its label while still letting vectorization proceed.
X_train = train_df["clean_text"].fillna("")
y_train = train_df["label_text"]
X_val = val_df["clean_text"].fillna("")
y_val = val_df["label_text"]
# Output directory for the serialized best pipeline (models/best_model.pkl).
os.makedirs("models", exist_ok=True)

# Candidate classifiers, keyed by the name used for each MLflow run.
# LinearSVC is seeded like the random forest so that repeated comparisons
# are reproducible (its dual solver shuffles data using random_state).
models = {
    "logistic_regression": LogisticRegression(max_iter=1000),
    "linear_svm": LinearSVC(random_state=42),
    "random_forest": RandomForestClassifier(n_estimators=100, random_state=42),
}

mlflow.set_experiment("bbc-document-classification")

# Running record of the best pipeline seen so far. The score starts at -inf
# (not 0) so that the strict ``f1 > best_f1`` comparison always accepts the
# first evaluated candidate, even if its F1 is exactly 0; otherwise
# best_model could remain None and be pickled as such.
best_model = None
best_model_name = None
best_f1 = float("-inf")
# Train and evaluate every candidate, tracking each one as its own MLflow run
# and remembering the pipeline with the highest weighted F1 on validation.
for model_name, classifier in models.items():
    with mlflow.start_run(run_name=model_name):
        # Keep a handle on the vectorizer so the logged hyper-parameter is
        # read from the object actually used, not a duplicated literal.
        vectorizer = TfidfVectorizer(max_features=5000)
        pipeline = Pipeline([
            ("tfidf", vectorizer),
            ("classifier", classifier),
        ])
        pipeline.fit(X_train, y_train)

        # Score on the held-out validation split.
        y_pred = pipeline.predict(X_val)
        accuracy = accuracy_score(y_val, y_pred)
        f1 = f1_score(y_val, y_pred, average="weighted")

        print("\n==============================")
        print(f"Model: {model_name}")
        print("Accuracy:", accuracy)
        print("F1 Score:", f1)
        print(classification_report(y_val, y_pred))

        # Record configuration, metrics and the fitted pipeline in MLflow.
        mlflow.log_params({
            "model_name": model_name,
            "vectorizer": "TF-IDF",
            "max_features": vectorizer.max_features,
        })
        mlflow.log_metrics({"accuracy": accuracy, "f1_score": f1})
        mlflow.sklearn.log_model(pipeline, model_name)

        # Strict comparison: on a tie the earlier candidate is kept.
        if f1 > best_f1:
            best_f1 = f1
            best_model = pipeline
            best_model_name = model_name
# best_model is still None if no candidate's F1 exceeded the initial
# best_f1; fail loudly instead of pickling None and reporting success.
if best_model is None:
    raise RuntimeError("No model was selected; nothing to save.")

# Persist the winning pipeline (vectorizer + classifier) and report it.
joblib.dump(best_model, "models/best_model.pkl")
print("\nBest model:", best_model_name)
print("Best F1:", best_f1)
print("Saved to models/best_model.pkl")