"""Benchmark Whisper greedy decoding against speculative (assisted) decoding.

For every sample of the LibriSpeech dummy validation split the script:

1. transcribes the audio with the teacher model alone,
2. transcribes it again with the teacher assisted by the student model,
   re-using one encoder pass for both models,
3. compares the two transcriptions and accumulates the wall-clock time of
   each path.

A short summary (total baseline time, overall speed-up, number of
mismatching transcriptions) is printed at the end.
"""

import copy
import time

import torch
from datasets import load_dataset
from transformers import (
    AutoProcessor,
    WhisperForConditionalGeneration,
)

DEVICE = "cuda"
DTYPE = torch.float16
SAMPLING_RATE = 16_000
BATCH_SIZE = 1
USE_FLASH_ATTN_2 = True
# Upper bound on generated sequence length, shared by every `generate` call.
MAX_LENGTH = 100

# NOTE(review): GAMMAS is not referenced anywhere in this script as shown;
# kept so that external code relying on the name keeps working.
GAMMAS = [5, 7, 6, 5, 4, 3, 5]
# Number of batches whose speculative transcription differs from the baseline.
COUNT = 0

teacher = WhisperForConditionalGeneration.from_pretrained(
    "/home/patrick/distil_whisper/",
    torch_dtype=DTYPE,
    variant="fp16",
    low_cpu_mem_usage=True,
    use_flash_attention_2=USE_FLASH_ATTN_2,
)
student = WhisperForConditionalGeneration.from_pretrained(
    "/home/patrick/distil_whisper_student/",
    torch_dtype=DTYPE,
    variant="fp16",
    low_cpu_mem_usage=True,
    use_flash_attention_2=USE_FLASH_ATTN_2,
)

# The student generates with the teacher's settings, except that the number
# of assistant tokens follows the "constant" schedule.
student.generation_config = copy.deepcopy(teacher.generation_config)
student.generation_config.num_assistant_tokens_schedule = "constant"

teacher.to(DEVICE)
student.to(DEVICE)

processor = AutoProcessor.from_pretrained("sanchit-gandhi/large-32-2-gpu-flat-lr")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

total_time_default = 0
# NOTE(review): total_time_spec is never updated in this script as shown.
total_time_spec = 0
total_time_spec_2 = 0


def _synced_time():
    """Return a monotonic timestamp taken after pending CUDA work finished.

    CUDA kernels are launched asynchronously, so the device is synchronised
    first; otherwise a timed region could end before its GPU work does, or
    absorb work queued by the previous region.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter()


def _generate_default(input_features):
    """Transcribe `input_features` with the teacher model alone."""
    return teacher.generate(input_features, max_length=MAX_LENGTH)


def _generate_speculative(input_features):
    """Transcribe `input_features` with the teacher assisted by the student.

    The teacher's encoder is run once and its output is handed to both the
    teacher and the assistant, so the audio is encoded a single time.
    """
    with torch.no_grad():
        encoder_outputs = teacher.get_encoder()(input_features)

    return teacher.generate(
        assistant_model=student,
        assistant_encoder_outputs=encoder_outputs,
        encoder_outputs=encoder_outputs,
        max_length=MAX_LENGTH,
    )


def _timed(generate_fn, input_features):
    """Run `generate_fn(input_features)`; return `(output, elapsed_seconds)`."""
    start = _synced_time()
    output = generate_fn(input_features)
    return output, _synced_time() - start


def _prepare_features(audio_arrays):
    """Turn raw audio array(s) into model input features on DEVICE / DTYPE."""
    inputs = processor(audio_arrays, return_tensors="pt", sampling_rate=SAMPLING_RATE)
    return inputs.input_features.to(device=DEVICE, dtype=DTYPE)


# Warm-up: run BOTH decoding paths once before measuring, so that one-off
# start-up costs are not charged to the first timed sample of either path.
warmup_features = _prepare_features(ds[0]["audio"]["array"])
_ = _generate_default(warmup_features)
_ = _generate_speculative(warmup_features)

end_idx = len(ds)
for audio_idx in range(0, end_idx, BATCH_SIZE):
    batch = ds[audio_idx : audio_idx + BATCH_SIZE]
    input_values = [audio["array"] for audio in batch["audio"]]
    input_features = _prepare_features(input_values)

    # Baseline: teacher only.
    out, run_time = _timed(_generate_default, input_features)
    print(f"Normal Decoding: {run_time}")
    total_time_default += run_time

    default_out = processor.batch_decode(out, skip_special_tokens=True)

    # Speculative decoding; the timed region includes the encoder pass.
    out, run_time = _timed(_generate_speculative, input_features)

    spec_out_2 = processor.batch_decode(out, skip_special_tokens=True)

    print(f"Speculative Decoding 2: {run_time}")
    total_time_spec_2 += run_time

    if spec_out_2 != default_out:
        COUNT += 1
        print(f"Audio {audio_idx} does not match. Spec: {spec_out_2}, True: {default_out}")

print(20 * "=")
print("Total time", total_time_default)
print(f"Overall speed-up spec 2 {total_time_default / total_time_spec_2}")
print(f"Mismatching transcriptions: {COUNT}")