#!/usr/bin/env python3
"""
compare_models.py — Compare different Stack 2.9 model versions.
Reads from models/registry.json and produces a side-by-side comparison
of model properties and performance metrics.
Usage:
    python scripts/eval/compare_models.py
    python scripts/eval/compare_models.py --models stack-2.9-1.5B stack-2.9-7B
    python scripts/eval/compare_models.py --metrics hellaswag mmlu humaneval
    python scripts/eval/compare_models.py --verbose
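
Expected registry entry shape (illustrative; the keys shown are the ones
this script reads, and the values below are placeholders):
    {
      "models": [
        {
          "version": "stack-2.9-1.5B",
          "base_model": "...",
          "parameters": 1500000000,
          "lora": {"rank": 16, "alpha": 32, "dropout": 0.05},
          "performance": {"hellaswag": 0.0, "mmlu": 0.0, "loss": 0.0},
          ...
        }
      ]
    }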
"""
import argparse
import json
import sys
from pathlib import Path
# The script lives in scripts/eval/, so the repo root is two levels up.
REGISTRY_PATH = Path(__file__).resolve().parents[2] / "models" / "registry.json"
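# Benchmark keys looked up under each model's "performance" dict
# ("loss" is assumed here to be the final eval loss).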
ALL_METRICS = ["hellaswag", "arc_challenge", "mmlu", "humaneval", "loss"]
def load_registry(registry_path: Path = REGISTRY_PATH) -> dict:
"""Load the model registry JSON."""
if not registry_path.exists():
print(f"ERROR: Registry not found at {registry_path}", file=sys.stderr)
sys.exit(1)
with open(registry_path) as f:
return json.load(f)
def format_params(n: int) -> str:
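    """Render a raw parameter count as a short human-readable string:
    1_500_000_000 -> "1.5B", 125_000_000 -> "125M", smaller counts verbatim."""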
if n >= 1_000_000_000:
return f"{n / 1_000_000_000:.1f}B"
elif n >= 1_000_000:
return f"{n / 1_000_000:.0f}M"
return str(n)
def compare_params(a: int, b: int) -> str:
    """Describe how parameter count b compares to a."""
    ratio = b / a
    if ratio == 1:
        return f" same size ({format_params(b)})"
    if ratio > 1:
        return f" {ratio:.1f}x larger ({format_params(b)} vs {format_params(a)})"
    return f" {1 / ratio:.1f}x smaller ({format_params(b)} vs {format_params(a)})"
def format_value(value) -> str:
    """Render a single field value for display."""
    if value is None:
        return "—"
    if isinstance(value, float):
        return f"{value:.4f}"
    if isinstance(value, int):
        return f"{value:,}"
    return str(value)


def build_row(version: str, key: str, value) -> str:
    """Build a comparison table row."""
    return f"  {version:<22} {key:<30} {format_value(value)}"
def print_comparison(models: list, metrics: list, verbose: bool = False):
"""Print a side-by-side comparison table."""
# Header
versions = [m["version"] for m in models]
max_ver_len = max(len(v) for v in versions)
print(f"\n{'='*72}")
print(f" Model Comparison — Stack 2.9")
print(f"{'='*72}")
# Non-metric fields
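    # (label, key) pairs; a tuple key addresses a nested dict,
    # e.g. ("lora", "rank") -> m["lora"]["rank"].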
fields = [
("Base Model", "base_model"),
("Parameters", "parameters"),
("Quantization", "quantization"),
("Precision", "precision"),
("Context Length", "context_length"),
("Vocabulary Size", "vocabulary_size"),
("Dataset", "dataset"),
("LoRA Rank", ("lora", "rank")),
("LoRA Alpha", ("lora", "alpha")),
("LoRA Dropout", ("lora", "dropout")),
("Status", "status"),
("Created", "created_at"),
("Use Case", "use_case"),
]
print(f"\n {'Model':<{max_ver_len}} {'Field':<30} {'Value'}")
print(f" {'-'*max_ver_len} {'-'*30} {'-'*20}")
    for label, key in fields:
        row_values = []
        for m in models:
            if isinstance(key, tuple):
                # Walk nested dicts; a missing key yields None rather
                # than swallowing legitimate falsy values like a 0.0 dropout.
                nested = m
                for k in key:
                    nested = nested.get(k) if isinstance(nested, dict) else None
                row_values.append(nested)
            else:
                val = m.get(key)
                # Format parameters as human-readable
                if key == "parameters" and val:
                    val = f"{format_params(val)} ({val:,})"
                row_values.append(val)
        # Skip fields no model defines, unless --verbose asks to show them.
        if not verbose and all(v is None for v in row_values):
            continue
print(f"\n {label}:")
        for i, (ver, val) in enumerate(zip(versions, row_values)):
            val_str = format_value(val)
            # Arrow marks values that differ from the first model's.
            marker = " →" if i > 0 and row_values[i] != row_values[0] else "  "
            print(f"    {marker} {ver:<{max_ver_len}} {val_str}")
# Performance metrics comparison
has_any_metrics = any(
any(m.get("performance", {}).get(metric) is not None for m in models)
for metric in metrics
)
if has_any_metrics:
print(f"\n\n Performance Benchmarks")
print(f" {'-'*max_ver_len} {'-'*30} {'-'*10}")
for metric in metrics:
metric_name = metric.replace("_", " ").title()
values = [m.get("performance", {}).get(metric) for m in models]
if all(v is None for v in values):
continue
print(f"\n {metric_name}:")
            for i, (ver, val) in enumerate(zip(versions, values)):
                val_str = "N/A" if val is None else f"{val:.4f}"
                # Match the field table: arrow marks values that differ
                # from the first model's.
                marker = " →" if i > 0 and val != values[0] else "  "
                print(f"    {marker} {ver:<{max_ver_len}} {val_str}")
    # Parameter size comparison (pairwise)
    if len(models) >= 2:
        print("\n\n  Parameter Size Comparison:")
        for i in range(len(models)):
            for j in range(i + 1, len(models)):
                a, b = models[i], models[j]
                pa = a.get("parameters", 0)
                pb = b.get("parameters", 0)
                if pa and pb:
                    # compare_params inverts the ratio for the smaller model,
                    # so a 1.5B-vs-7B pair reads "4.7x smaller", not "0.21x smaller".
                    print(f"    {b['version']} vs {a['version']}:{compare_params(pa, pb)}")
print(f"\n{'='*72}\n")
def main():
parser = argparse.ArgumentParser(
description="Compare Stack 2.9 model versions side by side."
)
parser.add_argument(
"--models", "-m",
nargs="+",
metavar="VERSION",
help="Model versions to compare (e.g., stack-2.9-1.5B stack-2.9-7B). "
"If omitted, compares all available models."
)
parser.add_argument(
"--metrics", "-M",
nargs="+",
choices=ALL_METRICS,
default=ALL_METRICS,
help=f"Benchmark metrics to include (default: all). Choices: {ALL_METRICS}"
)
parser.add_argument(
"--verbose", "-v",
action="store_true",
help="Show verbose output."
)
parser.add_argument(
"--registry",
default=REGISTRY_PATH,
metavar="PATH",
help=f"Path to registry.json (default: {REGISTRY_PATH})."
)
args = parser.parse_args()
    registry_path = Path(args.registry)
    registry = load_registry(registry_path)
    models = registry.get("models", [])
    if not models:
        print("ERROR: Registry contains no models.", file=sys.stderr)
        sys.exit(1)
if args.models:
selected = []
for v in args.models:
found = next((m for m in models if m["version"] == v), None)
if found:
selected.append(found)
else:
print(f"WARNING: Model '{v}' not found in registry. Skipping.", file=sys.stderr)
available = ", ".join(m["version"] for m in models)
print(f" Available: {available}", file=sys.stderr)
if not selected:
print("ERROR: No valid models selected.", file=sys.stderr)
sys.exit(1)
else:
selected = models
    print_comparison(selected, metrics=args.metrics, verbose=args.verbose)
if __name__ == "__main__":
main()