#!/usr/bin/env python3
"""
Stack 2.9 Model Quantization Script
Applies AWQ/GPTQ quantization to the trained model for efficient inference.
"""
import argparse
import os
import sys
import torch
from pathlib import Path
# Add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

def parse_args():
    parser = argparse.ArgumentParser(description="Quantize Stack 2.9 model")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./output/stack-2.9-merged",
        help="Path to merged LoRA model"
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default="./output/stack-2.9-quantized",
        help="Output path for quantized model"
    )
    parser.add_argument(
        "--method",
        type=str,
        choices=["awq", "gptq", "bnb"],
        default="bnb",
        help="Quantization method (awq, gptq, or bitsandbytes)"
    )
    parser.add_argument(
        "--bits",
        type=int,
        default=4,
        help="Quantization bits (2, 4, or 8)"
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Run benchmark after quantization"
    )
    return parser.parse_args()

def get_model_size(path: str) -> float:
    """Calculate model size in GB."""
    total_size = 0
    for dirpath, _dirnames, filenames in os.walk(path):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            if os.path.exists(fp):
                total_size += os.path.getsize(fp)
    return total_size / (1024 ** 3)

def benchmark_model(model, tokenizer, device="cuda"):
    """Benchmark model inference speed."""
    if not torch.cuda.is_available():
        print("CUDA not available, skipping GPU benchmark")
        return None

    # Warm up so CUDA kernels are initialized before timing
    prompt = "Write a hello world program in Python"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        _ = model.generate(**inputs, max_new_tokens=20, do_sample=False)

    # Benchmark: synchronize around each run so wall-clock timings are accurate
    import time
    num_runs = 5
    times = []
    tokens_generated = 0
    torch.cuda.reset_peak_memory_stats()
    for _ in range(num_runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
        # Count tokens actually produced; greedy decoding may stop early at EOS
        tokens_generated = outputs.shape[1] - inputs["input_ids"].shape[1]

    avg_time = sum(times) / len(times)
    return {
        "avg_time": avg_time,
        "tokens_per_sec": tokens_generated / avg_time,
        "memory_allocated": torch.cuda.max_memory_allocated() / (1024 ** 3)
    }
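
# Note: each timed generate() call includes both prompt prefill and
# token-by-token decode, so tokens/sec blends the two phases; with longer
# prompts you would want to time prefill and decode separately.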

def quantize_awq(args):
    """Apply AWQ quantization via the AutoAWQ package."""
    try:
        from awq import AutoAWQForCausalLM
        from transformers import AutoTokenizer

        print(f"Loading model from {args.model_path}...")
        model = AutoAWQForCausalLM.from_pretrained(args.model_path)
        tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

        # AWQ is weight-only 4-bit; these are AutoAWQ's common defaults
        quant_config = {
            "zero_point": True,
            "q_group_size": 128,
            "w_bit": 4,
            "version": "GEMM",
        }
        print("Applying AWQ quantization...")
        # quantize() runs activation-aware calibration on a default dataset
        model.quantize(tokenizer, quant_config=quant_config)

        print(f"Saving to {args.output_path}...")
        os.makedirs(args.output_path, exist_ok=True)
        model.save_quantized(args.output_path)
        tokenizer.save_pretrained(args.output_path)
        return model
    except ImportError:
        print("AWQ not available, falling back to bitsandbytes")
        return quantize_bnb(args)

def quantize_gptq(args):
    """Apply GPTQ quantization."""
    try:
        # Import check only: fall back if auto_gptq is missing
        import auto_gptq  # noqa: F401

        # GPTQ needs calibration examples passed to model.quantize(); none are
        # wired up in this script, so fall back to bitsandbytes rather than
        # loading the full model in fp16 only to discard it.
        print("GPTQ requires calibration data - using bitsandbytes instead")
        return quantize_bnb(args)
    except ImportError:
        print("GPTQ not available, falling back to bitsandbytes")
        return quantize_bnb(args)
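
# For reference, a full GPTQ pass would look roughly like this (a sketch
# assuming auto_gptq's documented API; calibration_texts is illustrative):
#
#   from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
#   quantize_config = BaseQuantizeConfig(bits=args.bits, group_size=128, desc_act=False)
#   model = AutoGPTQForCausalLM.from_pretrained(args.model_path, quantize_config)
#   examples = [tokenizer(text, return_tensors="pt") for text in calibration_texts]
#   model.quantize(examples)
#   model.save_quantized(args.output_path)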

def quantize_bnb(args):
    """Apply bitsandbytes quantization (default, most compatible)."""
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    bits = args.bits
    if bits not in (4, 8):
        print(f"bitsandbytes supports 4- or 8-bit only; got {bits}, using 4-bit")
        bits = 4

    print(f"Loading model from {args.model_path}...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=(bits == 4),
        load_in_8bit=(bits == 8),
        bnb_4bit_compute_dtype=torch.float16
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )

    print(f"Saving to {args.output_path}...")
    os.makedirs(args.output_path, exist_ok=True)
    model.save_pretrained(args.output_path)

    # Also save the tokenizer so the output directory is self-contained
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
        tokenizer.save_pretrained(args.output_path)
    except Exception:
        print("Could not save tokenizer separately")
    return model
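
# Reloading the saved checkpoint later should pick up the quantization_config
# stored alongside it (a sketch; the path is this script's default output):
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained(
#       "./output/stack-2.9-quantized", device_map="auto"
#   )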

def main():
    args = parse_args()

    # Validate input
    if not os.path.exists(args.model_path):
        print(f"Error: Model path {args.model_path} does not exist")
        print("Please run training first or specify correct path")
        sys.exit(1)

    print("=" * 60)
    print("Stack 2.9 Model Quantization")
    print("=" * 60)
    print(f"Input model: {args.model_path}")
    print(f"Output path: {args.output_path}")
    print(f"Method: {args.method}")
    print(f"Bits: {args.bits}")
    print("=" * 60)

    # Get original size
    original_size = get_model_size(args.model_path)
    print(f"Original model size: {original_size:.2f} GB")

    # Quantize based on method
    if args.method == "awq":
        model = quantize_awq(args)
    elif args.method == "gptq":
        model = quantize_gptq(args)
    else:
        model = quantize_bnb(args)

    # Get quantized size
    quantized_size = get_model_size(args.output_path)
    compression_ratio = original_size / quantized_size if quantized_size > 0 else 0

    print("=" * 60)
    print("Quantization Complete!")
    print("=" * 60)
    print(f"Original size: {original_size:.2f} GB")
    print(f"Quantized size: {quantized_size:.2f} GB")
    print(f"Compression ratio: {compression_ratio:.2f}x")
    print(f"Output saved to: {args.output_path}")

    # Benchmark if requested
    if args.benchmark:
        print("\nRunning benchmark...")
        try:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
            results = benchmark_model(model, tokenizer)
            if results:
                print("\nBenchmark Results:")
                print(f"  Average time: {results['avg_time']:.2f}s")
                print(f"  Tokens/sec: {results['tokens_per_sec']:.1f}")
                print(f"  GPU memory: {results['memory_allocated']:.2f} GB")
        except Exception as e:
            print(f"Benchmark failed: {e}")

    print("\n✓ Quantization complete!")
    return 0


if __name__ == "__main__":
    sys.exit(main())