# NOTE(review): kept first, as in the original ordering — presumably it performs
# environment setup on import; confirm before reordering.
import entrypoint_setup

import argparse
import copy
import random
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForMaskedLM


def set_seed(seed: int) -> None:
    """Seed the Python, PyTorch (CPU and CUDA) and NumPy random generators.

    Args:
        seed: Value handed to each library's seeding function.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)


# Attention backend names accepted by the --backends command-line option; each
# selected name is assigned to the model's ``attn_backend`` attribute before timing.
SUPPORTED_BACKENDS = ("sdpa", "flex", "kernels_flash")


class ThroughputChecker:
    """Measure forward-pass throughput of masked-LM checkpoints on random sequences.

    Random amino-acid strings are generated, tokenized and pushed through the
    model; the elapsed time and the number of non-padding tokens are recorded
    for every (backend, batch size, sequence length) combination.
    """

    def __init__(
        self,
        warmup_batches: int = 10,
        timed_batches: int = 100,
    ):
        """Configure the benchmark.

        Args:
            warmup_batches: Minimum number of untimed warmup batches; the
                dynamic warmup runs at most ten times this many.
            timed_batches: Number of batches in the timed measurement.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.warmup_batches = warmup_batches
        self.timed_batches = timed_batches
        self.canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"

    def _load_model(self, model_path: str):
        """Load ``model_path`` in bfloat16 on ``self.device`` and put it in eval mode.

        The checkpoint is loaded with ``trust_remote_code=True``, so code
        shipped with the checkpoint is executed.
        """
        model = AutoModelForMaskedLM.from_pretrained(
            model_path,
            dtype=torch.bfloat16,
            device_map=self.device,
            trust_remote_code=True,
        ).eval()
        return model

    def _generate_random_sequence(self, length: int) -> str:
        """Return "M" followed by ``length - 1`` random canonical amino acids."""
        return "M" + "".join(random.choices(self.canonical_amino_acids, k=length - 1))

    def _generate_random_batch(self, batch_size: int, min_length: int, max_length: int) -> list[str]:
        """Return a batch whose first sequence has exactly ``max_length`` residues.

        The remaining ``batch_size - 1`` sequences get a random length between
        ``min_length`` and ``max_length`` inclusive.

        Args:
            batch_size: Number of sequences in the batch.
            min_length: Smallest length for the randomly sized sequences.
            max_length: Largest length, and the length of the first sequence.
        """
        # random.randint raises ValueError on an empty range, so clamp the lower
        # bound when the requested minimum exceeds this configuration's maximum.
        shortest = min(min_length, max_length)
        max_length_example = self._generate_random_sequence(max_length)
        return [max_length_example] + [
            self._generate_random_sequence(random.randint(shortest, max_length))
            for _ in range(batch_size - 1)
        ]

    @torch.inference_mode()
    def _time(self, model, tokenizer, batch_size: int, min_length: int, max_length: int):
        """Compile ``model``, warm it up, then time ``self.timed_batches`` batches.

        Each batch is freshly generated, tokenized (padded/truncated to
        ``max_length``), moved to ``self.device`` and run through the model with
        ``output_hidden_states=True``. All of those steps happen inside the
        timed loop, so the measurement covers them too.

        Warmup runs at least ``self.warmup_batches`` batches and stops early
        once the mean latency of the last three batches is within 10% of the
        mean of the three before them; it never exceeds ten times
        ``self.warmup_batches`` batches.

        Args:
            model: Model to benchmark; called as ``model(**tokenized, ...)``.
            tokenizer: Callable that tokenizes a list of strings into tensors.
            batch_size: Number of sequences per batch.
            min_length: Minimum random sequence length.
            max_length: Maximum sequence length and padding/truncation length.

        Returns:
            ``(elapsed_seconds, non_pad_token_count)`` for the timed batches.
        """
        model = model.to(self.device).eval()
        set_seed(42)
        min_dynamic_warmup_batches = self.warmup_batches
        max_dynamic_warmup_batches = self.warmup_batches * 10
        stability_window = 3
        relative_stability_tolerance = 0.10

        def synchronize() -> None:
            # CUDA kernels launch asynchronously; wait for them so the clock
            # readings bracket the actual work.
            if self.device.type == "cuda":
                torch.cuda.synchronize()

        def run_one_batch() -> int:
            batch = self._generate_random_batch(batch_size, min_length, max_length)
            tokenized = tokenizer(
                batch,
                return_tensors="pt",
                padding="max_length",
                max_length=max_length,
                truncation=True,
                add_special_tokens=True,
            )
            input_ids = tokenized["input_ids"]
            # Count non-padding tokens: prefer the attention mask, fall back to
            # comparing against the pad id, else count every position.
            if "attention_mask" in tokenized:
                nonpad_tokens_this = tokenized["attention_mask"].sum().item()
            else:
                pad_token_id = tokenizer.pad_token_id
                if pad_token_id is not None:
                    nonpad_tokens_this = (input_ids != pad_token_id).sum().item()
                else:
                    nonpad_tokens_this = input_ids.numel()
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            _ = model(**tokenized, output_hidden_states=True)
            return nonpad_tokens_this

        def time_batches(num_batches: int, message: str):
            processed_tokens = 0
            synchronize()
            # perf_counter is monotonic and high resolution, unlike time.time().
            start_time = time.perf_counter()
            for _ in tqdm(range(num_batches), desc=message, leave=False):
                processed_tokens += run_one_batch()
            synchronize()
            end_time = time.perf_counter()
            return end_time - start_time, processed_tokens

        # run_one_batch looks up ``model`` at call time, so it uses the compiled module.
        model = torch.compile(model)
        warmup_latencies = []
        for warmup_idx in tqdm(range(max_dynamic_warmup_batches), desc="Warmup (dynamic)", leave=False):
            synchronize()
            warmup_start = time.perf_counter()
            _ = run_one_batch()
            synchronize()
            warmup_latency = time.perf_counter() - warmup_start
            warmup_latencies.append(warmup_latency)

            if warmup_idx + 1 < min_dynamic_warmup_batches:
                continue
            if len(warmup_latencies) < 2 * stability_window:
                continue

            previous_window = warmup_latencies[-2 * stability_window:-stability_window]
            current_window = warmup_latencies[-stability_window:]
            previous_mean = sum(previous_window) / stability_window
            current_mean = sum(current_window) / stability_window
            if previous_mean <= 0.0:
                # A relative change cannot be computed; keep warming up rather
                # than dividing by zero.
                continue
            relative_change = abs(current_mean - previous_mean) / previous_mean
            if relative_change <= relative_stability_tolerance:
                break

        time_taken, timed_tokens_sum = time_batches(self.timed_batches, "Timed")
        if self.device.type == "cuda":
            torch.cuda.empty_cache()
        return time_taken, timed_tokens_sum

    def evaluate(self, model_path: str, batch_sizes: list[int], min_length: int, sequence_lengths: list[int], backends: list[str]):
        """Benchmark one checkpoint across backends, batch sizes and sequence lengths.

        Args:
            model_path: Checkpoint identifier passed to ``_load_model``.
            batch_sizes: Batch sizes to time.
            min_length: Minimum random sequence length.
            sequence_lengths: Maximum sequence lengths to time.
            backends: Values assigned to the model's ``attn_backend`` attribute.

        Returns:
            ``{backend: {(batch_size, max_length): {"time": seconds, "tokens": count}}}``.
            A backend that was skipped keeps an empty inner dict.
        """
        results = {backend: {} for backend in backends}

        original_model = self._load_model(model_path)
        tokenizer = original_model.tokenizer

        for backend in backends:
            print(f"Benchmarking {model_path} with backend={backend}")
            try:
                backend_model = copy.deepcopy(original_model)
                # NOTE(review): presumably the attn_backend setter asserts on
                # unsupported backends — confirm against the model class.
                backend_model.attn_backend = backend
            except AssertionError as error:
                print(f"Skipping backend '{backend}' for {model_path}: {error}")
                continue

            for bs in batch_sizes:
                for max_length in sequence_lengths:
                    # Fresh copy per configuration: _time compiles the model it is given.
                    model_copy = copy.deepcopy(backend_model)
                    time_taken, tokens = self._time(
                        model_copy,
                        tokenizer,
                        bs,
                        min_length,
                        max_length,
                    )
                    results[backend][(bs, max_length)] = {"time": time_taken, "tokens": tokens}
                    # Drop the copy before the next one is created so fewer
                    # model copies are alive at once.
                    del model_copy

            del backend_model

        original_model.cpu()
        del original_model
        if self.device.type == "cuda":
            torch.cuda.empty_cache()
        return results


