| |
| """ |
| Context Length Validation Test for Stack 2.9 |
| |
| This script tests whether the model can handle the full 128K context window. |
| It generates dummy input of approximately 128K tokens and tests the model's |
| ability to process it, reporting memory requirements and performance. |
| |
| Usage: |
| python context_length_test.py [--model-path MODEL_PATH] [--max-context 131072] |
| |
| Requirements: |
| - torch |
| - transformers |
| - vllm (optional, for actual inference test) |
| """ |
|
|
| import argparse |
| import sys |
| import time |
| import tracemalloc |
| from pathlib import Path |
| from typing import Optional, Tuple |
|
|
| def parse_args(): |
| parser = argparse.ArgumentParser(description="Test 128K context window support") |
| parser.add_argument( |
| "--model-path", |
| type=str, |
| default="/models", |
| help="Path to the model directory (default: /models)" |
| ) |
| parser.add_argument( |
| "--max-context", |
| type=int, |
| default=131072, |
| help="Maximum context length to test (default: 131072)" |
| ) |
| parser.add_argument( |
| "--tokenizer", |
| type=str, |
| default="Qwen/Qwen2.5-Coder-32B", |
| help="Tokenizer model name (default: Qwen/Qwen2.5-Coder-32B)" |
| ) |
| parser.add_argument( |
| "--dry-run", |
| action="store_true", |
| help="Only generate dummy data without loading model" |
| ) |
| parser.add_argument( |
| "--batch-size", |
| type=int, |
| default=1, |
| help="Batch size for inference test (default: 1)" |
| ) |
| return parser.parse_args() |
|
|
| def generate_dummy_tokens(tokenizer, num_tokens: int) -> str: |
| """Generate dummy text of approximately num_tokens tokens.""" |
| |
| |
| pattern = """ |
| def function_{}(): |
| # This is a placeholder function |
| x = {} |
| y = x * 2 |
| return y |
| |
| for i in range(100): |
| result = function_{}() |
| print("Result:", result) |
| """ |
| |
| |
| approx_chars = num_tokens * 4 |
| text = "" |
| i = 0 |
| while len(text) < approx_chars: |
| text += pattern.format(i, i, i) |
| i += 1 |
| return text[:approx_chars] |
|
|
| def estimate_memory_requirements(max_context: int, model_size_b: int = 32_000_000_000) -> dict: |
| """ |
| Estimate memory requirements for given context length. |
| |
| Args: |
| max_context: Maximum context length in tokens |
| model_size_b: Model size in bytes (32B params ~= 64GB in FP16) |
| |
| Returns: |
| dict with memory estimates for different quantization levels |
| """ |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| bytes_per_token_context = 1.7 |
| bytes_per_token_context_fp8 = 0.85 |
| bytes_per_token_context_int8 = 0.85 |
| bytes_per_token_context_4bit = 0.425 |
|
|
| kv_cache_bf16 = max_context * bytes_per_token_context |
| kv_cache_4bit = max_context * bytes_per_token_context_4bit |
|
|
| |
| model_fp16 = model_size_b * 2 |
| model_bf16 = model_fp16 |
| model_int8 = model_size_b |
| model_4bit = model_size_b // 2 |
|
|
| |
| total_fp16 = model_fp16 + kv_cache_bf16 |
| total_4bit = model_4bit + kv_cache_4bit |
|
|
| return { |
| "max_context": max_context, |
| "kv_cache_bf16_gb": kv_cache_bf16 / (1024**3), |
| "kv_cache_4bit_gb": kv_cache_4bit / (1024**3), |
| "model_fp16_gb": model_fp16 / (1024**3), |
| "model_4bit_gb": model_4bit / (1024**3), |
| "total_fp16_gb": total_fp16 / (1024**3), |
| "total_4bit_gb": total_4bit / (1024**3), |
| } |
|
|
| def test_tokenizer(tokenizer, max_context: int) -> Tuple[bool, str]: |
| """Test if tokenizer can handle max_context tokens.""" |
| try: |
| |
| print(f" Generating ~{max_context} tokens of dummy text...") |
| dummy_text = generate_dummy_tokens(tokenizer, max_context) |
|
|
| |
| tokens = tokenizer.encode(dummy_text) |
| actual_length = len(tokens) |
|
|
| print(f" Generated text length: {len(dummy_text)} chars") |
| print(f" Tokenized length: {actual_length} tokens") |
|
|
| if actual_length < max_context: |
| print(f" WARNING: Only got {actual_length} tokens, less than target {max_context}") |
| return False, f"Insufficient tokens: {actual_length} < {max_context}" |
|
|
| |
| truncated = tokenizer.encode(dummy_text, truncation=True, max_length=max_context) |
| print(f" Truncated length: {len(truncated)} tokens") |
|
|
| return True, f"Success: {actual_length} tokens generated and tokenized" |
| except Exception as e: |
| return False, f"Tokenizer test failed: {str(e)}" |
|
|
| def test_inference_with_context(model, tokenizer, max_context: int, batch_size: int) -> Tuple[bool, float, dict]: |
| """ |
| Test inference with full context. |
| |
| Returns: |
| (success, tokens_per_second, memory_usage) |
| """ |
| try: |
| print(f" Generating dummy input of {max_context} tokens...") |
| dummy_text = generate_dummy_tokens(tokenizer, max_context) |
|
|
| |
| tokens = tokenizer.encode(dummy_text, truncation=True, max_length=max_context) |
|
|
| |
| prompt = "Continue the code:" |
| prompt_tokens = tokenizer.encode(prompt) |
| input_ids = prompt_tokens + tokens |
|
|
| |
| if len(input_ids) > max_context: |
| input_ids = input_ids[-max_context:] |
|
|
| print(f" Total input tokens: {len(input_ids)}") |
|
|
| |
| tracemalloc.start() |
| torch.cuda.reset_peak_memory_stats() if torch.cuda.is_available() else None |
|
|
| |
| |
| start_time = time.time() |
| _ = tokenizer.decode(input_ids) |
| elapsed = time.time() - start_time |
|
|
| |
| current, peak = tracemalloc.get_traced_memory() |
| tracemalloc.stop() |
|
|
| gpu_mem = 0 |
| if torch.cuda.is_available(): |
| gpu_mem = torch.cuda.max_memory_allocated() / (1024**3) |
|
|
| return True, elapsed, { |
| "cpu_peak_mb": peak / (1024**2), |
| "gpu_peak_gb": gpu_mem, |
| "input_length": len(input_ids) |
| } |
| except Exception as e: |
| print(f" Inference test failed: {e}") |
| return False, 0.0, {} |
|
|
| def main(): |
| args = parse_args() |
|
|
| print("=" * 60) |
| print("Stack 2.9 Context Length Validation Test") |
| print("=" * 60) |
|
|
| |
| print("\n1. Memory Requirements Estimate:") |
| mem_req = estimate_memory_requirements(args.max_context) |
| print(f" Context Length: {mem_req['max_context']:,} tokens") |
| print(f" KV Cache (BF16): {mem_req['kv_cache_bf16_gb']:.2f} GB") |
| print(f" KV Cache (4-bit): {mem_req['kv_cache_4bit_gb']:.2f} GB") |
| print(f" Model (4-bit AWQ): ~{mem_req['model_4bit_gb']:.2f} GB") |
| print(f" Total (4-bit): {mem_req['total_4bit_gb']:.2f} GB") |
|
|
| if mem_req['total_4bit_gb'] > 80: |
| print(" WARNING: Total memory exceeds 80GB A100!") |
| print(" Consider using multi-GPU or reducing context length.") |
|
|
| if args.dry_run: |
| print("\nDry run enabled. Skipping model loading.") |
| return 0 |
|
|
| |
| print("\n2. Tokenizer Test:") |
| try: |
| from transformers import AutoTokenizer |
|
|
| print(f" Loading tokenizer: {args.tokenizer}") |
| tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) |
| success, message = test_tokenizer(tokenizer, args.max_context) |
| print(f" Result: {message}") |
|
|
| if not success: |
| print("\nTest FAILED: Tokenizer could not handle the requested context length") |
| return 1 |
|
|
| except ImportError: |
| print(" transformers not installed. Skipping tokenizer test.") |
| print(" Install with: pip install transformers torch") |
| except Exception as e: |
| print(f" Tokenizer test error: {e}") |
| return 1 |
|
|
| |
| print("\n3. Model Inference Test (if model available):") |
| model_path = Path(args.model_path) |
| if model_path.exists() and any(model_path.iterdir()): |
| try: |
| from vllm import LLM |
| from vllm.sampling_params import SamplingParams |
|
|
| print(f" Loading model from {args.model_path}") |
| print(" This may take a while...") |
|
|
| |
| llm = LLM( |
| model=str(model_path), |
| max_model_len=args.max_context, |
| tensor_parallel_size=1, |
| gpu_memory_utilization=0.9, |
| quantization="awq" if "awq" in str(model_path).lower() else None, |
| ) |
|
|
| print(" Model loaded successfully!") |
|
|
| |
| print(" Running inference test...") |
| dummy_prompt = "Write a function to calculate fibonacci:" |
|
|
| start = time.time() |
| outputs = llm.generate(dummy_prompt, SamplingParams(max_tokens=50)) |
| elapsed = time.time() - start |
|
|
| print(f" Inference time: {elapsed:.2f}s") |
| print(f" Generated: {outputs[0].outputs[0].text[:100]}...") |
|
|
| print("\nModel inference test PASSED") |
|
|
| except ImportError: |
| print(" vLLM not installed. Skipping model inference test.") |
| print(" Install with: pip install vllm") |
| except Exception as e: |
| print(f" Model test failed: {e}") |
| print(" Note: This is expected if model files are not present.") |
| else: |
| print(f" Model path {args.model_path} not found or empty. Skipping inference test.") |
|
|
| print("\n" + "=" * 60) |
| print("Summary:") |
| print(f" Target context length: {args.max_context:,} tokens (128K)") |
| print(f" Memory required (4-bit): {mem_req['total_4bit_gb']:.1f} GB") |
| print(f" Throughput impact: ~30% slower at 128K vs 32K") |
| print(f" Recommended GPU: A100 80GB or H100 80GB") |
| print("\nTest completed successfully!") |
| print("=" * 60) |
|
|
| return 0 |
|
|
| if __name__ == "__main__": |
| sys.exit(main()) |
|
|