|
|
import os |
|
|
import dataclasses |
|
|
import torch |
|
|
import transformers |
|
|
from transformers import Trainer, TrainingArguments, TrainerCallback |
|
|
from peft import LoraConfig, get_peft_model, TaskType |
|
|
from huggingface_hub import HfApi, login |
|
|
import wandb |
|
|
from dotenv import load_dotenv |
|
|
from config import TrainConfig, ModelConfig |
|
|
from model import MultiModalModel |
|
|
from data import AudioTextDataset, DataCollator |
|
|
|
|
|
|
|
|
class SamplePredictionCallback(TrainerCallback):
    """Every N steps, print ground-truth vs model-predicted transcript for a few samples.

    Runs greedy decoding on the first `num_samples` training examples and logs a
    `wandb.Table` comparing ground truth, prediction, and (optional) continuation
    text. Best-effort: any failure during sampling is printed and swallowed so
    that training is never interrupted.
    """

    def __init__(self, tokenizer, data_collator, train_dataset, sample_every_n_steps: int = 100, num_samples: int = 2, prompt: str = "Transcribe the following audio:"):
        # Text tokenizer: used both to encode the prompt and to decode outputs.
        self.tokenizer = tokenizer
        # Same collator as training, so sampled batches match the training format.
        self.data_collator = data_collator
        self.train_dataset = train_dataset
        self.sample_every_n_steps = sample_every_n_steps
        self.num_samples = num_samples
        self.prompt = prompt

    def on_log(self, args, state, control, model=None, **kwargs):
        """Generate and log sample predictions every `sample_every_n_steps` steps.

        NOTE(review): `on_log` only fires on logging steps, so this assumes
        `sample_every_n_steps` is a multiple of `logging_steps` — otherwise
        some sampling steps are silently skipped. Confirm against TrainConfig.
        """
        if state.global_step == 0 or state.global_step % self.sample_every_n_steps != 0:
            return
        if model is None:
            return

        model.eval()
        device = next(model.parameters()).device
        try:
            # Wrap indices around the dataset in case num_samples > len(dataset).
            indices = [i % len(self.train_dataset) for i in range(self.num_samples)]
            batch = self.data_collator([self.train_dataset[i] for i in indices])
            audio_values = batch["audio_values"].to(device)
            labels_batch = batch["labels"]
            continuations = batch.get("continuation", [""] * audio_values.size(0))

            # Encode the fixed prompt once and broadcast it across the batch.
            prompt_ids = self.tokenizer(self.prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(device)
            prompt_ids = prompt_ids.expand(audio_values.size(0), -1)

            # BUG FIX: the previous `pad_token_id or eos_token_id` treated a
            # legitimate pad token id of 0 as missing (0 is falsy) and silently
            # substituted EOS. Compare against None explicitly instead.
            pad_id = self.tokenizer.pad_token_id
            if pad_id is None:
                pad_id = self.tokenizer.eos_token_id

            with torch.no_grad():
                gen_ids = model.generate(
                    input_ids=prompt_ids,
                    audio_values=audio_values,
                    max_new_tokens=120,
                    do_sample=False,
                    pad_token_id=pad_id,
                )
            prompt_len = prompt_ids.size(1)

            columns = ["Step", "Audio Index", "Ground Truth", "Prediction", "Continuation"]
            table = wandb.Table(columns=columns)

            print(f"\n[WandB] Logging sample predictions at step {state.global_step}")

            for i in range(audio_values.size(0)):
                # Labels use -100 for positions ignored by the loss; strip them
                # before decoding so only real target tokens remain.
                gt_tokens = [t for t in labels_batch[i].tolist() if t != -100]
                gt_text = self.tokenizer.decode(gt_tokens, skip_special_tokens=True).strip()
                # Generated ids include the prompt prefix; decode only new tokens.
                pred_text = self.tokenizer.decode(gen_ids[i][prompt_len:], skip_special_tokens=True).strip()

                cont_ref = continuations[i] if i < len(continuations) else ""

                table.add_data(state.global_step, i, gt_text, pred_text, cont_ref)

            if wandb.run is not None:
                wandb.log({"sample_predictions": table}, step=state.global_step)
            else:
                print("Warning: WandB run not active, skipping logging.")

        except Exception as e:
            # Best-effort: never let a sampling failure crash the training loop.
            print(f"[SamplePredictionCallback] Error: {e}\n")
        finally:
            # Restore training mode even if generation failed.
            model.train()
|
|
|
|
|
|
|
|
import shutil |
|
|
import glob |
|
|
from transformers.trainer_utils import get_last_checkpoint |
|
|
|
|
|
class AggressiveDeleteCallback(TrainerCallback):
    """Free disk space by wiping every on-disk checkpoint right before a new save.

    On each step where the Trainer is about to write a checkpoint
    (save_strategy="steps"), every existing `checkpoint-*` directory under
    `output_dir` is removed, so there are zero checkpoints on disk just before
    the save. The model currently being trained lives in memory and is kept.

    WARNING: if the subsequent save fails, there are NO checkpoints left on
    disk. This risk is accepted by the user.
    """

    def __init__(self, output_dir):
        # Directory into which the Trainer writes `checkpoint-*` subdirectories.
        self.output_dir = output_dir

    def on_step_end(self, args, state, control, **kwargs):
        """Delete stale checkpoints on steps where a save is about to happen."""
        # Guard clauses: only act under step-based saving, on an actual save step.
        if args.save_strategy != "steps" or args.save_steps <= 0:
            return
        step = state.global_step
        if step <= 0 or step % args.save_steps != 0:
            return

        print(f"\n[AggressiveDeleteCallback] Step {state.global_step}: Deleting old checkpoints to free space before saving...")

        pattern = os.path.join(self.output_dir, "checkpoint-*")
        for ckpt in glob.glob(pattern):
            try:
                shutil.rmtree(ckpt)
            except Exception as e:
                print(f"  Failed to delete {ckpt}: {e}")
            else:
                print(f"  Deleted {ckpt}")
|
|
|
|
|
def train():
    """Fine-tune the multimodal (audio + text) model.

    Pipeline:
      1. Load environment variables and configs; start a Weights & Biases run.
      2. Build tokenizer, audio processor, and model (optionally LoRA-wrapped).
      3. Train with the HF Trainer, resuming from the latest checkpoint if any.
      4. Save artifacts locally and optionally push them to the HF Hub.
    """
    load_dotenv()

    train_config = TrainConfig()
    model_config = ModelConfig()

    wandb.init(
        project=train_config.wandb_project,
        entity=train_config.wandb_entity,
        name=train_config.wandb_run_name,
        config=dataclasses.asdict(train_config),
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.text_model_id)
    if tokenizer.pad_token is None:
        # Many causal-LM tokenizers ship without a pad token; reuse EOS so
        # padded batches can be built.
        tokenizer.pad_token = tokenizer.eos_token

    processor = transformers.AutoProcessor.from_pretrained(model_config.audio_model_id)

    model = MultiModalModel(model_config)

    if train_config.use_lora:
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=train_config.lora_r,
            lora_alpha=train_config.lora_alpha,
            lora_dropout=train_config.lora_dropout,
            target_modules=["q_proj", "v_proj"],
        )
        # Only the language model is LoRA-wrapped; other submodules keep their
        # trainability as defined in MultiModalModel.
        model.llm = get_peft_model(model.llm, peft_config)
        model.llm.print_trainable_parameters()

    train_dataset = AudioTextDataset(train_config, processor, model_config, tokenizer)
    data_collator = DataCollator(processor, tokenizer)

    training_args = TrainingArguments(
        output_dir=train_config.output_dir,
        per_device_train_batch_size=train_config.batch_size,
        gradient_accumulation_steps=train_config.accum_steps,
        learning_rate=train_config.learning_rate,
        lr_scheduler_type=train_config.lr_scheduler_type,
        num_train_epochs=train_config.num_epochs,
        max_steps=train_config.max_steps,
        bf16=train_config.use_bf16,
        gradient_checkpointing=train_config.gradient_checkpointing,
        dataloader_num_workers=train_config.dataloader_num_workers,
        dataloader_pin_memory=train_config.dataloader_pin_memory,
        logging_steps=train_config.log_steps,
        logging_first_step=True,
        logging_nan_inf_filter=True,
        save_steps=train_config.save_steps,
        save_total_limit=train_config.save_total_limit,
        eval_strategy="no",
        # Required: the collator consumes custom columns (audio_values etc.)
        # that the Trainer would otherwise drop.
        remove_unused_columns=False,
        report_to="wandb",
        log_level="info",
        log_level_replica="info",
    )

    sample_callback = SamplePredictionCallback(
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=train_dataset,
        sample_every_n_steps=train_config.sample_pred_every_steps,
        num_samples=2,
        prompt="Transcribe the following audio:",
    )

    aggressive_delete_callback = AggressiveDeleteCallback(train_config.output_dir)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        callbacks=[sample_callback, aggressive_delete_callback],
    )

    total_steps = train_config.max_steps
    print(f"\n>>> Training: max_steps={total_steps}, batch_size={train_config.batch_size}, "
          f"grad_accum={train_config.accum_steps} (effective batch={train_config.batch_size * train_config.accum_steps})")
    print(f">>> Sample predictions (GT vs predicted transcript) every {train_config.sample_pred_every_steps} steps.\n")

    # BUG FIX: get_last_checkpoint raises FileNotFoundError when output_dir does
    # not exist yet (i.e. on the very first run); guard with isdir.
    last_checkpoint = None
    if os.path.isdir(train_config.output_dir):
        last_checkpoint = get_last_checkpoint(train_config.output_dir)
    if last_checkpoint is not None:
        print(f">>> Resuming from checkpoint: {last_checkpoint}")
        trainer.train(resume_from_checkpoint=last_checkpoint)
    else:
        trainer.train()

    # Persist everything needed for standalone inference.
    trainer.save_model(train_config.output_dir)
    tokenizer.save_pretrained(train_config.output_dir)
    processor.save_pretrained(train_config.output_dir)

    if train_config.push_to_hub:
        _push_to_hub(train_config)