def plot_results(all_results: dict, output_path: str) -> None:
    """Plot non-pad tokens per second against sequence length and save the figure.

    One facet is drawn per model; line colour encodes the backend and line
    style the batch size. Nothing is drawn or written when ``all_results``
    contains no measurements.

    Args:
        all_results: ``{model_path: {backend: {(batch_size, max_length):
            {"time": seconds, "tokens": count}}}}``.
        output_path: Destination file for the rendered figure.
    """
    plot_data = []

    for model_path, results in all_results.items():
        model_name = Path(model_path).name
        for backend in sorted(results.keys()):
            for (bs, max_length), entry in results[backend].items():
                time_taken = entry["time"]
                nonpad_tokens = entry["tokens"]
                # Report zero throughput instead of dividing by a zero duration.
                tokens_per_sec = nonpad_tokens / time_taken if time_taken > 0 else 0.0
                plot_data.append(
                    {
                        "Model": model_name,
                        "Backend": backend,
                        "Batch": bs,
                        "SeqLen": max_length,
                        "TokensPerSec": tokens_per_sec,
                        "NonPadTokens": nonpad_tokens,
                        "Seconds": time_taken,
                    }
                )

    if not plot_data:
        return

    # Set the global seaborn theme only once there is something to draw.
    sns.set_theme(style="whitegrid")

    plot_df = pd.DataFrame(plot_data)
    sequence_lengths = sorted(plot_df["SeqLen"].dropna().unique().tolist())

    plot = sns.relplot(
        data=plot_df,
        x="SeqLen",
        y="TokensPerSec",
        hue="Backend",
        style="Batch",
        kind="line",
        marker="o",
        dashes=False,
        col="Model",
        col_wrap=1,
        height=4.5,
        aspect=1.5,
        facet_kws={"sharey": False},
    )
    plot.set_titles("{col_name}")
    plot.set(xticks=sequence_lengths)
    plot.set_axis_labels("Sequence length", "Non-pad tokens/s")
    plot.figure.suptitle("Throughput comparison by model")
    plot.tight_layout()
    plot.figure.subplots_adjust(top=0.93, right=0.95, bottom=0.06)
    plot.add_legend(title="Backend / Batch")
    # Save this grid's own figure rather than whichever figure pyplot
    # currently treats as active.
    plot.figure.savefig(output_path, dpi=300)
    print(f"Results saved to {output_path}")


