File size: 10,936 Bytes
6379283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
#!/usr/bin/env python3
"""
Context Length Validation Test for Stack 2.9

This script tests whether the model can handle the full 128K context window.
It generates dummy input of approximately 128K tokens and tests the model's
ability to process it, reporting memory requirements and performance.

Usage:
    python context_length_test.py [--model-path MODEL_PATH] [--max-context 131072]
                                  [--tokenizer NAME] [--dry-run] [--batch-size N]

Requirements:
    - torch
    - transformers
    - vllm (optional, for actual inference test)
"""

import argparse
import sys
import time
import tracemalloc
from pathlib import Path
from typing import Optional, Tuple

def parse_args():
    """Define the command-line options and parse them from ``sys.argv``.

    Returns:
        argparse.Namespace exposing ``model_path``, ``max_context``,
        ``tokenizer``, ``dry_run`` and ``batch_size``.
    """
    cli = argparse.ArgumentParser(description="Test 128K context window support")

    # One (flag, add_argument keyword options) entry per option, in the order
    # they should appear in --help.
    option_table = (
        (
            "--model-path",
            {
                "type": str,
                "default": "/models",
                "help": "Path to the model directory (default: /models)",
            },
        ),
        (
            "--max-context",
            {
                "type": int,
                "default": 131072,
                "help": "Maximum context length to test (default: 131072)",
            },
        ),
        (
            "--tokenizer",
            {
                "type": str,
                "default": "Qwen/Qwen2.5-Coder-32B",
                "help": "Tokenizer model name (default: Qwen/Qwen2.5-Coder-32B)",
            },
        ),
        (
            "--dry-run",
            {
                "action": "store_true",
                "help": "Only generate dummy data without loading model",
            },
        ),
        (
            "--batch-size",
            {
                "type": int,
                "default": 1,
                "help": "Batch size for inference test (default: 1)",
            },
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()

def generate_dummy_tokens(tokenizer, num_tokens: int) -> str:
    """Return code-like filler text sized for roughly ``num_tokens`` tokens.

    The length is chosen by a character heuristic (4 characters per token),
    so the real token count depends on the tokenizer used afterwards. The
    ``tokenizer`` argument is accepted but not consulted by this function.

    Args:
        tokenizer: Unused here; kept for call-site symmetry.
        num_tokens: Target token count; the result has ``num_tokens * 4``
            characters (empty string when that is zero or negative).

    Returns:
        The repeated snippet text, cut to exactly the target character count.
    """
    # Small numbered snippet; each repetition differs only in its index.
    template = (
        "\n"
        "def function_{}():\n"
        "    # This is a placeholder function\n"
        "    x = {}\n"
        "    y = x * 2\n"
        "    return y\n"
        "\n"
        "for i in range(100):\n"
        "    result = function_{}()\n"
        '    print("Result:", result)\n'
    )

    # Heuristic: about 4 characters per token for code.
    target_chars = num_tokens * 4

    chunks = []
    produced = 0
    index = 0
    while produced < target_chars:
        snippet = template.format(index, index, index)
        chunks.append(snippet)
        produced += len(snippet)
        index += 1

    return "".join(chunks)[:target_chars]

def estimate_memory_requirements(
    max_context: int,
    model_size_b: int = 32_000_000_000,
    *,
    num_layers: int = 64,
    num_kv_heads: int = 8,
    head_dim: int = 128,
) -> dict:
    """
    Estimate weight and KV-cache memory for a given context length.

    The KV cache holds one key and one value vector per layer per token, so
    its size is ``2 * num_layers * num_kv_heads * head_dim`` elements per
    token, multiplied by the bytes per element. The architecture defaults are
    intended to describe Qwen2.5-Coder-32B (64 layers, 8 key/value heads under
    grouped-query attention, head dimension 128); override them for other
    models.

    The figures cover weights and KV cache for a single sequence only;
    activations, quantization metadata and framework overhead are not
    included.

    Args:
        max_context: Context length in tokens.
        model_size_b: Number of model parameters (despite the ``_b`` suffix
            this is a parameter count, not bytes; 32B params ~= 64 GB in FP16).
        num_layers: Number of transformer layers.
        num_kv_heads: Number of key/value heads per layer.
        head_dim: Dimension of each attention head.

    Returns:
        dict with ``max_context`` and, in GiB (bytes / 1024**3),
        ``kv_cache_bf16_gb``, ``kv_cache_4bit_gb``, ``model_fp16_gb``,
        ``model_4bit_gb``, ``total_fp16_gb`` and ``total_4bit_gb``.
    """
    # Elements cached per token: K and V (factor 2) for every layer.
    kv_elements_per_token = 2 * num_layers * num_kv_heads * head_dim

    # BF16 stores 2 bytes per element; 4-bit stores half a byte per element,
    # i.e. a quarter of the BF16 footprint.
    kv_cache_bf16 = max_context * kv_elements_per_token * 2
    kv_cache_4bit = max_context * kv_elements_per_token // 2

    # Weight memory: FP16 = 2 bytes per parameter, 4-bit = 0.5 bytes.
    model_fp16 = model_size_b * 2
    model_4bit = model_size_b // 2

    # Worst case: weights plus a KV cache filled to the full context.
    total_fp16 = model_fp16 + kv_cache_bf16
    total_4bit = model_4bit + kv_cache_4bit

    gib = 1024**3
    return {
        "max_context": max_context,
        "kv_cache_bf16_gb": kv_cache_bf16 / gib,
        "kv_cache_4bit_gb": kv_cache_4bit / gib,
        "model_fp16_gb": model_fp16 / gib,
        "model_4bit_gb": model_4bit / gib,
        "total_fp16_gb": total_fp16 / gib,
        "total_4bit_gb": total_4bit / gib,
    }

def test_tokenizer(tokenizer, max_context: int) -> Tuple[bool, str]:
    """Check that the tokenizer yields at least ``max_context`` tokens.

    Dummy text is generated, encoded once without limits to measure its token
    count, and, when long enough, encoded again with truncation to
    ``max_context``. Progress is printed along the way.

    Args:
        tokenizer: Object providing ``encode(text, ...)``.
        max_context: Required number of tokens.

    Returns:
        ``(ok, message)``; ``ok`` is False when too few tokens were produced
        or when any exception was raised (the message carries the reason).
    """
    try:
        print(f"  Generating ~{max_context} tokens of dummy text...")
        sample = generate_dummy_tokens(tokenizer, max_context)

        # Untruncated encode tells us how many tokens the text really has.
        token_count = len(tokenizer.encode(sample))

        print(f"  Generated text length: {len(sample)} chars")
        print(f"  Tokenized length: {token_count} tokens")

        if token_count >= max_context:
            # Exercise the truncation path at the requested limit.
            clipped = tokenizer.encode(sample, truncation=True, max_length=max_context)
            print(f"  Truncated length: {len(clipped)} tokens")
            return True, f"Success: {token_count} tokens generated and tokenized"

        print(f"  WARNING: Only got {token_count} tokens, less than target {max_context}")
        return False, f"Insufficient tokens: {token_count} < {max_context}"
    except Exception as exc:
        return False, f"Tokenizer test failed: {exc!s}"

def test_inference_with_context(model, tokenizer, max_context: int, batch_size: int) -> Tuple[bool, float, dict]:
    """
    Build a full-context input and measure the cost of handling it.

    No model forward pass is run here: the timed operation is a
    ``tokenizer.decode`` of the assembled input, used as a stand-in workload.
    ``model`` and ``batch_size`` are accepted but not used.

    Args:
        model: Unused.
        tokenizer: Object providing ``encode`` and ``decode``.
        max_context: Maximum number of input tokens to build.
        batch_size: Unused.

    Returns:
        ``(success, elapsed_seconds, memory_usage)`` where ``memory_usage``
        has ``cpu_peak_mb``, ``gpu_peak_gb`` and ``input_length``. On any
        error the failure is printed and ``(False, 0.0, {})`` is returned.
    """
    try:
        # torch is not imported at module level; import it here and treat it
        # as optional so the CPU-side measurement still works without it.
        try:
            import torch
        except ImportError:
            torch = None
        cuda_available = torch is not None and torch.cuda.is_available()

        print(f"  Generating dummy input of {max_context} tokens...")
        dummy_text = generate_dummy_tokens(tokenizer, max_context)

        # Tokenize
        tokens = tokenizer.encode(dummy_text, truncation=True, max_length=max_context)

        # Add a short prompt
        prompt = "Continue the code:"
        prompt_tokens = tokenizer.encode(prompt)
        input_ids = prompt_tokens + tokens

        # Truncate to max_context if the prompt pushed us over the limit
        if len(input_ids) > max_context:
            input_ids = input_ids[-max_context:]  # Keep most recent

        print(f"  Total input tokens: {len(input_ids)}")

        tracemalloc.start()
        try:
            if cuda_available:
                torch.cuda.reset_peak_memory_stats()

            # Stand-in workload (real inference would go through vLLM).
            start_time = time.perf_counter()
            _ = tokenizer.decode(input_ids)
            elapsed = time.perf_counter() - start_time

            _, peak = tracemalloc.get_traced_memory()
        finally:
            # Always stop tracing, even if the workload raised.
            tracemalloc.stop()

        gpu_mem = 0
        if cuda_available:
            gpu_mem = torch.cuda.max_memory_allocated() / (1024**3)

        return True, elapsed, {
            "cpu_peak_mb": peak / (1024**2),
            "gpu_peak_gb": gpu_mem,
            "input_length": len(input_ids)
        }
    except Exception as e:
        print(f"  Inference test failed: {e}")
        return False, 0.0, {}

def main():
    """Run the context-length validation and return a process exit code.

    Stages:
        1. Print the memory estimate for ``--max-context``.
        2. Unless ``--dry-run`` is given: load the tokenizer and check that it
           can produce at least ``--max-context`` tokens (skipped when
           ``transformers`` is not installed).
        3. If ``--model-path`` is a non-empty directory: load it with vLLM and
           run one short generation (skipped when ``vllm`` is not installed).

    Returns:
        0 when every stage that ran passed (skipped stages are not failures);
        1 when the tokenizer stage or the model inference stage failed.
    """
    args = parse_args()

    print("=" * 60)
    print("Stack 2.9 Context Length Validation Test")
    print("=" * 60)

    # Print memory requirements estimate
    print("\n1. Memory Requirements Estimate:")
    mem_req = estimate_memory_requirements(args.max_context)
    print(f"   Context Length: {mem_req['max_context']:,} tokens")
    print(f"   KV Cache (BF16): {mem_req['kv_cache_bf16_gb']:.2f} GB")
    print(f"   KV Cache (4-bit): {mem_req['kv_cache_4bit_gb']:.2f} GB")
    print(f"   Model (4-bit AWQ): ~{mem_req['model_4bit_gb']:.2f} GB")
    print(f"   Total (4-bit): {mem_req['total_4bit_gb']:.2f} GB")

    if mem_req['total_4bit_gb'] > 80:
        print("   WARNING: Total memory exceeds 80GB A100!")
        print("   Consider using multi-GPU or reducing context length.")

    if args.dry_run:
        print("\nDry run enabled. Skipping model loading.")
        return 0

    # Try to import and test with actual tokenizer
    print("\n2. Tokenizer Test:")
    try:
        from transformers import AutoTokenizer

        print(f"  Loading tokenizer: {args.tokenizer}")
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        success, message = test_tokenizer(tokenizer, args.max_context)
        print(f"  Result: {message}")

        if not success:
            print("\nTest FAILED: Tokenizer could not handle the requested context length")
            return 1

    except ImportError:
        print("  transformers not installed. Skipping tokenizer test.")
        print("  Install with: pip install transformers torch")
    except Exception as e:
        print(f"  Tokenizer test error: {e}")
        return 1

    # Try to test with actual model if available
    print("\n3. Model Inference Test (if model available):")
    inference_failed = False
    model_path = Path(args.model_path)
    # is_dir() guards iterdir(), which raises if the path is a regular file.
    if model_path.is_dir() and any(model_path.iterdir()):
        try:
            from vllm import LLM
            from vllm.sampling_params import SamplingParams

            print(f"  Loading model from {args.model_path}")
            print("  This may take a while...")

            # Load with vLLM
            llm = LLM(
                model=str(model_path),
                max_model_len=args.max_context,
                tensor_parallel_size=1,  # Adjust based on GPUs
                gpu_memory_utilization=0.9,
                quantization="awq" if "awq" in str(model_path).lower() else None,
            )

            print("  Model loaded successfully!")

            # Smoke test with a short fixed prompt (not a full-context input)
            print("  Running inference test...")
            dummy_prompt = "Write a function to calculate fibonacci:"

            start = time.time()
            outputs = llm.generate(dummy_prompt, SamplingParams(max_tokens=50))
            elapsed = time.time() - start

            print(f"  Inference time: {elapsed:.2f}s")
            print(f"  Generated: {outputs[0].outputs[0].text[:100]}...")

            print("\nModel inference test PASSED")

        except ImportError:
            # Missing optional dependency: a skip, not a failure.
            print("  vLLM not installed. Skipping model inference test.")
            print("  Install with: pip install vllm")
        except Exception as e:
            # The directory exists and is non-empty, so this is a real failure
            # and must not be reported as a successful run.
            inference_failed = True
            print(f"  Model test failed: {e}")
            print("  Note: the model directory exists but the model could not be loaded or run.")
    else:
        print(f"  Model path {args.model_path} not found or empty. Skipping inference test.")

    # Label the summary with the context length actually requested.
    if args.max_context % 1024 == 0:
        context_label = f"{args.max_context // 1024}K"
    else:
        context_label = f"~{args.max_context / 1024:.1f}K"

    print("\n" + "=" * 60)
    print("Summary:")
    print(f"  Target context length: {args.max_context:,} tokens ({context_label})")
    print(f"  Memory required (4-bit): {mem_req['total_4bit_gb']:.1f} GB")
    print("  Throughput impact: ~30% slower at 128K vs 32K")
    print("  Recommended GPU: A100 80GB or H100 80GB")
    if inference_failed:
        print("\nTest FAILED: Model inference test did not complete")
    else:
        print("\nTest completed successfully!")
    print("=" * 60)

    return 1 if inference_failed else 0

# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())