| from __future__ import annotations |
|
|
| import json |
| import wave |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| from safetensors.torch import load_file as load_safetensors_file |
|
|
| from .constants import DEFAULT_ESPEAK_VOICE, DEFAULT_SAMPLE_RATE |
| from .processor import PreparedInput, prepare_input |
| from .vits import SynthesizerTrn |
|
|
|
|
| def _repo_root() -> Path: |
| return Path(__file__).resolve().parents[2] |
|
|
|
|
def _default_model_path() -> Path:
    """Pick the default weights file: prefer model.safetensors, else model.ckpt."""
    root = _repo_root()
    preferred = root / "model.safetensors"
    return preferred if preferred.exists() else root / "model.ckpt"
|
|
|
|
def _default_config_path() -> Path:
    """Return the default location of config.json (repository root)."""
    return _repo_root().joinpath("config.json")
|
|
|
|
| def _import_torch() -> Any: |
| try: |
| import torch |
| except ImportError as exc: |
| raise ImportError("torch is required for checkpoint inference") from exc |
|
|
| return torch |
|
|
|
|
def load_release_config(config_path: str | Path) -> dict[str, Any]:
    """Parse the release config.json at *config_path* into a dict."""
    raw_text = Path(config_path).read_text(encoding="utf-8")
    return json.loads(raw_text)
|
|
|
|
def audio_float_to_int16(audio: np.ndarray, max_wav_value: float = 32767.0) -> np.ndarray:
    """Peak-normalize float audio and convert it to int16 PCM samples.

    The peak used for normalization is floored at 0.01 so near-silent input
    is not amplified without bound; an empty array passes through unscaled.
    """
    data = np.asarray(audio, dtype=np.float32)
    if data.size:
        peak = float(np.max(np.abs(data)))
        divisor = peak if peak > 0.01 else 0.01
    else:
        divisor = 1.0
    scaled = data * (max_wav_value / divisor)
    return np.clip(scaled, -max_wav_value, max_wav_value).astype(np.int16)
|
|
|
|
def write_wave(path: str | Path, samples: np.ndarray, sample_rate: int) -> Path:
    """Write *samples* as a mono 16-bit PCM WAV file and return its path.

    Samples are peak-normalized and converted to int16 via
    audio_float_to_int16 before writing.
    """
    target = Path(path)
    frames = audio_float_to_int16(samples).tobytes()
    with wave.open(str(target), "wb") as wav_out:
        wav_out.setnchannels(1)  # mono
        wav_out.setsampwidth(2)  # 2 bytes per sample == 16-bit PCM
        wav_out.setframerate(sample_rate)
        wav_out.writeframes(frames)
    return target
|
|
|
|
| def _generator_kwargs_from_config(config: dict[str, Any]) -> dict[str, Any]: |
| model = config.get("model", {}) |
|
|
| return { |
| "n_vocab": int(config["num_symbols"]), |
| "spec_channels": int(model["filter_length"]) // 2 + 1, |
| "segment_size": int(model["segment_size"]) // int(model["hop_length"]), |
| "inter_channels": int(model["inter_channels"]), |
| "hidden_channels": int(model["hidden_channels"]), |
| "filter_channels": int(model["filter_channels"]), |
| "n_heads": int(model["n_heads"]), |
| "n_layers": int(model["n_layers"]), |
| "kernel_size": int(model["kernel_size"]), |
| "p_dropout": float(model["p_dropout"]), |
| "resblock": model["resblock"], |
| "resblock_kernel_sizes": tuple(model["resblock_kernel_sizes"]), |
| "resblock_dilation_sizes": tuple(tuple(x) for x in model["resblock_dilation_sizes"]), |
| "upsample_rates": tuple(model["upsample_rates"]), |
| "upsample_initial_channel": int(model["upsample_initial_channel"]), |
| "upsample_kernel_sizes": tuple(model["upsample_kernel_sizes"]), |
| "n_speakers": int(config["num_speakers"]), |
| "gin_channels": int(model["gin_channels"]), |
| "use_sdp": bool(model.get("use_sdp", True)), |
| } |
|
|
|
|
| def _load_generator_state(model_path: Path, torch_module: Any) -> dict[str, Any]: |
| if model_path.suffix == ".safetensors": |
| return load_safetensors_file(str(model_path), device="cpu") |
|
|
| checkpoint = torch_module.load(model_path, map_location="cpu", weights_only=False) |
| state_dict = checkpoint["state_dict"] |
| return { |
| key[len("model_g.") :]: value |
| for key, value in state_dict.items() |
| if key.startswith("model_g.") |
| } |
|
|
|
|
@dataclass(frozen=True)
class GeneratedAudio:
    """Immutable result of a synthesis call."""

    # Waveform as float32; produced by squeezing the model output to numpy.
    samples: np.ndarray
    # Output sampling rate in Hz, taken from the loaded release config.
    sample_rate: int
    # The tokenized input that produced this audio (from prepare_input).
    prepared_input: PreparedInput
