| | import argparse |
| |
|
| | import torch.nn as nn |
| | |
| |
|
| | from .macros import ( |
| | NUM_AUDIO_TOKENS, |
| | NUM_MEL_BINS, |
| | NUM_SPEAKER_CLASSES, |
| | NUM_TEXT_TOKENS, |
| | SPEAKER_EMBEDDING_DIM, |
| | ) |
| | from .vallex import VALLE, VALLF |
| |
|
| |
|
def _str2bool(value: str) -> bool:
    """Convert a command-line string into a real boolean.

    ``argparse`` applies ``type`` to the raw string, and ``bool("False")`` is
    ``True`` because every non-empty string is truthy.  This converter parses
    the usual spellings instead.

    Args:
        value: Raw option value from the command line.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    # The empty string is accepted as False for backward compatibility: with
    # ``type=bool`` it was the only value that could switch a flag off.
    if normalized in ("no", "false", "f", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"Boolean value expected, got {value!r}."
    )


def add_model_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the model-related command-line options on ``parser``.

    Boolean options take an explicit value (e.g. ``--norm-first false``) and
    are parsed with ``_str2bool`` so that ``false``/``0``/``no`` really
    disable them.

    Args:
        parser: The argument parser to extend in place.
    """
    parser.add_argument(
        "--model-name",
        type=str,
        default="VALL-E",
        help="VALL-E, VALL-F, Transformer.",
    )
    parser.add_argument(
        "--decoder-dim",
        type=int,
        default=1024,
        help="Embedding dimension in the decoder model.",
    )
    parser.add_argument(
        "--nhead",
        type=int,
        default=16,
        help="Number of attention heads in the Decoder layers.",
    )
    parser.add_argument(
        "--num-decoder-layers",
        type=int,
        default=12,
        help="Number of Decoder layers.",
    )
    parser.add_argument(
        "--scale-factor",
        type=float,
        default=1.0,
        help="Model scale factor which will be assigned different meanings in different models.",
    )
    parser.add_argument(
        "--norm-first",
        type=_str2bool,
        default=True,
        help="Pre or Post Normalization.",
    )
    parser.add_argument(
        "--add-prenet",
        type=_str2bool,
        default=False,
        help="Whether add PreNet after Inputs.",
    )

    parser.add_argument(
        "--prefix-mode",
        type=int,
        default=1,
        help="The mode for how to prefix VALL-E NAR Decoder, "
        "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
    )
    parser.add_argument(
        "--share-embedding",
        type=_str2bool,
        default=True,
        help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
    )
    parser.add_argument(
        "--prepend-bos",
        type=_str2bool,
        default=False,
        help="Whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs.",
    )
    parser.add_argument(
        "--num-quantizers",
        type=int,
        default=8,
        help="Number of Audio/Semantic quantization layers.",
    )

    parser.add_argument(
        "--scaling-xformers",
        type=_str2bool,
        default=False,
        help="Apply Reworked Conformer scaling on Transformers.",
    )
| |
|
| |
|
def get_model(params) -> nn.Module:
    """Instantiate the model selected by ``params.model_name``.

    The name is matched case-insensitively: ``"vall-f"``/``"vallf"`` selects
    ``VALLF`` and ``"vall-e"``/``"valle"`` selects ``VALLE``.  Both classes
    are constructed with the same set of arguments read from ``params``.

    Args:
        params: Object exposing ``model_name``, ``decoder_dim``, ``nhead``,
            ``num_decoder_layers``, ``norm_first``, ``add_prenet``,
            ``prefix_mode``, ``share_embedding``, ``scale_factor``,
            ``prepend_bos`` and ``num_quantizers`` as attributes.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``params.model_name`` matches neither model.
    """
    requested = params.model_name.lower()

    # Resolve the class first; both constructors are called identically.
    if requested in ("vall-f", "vallf"):
        model_cls = VALLF
    elif requested in ("vall-e", "valle"):
        model_cls = VALLE
    else:
        raise ValueError("No such model")

    return model_cls(
        params.decoder_dim,
        params.nhead,
        params.num_decoder_layers,
        norm_first=params.norm_first,
        add_prenet=params.add_prenet,
        prefix_mode=params.prefix_mode,
        share_embedding=params.share_embedding,
        nar_scale_factor=params.scale_factor,
        prepend_bos=params.prepend_bos,
        num_quantizers=params.num_quantizers,
    )
| |
|