# bbc-document-classifier/src/models/train_transformer.py
# pearlll · Deploy document classifier app · 492754f
import os
import mlflow
import mlflow.transformers
import pandas as pd
import numpy as np
from datasets import Dataset
from transformers import (
DistilBertTokenizerFast,
DistilBertForSequenceClassification,
TrainingArguments,
Trainer
)
from sklearn.metrics import accuracy_score, f1_score
# Load the pre-split training and validation sets
train_df = pd.read_csv("data/splits/train.csv")
val_df = pd.read_csv("data/splits/val.csv")

# Label mapping: ids follow the sorted order of the labels found in the
# training split, so the same training data always yields the same mapping.
labels = sorted(train_df["label_text"].unique())
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for label, idx in label2id.items()}

# Series.map() produces NaN for any value that is missing from the mapping,
# so a label that only occurs in the validation split would silently turn
# into a NaN (float) id. Fail early and name the offending labels instead.
_unseen_labels = val_df.loc[
    ~val_df["label_text"].isin(labels), "label_text"
].unique()
if len(_unseen_labels) > 0:
    raise ValueError(
        "Validation labels missing from the training split: "
        f"{list(_unseen_labels)}"
    )

train_df["label_id"] = train_df["label_text"].map(label2id)
val_df["label_id"] = val_df["label_text"].map(label2id)
# Convert to Hugging Face Dataset.
# The label column is renamed to "labels" on the way in: with the default
# remove_unused_columns=True, Trainer drops dataset columns that do not match
# an argument of the model's forward() (it only additionally keeps "label" /
# "label_ids"), so a column called "label_id" would never reach the model and
# no loss could be computed.
train_dataset = Dataset.from_pandas(
    train_df[["clean_text", "label_id"]].rename(columns={"label_id": "labels"})
)
val_dataset = Dataset.from_pandas(
    val_df[["clean_text", "label_id"]].rename(columns={"label_id": "labels"})
)
# Tokenizer matching the pretrained checkpoint that is fine-tuned below
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
# Tokenization function
def tokenize(batch):
    """Tokenize the ``clean_text`` entries of a batch.

    The module-level ``tokenizer`` is asked to truncate and pad every
    sequence to a fixed length of 256 tokens.

    Args:
        batch: Mapping with a ``"clean_text"`` entry holding the texts.

    Returns:
        The tokenizer's output for the batch.
    """
    texts = batch["clean_text"]
    encoding_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 256,
    }
    return tokenizer(texts, **encoding_options)
# Tokenize both splits in batches (training split first, then validation)
train_dataset, val_dataset = (
    split.map(tokenize, batched=True)
    for split in (train_dataset, val_dataset)
)
# Sequence-classification model built on the pretrained DistilBERT weights;
# the label mappings are passed in so they are stored in the model config.
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    label2id=label2id,
    id2label=id2label,
    num_labels=len(labels),
)
# Metrics
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for one evaluation pass.

    Args:
        eval_pred: Pair unpacked as ``(logits, reference labels)``.

    Returns:
        Dict with the keys ``"accuracy"`` and ``"f1"``.
    """
    scores, references = eval_pred
    # The predicted class is the index of the highest score on the last axis.
    predicted = np.argmax(scores, axis=-1)
    return {
        "accuracy": accuracy_score(references, predicted),
        "f1": f1_score(references, predicted, average="weighted"),
    }
# Training arguments
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` - confirm against the transformers version this project pins.
training_args = TrainingArguments(
    # Output locations
    output_dir="models/distilbert_output",
    logging_dir="./logs",
    # Evaluate and checkpoint once per epoch, then reload the best checkpoint
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Optimisation hyperparameters
    num_train_epochs=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)
# Trainer: ties together the model, its configuration, both data splits and
# the metric function used at evaluation time.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)
# MLflow: train, evaluate and record the run under one experiment
mlflow.set_experiment("bbc-document-classification")
with mlflow.start_run(run_name="distilbert_classifier"):
    trainer.train()

    # Final evaluation on the validation split
    metrics = trainer.evaluate()
    print(metrics)

    # Read the hyperparameters from training_args instead of repeating the
    # literals, so the logged values cannot drift from what was actually used.
    mlflow.log_params({
        "model": "DistilBERT",
        "epochs": training_args.num_train_epochs,
        "batch_size": training_args.per_device_train_batch_size,
        "learning_rate": training_args.learning_rate,
    })
    mlflow.log_metrics(metrics)

    # Save the tokenizer next to the weights so the directory can be loaded
    # on its own; this Trainer was not given the tokenizer, so save_model()
    # alone writes only the model files.
    trainer.save_model("models/distilbert_model")
    tokenizer.save_pretrained("models/distilbert_model")

    # State the task explicitly: when given a dict of components, MLflow
    # otherwise has to infer the pipeline type itself, and this model is a
    # sequence classifier fine-tuned from a general-purpose checkpoint.
    mlflow.transformers.log_model(
        transformers_model={
            "model": model,
            "tokenizer": tokenizer,
        },
        artifact_path="distilbert_model",
        task="text-classification",
    )

print("Transformer training completed!")