import argparse
import time
import random
from itertools import chain
from types import SimpleNamespace

from loguru import logger
import numpy as np
import torch
from rich import print
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

from model import DFlashDraftModel, sample, load_and_process_dataset, extract_context_feature
import distributed as dist


def cuda_time() -> float:
    # Wait for all queued CUDA work so wall-clock timings are accurate.
    torch.cuda.synchronize()
    return time.perf_counter()


@torch.inference_mode()
def dflash_generate(
    model: DFlashDraftModel,
    target: AutoModelForCausalLM,
    input_ids: torch.Tensor,
    mask_token_id: int,
    max_new_tokens: int,
    block_size: int,
    stop_token_ids: list[int],
    temperature: float = 0.0,
) -> SimpleNamespace:
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens

    # Output buffer pre-filled with mask tokens; the extra block_size slots
    # keep the last block's slices and writes in bounds near max_length.
    output_ids = torch.full(
        (1, max_length + block_size),
        mask_token_id,
        dtype=torch.long,
        device=model.device,
    )
    position_ids = torch.arange(output_ids.shape[1], device=model.device).unsqueeze(0)

    past_key_values_target = DynamicCache()
    past_key_values_draft = DynamicCache()

    # Prefill stage
    prefill_start = cuda_time()
    output = target(
        input_ids,
        position_ids=position_ids[:, :num_input_tokens],
        past_key_values=past_key_values_target,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=block_size > 1,
    )

    output_ids[:, :num_input_tokens] = input_ids
    output_ids[:, num_input_tokens:num_input_tokens + 1] = sample(output.logits, temperature)
    if block_size > 1:
        target_hidden = extract_context_feature(output.hidden_states, model.target_layer_ids)

    time_to_first_token = cuda_time() - prefill_start

    # Decode stage
    decode_start = cuda_time()
    start = input_ids.shape[1]
    acceptance_lengths = []
    draft_prefill = True

    while start < max_length:
        # The block begins with the last committed token; the remaining slots are mask tokens.
        block_output_ids = output_ids[:, start : start + block_size].clone()
        block_position_ids = position_ids[:, start : start + block_size]

        if block_size > 1:
            # Draft: fill the masked slots of the block in a single forward pass,
            # conditioned on the target model's hidden states.
            noise_embedding = target.model.embed_tokens(block_output_ids)
            draft_logits = target.lm_head(model(
                target_hidden=target_hidden,
                noise_embedding=noise_embedding,
                position_ids=position_ids[:, past_key_values_draft.get_seq_length(): start + block_size],
                past_key_values=past_key_values_draft,
                use_cache=True,
                is_causal=False,
            )[:, -block_size + 1:, :])
            past_key_values_draft.crop(start)
            block_output_ids[:, 1:] = sample(draft_logits)
            if draft_prefill:
                # Restart the decode timer after the first draft pass, which also
                # prefills the draft KV cache with the prompt context.
                draft_prefill = False
                decode_start = cuda_time()

        # Verify: score the whole block with the target model in one forward pass.
        output = target(
            block_output_ids,
            position_ids=block_position_ids,
            past_key_values=past_key_values_target,
            use_cache=True,
            output_hidden_states=block_size > 1,
        )

        posterior = sample(output.logits, temperature)
        # Count the leading draft tokens that match the target's own predictions.
        acceptance_length = (block_output_ids[:, 1:] == posterior[:, :-1]).cumprod(dim=1).sum(dim=1)[0].item()
        output_ids[:, start : start + acceptance_length + 1] = block_output_ids[:, : acceptance_length + 1]
        # The target's prediction after the accepted prefix is committed as the next token.
        output_ids[:, start + acceptance_length + 1] = posterior[:, acceptance_length]

        acceptance_lengths.append(acceptance_length + 1)
        start += acceptance_length + 1
        # Discard target KV entries for rejected draft tokens.
        past_key_values_target.crop(start)
        if block_size > 1:
            target_hidden = extract_context_feature(
                output.hidden_states, model.target_layer_ids
            )[:, :acceptance_length + 1, :]

        if stop_token_ids is not None and any(
            stop_token_id in output_ids[:, num_input_tokens:] for stop_token_id in stop_token_ids
        ):
            break

    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_token_ids = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_token_indices = torch.isin(output_ids[0][num_input_tokens:], stop_token_ids).nonzero(as_tuple=True)[0]
        if stop_token_indices.numel() > 0:
            output_ids = output_ids[:, : num_input_tokens + stop_token_indices[0] + 1]

    num_output_tokens = output_ids.shape[1] - num_input_tokens
    total_decode_time = cuda_time() - decode_start
    time_per_output_token = total_decode_time / num_output_tokens

    return SimpleNamespace(
        output_ids=output_ids,
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        time_to_first_token=time_to_first_token,
        time_per_output_token=time_per_output_token,
        acceptance_lengths=acceptance_lengths,
    )


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name-or-path", type=str, required=True)
    parser.add_argument("--draft-name-or-path", type=str, required=True)
    parser.add_argument("--block-size", type=int, default=None)
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--max-samples", type=int, default=None)
    parser.add_argument("--max-new-tokens", type=int, default=16384)
    parser.add_argument("--temperature", type=float, default=0.0)
    args = parser.parse_args()

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    dist.init()
    torch.cuda.set_device(dist.local_rank())
    device = torch.device(f"cuda:{dist.local_rank()}")

    def has_flash_attn():
        try:
            import flash_attn  # noqa: F401
            return True
        except ImportError:
            logger.warning("flash_attn is not installed. Falling back to torch.sdpa. The speedup will be lower.")
            return False

    installed_flash_attn = has_flash_attn()

    target = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        attn_implementation="flash_attention_2" if installed_flash_attn else "sdpa",
        dtype=torch.bfloat16,
    ).to(device).eval()

    draft_model = DFlashDraftModel.from_pretrained(
        args.draft_name_or_path,
        attn_implementation="flash_attention_2" if installed_flash_attn else "sdpa",
        dtype=torch.bfloat16,
    ).to(device).eval()

    block_size = args.block_size if args.block_size is not None else draft_model.block_size

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    dataset = load_and_process_dataset(args.dataset)

    if args.max_samples is not None and len(dataset) > args.max_samples:
        dataset = dataset.shuffle(seed=0).select(range(args.max_samples))

    responses = []
    # Shard samples round-robin across ranks.
    indices = range(dist.rank(), len(dataset), dist.size())
    for idx in tqdm(indices, disable=not dist.is_main()):
        instance = dataset[idx]
        messages = []
        for user_content in instance["turns"]:
            messages.append({"role": "user", "content": user_content})
            input_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
            )
            input_ids = tokenizer.encode(input_text, return_tensors="pt").to(target.device)

            # Block size 1 is the plain autoregressive baseline; block_size runs speculative decoding.
            response = {}
            for bs in [1, block_size]:
                response[bs] = dflash_generate(
                    model=draft_model,
                    target=target,
                    input_ids=input_ids,
                    mask_token_id=draft_model.mask_token_id,
                    max_new_tokens=args.max_new_tokens,
                    block_size=bs,
                    stop_token_ids=[tokenizer.eos_token_id],
                    temperature=args.temperature,
                )

            # Feed the speculative output back as the assistant turn for the next user turn.
            spec_response = response[block_size]
            generated_ids = spec_response.output_ids[0, spec_response.num_input_tokens:]
            output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
            messages.append({"role": "assistant", "content": output_text})
            responses.append(response)

    if dist.size() > 1:
        responses = dist.gather(responses, dst=0)
        if not dist.is_main():
            return
        responses = list(chain(*responses))

    t1 = np.mean([r[1].time_per_output_token for r in responses])
    tb = np.mean([r[block_size].time_per_output_token for r in responses])
    print(f"Decoding speedup: {t1 / tb:.2f}")

    tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses])
    print(f"Average Acceptance length: {tau:.2f}")

    # Fraction of decode steps that committed b tokens, for b = 0..block_size.
    acceptance_lengths = list(chain(*[r[block_size].acceptance_lengths for r in responses]))
    histogram = [acceptance_lengths.count(b) / len(acceptance_lengths) for b in range(block_size + 1)]
    print(f"Acceptance length histogram: {[f'{x * 100:.1f}%' for x in histogram]}")


if __name__ == "__main__":
    main()