import torch
import torch.nn as nn
import transformers
from typing import Optional
from config import ModelConfig

class ModelProjector(nn.Module):
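    """Maps encoder audio features into the language model's input space.

    Groups of `stack_factor` consecutive audio frames are stacked into single,
    wider vectors and passed through a two-layer MLP with a final LayerNorm.
    """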
    def __init__(self, config: ModelConfig, audio_hidden_size: int):
        super().__init__()
        self.stack_factor = config.stack_factor
        input_dim = audio_hidden_size * self.stack_factor
        
        self.linear1 = nn.Linear(input_dim, config.hidden_size)
        self.act = nn.GELU() if config.projector_act == 'gelu' else nn.ReLU()
        self.linear2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.norm = nn.LayerNorm(config.hidden_size)
    
    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        # Heuristic: some encoders return (B, C, T); transpose to (B, T, C) when the
        # middle dim looks like the (smaller) channel dim rather than the time dim.
        if audio_features.dim() == 3 and audio_features.shape[1] < audio_features.shape[2]:
            audio_features = audio_features.transpose(1, 2)
        
        B, T, C = audio_features.shape
        
        # Zero-pad the time dimension so it divides evenly by stack_factor.
        if T % self.stack_factor != 0:
            pad_len = self.stack_factor - (T % self.stack_factor)
            audio_features = torch.nn.functional.pad(audio_features, (0, 0, 0, pad_len))
            T = T + pad_len

        # Stack groups of stack_factor consecutive frames into single, wider vectors:
        # (B, T, C) -> (B, T // stack_factor, C * stack_factor).
        audio_features = audio_features.view(B, T // self.stack_factor, C * self.stack_factor)
        
        x = self.linear1(audio_features)
        x = self.act(x)
        x = self.linear2(x)
        x = self.norm(x)
        return x

class MultiModalModel(nn.Module):
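    """Audio-conditioned causal language model.

    A frozen pretrained audio encoder produces frame features, a trainable
    projector maps them into the LLM embedding space, and the resulting audio
    embeddings are prepended to the text token embeddings.
    """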
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        
        # Pretrained audio encoder, kept frozen: it receives no gradient updates.
        self.audio_encoder = transformers.AutoModel.from_pretrained(config.audio_model_id).encoder
        for param in self.audio_encoder.parameters():
            param.requires_grad = False
            
        audio_hidden_size = self.audio_encoder.config.hidden_size
        
        self.llm = transformers.AutoModelForCausalLM.from_pretrained(config.text_model_id, trust_remote_code=True)
        self.llm_hidden_size = self.llm.config.hidden_size
        
        self.projector = ModelProjector(config, audio_hidden_size)
        if config.hidden_size != self.llm_hidden_size:
            # If the projector width differs from the LLM width, resize its output
            # layer and norm so projected audio lands in the LLM embedding space.
            self.projector.linear2 = nn.Linear(config.hidden_size, self.llm_hidden_size)
            self.projector.norm = nn.LayerNorm(self.llm_hidden_size)

        
    def forward(
        self, 
        input_ids: torch.Tensor, 
        audio_values: Optional[torch.Tensor] = None, 
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ):
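        # Embed the text tokens; projected audio embeddings are prepended below.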
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        
        if audio_values is not None:
            # Encode the audio and project it into the LLM embedding space.
            audio_outputs = self.audio_encoder(audio_values)
            audio_features = audio_outputs.last_hidden_state
            audio_projected = self.projector(audio_features)

            # Prepend the audio embeddings to the text token embeddings.
            inputs_embeds = torch.cat([audio_projected, inputs_embeds], dim=1)

            if labels is not None:
                # Audio positions carry no text targets; mask them out of the loss.
                audio_labels = torch.full(
                    (audio_projected.shape[0], audio_projected.shape[1]),
                    -100, device=labels.device, dtype=labels.dtype,
                )
                labels = torch.cat([audio_labels, labels], dim=1)

            if "attention_mask" in kwargs:
                # Audio positions are always attended to.
                audio_mask = torch.ones(
                    (audio_projected.shape[0], audio_projected.shape[1]),
                    device=inputs_embeds.device, dtype=kwargs["attention_mask"].dtype,
                )
                kwargs["attention_mask"] = torch.cat([audio_mask, kwargs["attention_mask"]], dim=1)

        # Match LLM dtype (e.g. bfloat16) to avoid "float != bfloat16" in linear layers
        llm_dtype = next(self.llm.parameters()).dtype
        inputs_embeds = inputs_embeds.to(llm_dtype)
        if labels is not None:
            labels = labels.to(llm_dtype) if labels.dtype.is_floating_point else labels

        # Drop non-tensor keys (e.g. continuation) so LLM forward doesn't receive them
        kwargs = {k: v for k, v in kwargs.items() if isinstance(v, torch.Tensor)}
        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            labels=labels,
            **kwargs
        )
        
        return outputs
    
    def generate(self, input_ids, audio_values=None, **kwargs):
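        # Mirror forward(): build the combined audio + text embeddings, then
        # delegate generation to the underlying LLM.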
        inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        
        if audio_values is not None:
            audio_outputs = self.audio_encoder(audio_values)
            audio_features = audio_outputs.last_hidden_state
            audio_projected = self.projector(audio_features)
            inputs_embeds = torch.cat([audio_projected, inputs_embeds], dim=1)

            if "attention_mask" in kwargs:
                audio_mask = torch.ones(
                    (audio_projected.shape[0], audio_projected.shape[1]),
                    device=inputs_embeds.device, dtype=kwargs["attention_mask"].dtype,
                )
                kwargs["attention_mask"] = torch.cat([audio_mask, kwargs["attention_mask"]], dim=1)

            # The projector may run in a different dtype than the LLM; match it here.
            inputs_embeds = inputs_embeds.to(next(self.llm.parameters()).dtype)

        return self.llm.generate(inputs_embeds=inputs_embeds, **kwargs)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
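        # Only the LLM is checkpointed; the audio encoder is frozen and does not
        # contribute to the backward pass.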
        self.llm.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
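
# --- Illustrative usage sketch (not part of the training pipeline). ---
# Assumes ModelConfig is a plain dataclass exposing the fields read above
# (audio_model_id, text_model_id, hidden_size, stack_factor, projector_act);
# the checkpoint names, sizes, and shapes below are example values only.
if __name__ == "__main__":
    config = ModelConfig(
        audio_model_id="openai/whisper-tiny",  # hypothetical encoder choice
        text_model_id="gpt2",                  # hypothetical LLM choice
        hidden_size=768,                       # projector width; matches GPT-2 here
        stack_factor=8,
        projector_act="gelu",
    )
    model = MultiModalModel(config)

    input_ids = torch.randint(0, model.llm.config.vocab_size, (1, 16))
    attention_mask = torch.ones_like(input_ids)
    audio_values = torch.randn(1, 80, 3000)    # Whisper-style log-mel features

    out = model(input_ids=input_ids, audio_values=audio_values,
                attention_mask=attention_mask)
    # Logits cover the prepended audio positions plus the 16 text tokens.
    print(out.logits.shape)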