# =============================================================================
# train.py
# Dataset Preparation + Finetuning Entry Point
# SmolLM2 Service Space
# Copyright 2026 - Volkan Kücükbudak
# Apache License V2 + ESOL 1.1
# =============================================================================
# Usage:
# python train.py --mode export    → export HF dataset to training format
# python train.py --mode validate  → validate ADI weights against dataset
# python train.py --mode finetune  → finetune SmolLM2 on exported data
# =============================================================================
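# ── Configuration (inferred from usage below; model.py is authoritative) ─────
# SPACE_ID                    env var set automatically on HF Spaces;
#                             when present, all paths switch to /tmp
# model_module.TOKEN          HF access token used for the private model push
# model_module.PRIVATE_MODEL  target repo id for the finetuned weights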
import os
import argparse
import json
import logging
from datetime import datetime
from pathlib import Path
# ── Path Resolution ───────────────────────────────────────────────────────────
# HF Spaces: app filesystem is read-only, so all artifacts go to /tmp
# Local dev: current directory
_TMP = Path("/tmp") if os.getenv("SPACE_ID") else Path(".")
TRAIN_DATA = _TMP / "train_data.jsonl"
VALID_RESULT = _TMP / "validation_results.json"
MODEL_OUTPUT = _TMP / "finetuned_model"
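# Artifacts produced by the three modes (paths resolved above):
#   TRAIN_DATA    JSONL written by --mode export, consumed by --mode finetune
#   VALID_RESULT  JSON report written by --mode validate
#   MODEL_OUTPUT  checkpoint directory written by --mode finetune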
import model as model_module
from adi import DumpindexAnalyzer
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("train")
# =============================================================================
# Mode 1 β€” Export dataset to training format
# =============================================================================
def export_dataset(output_path: str | None = None):
    """
    Export HF dataset logs to JSONL format for training.
    Includes HIGH_PRIORITY, MEDIUM_PRIORITY and BLOCKED entries.
    BLOCKED entries teach the model what to reject.
    REJECT entries (ADI noise/quality fail) are skipped, since no response is logged for them.
    """
    output = Path(output_path) if output_path else TRAIN_DATA
    logger.info("Loading dataset from HF...")
    entries = model_module.load_logs()
    if not entries:
        logger.warning("Dataset empty; nothing to export")
        return
    count = 0
    skipped = 0
    with open(output, "w", encoding="utf-8") as f:
        for entry in entries:
            # Skip ADI-rejected entries: no meaningful response was logged
            if entry.get("adi_decision") == "REJECT":
                skipped += 1
                continue
            if not entry.get("response"):
                skipped += 1
                continue
            # Format as an instruction-tuning pair.
            # BLOCKED entries are included so the model learns what to refuse.
            record = {
                "instruction": entry.get("system_prompt", "You are a helpful assistant."),
                "input": entry.get("prompt", ""),
                "output": entry.get("response", ""),
                "adi_score": entry.get("adi_score"),
                "adi_decision": entry.get("adi_decision"),
                "is_safe": entry.get("adi_decision") != "BLOCKED",
            }
            f.write(json.dumps(record) + "\n")
            count += 1
    logger.info(f"Exported {count}/{len(entries)} entries → {output} (skipped: {skipped})")
# =============================================================================
# Mode 2 β€” Validate ADI weights against collected data
# =============================================================================
def validate_adi():
    """
    Run ADI weight validation against the dataset.
    Uses entries that have a human_label field (manually labeled).
    """
logger.info("Loading dataset for ADI validation...")
entries = model_module.load_logs()
labeled = [(e["prompt"], e["human_label"]) for e in entries if e.get("human_label")]
if not labeled:
logger.warning("No labeled entries found β€” add 'human_label' field to dataset entries")
logger.info("Expected labels: REJECT | MEDIUM_PRIORITY | HIGH_PRIORITY")
return
analyzer = DumpindexAnalyzer()
accuracy = analyzer.validate_weights(labeled)
logger.info(f"ADI Validation accuracy: {accuracy:.1%} on {len(labeled)} samples")
result = {
"timestamp": datetime.utcnow().isoformat(),
"accuracy": accuracy,
"samples": len(labeled),
"weights": analyzer.weights,
}
VALID_RESULT.write_text(json.dumps(result, indent=2))
logger.info(f"Results saved β†’ {VALID_RESULT}")
# =============================================================================
# Mode 3 β€” Finetune SmolLM2 with TRL SFTTrainer
# =============================================================================
def finetune():
    """
    Finetune SmolLM2 on the exported dataset using the TRL SFTTrainer.
    Requires a prior export and enough data (500+ samples recommended).
    On completion, pushes the finetuned weights to a private HF model repo.
    """
    if not TRAIN_DATA.exists():
        logger.error(f"train_data.jsonl not found at {TRAIN_DATA}; run export first")
        return
    lines = TRAIN_DATA.read_text().strip().splitlines()
    logger.info(f"Training samples available: {len(lines)}")
    if len(lines) < 10:
        logger.error(f"Too few samples ({len(lines)}); aborting finetune")
        return
    if len(lines) < 500:
        logger.warning(f"Only {len(lines)} samples; 500+ recommended for meaningful finetuning")
    # ── Imports ───────────────────────────────────────────────────────────────
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from trl import SFTTrainer, SFTConfig
        from datasets import Dataset
        import torch
    except ImportError as e:
        logger.error(f"Missing dependency: {e}; run: pip install trl transformers datasets torch")
        return
    # ── Load dataset ──────────────────────────────────────────────────────────
    logger.info("Loading training data...")
    records = [json.loads(line) for line in lines]

    def format_record(record):
        """Format a record into a chat-template string."""
        instruction = record.get("instruction", "You are a helpful assistant.")
        user_input = record.get("input", "")
        output = record.get("output", "")
        return {
            "text": f"<|system|>\n{instruction}\n<|user|>\n{user_input}\n<|assistant|>\n{output}"
        }
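    # Note: the template above is a hand-rolled plain-text format, not the
    # model's built-in chat template. A sketch of the alternative, assuming
    # the tokenizer ships its own template:
    #   text = tokenizer.apply_chat_template(
    #       [{"role": "system", "content": instruction},
    #        {"role": "user", "content": user_input},
    #        {"role": "assistant", "content": output}],
    #       tokenize=False,
    #   )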
    formatted = [format_record(r) for r in records]
    dataset = Dataset.from_list(formatted)
    logger.info(f"Dataset ready: {len(dataset)} samples")
    # ── Load model + tokenizer ────────────────────────────────────────────────
    model_id = model_module.get_model_id()
    kwargs = model_module.get_model_kwargs()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Loading base model: {model_id} on {device}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
    model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
    # Ensure a pad token exists (SFT batch padding needs one)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # ── Training config ───────────────────────────────────────────────────────
    # Conservative settings for CPU / low RAM (2-8 GB)
    sft_config = SFTConfig(
        output_dir=str(MODEL_OUTPUT),
        num_train_epochs=3,
        per_device_train_batch_size=1,   # CPU friendly
        gradient_accumulation_steps=4,   # effective batch size = 4
        learning_rate=2e-5,
        warmup_steps=10,
        logging_steps=10,
        save_steps=50,
        save_total_limit=2,
        fp16=False,                      # no GPU, no fp16
        bf16=False,
        dataloader_num_workers=0,        # HF Spaces: no multiprocessing
        report_to="none",                # no wandb/tensorboard
        max_seq_length=512,              # conservative cap for CPU training
        dataset_text_field="text",
    )
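    # Field names vary across TRL releases: older versions take max_seq_length
    # on SFTConfig as above, newer ones rename it to max_length. If your
    # installed TRL rejects either keyword, check its SFTConfig signature.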
    # ── SFTTrainer ────────────────────────────────────────────────────────────
    logger.info("Initializing SFTTrainer...")
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=dataset,
        # Recent TRL takes the tokenizer as processing_class=tokenizer;
        # when omitted, the trainer resolves one from the model itself.
        # processing_class=tokenizer,
    )
    # ── Train ─────────────────────────────────────────────────────────────────
    logger.info("Starting finetuning...")
    start = datetime.utcnow()
    trainer.train()
    duration = (datetime.utcnow() - start).total_seconds()
    logger.info(f"Training complete in {duration:.0f}s")
    # ── Save locally ──────────────────────────────────────────────────────────
    trainer.save_model(str(MODEL_OUTPUT))
    tokenizer.save_pretrained(str(MODEL_OUTPUT))
    logger.info(f"Model saved → {MODEL_OUTPUT}")
    # ── Push to HF private repo ───────────────────────────────────────────────
    token = model_module.TOKEN
    private_repo = model_module.PRIVATE_MODEL
    if token and private_repo:
        logger.info(f"Pushing to HF: {private_repo}...")
        try:
            model.push_to_hub(private_repo, token=token, private=True)
            tokenizer.push_to_hub(private_repo, token=token, private=True)
            model_module.push_model_card({
                "model_id": model_id,
                "samples": len(dataset),
                "epochs": 3,
                "duration_sec": int(duration),
                "finetuned_from": model_id,
            })
            logger.info(f"Model pushed → {private_repo}")
        except Exception as e:
            logger.error(f"Push failed: {type(e).__name__}: {e}")
    else:
        logger.warning("No token or private repo configured; skipping HF push")
# =============================================================================
# CLI
# =============================================================================
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SmolLM2 Training Utilities")
    parser.add_argument(
        "--mode",
        choices=["export", "validate", "finetune"],
        required=True,
        help="export: dump dataset to JSONL | validate: test ADI weights | finetune: train model",
    )
    parser.add_argument("--output", default=None, help="Output file for export mode (default: auto)")
    args = parser.parse_args()

    if args.mode == "export":
        export_dataset(args.output)
    elif args.mode == "validate":
        validate_adi()
    elif args.mode == "finetune":
        finetune()