| | |
| | from datasets import load_dataset, Audio |
| | from transformers import ( |
| | WhisperProcessor, |
| | WhisperForConditionalGeneration, |
| | Seq2SeqTrainingArguments, |
| | Seq2SeqTrainer |
| | ) |
| | import torch |
| | from dataclasses import dataclass |
| | from typing import Any, Dict, List, Union |
| | from functools import partial |
| | import evaluate |
| |
|
| | |
# NOTE(review): the dataset identifier is an empty string here; a real
# dataset name or path presumably has to be supplied before this can run.
dataset = load_dataset("")
dataset

# Hold out 20% of the "train" split for evaluation. No seed is passed, so the
# partition may differ between runs -- confirm whether that is intended.
split_dataset = dataset["train"].train_test_split(test_size=0.2)
split_dataset

# Restrict the training split to the two columns read by prepare_dataset.
split_dataset["train"] = split_dataset["train"].select_columns(
    ["audio", "sentence"]
)
split_dataset["train"]
| |
|
| | |
# Processor for the "openai/whisper-small" checkpoint, configured for
# Swahili transcription.
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small", language="swahili", task="transcribe"
)

# Re-cast the audio column of both splits to the sampling rate reported by the
# processor's feature extractor, printing the column feature before and after.
sampling_rate = processor.feature_extractor.sampling_rate
for split_name in ("train", "test"):
    print('BEFORE>>> ', split_dataset[split_name].features['audio'])
    split_dataset[split_name] = split_dataset[split_name].cast_column(
        "audio", Audio(sampling_rate=sampling_rate)
    )
    print('AFTER>>> ', split_dataset[split_name].features['audio'])
| |
|
def prepare_dataset(example):
    """Convert one raw record into processor outputs plus its duration.

    Args:
        example: Mapping with an ``"audio"`` entry (itself providing
            ``"array"`` and ``"sampling_rate"``) and a ``"sentence"`` entry.

    Returns:
        The object returned by the module-level ``processor`` for this audio
        and text, extended with an ``"input_length"`` entry: the number of
        samples divided by the sampling rate (i.e. the clip length in seconds).
    """
    clip = example["audio"]
    waveform = clip["array"]
    rate = clip["sampling_rate"]

    encoded = processor(
        audio=waveform,
        sampling_rate=rate,
        text=example["sentence"],
    )

    # Stored so that clips can later be filtered by duration.
    encoded["input_length"] = len(waveform) / rate
    return encoded
| |
|
| | |
# Run prepare_dataset over each split, dropping every original column so only
# the values it returns remain. The training split uses 4 worker processes,
# the test split a single one.
for split_name, worker_count in (("train", 4), ("test", 1)):
    split_dataset[split_name] = split_dataset[split_name].map(
        prepare_dataset,
        remove_columns=split_dataset[split_name].column_names,
        num_proc=worker_count,
    )
| |
|
| | |
# Default upper bound (exclusive) on clip duration, in the same unit as the
# "input_length" values being filtered (seconds, as computed in this file).
max_input_length = 30.0


def is_audio_in_length_range(length, max_length=None):
    """Return True when ``length`` is strictly below the allowed maximum.

    Args:
        length: Duration of a single clip.
        max_length: Optional exclusive upper bound. When omitted (``None``),
            the module-level ``max_input_length`` is looked up at call time,
            so existing one-argument callers behave exactly as before.

    Returns:
        ``True`` if ``length < limit``, otherwise ``False``.
    """
    limit = max_input_length if max_length is None else max_length
    return length < limit