|
|
|
|
class WfloatGenerator:
    """Text-to-speech front end around a VITS ``SynthesizerTrn`` checkpoint.

    Loads generator weights (.safetensors or training .ckpt) together with a
    release config.json, then synthesizes waveforms from text via
    :meth:`generate`.
    """

    def __init__(
        self,
        checkpoint_path: str | Path | None = None,
        config_path: str | Path | None = None,
        device: str = "cpu",
    ) -> None:
        """Load weights and config, preparing the model for inference.

        Args:
            checkpoint_path: Weights file; defaults to model.safetensors (or
                model.ckpt) at the repository root.
            config_path: Release config.json; defaults to the repository root.
            device: torch device string the model is moved to, e.g. "cpu".

        Raises:
            FileNotFoundError: If the checkpoint or config file is missing.
            ValueError: If config.json and the checkpoint disagree on the
                number of speakers.
            ImportError: If torch is not installed.
        """
        self.checkpoint_path = Path(checkpoint_path or _default_model_path())
        self.config_path = Path(config_path or _default_config_path())
        self.device = device


        if not self.checkpoint_path.exists():
            raise FileNotFoundError(
                f"Checkpoint not found at {self.checkpoint_path}. "
                "Place a compatible multi-speaker checkpoint there or pass --checkpoint."
            )


        if not self.config_path.exists():
            raise FileNotFoundError(f"Config not found at {self.config_path}")


        # Audio / espeak settings fall back to package defaults when the
        # config omits them.
        self.config = load_release_config(self.config_path)
        self.sample_rate = int(self.config.get("audio", {}).get("sample_rate", DEFAULT_SAMPLE_RATE))
        self.espeak_voice = self.config.get("espeak", {}).get("voice", DEFAULT_ESPEAK_VOICE)
        self.num_speakers = int(self.config.get("num_speakers", 1))


        # Build the generator from config and load only the "model_g" weights.
        torch = _import_torch()
        self._torch = torch
        self._model = SynthesizerTrn(**_generator_kwargs_from_config(self.config))
        state_dict = _load_generator_state(self.checkpoint_path, torch)
        self._model.load_state_dict(state_dict, strict=True)
        self._model.eval()


        # Fold weight normalization out of the decoder for inference.
        with torch.no_grad():
            self._model.dec.remove_weight_norm()


        # Prefer the speaker count reported by the instantiated model over
        # the value read from config earlier.
        self._model.to(self.device)
        self.num_speakers = int(getattr(self._model, "n_speakers", self.num_speakers))


        # Fail loudly if config.json disagrees with the loaded checkpoint.
        configured_num_speakers = int(self.config.get("num_speakers", self.num_speakers))
        if configured_num_speakers != self.num_speakers:
            raise ValueError(
                "Checkpoint/config mismatch: "
                f"config.json declares num_speakers={configured_num_speakers}, "
                f"but checkpoint reports num_speakers={self.num_speakers}."
            )


    def generate(
        self,
        text: str,
        sid: int = 0,
        emotion: str = "neutral",
        intensity: float = 0.5,
        noise_scale: float | None = None,
        length_scale: float | None = None,
        noise_w: float | None = None,
    ) -> GeneratedAudio:
        """Synthesize *text* and return the waveform with its prepared input.

        Args:
            text: Input text to synthesize.
            sid: Speaker id; must be 0 (or None) for single-speaker models.
            emotion: Emotion label forwarded to prepare_input.
            intensity: Emotion intensity forwarded to prepare_input.
            noise_scale: Override for the sampling noise scale; defaults to
                the config "inference" value, then 0.667.
            length_scale: Override for duration scaling; defaults to the
                config "inference" value, then 1.0.
            noise_w: Override passed to the model as ``noise_scale_w``;
                defaults to the config "inference" value, then 0.8.

        Raises:
            ValueError: If a nonzero sid is given for a single-speaker model.
        """
        # Single-speaker checkpoints take no speaker-id tensor.
        if self.num_speakers <= 1:
            if sid not in (0, None):
                raise ValueError(
                    f"Loaded checkpoint is single-speaker but sid={sid} was provided"
                )
            sid_tensor = None
        else:
            sid_tensor = self._torch.LongTensor([int(sid)]).to(self.device)


        # Convert text to model token ids (see processor.prepare_input).
        prepared = prepare_input(
            text=text,
            config=self.config,
            emotion=emotion,
            intensity=intensity,
            espeak_voice=self.espeak_voice,
        )


        text_tensor = self._torch.LongTensor(prepared.token_ids).unsqueeze(0).to(self.device)
        text_lengths = self._torch.LongTensor([len(prepared.token_ids)]).to(self.device)


        # Resolve scales: explicit argument > config "inference" > defaults.
        inference = self.config.get("inference", {})
        scales = [
            float(inference.get("noise_scale", 0.667) if noise_scale is None else noise_scale),
            float(inference.get("length_scale", 1.0) if length_scale is None else length_scale),
            float(inference.get("noise_w", 0.8) if noise_w is None else noise_w),
        ]


        with self._torch.no_grad():
            audio, *_ = self._model.infer(
                text_tensor,
                text_lengths,
                sid=sid_tensor,
                noise_scale=scales[0],
                length_scale=scales[1],
                noise_scale_w=scales[2],
            )


        # Move to CPU, drop singleton batch/channel dims, and cast to float32.
        samples = audio.detach().cpu().numpy().squeeze().astype(np.float32)


        return GeneratedAudio(
            samples=samples,
            sample_rate=self.sample_rate,
            prepared_input=prepared,
        )
|
|
|
|
def load_generator(
    checkpoint_path: str | Path | None = None,
    config_path: str | Path | None = None,
    device: str = "cpu",
) -> WfloatGenerator:
    """Factory helper: build a WfloatGenerator with the given overrides."""
    return WfloatGenerator(checkpoint_path, config_path, device)
|
|