import argparse
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch._dynamo as dynamo
from transformers import LlamaConfig
from transformers.cache_utils import DynamicCache

from specforge.modeling.draft.llama3_eagle import (
    LlamaAttention,
    LlamaFlexAttention,
    prepare_decoder_attention_mask,
)

# flex_attention recompiles for each new sequence length in the sweep; raise
# torch._dynamo's recompile cap so later lengths do not fall back to eager.
dynamo.config.recompile_limit = 64
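# If the cap is still hit, recompile reasons can be inspected by rerunning
# with PyTorch's recompile logging enabled, e.g.:
#   TORCH_LOGS="recompiles" python <this_script>.py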

# Llama-style config (GQA: 32 query heads sharing 8 KV heads), trimmed to a
# single layer so the benchmark isolates the attention module.
config_dict = {
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "max_position_embeddings": 16384,
    "rms_norm_eps": 1e-05,
    "vocab_size": 32000,
    "hidden_act": "silu",
    "num_hidden_layers": 1,
}

config = LlamaConfig(**config_dict)

TTT_LENGTH = 7  # number of unrolled draft steps per run
BATCH_SIZE = 4
# Eagle-style drafts concatenate the token embedding with the target model's
# hidden state, hence twice the base hidden size.
HIDDEN_SIZE = config.hidden_size * 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
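# Reference shapes (derived from the constants above): each draft step feeds
# hidden_states of shape (BATCH_SIZE, seq_len, HIDDEN_SIZE) = (4, seq_len, 8192);
# the per-head dimension is hidden_size / num_attention_heads = 4096 / 32 = 128.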


def run_attention(
    seq_len: int,
    hidden_states_list: list[torch.Tensor],
    attention_backend: str = "sdpa",
    enable_profile: bool = False,
):
    batch_size = hidden_states_list[0].shape[0]

    if attention_backend == "sdpa":
        # The SDPA path tracks per-step keys/values in plain Python lists.
        cache_hidden = [[], []]
        past_key_values = None
        attn_func = LlamaAttention(config).to(DEVICE).to(torch.bfloat16)
    elif attention_backend == "flex_attention":
        cache_hidden = None
        past_key_values = DynamicCache()
        attn_func = LlamaFlexAttention(config).to(DEVICE).to(torch.bfloat16)
    else:
        raise ValueError(f"Unknown attention backend: {attention_backend}")

    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1).to(DEVICE)
    # inputs_embeds only supplies the shape and dtype of the additive mask;
    # match the module's bfloat16 so the mask dtype agrees with the queries.
    input_embeds = torch.randn(
        batch_size, seq_len, config.hidden_size, dtype=torch.bfloat16
    ).to(DEVICE)
    attention_mask = torch.ones(batch_size, seq_len).to(DEVICE)
    decoder_attention_mask = prepare_decoder_attention_mask(
        attention_mask=attention_mask,
        input_shape=(batch_size, seq_len),
        inputs_embeds=input_embeds,
        past_key_values_length=0,
    )
    loss_list = []

    if attention_backend == "flex_attention" and enable_profile:
        profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                f"./profiler_logs/{attention_backend}"
            ),
            record_shapes=False,
            profile_memory=False,
            with_stack=True,
            with_modules=False,
        )
        profiler.start()

    for idx in range(TTT_LENGTH):
        hidden_states = hidden_states_list[idx]

        if attention_backend == "sdpa":
            output = attn_func(
                hidden_states=hidden_states,
                attention_mask=decoder_attention_mask,
                position_ids=position_ids,
                cache_hidden=cache_hidden,
                output_attentions=False,
                use_cache=True,
            )
        else:
            output = attn_func(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=False,
                use_cache=True,
            )

        # Dummy scalar loss so the backward kernels are exercised too.
        loss = output[0].sum()
        loss_list.append(loss)

    if loss_list:
        mean_loss = sum(loss_list) / len(loss_list)
        mean_loss.backward()

    if attention_backend == "flex_attention" and enable_profile:
        profiler.stop()
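
# Example: a single unrolled run at 256 tokens (shapes assumed from the
# constants above); gradients flow back into the hidden_states list.
#   hs = [torch.randn(BATCH_SIZE, 256, HIDDEN_SIZE, device=DEVICE,
#                     dtype=torch.bfloat16, requires_grad=True)
#         for _ in range(TTT_LENGTH)]
#   run_attention(256, hs, attention_backend="flex_attention")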


def benchmark_function(
    attention_backend: str,
    seq_lengths: list,
    enable_profile: bool = False,
    enable_warmup: bool = True,
):
    """Benchmark an attention backend for speed and peak GPU memory per sequence length."""
    print(f"\n=== Benchmarking {attention_backend} ===")

    results_per_seq_len = []

    for seq_len in seq_lengths:
        print(f"\nTesting sequence length: {seq_len}")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

        if enable_warmup:
            print("Warming up...")
            # Warmup runs trigger (re)compilation so the timed run below
            # measures steady-state kernels only.
            for _ in range(2):
                hidden_states = [
                    torch.randn(
                        BATCH_SIZE,
                        seq_len,
                        HIDDEN_SIZE,
                        requires_grad=True,
                        device=DEVICE,
                        dtype=torch.bfloat16,
                    )
                    for _ in range(TTT_LENGTH)
                ]
                run_attention(seq_len, hidden_states, attention_backend)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats()

        initial_memory = 0
        if torch.cuda.is_available():
            initial_memory = torch.cuda.memory_allocated()
        hidden_states = [
            torch.randn(
                BATCH_SIZE,
                seq_len,
                HIDDEN_SIZE,
                requires_grad=True,
                device=DEVICE,
                dtype=torch.bfloat16,
            )
            for _ in range(TTT_LENGTH)
        ]
        if torch.cuda.is_available():
            # Drain pending work so allocation is not billed to the timed region.
            torch.cuda.synchronize()
        start_time = time.time()
        run_attention(
            seq_len,
            hidden_states,
            attention_backend,
            # Profile only the first sequence length to keep traces small.
            enable_profile and seq_len == seq_lengths[0],
        )
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        end_time = time.time()

        peak_memory = 0
        current_memory = 0
        if torch.cuda.is_available():
            peak_memory = torch.cuda.max_memory_allocated()
            current_memory = torch.cuda.memory_allocated()
        results_per_seq_len.append(
            {
                "seq_len": seq_len,
                "time": end_time - start_time,
                "peak_memory": peak_memory,
                "memory_increase": current_memory - initial_memory,
            }
        )

        print(f"  Time: {end_time - start_time:.3f}s")
        print(f"  Peak memory: {peak_memory / 1024**3:.3f} GB")
        print(
            f"  Memory increase: {(current_memory - initial_memory) / 1024**3:.3f} GB"
        )

    return results_per_seq_len
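

# Wall-clock timing around a synchronize is adequate at these run lengths; for
# finer-grained numbers, CUDA events avoid host-side overhead. A minimal
# sketch (not used by the benchmark above):
def time_with_cuda_events(fn) -> float:
    """Return fn's GPU runtime in seconds (wall clock on CPU-only machines)."""
    if not torch.cuda.is_available():
        t0 = time.time()
        fn()
        return time.time() - t0
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0  # elapsed_time is in milliseconds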


def plot_results(eagle_results, flex_results, seq_lengths):
    """Plot speed and memory comparison between Eagle and Flex attention."""
    eagle_times = [r["time"] for r in eagle_results]
    flex_times = [r["time"] for r in flex_results]
    eagle_memory = [r["peak_memory"] / 1024**3 for r in eagle_results]
    flex_memory = [r["peak_memory"] / 1024**3 for r in flex_results]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Speed: log-scale y-axis since times span orders of magnitude.
    ax1.plot(
        seq_lengths, eagle_times, "b-o", label="Eagle (SDPA)", linewidth=2, markersize=8
    )
    ax1.plot(
        seq_lengths,
        flex_times,
        "r-s",
        label="Flex Attention",
        linewidth=2,
        markersize=8,
    )
    ax1.set_xlabel("Sequence Length")
    ax1.set_ylabel("Time (seconds)")
    ax1.set_title("Speed Comparison: Eagle vs Flex Attention")
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_xscale("linear")
    ax1.set_yscale("log")

    # Memory: peak allocated bytes per backend.
    ax2.plot(
        seq_lengths,
        eagle_memory,
        "b-o",
        label="Eagle (SDPA)",
        linewidth=2,
        markersize=8,
    )
    ax2.plot(
        seq_lengths,
        flex_memory,
        "r-s",
        label="Flex Attention",
        linewidth=2,
        markersize=8,
    )
    ax2.set_xlabel("Sequence Length")
    ax2.set_ylabel("Peak Memory (GB)")
    ax2.set_title("Memory Usage Comparison: Eagle vs Flex Attention")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    max_memory = max(max(eagle_memory), max(flex_memory))
    ax2.set_yticks(np.arange(0, max_memory + 10, 10))

    plt.tight_layout()
    plt.savefig("attention_benchmark_comparison.png", dpi=300, bbox_inches="tight")
    plt.show()

    # Note: lengths the SDPA path could not run are filled with sentinel
    # values in __main__, so ratios at those points are not meaningful.
    print("\n=== Performance Summary ===")
    print(f"Sequence lengths tested: {seq_lengths}")
    print("\nSpeed ratios (Eagle/Flex):")
    for i, seq_len in enumerate(seq_lengths):
        ratio = eagle_times[i] / flex_times[i] if flex_times[i] > 0 else float("inf")
        print(f"  {seq_len:5d}: {ratio:.2f}x")

    print("\nMemory ratios (Eagle/Flex):")
    for i, seq_len in enumerate(seq_lengths):
        ratio = eagle_memory[i] / flex_memory[i] if flex_memory[i] > 0 else float("inf")
        print(f"  {seq_len:5d}: {ratio:.2f}x")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark attention mechanisms")
    parser.add_argument(
        "--enable-profile", action="store_true", help="Enable profiling"
    )
    args = parser.parse_args()
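
    # Example invocation (file name assumed; traces land under ./profiler_logs
    # and can be viewed with `tensorboard --logdir ./profiler_logs`):
    #   python benchmark_attention.py --enable-profile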

    print("PyTorch version:", torch.__version__)
    if torch.cuda.is_available():
        print("GPU:", torch.cuda.get_device_name())
        print(
            f"GPU memory: "
            f"{torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB"
        )
    else:
        print("CUDA not available - running on CPU")

    # 128 to 3200 in steps of 512, plus two long-context points.
    seq_lengths = [128 * i for i in range(1, 28, 4)]
    seq_lengths.extend([16384, 32768])

    print(f"Testing sequence lengths: {seq_lengths}")

| print("\n" + "=" * 50) |
| |
| eagle_seq_lengths = [seq_len for seq_len in seq_lengths if seq_len <= 2560] |
| eagle_results = benchmark_function("sdpa", eagle_seq_lengths) |
| print("\n" + "=" * 50) |
| flex_results = benchmark_function( |
| "flex_attention", seq_lengths, enable_profile=args.enable_profile |
| ) |
| |
| max_time = max(result["time"] for result in flex_results) |
| for result in flex_results: |
| if result["seq_len"] not in eagle_seq_lengths: |
| eagle_results.append( |
| { |
| "seq_len": result["seq_len"], |
| "time": max_time, |
| "peak_memory": 80 * 1024**3, |
| "memory_increase": 0, |
| } |
| ) |
|
|
| |
| plot_results(eagle_results, flex_results, seq_lengths) |
|
|