| from typing import Optional |
| from transformers import ModernBertConfig, DacConfig |
|
|
|
|
| DFLT_DACVAE_CONFIG = { |
| "encoder_hidden_size": 64, |
| "downsampling_ratios": [2, 8, 10, 12], |
| "decoder_hidden_size": 1536, |
| "n_codebooks": 16, |
| "codebook_size": 1024, |
| "codebook_dim": 128, |
| "quantizer_dropout": 0, |
| "sampling_rate": 48000, |
| } |
|
|
| DFLT_TEXT_ENCODER_CONFIG = { |
| "classifier_pooling": "mean", |
| "global_attn_every_n_layers": 3, |
| "global_rope_theta": 160000.0, |
| "hidden_size": 1024, |
| "intermediate_size": 2624, |
| "layer_norm_eps": 1e-5, |
| "local_rope_theta": 10000.0, |
| "model_type": "modernbert", |
| "num_attention_heads": 16, |
| "num_hidden_layers": 28, |
| "position_embedding_type": "absolute", |
| "tie_word_embeddings": True, |
| "torch_dtype": "float32", |
| } |
|
|
|
|
| class TransformerConfig: |
| def __init__( |
| self, |
| hidden_size=1024, |
| intermediate_size=2752, |
| num_hidden_layers=16, |
| num_attention_heads=8, |
| num_key_value_heads=None, |
| hidden_act="silu", |
| max_position_embeddings=10_000, |
| rms_norm_eps=1e-5, |
| rope_theta=20000.0, |
| rope_scaling=None, |
| attention_bias=False, |
| attention_dropout=0.0, |
| **kwargs, |
| ): |
| self.max_position_embeddings = max_position_embeddings |
| self.hidden_size = hidden_size |
| self.intermediate_size = intermediate_size |
| self.num_hidden_layers = num_hidden_layers |
| self.num_attention_heads = num_attention_heads |
| if num_key_value_heads is None: |
| num_key_value_heads = num_attention_heads |
| self.num_key_value_heads = num_key_value_heads |
| self.head_dim = hidden_size // num_attention_heads |
| self.hidden_act = hidden_act |
| self.rms_norm_eps = rms_norm_eps |
| self.rope_theta = rope_theta |
| self.rope_scaling = rope_scaling |
| self.attention_bias = attention_bias |
| self.attention_dropout = attention_dropout |
|
|
|
|
| class AudioEncoderConfig: |
| def __init__( |
| self, |
| dac_vae_encoder: Optional[dict] = None, |
| audio_transformer: Optional[dict] = None, |
| **kwargs, |
| ): |
| dac_vae_encoder = dac_vae_encoder or DFLT_DACVAE_CONFIG |
| audio_transformer = audio_transformer or {} |
| self.dac_vae_encoder = DacConfig(**dac_vae_encoder) |
| self.audio_transformer = TransformerConfig(**audio_transformer) |
|
|
|
|
| class VisualEncoderConfig: |
| def __init__( |
| self, |
| pe_encoder: str = "PE-Core-L14-336", |
| visual_transformer: Optional[dict] = None, |
| fixed_len_video: bool = False, |
| **kwargs, |
| ): |
| visual_transformer = visual_transformer or {} |
|
|
| self.pe_encoder = pe_encoder |
| self.visual_transformer = TransformerConfig(**visual_transformer) |
| self.fixed_len_video = fixed_len_video |
|
|
|
|
| class PEAudioVisualEncoderConfig: |
| def __init__( |
| self, |
| audio_visual_transformer: Optional[dict] = None, |
| visual_model: Optional[dict] = None, |
| audio_model: Optional[dict] = None, |
| **kwargs, |
| ): |
| visual_model = visual_model or {} |
| audio_visual_transformer = audio_visual_transformer or {} |
| audio_model = audio_model or {} |
|
|
| self.visual_model = VisualEncoderConfig(**visual_model) |
| self.audio_model = AudioEncoderConfig(**audio_model) |
| self.audio_visual_transformer = TransformerConfig(**audio_visual_transformer) |
|
|
|
|
| class PEAudioVisualConfig: |
| def __init__( |
| self, |
| audio_visual_model: Optional[dict] = None, |
| text_model: Optional[dict] = None, |
| output_dim: int = 1024, |
| nth_text_layer: Optional[int] = 22, |
| **kwargs, |
| ): |
| text_model = text_model or DFLT_TEXT_ENCODER_CONFIG |
| audio_visual_model = audio_visual_model or {} |
| self.text_model = ModernBertConfig(**text_model) |
| self.audio_visual_model = PEAudioVisualEncoderConfig(**audio_visual_model) |
| self.output_dim = output_dim |
| self.nth_text_layer = nth_text_layer |
|
|
|
|
| class PEAudioFrameConfig: |
| def __init__( |
| self, |
| audio_model: Optional[dict] = None, |
| text_model: Optional[dict] = None, |
| output_dim: int = 1024, |
| nth_text_layer: Optional[int] = 22, |
| **kwargs, |
| ): |
| text_model = text_model or DFLT_TEXT_ENCODER_CONFIG |
| audio_model = audio_model or {} |
| self.text_model = ModernBertConfig(**text_model) |
| self.audio_model = AudioEncoderConfig(**audio_model) |
| self.output_dim = output_dim |
| self.nth_text_layer = nth_text_layer |
|
|