# config.py — model and training configuration (Hub commit 2ca914e)
import os
import dataclasses
from typing import Optional, Tuple
@dataclasses.dataclass
class ModelConfig:
    """Model architecture settings: audio encoder, text decoder, projector.

    The two ``*_model_id`` fields name pretrained Hugging Face checkpoints;
    the remaining fields presumably configure the audio→text projector
    (hidden width, activation, temporal stacking) — confirm against the
    model-building code.
    """

    # Pretrained checkpoint identifiers for the two halves of the model.
    audio_model_id: str = "openai/whisper-medium"
    text_model_id: str = "sarvamai/sarvam-m"
    # Projector hyperparameters.
    hidden_size: int = 2048
    projector_act: str = "gelu"
    stack_factor: int = 8

    def to_dict(self):
        """Return the configuration as a plain ``{field_name: value}`` dict."""
        # All fields are flat scalars, so this is equivalent to
        # dataclasses.asdict(self).
        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}
@dataclasses.dataclass
class TrainConfig:
    """Training hyperparameters: batching, data source, LoRA, Hub, and logging.

    NOTE: the ``os.getenv(...)`` defaults below are evaluated once, at class
    definition (import) time — changing the environment afterwards does not
    affect already-imported defaults.
    """

    # --- Batch & GPU (tuned for A100 80GB) ---
    batch_size: int = 8                    # per-device; try 64 if no OOM
    # Fixed stale comment: with batch_size=8 the effective batch is 8*2=16,
    # not 32*2=64 as previously claimed.
    accum_steps: int = 2                   # effective batch = 8*2=16; reduce if OOM
    use_bf16: bool = True                  # A100 native bf16: faster + less VRAM
    gradient_checkpointing: bool = False   # set True if OOM to trade compute for memory
    dataloader_num_workers: int = 8
    dataloader_pin_memory: bool = True

    # --- Optimization ---
    learning_rate: float = 1e-4
    lr_scheduler_type: str = "cosine"
    num_epochs: int = 1
    max_steps: int = 10000                 # Use either epochs or steps

    # --- Paths / data ---
    output_dir: str = "./checkpoints"
    dataset_name: str = "fixie-ai/common_voice_17_0"
    dataset_subset: str = "hi"             # Hindi
    dataset_split: str = "train"
    val_dataset_split: str = "validation"

    # --- LoRA ---
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05

    # --- Hub ---
    push_to_hub: bool = False
    hub_model_id: Optional[str] = os.getenv("HUB_MODEL_ID", None)  # e.g. "username/model-name"
    hub_token: Optional[str] = os.getenv("HUB_TOKEN", None)
    hub_private_repo: bool = True

    # --- WandB ---
    wandb_project: str = os.getenv("WANDB_PROJECT", "audio-language-model")
    wandb_entity: Optional[str] = os.getenv("WANDB_ENTITY", None)
    wandb_run_name: Optional[str] = None
    wandb_watch: str = "false"             # "gradients", "all", "false"
    wandb_log_model: str = "false"         # "true", "false"

    # --- Misc ---
    seed: int = 42
    log_steps: int = 10
    eval_steps: int = 250
    save_steps: int = 500
    save_total_limit: int = 1
    sample_pred_every_steps: int = 250     # print ground-truth vs predicted transcript every N steps

    def to_dict(self):
        """Return the configuration as a plain dict (mirrors ModelConfig.to_dict)."""
        return dataclasses.asdict(self)