| import pytorch_lightning as pl |
| import torch |
| import torch.nn as nn |
| import os |
| import numpy as np |
| import hydra |
| from model import load_ssl_model, PhonemeEncoder, DomainEmbedding, LDConditioner, Projection |
|
|
|
|
class BaselineLightningModule(pl.LightningModule):
    """Baseline MOS predictor: SSL + domain-ID features routed through
    judge-conditioned output layers ending in a projection head."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.construct_model()
        # Record cfg so checkpoints can rebuild this module.
        self.save_hyperparameters()

    def construct_model(self):
        """Assemble feature extractors and the stacked output layers.

        Extractors each return a dict of features; their output dims are
        summed to size the first output layer.
        """
        extractors = [
            load_ssl_model(cp_path='wav2vec_small.pt'),
            DomainEmbedding(3, 128),
        ]
        self.feature_extractors = nn.ModuleList(extractors)

        feature_dim = sum(
            extractor.get_output_dim() for extractor in self.feature_extractors
        )
        conditioner = LDConditioner(
            judge_dim=128, num_judges=3000, input_dim=feature_dim
        )
        projection = Projection(
            hidden_dim=2048,
            activation=torch.nn.ReLU(),
            range_clipping=False,
            input_dim=conditioner.get_output_dim(),
        )
        self.output_layers = nn.ModuleList([conditioner, projection])

    def forward(self, inputs):
        """Merge every extractor's feature dict, then thread the result
        through each output layer (each also sees the raw inputs)."""
        merged = {}
        for extractor in self.feature_extractors:
            merged.update(extractor(inputs))

        hidden = merged
        for layer in self.output_layers:
            hidden = layer(hidden, inputs)
        return hidden
|
|