from transformers import AutoTokenizer, GPT2LMHeadModel
from datasets import load_dataset, Dataset, DatasetDict
import random
import string
import torch

from torchmetrics.text import WordErrorRate, CharErrorRate

wer = WordErrorRate()
cer = CharErrorRate()
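
# Quick sanity check of the torchmetrics API used below (illustrative values,
# not from this script): both metrics take (preds, target) lists of strings.
#   wer(["hello world"], ["hello there"])  # -> tensor(0.5000), 1 of 2 words wrong
#   cer(["hello"], ["hallo"])              # -> tensor(0.2000), 1 of 5 chars wrong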

def process(text):
    # Normalize a word: lowercase, drop all punctuation except apostrophes,
    # and trim surrounding whitespace.
    text = text.lower()

    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)

    return text.strip()
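
# Examples of the normalization (hypothetical inputs, for illustration only):
#   process("  Hello, World!  ")  # -> "hello world"
#   process("DON'T")              # -> "don't"   (apostrophes are kept)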

import jiwer
from edit_distance import SequenceMatcher


def correct_text(text):
    transforms = jiwer.Compose(
        [
            jiwer.ExpandCommonEnglishContractions(),
            jiwer.ToLowerCase(),
            jiwer.RemoveMultipleSpaces(),
            jiwer.Strip(),
            jiwer.RemovePunctuation(),
            jiwer.ReduceToListOfListOfWords(),
        ]
    )
    return transforms(text)
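
# The jiwer pipeline maps raw transcripts to lists of word lists, e.g.
# (illustrative input): correct_text(["It's a Test."]) -> [["it", "is", "a", "test"]]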

def align_gt_asr(gt, asr):
    sm = SequenceMatcher(a=gt, b=asr)
    best_path = []
    opcodes = sm.get_opcodes()
    for tag, i1, i2, j1, j2 in opcodes:
        if tag == "delete":
            for i in range(i1, i2):
                best_path.append([gt[i], ""])
        elif tag == "replace" or tag == "equal":
            for i, j in zip(range(i1, i2), range(j1, j2)):
                best_path.append([gt[i], asr[j]])
        elif tag == "insert":
            for j in range(j1, j2):
                best_path.append(["", asr[j]])
    return best_path
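
# Word-level alignment example (made-up words): deletions pair a reference
# word with "", insertions pair "" with a hypothesis word.
#   align_gt_asr(["the", "cat", "sat", "down"], ["the", "cat", "down"])
#   # -> [["the", "the"], ["cat", "cat"], ["sat", ""], ["down", "down"]]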

dtype = torch.float16

dataset_name = "./../libripseech_tokenized"
dataset = DatasetDict.load_from_disk(dataset_name)

# Pool of rare words used to pad the biasing list with distractors.
with open("./../prompting/blist/all_rare_words.txt") as fin:
    rarewords = [process(word.strip()) for word in fin]

tokenizer = AutoTokenizer.from_pretrained("./../tokenizer")
tokenizer.pad_token_id = 0
tokenizer.pad_token = "<|padding|>"
tokenizer.padding_side = "left"  # left-pad so generation continues from the prompt

# Special tokens that delimit the biasing-word prompt.
tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])
sot_token = tokenizer.encode("<|startoftranscript|>")[0]
eot_token = tokenizer.encode("<|endoftranscript|>")[0]

from math import ceil
from tqdm import tqdm

val_bs = 32        # evaluation batch size
n_bwords = 25      # biasing words per utterance (ground truth + distractors)
context_length = 2048


def prepare(element):
    # Render the discrete audio tokens as special tokens, e.g. <|audio:123|>.
    audio_tkns = element["audio_tokens"]
    data = "".join([f"<|audio:{tkn}|>" for tkn in audio_tkns])

    # Pad the utterance's biasing words with random rare-word distractors
    # (or subsample them) so every prompt contains exactly n_bwords words.
    b_words = element["b_words"]
    if n_bwords > len(b_words):
        context = b_words + random.sample(rarewords, n_bwords - len(b_words))
    else:
        context = random.sample(b_words, n_bwords)
    random.shuffle(context)

    data += "<|startofprompt|>" + "<|sepofprompt|>".join(context) + "<|endofprompt|>"

    # Generation starts after the transcript marker.
    data += "<|startoftranscript|>"

    return {"data": data, "context": context}
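
# Resulting prompt layout for one utterance (schematic, with made-up token ids):
#   <|audio:17|><|audio:402|>...<|startofprompt|>word1<|sepofprompt|>word2
#   ...<|endofprompt|><|startoftranscript|>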

@torch.no_grad()
def evaluate_model(model):
    transcripts = []

    processed_data = dataset["test.clean"].map(prepare)
    data = processed_data["data"]

    for idx in tqdm(range(ceil(len(data) / val_bs))):
        outputs = tokenizer(
            data[idx * val_bs: (idx + 1) * val_bs],
            truncation=False,
            max_length=None,
            padding=True,
            return_tensors="pt",
        ).to(model.device)
        input_ids = outputs["input_ids"]
        par = input_ids.shape[-1]

        # Pass the attention mask so the left-padding tokens are ignored
        # during generation.
        generations = model.generate(
            input_ids,
            attention_mask=outputs["attention_mask"],
            max_new_tokens=context_length - par - 1,
            eos_token_id=eot_token,
            pad_token_id=tokenizer.pad_token_id,
        )
        # Keep only the newly generated tokens (everything after the prompt).
        transcripts += tokenizer.batch_decode(generations[:, par:], skip_special_tokens=True)

    # Biased WER (B-WER) is computed over words that appear in the biasing
    # list; unbiased WER (U-WER) over the remaining words.
    bias_word_cnt = 0
    normal_word_cnt = 0
    u_wer = 0.0
    b_wer = 0.0
    pred_list = correct_text(transcripts)
    text_list = correct_text(processed_data["text"])
    prompt_list = processed_data["context"]
    for pred, ref, blist in zip(pred_list, text_list, prompt_list):
        aligned_pair = align_gt_asr(ref, pred)
        for gt_word, asr_word in aligned_pair:
            if gt_word in blist or asr_word in blist:
                if gt_word != asr_word:
                    b_wer += 1.0
                if gt_word in blist:
                    bias_word_cnt += 1
            else:
                if gt_word != asr_word:
                    u_wer += 1.0
                if gt_word != "":
                    normal_word_cnt += 1
    # Guard against empty word counts to avoid division by zero.
    u_wer = u_wer / max(normal_word_cnt, 1) * 100
    b_wer = b_wer / max(bias_word_cnt, 1) * 100

    return (
        wer(transcripts, processed_data["text"]).item() * 100,
        cer(transcripts, processed_data["text"]).item() * 100,
        b_wer,
        u_wer,
    )
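
# Minimal usage sketch. The checkpoint path "./../model" and the
# resize_token_embeddings call are assumptions for illustration; substitute
# the actual fine-tuned checkpoint, which may already include the added tokens.
if __name__ == "__main__":
    model = GPT2LMHeadModel.from_pretrained("./../model", torch_dtype=dtype)
    model.resize_token_embeddings(len(tokenizer))  # account for added prompt tokens
    model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

    wer_score, cer_score, b_wer, u_wer = evaluate_model(model)
    print(f"WER {wer_score:.2f} | CER {cer_score:.2f} | "
          f"B-WER {b_wer:.2f} | U-WER {u_wer:.2f}")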