| |
|
# Keep only the training examples whose "input_length" passes
# is_audio_in_length_range; the test split is not filtered here.
split_dataset["train"] = split_dataset["train"].filter(
    is_audio_in_length_range, input_columns=["input_length"]
)
| |
|
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate per-example features into one padded batch.

    Audio features and label ids are padded independently, using the
    ``feature_extractor`` and ``tokenizer`` attributes of ``processor``.
    """

    # Object exposing ``feature_extractor`` and ``tokenizer``; both must
    # provide ``pad`` and the tokenizer must provide ``bos_token_id``.
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Build a batch from a list of examples.

        Args:
            features: Examples, each with ``"input_features"`` (only its first
                element is used) and ``"labels"``.

        Returns:
            The padded audio batch with an added ``"labels"`` tensor in which
            every padded position has been replaced by -100.
        """
        extractor = self.processor.feature_extractor
        tokenizer = self.processor.tokenizer

        # Audio side: unwrap the leading element of each example and pad.
        audio_inputs = []
        for item in features:
            audio_inputs.append({"input_features": item["input_features"][0]})
        batch = extractor.pad(audio_inputs, return_tensors="pt")

        # Text side: pad the label ids to a common length.
        label_inputs = []
        for item in features:
            label_inputs.append({"input_ids": item["labels"]})
        padded_labels = tokenizer.pad(label_inputs, return_tensors="pt")

        # Positions where the attention mask is not 1 are padding.
        padding_positions = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(padding_positions, -100)

        # If every sequence begins with the BOS id, drop that first column.
        starts_with_bos = (labels[:, 0] == tokenizer.bos_token_id).all().cpu().item()
        if starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
| |
|
| | |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

# Text normalizer applied to predictions and references in compute_metrics.
normalizer = BasicTextNormalizer()

# Word-error-rate metric used in compute_metrics.
metric = evaluate.load("wer")

# Batch collator bound to the shared processor.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
| |
|
def compute_metrics(pred):
    """Score predictions with raw and normalized word error rate.

    Args:
        pred: Object exposing ``predictions`` and ``label_ids``. Note that
            ``label_ids`` is modified in place: every -100 entry is replaced
            with the tokenizer's pad token id before decoding.

    Returns:
        Dict with ``"wer_ortho"`` (metric on the decoded text as-is, times 100)
        and ``"wer"`` (metric on normalized text, times 100, skipping pairs
        whose normalized reference is empty).
    """
    predicted_ids = pred.predictions
    reference_ids = pred.label_ids

    # Undo the -100 masking so the ids can be decoded.
    reference_ids[reference_ids == -100] = processor.tokenizer.pad_token_id

    hypotheses = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    references = processor.batch_decode(reference_ids, skip_special_tokens=True)

    # Score on the decoded text before any normalization.
    wer_ortho = 100 * metric.compute(predictions=hypotheses, references=references)

    normalized_hypotheses = [normalizer(text) for text in hypotheses]
    normalized_references = [normalizer(text) for text in references]

    # Keep only the pairs whose normalized reference is non-empty.
    kept_hypotheses = []
    kept_references = []
    for hypothesis, reference in zip(normalized_hypotheses, normalized_references):
        if not reference:
            continue
        kept_hypotheses.append(hypothesis)
        kept_references.append(reference)

    wer = 100 * metric.compute(predictions=kept_hypotheses, references=kept_references)

    return {"wer_ortho": wer_ortho, "wer": wer}
| |
|
| | |
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# Caching is switched off on the model config (presumably because gradient
# checkpointing is enabled in the training arguments -- confirm) ...
model.config.use_cache = False

# ... and switched back on for generation, with language and task pre-bound
# so every generate() call transcribes Swahili.
model.generate = partial(
    model.generate, language="swahili", task="transcribe", use_cache=True
)
| |
|
| | |
training_args = Seq2SeqTrainingArguments(
    output_dir="./model",
    # Optimisation
    learning_rate=1e-6,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=50,
    max_steps=10000,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    fp16=True,
    # Evaluation (generation-based)
    evaluation_strategy="steps",
    eval_steps=500,
    per_device_eval_batch_size=16,
    fp16_full_eval=True,
    predict_with_generate=True,
    generation_max_length=225,
    # Checkpointing and best-model selection: lower "wer" is better
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    # Logging and publishing
    logging_steps=100,
    report_to=["tensorboard", "wandb"],
    push_to_hub=True,
)
| |
|
| | |
# Wire the model, data splits, collator and metric function into the trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

# Start fine-tuning.
trainer.train()