from typing import Any, Mapping, Optional

import torch
from peft import LoraConfig, get_peft_model
from transformers import WhisperConfig, WhisperModel


class WhisperEncoderBackbone(torch.nn.Module):
    def __init__(
        self,
        model: str = "openai/whisper-small.en",
        hf_config: Optional[Mapping[str, Any]] = None,
        lora_params: Optional[Mapping[str, Any]] = None,
    ):
        """Whisper encoder model with optional Low-Rank Adaptation (LoRA).

        Parameters
        ----------
        model: Name of the WhisperModel whose encoder to load from HuggingFace
        hf_config: Optional config overrides for the HuggingFace model
        lora_params: Parameters for Low-Rank Adaptation; if None or empty, no adapters are added
        """
        super().__init__()
        hf_config = hf_config if hf_config is not None else dict()
        backbone_config = WhisperConfig.from_pretrained(model, **hf_config)
        # Load the full seq2seq model, keep only its encoder, and leave it in train mode
        self.backbone = (
            WhisperModel.from_pretrained(
                model,
                config=backbone_config,
            )
            .get_encoder()
            .train()
        )
        if lora_params is not None and len(lora_params) > 0:
            lora_config = LoraConfig(**lora_params)
            self.backbone = get_peft_model(self.backbone, lora_config)
        self.backbone_dim = backbone_config.hidden_size

    def forward(self, whisper_feature_batch):
        # Mean-pool encoder hidden states over the time dimension: (B, T, D) -> (B, D)
        return self.backbone(whisper_feature_batch).last_hidden_state.mean(dim=1)
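

# A hedged usage sketch: shapes assume the default "openai/whisper-small.en"
# checkpoint (80 mel bins, 3000-frame / 30-second inputs, 768-dim hidden states);
# the LoRA settings below are illustrative, not values taken from this repo.
#
#   backbone = WhisperEncoderBackbone(
#       lora_params={"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]}
#   )
#   features = torch.randn(4, 80, 3000)   # (batch, n_mels, frames) log-mel features
#   pooled = backbone(features)           # (4, 768): one time-averaged vector per clip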


class SharedLayers(torch.nn.Module):
    def __init__(self, input_dim: int, proj_dims: list[int]):
        """Fully connected network with Mish nonlinearities between linear layers.

        No nonlinearity is applied at the input or the output.

        Parameters
        ----------
        input_dim: Dimension of input features
        proj_dims: Dimensions of the layers to create
        """
        super().__init__()
        modules = []
        for output_dim in proj_dims[:-1]:
            modules.extend([torch.nn.Linear(input_dim, output_dim), torch.nn.Mish()])
            input_dim = output_dim
        # Final projection has no activation after it
        modules.append(torch.nn.Linear(input_dim, proj_dims[-1]))
        self.shared_layers = torch.nn.Sequential(*modules)

    def forward(self, x):
        return self.shared_layers(x)
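

# For example, SharedLayers(768, [256, 64]) builds the following Sequential
# (the dimensions here are illustrative, not fixed by this module):
#
#   Linear(768 -> 256) -> Mish() -> Linear(256 -> 64)
#
# i.e. one Linear per entry of proj_dims, with Mish between them only.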


class TaskHead(torch.nn.Module):
    def __init__(self, input_dim: int, proj_dim: int, dropout: float = 0.0):
        """Fully connected network with one hidden layer, dropout, and a scalar output.

        Parameters
        ----------
        input_dim: Dimension of input features
        proj_dim: Dimension of the hidden layer
        dropout: Dropout probability applied after the hidden activation
        """
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, proj_dim)
        self.activation = torch.nn.Mish()
        self.dropout = torch.nn.Dropout(dropout)
        self.final_layer = torch.nn.Linear(proj_dim, 1, bias=False)

    def forward(self, x):
        x = self.linear(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.final_layer(x)
        return x
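

# TaskHead maps a shared representation to a single scalar per example, e.g.
# (a sketch; the dimensions are illustrative):
#
#   head = TaskHead(input_dim=64, proj_dim=32, dropout=0.1)
#   head(torch.randn(4, 64)).shape   # torch.Size([4, 1])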


class MultitaskHead(torch.nn.Module):
    def __init__(
        self,
        backbone_dim: int,
        shared_projection_dim: list[int],
        tasks: Mapping[str, Mapping[str, Any]],
    ):
        """Fully connected network with multiple named scalar outputs.

        Parameters
        ----------
        backbone_dim: Dimension of the pooled backbone features
        shared_projection_dim: Layer dimensions for the shared trunk (see SharedLayers)
        tasks: Mapping from task name to TaskHead keyword arguments
        """
        super().__init__()
        # Initialize the shared network and task-specific networks
        self.shared_layers = SharedLayers(backbone_dim, shared_projection_dim)
        self.classifier_head = torch.nn.ModuleDict(
            {
                task: TaskHead(shared_projection_dim[-1], **task_config)
                for task, task_config in tasks.items()
            }
        )

    def forward(self, x):
        x = self.shared_layers(x)
        return {task: head(x) for task, head in self.classifier_head.items()}
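

# A hedged configuration sketch: the task names match this repo's stated
# depression/anxiety use case, but all dimensions here are illustrative.
#
#   head = MultitaskHead(
#       backbone_dim=768,
#       shared_projection_dim=[256, 64],
#       tasks={
#           "depression": {"proj_dim": 32, "dropout": 0.1},
#           "anxiety": {"proj_dim": 32, "dropout": 0.1},
#       },
#   )
#   out = head(torch.randn(4, 768))   # {"depression": (4, 1), "anxiety": (4, 1)}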


def average_tensor_in_segments(tensor: torch.Tensor, lengths: list[int] | torch.Tensor):
    """Average segments of a `tensor` along dimension 0 based on a list of `lengths`.

    For example, with input tensor `t` and `lengths` [1, 3, 2], the output would be
    [t[0], (t[1] + t[2] + t[3]) / 3, (t[4] + t[5]) / 2].

    Parameters
    ----------
    tensor : torch.Tensor
        The tensor to average
    lengths : list of int or torch.Tensor
        The lengths of each segment to average in the tensor, in order

    Returns
    -------
    torch.Tensor
        The tensor with relevant segments averaged
    """
    if not torch.is_tensor(lengths):
        lengths = torch.tensor(lengths, device=tensor.device)
    # Map each row of `tensor` to the index of the segment it belongs to
    index = torch.repeat_interleave(
        torch.arange(len(lengths), device=tensor.device), lengths
    )
    # Sum the rows of each segment, then divide by that segment's length
    out = torch.zeros(
        lengths.shape + tensor.shape[1:], device=tensor.device, dtype=tensor.dtype
    )
    out.index_add_(0, index, tensor)
    broadcastable_lengths = lengths.view((-1,) + (1,) * (len(out.shape) - 1))
    return out / broadcastable_lengths
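

# A quick worked check of the docstring example:
#
#   t = torch.arange(6.0).unsqueeze(1)        # rows 0..5 as a (6, 1) tensor
#   average_tensor_in_segments(t, [1, 3, 2])
#   # -> tensor([[0.0], [2.0], [4.5]]): t[0], mean(t[1:4]), mean(t[4:6])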


class Classifier(torch.nn.Module):
    def __init__(
        self,
        backbone_configs: Mapping[str, Mapping[str, Any]],
        classifier_config: Mapping[str, Any],
        inference_thresholds: Mapping[str, Any],
        preprocessor_config: Mapping[str, Any],
    ):
        """Full Kintsugi Depression and Anxiety model.

        Whisper encoder -> Mean pooling over time -> Layers shared across tasks -> Per-task heads

        Parameters
        ----------
        backbone_configs: Mapping from backbone name to WhisperEncoderBackbone keyword arguments
        classifier_config: Keyword arguments for MultitaskHead
        inference_thresholds: Per-task score thresholds used by `quantize_scores`
        preprocessor_config: Feature-preprocessor configuration, stored on the model for callers
        """
        super().__init__()
        # Sort keys so the concatenation order of backbone outputs is deterministic
        self.backbone = torch.nn.ModuleDict(
            {
                key: WhisperEncoderBackbone(**backbone_configs[key])
                for key in sorted(backbone_configs.keys())
            }
        )
        backbone_dim = sum(layer.backbone_dim for layer in self.backbone.values())
        self.head = MultitaskHead(backbone_dim, **classifier_config)
        self.inference_thresholds = inference_thresholds
        self.preprocessor_config = preprocessor_config

    def forward(self, x, lengths):
        if not torch.is_tensor(lengths):
            lengths = torch.tensor(lengths, device=x.device)
        # Pool each backbone's per-chunk embeddings into one vector per segment
        backbone_outputs = {
            key: average_tensor_in_segments(layer(x), lengths)
            for key, layer in self.backbone.items()
        }
        backbone_output = torch.cat(list(backbone_outputs.values()), dim=1)
        # After segment averaging, every segment has collapsed to length 1
        return self.head(backbone_output), torch.ones_like(lengths)

    def quantize_scores(self, scores: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        """Map per-task scores to discrete predictions per `inference_thresholds` config."""
        return {
            key: torch.searchsorted(
                torch.tensor(self.inference_thresholds[key], device=value.device),
                value.mean(),
                out_int32=True,
            )
            for key, value in scores.items()
        }
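

# A hedged sketch of how quantize_scores bins a score via torch.searchsorted
# (the threshold values here are illustrative, not the shipped configuration):
#
#   thresholds = {"depression": [0.25, 0.5, 0.75]}       # self.inference_thresholds
#   scores = {"depression": torch.tensor([0.4, 0.8])}    # mean score is 0.6
#   # searchsorted([0.25, 0.5, 0.75], 0.6) == 2, so the prediction is bucket 2
#   # of the four buckets that three thresholds delineate.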