def main() -> None:
    """Parse command-line options, benchmark every model and plot the results."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_token", type=str, default=None)
    parser.add_argument(
        "--model_paths",
        nargs="+",
        default=["Synthyra/ESM2-8M", "Synthyra/ESMplusplus_small"],
    )
    parser.add_argument("--batch_sizes", nargs="+", type=int, default=[2, 4, 8])
    parser.add_argument("--sequence_lengths", nargs="+", type=int, default=[64, 128, 256, 512, 1024, 2048])
    parser.add_argument("--backends", nargs="+", choices=SUPPORTED_BACKENDS, default=list(SUPPORTED_BACKENDS))
    parser.add_argument("--min_length", type=int, default=32)
    parser.add_argument("--warmup_batches", type=int, default=10)
    parser.add_argument("--timed_batches", type=int, default=100)
    parser.add_argument("--output_path", type=str, default="throughput_comparison.png")
    args = parser.parse_args()

    if args.hf_token:
        # Imported lazily so huggingface_hub is only needed when a token is given.
        from huggingface_hub import login

        login(token=args.hf_token)

    checker = ThroughputChecker(warmup_batches=args.warmup_batches, timed_batches=args.timed_batches)

    all_results = {}
    for model_path in args.model_paths:
        all_results[model_path] = checker.evaluate(
            model_path,
            args.batch_sizes,
            min_length=args.min_length,
            sequence_lengths=args.sequence_lengths,
            backends=args.backends,
        )

    plot_results(all_results, args.output_path)


if __name__ == "__main__":
    main()