from typing import Any, Mapping, Optional

import torch
from peft import LoraConfig, get_peft_model
from transformers import WhisperConfig, WhisperModel


class WhisperEncoderBackbone(torch.nn.Module):
    def __init__(
        self,
        model: str = "openai/whisper-small.en",
        hf_config: Optional[Mapping[str, Any]] = None,
        lora_params: Optional[Mapping[str, Any]] = None,
    ):
        """Whisper encoder model with optional Low-Rank Adaptation.

        Parameters
        ----------
        model: Name of the Hugging Face Whisper checkpoint whose encoder to load
        hf_config: Optional keyword overrides passed to `WhisperConfig.from_pretrained`
        lora_params: Keyword arguments for `peft.LoraConfig`; LoRA is skipped if this is
            empty or None

        """
        super().__init__()
        hf_config = hf_config if hf_config is not None else dict()
        backbone_config = WhisperConfig.from_pretrained(model, **hf_config)
        self.backbone = (
            WhisperModel.from_pretrained(
                model,
                config=backbone_config,
            )
            .get_encoder()
            .train()
        )
        if lora_params is not None and len(lora_params) > 0:
            lora_config = LoraConfig(**lora_params)
            self.backbone = get_peft_model(self.backbone, lora_config)
        self.backbone_dim = backbone_config.hidden_size

    def forward(self, whisper_feature_batch):
        return self.backbone(whisper_feature_batch).last_hidden_state.mean(dim=1)
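

# Illustrative usage (a sketch kept as comments so nothing runs on import; the LoRA
# settings below are hypothetical examples, not values taken from this repository):
#
#     backbone = WhisperEncoderBackbone(
#         model="openai/whisper-small.en",
#         lora_params={"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]},
#     )
#     features = torch.randn(2, 80, 3000)  # (batch, num_mel_bins, frames) log-mel features
#     pooled = backbone(features)          # (2, backbone.backbone_dim), mean-pooled over time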


class SharedLayers(torch.nn.Module):
    def __init__(self, input_dim: int, proj_dims: list[int]):
        """Fully connected network with Mish nonlinearities between linear layers. No nonlinearity at input or output.

        Parameters
        ----------
        input_dim: Dimension of input features
        proj_dims: Dimensions of layers to create

        """
        super().__init__()
        modules = []
        for output_dim in proj_dims[:-1]:
            modules.extend([torch.nn.Linear(input_dim, output_dim), torch.nn.Mish()])
            input_dim = output_dim
        modules.append(torch.nn.Linear(input_dim, proj_dims[-1]))
        self.shared_layers = torch.nn.Sequential(*modules)

    def forward(self, x):
        return self.shared_layers(x)
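

# Illustrative configuration (a sketch kept as comments; the dimensions below are made up):
#
#     trunk = SharedLayers(input_dim=768, proj_dims=[256, 64])
#     # builds Linear(768, 256) -> Mish -> Linear(256, 64), with no activation on the output
#     y = trunk(torch.randn(4, 768))  # y.shape == (4, 64)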


class TaskHead(torch.nn.Module):
    def __init__(self, input_dim: int, proj_dim: int, dropout: float = 0.0):
        """Fully connected network with one hidden layer, dropout, and a scalar output."""
        super().__init__()

        self.linear = torch.nn.Linear(input_dim, proj_dim)
        self.activation = torch.nn.Mish()
        self.dropout = torch.nn.Dropout(dropout)
        self.final_layer = torch.nn.Linear(proj_dim, 1, bias=False)

    def forward(self, x):
        x = self.linear(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.final_layer(x)
        return x


class MultitaskHead(torch.nn.Module):
    def __init__(
        self,
        backbone_dim: int,
        shared_projection_dim: list[int],
        tasks: Mapping[str, Mapping[str, Any]],
    ):
        """Fully connected network with multiple named scalar outputs."""
        super().__init__()

        # Initialize the shared network and task-specific networks
        self.shared_layers = SharedLayers(backbone_dim, shared_projection_dim)
        self.classifier_head = torch.nn.ModuleDict(
            {
                task: TaskHead(shared_projection_dim[-1], **task_config)
                for task, task_config in tasks.items()
            }
        )

    def forward(self, x):
        x = self.shared_layers(x)
        return {task: head(x) for task, head in self.classifier_head.items()}
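

# Illustrative configuration (a sketch kept as comments; the task names and sizes below
# are hypothetical examples, not values taken from this repository):
#
#     head = MultitaskHead(
#         backbone_dim=768,
#         shared_projection_dim=[256, 64],
#         tasks={"depression": {"proj_dim": 32, "dropout": 0.1}, "anxiety": {"proj_dim": 32}},
#     )
#     scores = head(torch.randn(4, 768))
#     # scores == {"depression": <tensor of shape (4, 1)>, "anxiety": <tensor of shape (4, 1)>}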


def average_tensor_in_segments(
    tensor: torch.Tensor, lengths: list[int] | torch.Tensor
) -> torch.Tensor:
    """Average segments of a `tensor` along dimension 0 based on a list of `lengths`.

    For example, with input tensor `t` and `lengths` [1, 3, 2], the output would be
    [t[0], (t[1] + t[2] + t[3]) / 3, (t[4] + t[5]) / 2]

    Parameters
    ----------
    tensor : torch.Tensor
        The tensor to average
    lengths : list of ints or torch.Tensor
        The lengths of each segment to average in the tensor, in order; they must
        sum to `tensor.shape[0]`

    Returns
    -------
    torch.Tensor
        Tensor of shape `(len(lengths), *tensor.shape[1:])` with each segment averaged
    """
    if not torch.is_tensor(lengths):
        lengths = torch.tensor(lengths, device=tensor.device)
    index = torch.repeat_interleave(
        torch.arange(len(lengths), device=tensor.device), lengths
    )
    out = torch.zeros(
        lengths.shape + tensor.shape[1:], device=tensor.device, dtype=tensor.dtype
    )
    out.index_add_(0, index, tensor)
    broadcastable_lengths = lengths.view((-1,) + (1,) * (len(out.shape) - 1))
    return out / broadcastable_lengths
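

# Worked example matching the docstring above:
#
#     t = torch.arange(12, dtype=torch.float32).reshape(6, 2)
#     average_tensor_in_segments(t, [1, 3, 2])
#     # -> tensor([[ 0.,  1.],
#     #            [ 4.,  5.],
#     #            [ 9., 10.]])
#     # row 0 is t[0], row 1 averages t[1:4], row 2 averages t[4:6]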


class Classifier(torch.nn.Module):
    def __init__(
        self,
        backbone_configs: Mapping[str, Mapping[str, Any]],
        classifier_config: Mapping[str, Any],
        inference_thresholds: Mapping[str, Any],
        preprocessor_config: Mapping[str, Any],
    ):
        """Full Kintsugi Depression and Anxiety model.

        Whisper encoder -> Mean pooling over time -> Layers shared across tasks -> Per-task heads

        Parameters
        ----------
        backbone_configs: Mapping from backbone name to `WhisperEncoderBackbone` keyword
            arguments; one backbone is built per entry and their pooled outputs are
            concatenated
        classifier_config: Keyword arguments for `MultitaskHead`
        inference_thresholds: Per-task score thresholds used by `quantize_scores`
        preprocessor_config: Audio preprocessing configuration stored on the model
            (not used in `forward`)

        """
        super().__init__()

        self.backbone = torch.nn.ModuleDict(
            {
                key: WhisperEncoderBackbone(**backbone_configs[key])
                for key in sorted(backbone_configs.keys())
            }
        )

        backbone_dim = sum(layer.backbone_dim for layer in self.backbone.values())
        self.head = MultitaskHead(backbone_dim, **classifier_config)
        self.inference_thresholds = inference_thresholds
        self.preprocessor_config = preprocessor_config

    def forward(self, x, lengths):
        """Encode, pool per example, and score each task.

        `x` stacks the Whisper features of every segment along dim 0; `lengths` gives
        the number of consecutive segments belonging to each example.
        """
        backbone_outputs = {
            key: average_tensor_in_segments(layer(x), lengths)
            for key, layer in self.backbone.items()
        }
        # Concatenate the per-example pooled features from all backbones.
        backbone_output = torch.cat(list(backbone_outputs.values()), dim=1)
        # After pooling, each example is a single vector, so the returned lengths are all ones.
        return self.head(backbone_output), torch.ones_like(lengths)

    def quantize_scores(
        self, scores: Mapping[str, torch.Tensor]
    ) -> Mapping[str, torch.Tensor]:
        """Map per-task scores to discrete predictions per `inference_thresholds` config."""
        return {
            key: torch.searchsorted(
                torch.tensor(self.inference_thresholds[key], device=value.device),
                value.mean(),
                out_int32=True,
            )
            for key, value in scores.items()
        }