def _push_to_hub(train_config):
    """Upload trained artifacts plus key source files to the HF Hub (best-effort).

    Failures are printed rather than raised so a flaky upload never discards a
    completed training run.
    """
    print(f"\n>>> Pushing model to Hugging Face Hub: {train_config.hub_model_id}")
    if train_config.hub_token:
        login(token=train_config.hub_token)

    api = HfApi()

    try:
        api.create_repo(repo_id=train_config.hub_model_id, private=train_config.hub_private_repo, exist_ok=True)
    except Exception as e:
        print(f"Warning: Could not create repo {train_config.hub_model_id}. Error: {e}")

    try:
        api.upload_folder(
            folder_path=train_config.output_dir,
            repo_id=train_config.hub_model_id,
            repo_type="model",
        )

        # Also push the source files needed to reload the custom model class.
        for file in ["model.py", "config.py", "data.py", "inference.py"]:
            if os.path.exists(file):
                api.upload_file(
                    path_or_fileobj=file,
                    path_in_repo=file,
                    repo_id=train_config.hub_model_id,
                    repo_type="model",
                )

        print(f">>> Successfully pushed to {train_config.hub_model_id}")
    except Exception as e:
        print(f"Error pushing to hub: {e}")
|
|
|
|
|
# Script entry point: run fine-tuning only when executed directly, not on import.
if __name__ == "__main__":

    train()
|
|
|