Stack-2-9-finetuned / stack /eval /context_length_test.py
walidsobhie-code
refactor: Squeeze folders further - cleaner structure
65888d5
#!/usr/bin/env python3
"""
Context Length Validation Test for Stack 2.9
This script tests whether the model can handle the full 128K context window.
It generates dummy input of approximately 128K tokens and tests the model's
ability to process it, reporting memory requirements and performance.
Usage:
python context_length_test.py [--model-path MODEL_PATH] [--max-context 131072]
Requirements:
- torch
- transformers
- vllm (optional, for actual inference test)
"""
import argparse
import sys
import time
import tracemalloc
from pathlib import Path
from typing import Optional, Tuple
def parse_args():
parser = argparse.ArgumentParser(description="Test 128K context window support")
parser.add_argument(
"--model-path",
type=str,
default="/models",
help="Path to the model directory (default: /models)"
)
parser.add_argument(
"--max-context",
type=int,
default=131072,
help="Maximum context length to test (default: 131072)"
)
parser.add_argument(
"--tokenizer",
type=str,
default="Qwen/Qwen2.5-Coder-32B",
help="Tokenizer model name (default: Qwen/Qwen2.5-Coder-32B)"
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Only generate dummy data without loading model"
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Batch size for inference test (default: 1)"
)
return parser.parse_args()
def generate_dummy_tokens(tokenizer, num_tokens: int) -> str:
"""Generate dummy text of approximately num_tokens tokens."""
# Generate a repeating pattern that tokenizes predictably
# Use code-like structure which tokenizes efficiently
pattern = """
def function_{}():
# This is a placeholder function
x = {}
y = x * 2
return y
for i in range(100):
result = function_{}()
print("Result:", result)
"""
# Generate enough characters to exceed token count
# Rough estimate: 4 chars per token for code
approx_chars = num_tokens * 4
text = ""
i = 0
while len(text) < approx_chars:
text += pattern.format(i, i, i)
i += 1
return text[:approx_chars]
def estimate_memory_requirements(max_context: int, model_size_b: int = 32_000_000_000) -> dict:
"""
Estimate memory requirements for given context length.
Args:
max_context: Maximum context length in tokens
model_size_b: Model size in bytes (32B params ~= 64GB in FP16)
Returns:
dict with memory estimates for different quantization levels
"""
# KV cache memory estimation
# Formula: 2 * num_layers * hidden_size * num_heads * head_dim * num_tokens * precision_bytes
# For Qwen2.5-Coder-32B:
# - num_layers: 64
# - hidden_size: 5120
# - num_heads: 40
# - head_dim: 128
# Simplified: ~ 2 * 64 * 5120 * 40 * 128 = ~ 1.7GB per 1K tokens in BF16
bytes_per_token_context = 1.7 # Approximate KV cache bytes per token (BF16)
bytes_per_token_context_fp8 = 0.85 # Approximate for FP8 quantization
bytes_per_token_context_int8 = 0.85 # Same for INT8
bytes_per_token_context_4bit = 0.425 # ~half of BF16 for 4-bit
kv_cache_bf16 = max_context * bytes_per_token_context
kv_cache_4bit = max_context * bytes_per_token_context_4bit
# Model memory (approximate)
model_fp16 = model_size_b * 2 # FP16 = 2 bytes per param
model_bf16 = model_fp16
model_int8 = model_size_b # INT8 = 1 byte per param
model_4bit = model_size_b // 2 # 4-bit = 0.5 bytes per param
# Total memory with KV cache (worst case: full context + model)
total_fp16 = model_fp16 + kv_cache_bf16
total_4bit = model_4bit + kv_cache_4bit
return {
"max_context": max_context,
"kv_cache_bf16_gb": kv_cache_bf16 / (1024**3),
"kv_cache_4bit_gb": kv_cache_4bit / (1024**3),
"model_fp16_gb": model_fp16 / (1024**3),
"model_4bit_gb": model_4bit / (1024**3),
"total_fp16_gb": total_fp16 / (1024**3),
"total_4bit_gb": total_4bit / (1024**3),
}
def test_tokenizer(tokenizer, max_context: int) -> Tuple[bool, str]:
"""Test if tokenizer can handle max_context tokens."""
try:
# Generate dummy text
print(f" Generating ~{max_context} tokens of dummy text...")
dummy_text = generate_dummy_tokens(tokenizer, max_context)
# Tokenize and check actual length
tokens = tokenizer.encode(dummy_text)
actual_length = len(tokens)
print(f" Generated text length: {len(dummy_text)} chars")
print(f" Tokenized length: {actual_length} tokens")
if actual_length < max_context:
print(f" WARNING: Only got {actual_length} tokens, less than target {max_context}")
return False, f"Insufficient tokens: {actual_length} < {max_context}"
# Test truncation/padding
truncated = tokenizer.encode(dummy_text, truncation=True, max_length=max_context)
print(f" Truncated length: {len(truncated)} tokens")
return True, f"Success: {actual_length} tokens generated and tokenized"
except Exception as e:
return False, f"Tokenizer test failed: {str(e)}"
def test_inference_with_context(model, tokenizer, max_context: int, batch_size: int) -> Tuple[bool, float, dict]:
"""
Test inference with full context.
Returns:
(success, tokens_per_second, memory_usage)
"""
try:
print(f" Generating dummy input of {max_context} tokens...")
dummy_text = generate_dummy_tokens(tokenizer, max_context)
# Tokenize
tokens = tokenizer.encode(dummy_text, truncation=True, max_length=max_context)
# Add a short prompt
prompt = "Continue the code:"
prompt_tokens = tokenizer.encode(prompt)
input_ids = prompt_tokens + tokens
# Pad/truncate to exact max_context if needed
if len(input_ids) > max_context:
input_ids = input_ids[-max_context:] # Keep most recent
print(f" Total input tokens: {len(input_ids)}")
# Measure memory before
tracemalloc.start()
torch.cuda.reset_peak_memory_stats() if torch.cuda.is_available() else None
# Run inference (this would be with vLLM in real scenario)
# For testing, we just measure tokenization memory
start_time = time.time()
_ = tokenizer.decode(input_ids) # Simple operation to include in timing
elapsed = time.time() - start_time
# Get memory stats
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
gpu_mem = 0
if torch.cuda.is_available():
gpu_mem = torch.cuda.max_memory_allocated() / (1024**3)
return True, elapsed, {
"cpu_peak_mb": peak / (1024**2),
"gpu_peak_gb": gpu_mem,
"input_length": len(input_ids)
}
except Exception as e:
print(f" Inference test failed: {e}")
return False, 0.0, {}
def main():
args = parse_args()
print("=" * 60)
print("Stack 2.9 Context Length Validation Test")
print("=" * 60)
# Print memory requirements estimate
print("\n1. Memory Requirements Estimate:")
mem_req = estimate_memory_requirements(args.max_context)
print(f" Context Length: {mem_req['max_context']:,} tokens")
print(f" KV Cache (BF16): {mem_req['kv_cache_bf16_gb']:.2f} GB")
print(f" KV Cache (4-bit): {mem_req['kv_cache_4bit_gb']:.2f} GB")
print(f" Model (4-bit AWQ): ~{mem_req['model_4bit_gb']:.2f} GB")
print(f" Total (4-bit): {mem_req['total_4bit_gb']:.2f} GB")
if mem_req['total_4bit_gb'] > 80:
print(" WARNING: Total memory exceeds 80GB A100!")
print(" Consider using multi-GPU or reducing context length.")
if args.dry_run:
print("\nDry run enabled. Skipping model loading.")
return 0
# Try to import and test with actual tokenizer
print("\n2. Tokenizer Test:")
try:
from transformers import AutoTokenizer
print(f" Loading tokenizer: {args.tokenizer}")
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
success, message = test_tokenizer(tokenizer, args.max_context)
print(f" Result: {message}")
if not success:
print("\nTest FAILED: Tokenizer could not handle the requested context length")
return 1
except ImportError:
print(" transformers not installed. Skipping tokenizer test.")
print(" Install with: pip install transformers torch")
except Exception as e:
print(f" Tokenizer test error: {e}")
return 1
# Try to test with actual model if available
print("\n3. Model Inference Test (if model available):")
model_path = Path(args.model_path)
if model_path.exists() and any(model_path.iterdir()):
try:
from vllm import LLM
from vllm.sampling_params import SamplingParams
print(f" Loading model from {args.model_path}")
print(" This may take a while...")
# Load with vLLM
llm = LLM(
model=str(model_path),
max_model_len=args.max_context,
tensor_parallel_size=1, # Adjust based on GPUs
gpu_memory_utilization=0.9,
quantization="awq" if "awq" in str(model_path).lower() else None,
)
print(" Model loaded successfully!")
# Generate a small test
print(" Running inference test...")
dummy_prompt = "Write a function to calculate fibonacci:"
start = time.time()
outputs = llm.generate(dummy_prompt, SamplingParams(max_tokens=50))
elapsed = time.time() - start
print(f" Inference time: {elapsed:.2f}s")
print(f" Generated: {outputs[0].outputs[0].text[:100]}...")
print("\nModel inference test PASSED")
except ImportError:
print(" vLLM not installed. Skipping model inference test.")
print(" Install with: pip install vllm")
except Exception as e:
print(f" Model test failed: {e}")
print(" Note: This is expected if model files are not present.")
else:
print(f" Model path {args.model_path} not found or empty. Skipping inference test.")
print("\n" + "=" * 60)
print("Summary:")
print(f" Target context length: {args.max_context:,} tokens (128K)")
print(f" Memory required (4-bit): {mem_req['total_4bit_gb']:.1f} GB")
print(f" Throughput impact: ~30% slower at 128K vs 32K")
print(f" Recommended GPU: A100 80GB or H100 80GB")
print("\nTest completed successfully!")
print("=" * 60)
return 0
if __name__ == "__main__":
sys.exit(main())