from __future__ import annotations

import os
import sys
from collections import OrderedDict

import tensorrt as trt
from tensorrt_llm._common import default_net

from ..._utils import str_dtype_to_trt
from ...functional import Tensor, concat
from ...layers import Linear
from ...module import Module, ModuleList
from ...plugin import current_all_reduce_helper
from ..modeling_utils import PretrainedConfig, PretrainedModel
from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding

# Make sibling modules in this directory importable by bare name.
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file_path)
sys.path.append(parent_dir)

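
# InputEmbedding fuses the noisy mel frames with the conditioning stream.
# The Linear in_features (mel_dim * 2 + text_dim) imply the expected shapes:
# `x` carries mel_dim features (the noisy mel) and `cond` carries
# mel_dim + text_dim features (reference mel concatenated with the text
# embedding). A convolutional position embedding is then added residually.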
class InputEmbedding(Module):
    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x, cond):
        x = self.proj(concat([x, cond], dim=-1))
        return self.conv_pos_embed(x) + x

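
# F5TTS assembles the F5-TTS DiT backbone for TensorRT-LLM graph tracing:
# a timestep embedding conditions a stack of DiTBlocks, followed by a final
# AdaLayerNormZero and a projection back to mel space.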
class F5TTS(PretrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.dtype = str_dtype_to_trt(config.dtype)

        self.time_embed = TimestepEmbedding(config.hidden_size)
        self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)

        self.dim = config.hidden_size
        self.depth = config.num_hidden_layers
        self.transformer_blocks = ModuleList(
            [
                DiTBlock(
                    dim=self.dim,
                    heads=config.num_attention_heads,
                    dim_head=config.dim_head,
                    ff_mult=config.ff_mult,
                    dropout=config.dropout,
                )
                for _ in range(self.depth)
            ]
        )

        self.norm_out = AdaLayerNormZero_Final(config.hidden_size)
        self.proj_out = Linear(config.hidden_size, config.mel_dim)

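    # forward's parameter names deliberately match the keys returned by
    # prepare_inputs, so the traced inputs can be passed as **kwargs.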
    def forward(
        self,
        noise,
        cond,
        time,
        rope_cos,
        rope_sin,
        input_lengths,
        scale=1.0,
    ):
        # One denoising step: embed the timestep, fuse noise with
        # conditioning, run the DiT stack, and project back to mel bins.
        t = self.time_embed(time)
        x = self.input_embed(noise, cond)
        for block in self.transformer_blocks:
            x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
        denoise = self.proj_out(self.norm_out(x, t))
        # Register the result as a named engine output with the network dtype.
        denoise.mark_output("denoised", self.dtype)
        return denoise

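    # prepare_inputs declares the symbolic engine inputs. Each dim_range entry
    # is a [min, opt, max] triple feeding the TensorRT optimization profile;
    # dimensions that share a name must share a range. Two layouts exist,
    # selected by the remove_input_padding plugin: a packed
    # [num_frames, features] layout with all sequences concatenated, or a
    # padded [batch_size, max_duration, features] layout.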
    def prepare_inputs(self, **kwargs):
        max_batch_size = kwargs["max_batch_size"]
        batch_size_range = [2, 2, max_batch_size]
        mel_size = 100
        max_seq_len = 3000
        num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
        hidden_size = 512
        concat_feature_dim = mel_size + hidden_size
        freq_embed_dim = 256
        head_dim = 64
        mapping = self.config.mapping
        if mapping.tp_size > 1:
            # Tensor parallelism requires a workspace tensor for all-reduce.
            current_all_reduce_helper().set_workspace_tensor(mapping, 1)
        if default_net().plugin_config.remove_input_padding:
            # Packed layout: sequences are concatenated along the frame axis.
            noise = Tensor(
                name="noise",
                dtype=self.dtype,
                shape=[-1, mel_size],
                dim_range=OrderedDict(
                    [
                        ("num_frames", [num_frames_range]),
                        ("n_mels", [mel_size]),
                    ]
                ),
            )
            cond = Tensor(
                name="cond",
                dtype=self.dtype,
                shape=[-1, concat_feature_dim],
                dim_range=OrderedDict(
                    [
                        ("num_frames", [num_frames_range]),
                        ("embedded_length", [concat_feature_dim]),
                    ]
                ),
            )
            time = Tensor(
                name="time",
                dtype=self.dtype,
                shape=[-1, freq_embed_dim],
                dim_range=OrderedDict(
                    [
                        ("num_frames", [num_frames_range]),
                        ("freq_dim", [freq_embed_dim]),
                    ]
                ),
            )
            rope_cos = Tensor(
                name="rope_cos",
                dtype=self.dtype,
                shape=[-1, head_dim],
                dim_range=OrderedDict(
                    [
                        ("num_frames", [num_frames_range]),
                        ("head_dim", [head_dim]),
                    ]
                ),
            )
            rope_sin = Tensor(
                name="rope_sin",
                dtype=self.dtype,
                shape=[-1, head_dim],
                dim_range=OrderedDict(
                    [
                        ("num_frames", [num_frames_range]),
                        ("head_dim", [head_dim]),
                    ]
                ),
            )
        else:
            # Padded layout: explicit batch and max_duration axes.
            noise = Tensor(
                name="noise",
                dtype=self.dtype,
                shape=[-1, -1, mel_size],
                dim_range=OrderedDict(
                    [
                        ("batch_size", [batch_size_range]),
                        ("max_duration", [[100, max_seq_len // 2, max_seq_len]]),
                        ("n_mels", [mel_size]),
                    ]
                ),
            )
            cond = Tensor(
                name="cond",
                dtype=self.dtype,
                shape=[-1, -1, concat_feature_dim],
                dim_range=OrderedDict(
                    [
                        ("batch_size", [batch_size_range]),
                        ("max_duration", [[100, max_seq_len // 2, max_seq_len]]),
                        ("embedded_length", [concat_feature_dim]),
                    ]
                ),
            )
            time = Tensor(
                name="time",
                dtype=self.dtype,
                shape=[-1, freq_embed_dim],
                dim_range=OrderedDict(
                    [
                        ("batch_size", [batch_size_range]),
                        ("freq_dim", [freq_embed_dim]),
                    ]
                ),
            )
            rope_cos = Tensor(
                name="rope_cos",
                dtype=self.dtype,
                shape=[-1, -1, head_dim],
                dim_range=OrderedDict(
                    [
                        ("batch_size", [batch_size_range]),
                        ("max_duration", [[100, max_seq_len // 2, max_seq_len]]),
                        ("head_dim", [head_dim]),
                    ]
                ),
            )
            rope_sin = Tensor(
                name="rope_sin",
                dtype=self.dtype,
                shape=[-1, -1, head_dim],
                dim_range=OrderedDict(
                    [
                        ("batch_size", [batch_size_range]),
                        ("max_duration", [[100, max_seq_len // 2, max_seq_len]]),
                        ("head_dim", [head_dim]),
                    ]
                ),
            )
        input_lengths = Tensor(
            name="input_lengths",
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([("batch_size", [batch_size_range])]),
        )
        return {
            "noise": noise,
            "cond": cond,
            "time": time,
            "rope_cos": rope_cos,
            "rope_sin": rope_sin,
            "input_lengths": input_lengths,
        }
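
# A minimal usage sketch (illustrative, not part of this module): during
# engine build, TensorRT-LLM obtains the symbolic inputs from prepare_inputs
# and traces forward() once, which marks "denoised" as the engine output.
# `max_batch_size` is the only keyword prepare_inputs reads.
#
#     model = F5TTS(config)
#     inputs = model.prepare_inputs(max_batch_size=8)
#     model(**inputs)  # populates the TRT network graph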