diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..989588081e5b95f478167c20cb1434c76dddf36d 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+data/ref.wav filter=lfs diff=lfs merge=lfs -text
+data/sample.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..7e99e367f8443d86e5e8825b9fda39dfbb39630d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+*.pyc
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 0bac9b53087b061ab5661c5dc6c591a9abaa1d26..178fe97a46523e31f5873b1414dd44f28acaafa0 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,155 @@
----
-title: Fun CineForge Demo
-emoji: 🏢
-colorFrom: green
-colorTo: gray
-sdk: gradio
-sdk_version: 6.9.0
-app_file: app.py
-pinned: false
-license: apache-2.0
-short_description: Fun-CineForge-zh-en-v1-0.5B
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+###
「English | [简体中文](./README_zh.md)」
+
+
+🎬 Fun-CineForge: A Unified Dataset Pipeline and Model for Zero-Shot Movie Dubbing
+in Diverse Cinematic Scenes
+
+
+
+
+
+

+

+

+
+
+
+
+**Fun-CineForge** contains an end-to-end dataset pipeline for producing large-scale dubbing datasets and an MLLM-based dubbing model designed for diverse cinematic scenes. Using this pipeline, we constructed the first large-scale Chinese television dubbing dataset CineDub-CN, which includes rich annotations and diverse scenes. In monologue, narration, dialogue, and multi-speaker scenes, our dubbing model consistently outperforms state-of-the-art methods in terms of audio quality, lip-sync, timbre transition, and instruction following.
+
+
+## Dataset & Demo 🎬
+You can access [https://funcineforge.github.io/](https://funcineforge.github.io/) to get our CineDub-CN dataset samples and demo samples.
+
+
+## Environmental Installation
+
+Fun-CineForge relies on Conda and Python environments. Execute **setup.py** to automatically install the entire project environment and open-source model.
+
+```shell
+# Conda
+git clone git@github.com:FunAudioLLM/FunCineForge.git
+conda create -n FunCineForge python=3.10 -y && conda activate FunCineForge
+sudo apt-get install ffmpeg
+# Initial settings
+python setup.py
+```
+
+
+## Dataset Pipeline 🔨
+
+### Data collection
+If you want to produce your own data,
+we recommend that you refer to the following requirements to collect the corresponding movies or television series.
+
+1. Video source: TV dramas or movies, non documentaries, with more monologues or dialogue scenes, clear and unobstructed faces (such as without masks and veils).
+2. Speech Requirements: Standard pronunciation, clear articulation, prominent human voice. Avoid materials with strong dialects, excessive background noise, or strong colloquialism.
+3. Image Requirements: High resolution, clear facial details, sufficient lighting, avoiding extremely dark or strong backlit scenes.
+
+### How to use
+
+- [1] Standardize video format and name; trim the beginning and end of long videos; extract the audio from the trimmed video. (default is to trim 10 seconds from both the beginning and end.)
+```shell
+python normalize_trim.py --root datasets/raw_zh --intro 10 --outro 10
+```
+
+- [2] [Speech Separation](./speech_separation/README.md). The audio is used to separate the vocals from the instrumental music.
+```shell
+cd speech_separation
+python run.py --root datasets/clean/zh --gpus 0 1 2 3
+```
+
+- [3] [VideoClipper](./video_clip/README.md). For long videos, VideoClipper is used to obtain sentence-level subtitle files and clip the long video into segments based on timestamps. Now it supports bilingualism in both Chinese and English. Below is an example in Chinese. It is recommended to use gpu acceleration for English.
+```shell
+cd video_clip
+bash run.sh --stage 1 --stop_stage 2 --input datasets/raw_zh --output datasets/clean/zh --lang zh --device cpu
+```
+
+- Video duration limit and check for cleanup. (Without --execute, only pre-deleted files will be printed. After checking, add --execute to confirm the deletion.)
+```shell
+python clean_video.py --root datasets/clean/zh
+python clean_srt.py --root datasets/clean/zh --lang zh
+```
+
+- [4] [Speaker Diarization](./speaker_diarization/README.md). Multimodal active speaker recognition obtains RTTM files; identifies the speaker's facial frames, extracts frame-level speaker face and lip raw data.
+```shell
+cd speaker_diarization
+bash run.sh --stage 1 --stop_stage 4 --hf_access_token hf_xxx --root datasets/clean/zh --gpus "0 1 2 3"
+```
+
+- (Reference) Extract speech tokens based on the CosyVoice3 tokenizer for llm training.
+```shell
+python speech_tokenizer.py --root datasets/clean/zh
+```
+
+- [5] Multimodal CoT Correction. Based on general-purpose MLLMs, the system uses audio, ASR text, and RTTM files as input. It leverages Chain-of-Thought (CoT) reasoning to extract clues and corrects the results of the specialized models. It also annotates character age, gender, and vocal timbre. Experimental results show that this strategy reduces the CER from 4.53% to 0.94% and the speaker diarization error rate from 8.38% to 1.20%, achieving quality comparable to or even better than manual transcription. Adding the --resume enables breakpoint COT inference to prevent wasted resources from repeated COT inferences. Now supports both Chinese and English.
+```shell
+python cot.py --root_dir datasets/clean/zh --lang zh --provider google --model gemini-3-pro-preview --api_key xxx --resume
+python cot.py --root_dir datasets/clean/en --lang en --provider google --model gemini-3-pro-preview --api_key xxx --resume
+```
+
+- The construction of the dataset retrieval file will read all production data, perform bidirectional verification of script content and speaker separation results.
+```shell
+python build_datasets.py --root_zh datasets/clean/zh --root_en datasets/clean/en --out_dir datasets/clean --save
+```
+
+
+## Dubbing Model ⚙️
+We've open-sourced the inference code and the **infer.sh** script, and provided some test cases in the data folder for your experience. Inference requires a consumer-grade GPU. Run the following command:
+
+```shell
+cd exps
+bash infer.sh
+```
+
+The API for multi-speaker dubbing from raw videos and SRT scripts is under development ...
+
+
+## Recent Updates 🚀
+- 2025/12/18: Fun-CineForge dataset pipeline toolkit is online! 🔥
+- 2026/01/19: Chinese demo samples and CineDub-CN dataset samples released. 🔥
+- 2026/01/25: Fix some environmental and operational issues.
+- 2026/02/09: Optimized the data pipeline and added support for English videos.
+- 2026/03/05: English demo samples and CineDub-EN dataset samples released. 🔥
+- 2026/03/16: Open source inference code and checkpoints. 🔥
+
+
+## Publication 📚
+If you use our dataset or code, please cite the following paper:
+
+@misc{liu2026funcineforgeunifieddatasettoolkit,
+ title={FunCineForge: A Unified Dataset Toolkit and Model for Zero-Shot Movie Dubbing in Diverse Cinematic Scenes},
+ author={Jiaxuan Liu and Yang Xiang and Han Zhao and Xiangang Li and Zhenhua Ling},
+ year={2026},
+ eprint={2601.14777},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+}
+
+
+
+## Comminicate 🍟
+The Fun-CineForge open-source project is developed and maintained by the Tongyi Lab Speech Team and a student from NERCSLIP, University of Science and Technology of China.
+We welcome you to participate in discussions on Fun-CineForge [GitHub Issues](https://github.com/FunAudioLLM/FunCineForge/issues) or contact us for collaborative development.
+For any questions, you can contact the [developer](mailto:jxliu@mail.ustc.edu.cn).
+
+⭐ Hope you will support Fun-CineForge. Thank you.
+
+### Disclaimer
+
+This repository contains research artifacts:
+
+⚠️ Currently not a commercial product of Tongyi Lab.
+
+⚠️ Released for academic research / cutting-edge exploration purposes
+
+⚠️ CineDub Dataset samples are subject to specific license terms.
\ No newline at end of file
diff --git a/README_zh.md b/README_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..074275eaee57de3019dc56c926776d393c9337ed
--- /dev/null
+++ b/README_zh.md
@@ -0,0 +1,153 @@
+### 「[English](./README.md) | 简体中文」
+
+
+🎬 Fun-CineForge:一种用于多样化影视场景零样本配音的统一数据集管道和模型
+
+
+
+
+
+

+

+

+
+
+
+
+**Fun-CineForge** 包含一个生产大规模配音数据集的端到端数据集管道,和一个基于多模态大模型的配音模型,该模型专为多样的电影场景而设计。利用该管道,我们构建了首个大规模中文电视剧配音数据集 CineDub-CN,该数据集包含丰富的标注和多样化的场景。在独白、旁白、对话和多说话人场景中,我们的配音模型在音频质量、唇形同步、音色转换和指令遵循等方面全部优于最先进的方法。
+
+
+## 数据集 & 样例 🎬
+您可以访问此 [https://funcineforge.github.io/](https://funcineforge.github.io/) 获取我们的 CineDub-CN 数据集和 CineDub-EN 数据集样例和演示样例。
+
+
+## 环境安装
+
+Fun-CineForge 依赖 Conda 和 Python 环境。执行 **setup.py** 自动安装整个项目环境和开源模型。
+
+```shell
+# Conda
+git clone git@github.com:FunAudioLLM/FunCineForge.git
+conda create -n FunCineForge python=3.10 -y && conda activate FunCineForge
+sudo apt-get install ffmpeg
+# 初始化设置
+python setup.py
+```
+
+
+## 数据集管道 🔨
+
+### 数据收集
+如果您想自行生产数据,我们建议您参考下面的要求收集相应的电影或影视剧。
+
+1. 视频来源:电视剧或电影,非纪录片,人物独白或对话场景较多,人脸清晰且无遮挡(如无面罩、面纱)。
+2. 语音要求:发音标准,吐字清晰,人声突出。避免方言浓重、背景噪音过大或口语感过强的素材。
+3. 图片要求:高分辨率,面部细节清晰,光线充足,避免极端阴暗或强烈逆光的场景。
+
+### 使用方法
+
+- [1] 将视频格式、名称标准化;裁剪长视频的片头片尾;提取裁剪后视频的音频。(默认是从起止各裁剪 10 秒。)
+```shell
+python normalize_trim.py --root datasets/raw_zh --intro 10 --outro 10
+```
+
+- [2] [Speech Separation](./speech_separation/README.md). 音频进行人声乐声分离。
+```shell
+cd speech_separation
+python run.py --root datasets/clean/zh --gpus 0 1 2 3
+```
+
+- [3] [VideoClipper](./video_clip/README.md). 对于长视频,使用 VideoClipper 获取句子级别的字幕文件,并根据时间戳将长视频剪辑成片段。现在它支持中英双语。以下是中文示例。英文建议采用 gpu 加速处理。
+```shell
+cd video_clip
+bash run.sh --stage 1 --stop_stage 2 --input datasets/raw_zh --output datasets/clean/zh --lang zh --device cpu
+```
+
+- 视频时长限制及清理检查。(若不使用--execute参数,则仅打印已预删除的文件。检查后,若需确认删除,请添加--execute参数。)
+```shell
+python clean_video.py --root datasets/clean/zh
+python clean_srt.py --root datasets/clean/zh --lang zh
+```
+
+- [4] [Speaker Diarization](./speaker_diarization/README.md). 多模态主动说话人识别,得到 RTTM 文件;识别说话人的面部帧,提取帧级的说话人面部和唇部原始数据,从面部帧中识别说话帧,提取说话帧的面部特征。
+```shell
+cd speaker_diarization
+bash run.sh --stage 1 --stop_stage 4 --hf_access_token hf_xxx --root datasets/clean/zh --gpus "0 1 2 3"
+```
+
+- (参考)基于 CosyVoice3 tokenizer 提取 speech tokens 用于大模型训练。
+```shell
+python speech_tokenizer.py --root datasets/clean/zh
+```
+
+- [5] 多模态思维链校正。该系统基于通用多模态大模型,以音频、ASR 抄本和 RTTM 文件为输入,利用思维链推理来提取线索,并校正专用模型的结果,并标注人物年龄、性别和音色。实验结果表明,该策略将词错率从4.53% 降低到 0.94%,说话人识别错误率从 8.38% 降低到 1.20%,其质量可与人工转录相媲美,甚至更优。添加--resume选项可启用断点思维链推理,以避免重复思维链推理造成的资源浪费。现支持中英文。
+```shell
+python cot.py --root_dir datasets/clean/zh --lang zh --provider google --model gemini-3-pro-preview --api_key xxx --resume
+python cot.py --root_dir datasets/clean/en --lang en --provider google --model gemini-3-pro-preview --api_key xxx --resume
+```
+
+- 数据集检索文件的构建会读取生产的所有数据,双向校验脚本内容和说话人分离结果。
+```shell
+python build_datasets.py --root_zh datasets/clean/zh --root_en datasets/clean/en --out_dir datasets/clean --save
+```
+
+
+## 配音模型 ⚙️
+我们开源了推理代码和 **infer.sh** 脚本,在 data 文件夹中提供了一些测试样例,以供体验。推理需要一张消费级 GPU。按下面的命令运行:
+
+```shell
+cd exps
+bash infer.sh
+```
+
+从原始视频和 SRT 脚本进行多人配音的 API 调用接口在开发中 ...
+
+
+## 近期更新 🚀
+- 2025/12/18:Fun-CineForge 数据集管道工具包上线!🔥
+- 2026/01/19:发布中文演示样例和 CineDub-CN 数据集样例。 🔥
+- 2026/01/25:修复了一些环境和运行问题。
+- 2026/02/09:优化了数据管道,新增支持英文视频的能力。
+- 2026/03/05:发布英文演示样例和 CineDub-EN 数据集样例。 🔥
+- 2026/03/16:开源推理代码和 checkpoints。 🔥
+
+
+## 发表 📚
+如果您使用了我们的数据集或代码,请引用以下论文:
+
+@misc{liu2026funcineforgeunifieddatasettoolkit,
+ title={FunCineForge: A Unified Dataset Toolkit and Model for Zero-Shot Movie Dubbing in Diverse Cinematic Scenes},
+ author={Jiaxuan Liu and Yang Xiang and Han Zhao and Xiangang Li and Zhenhua Ling},
+ year={2026},
+ eprint={2601.14777},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+}
+
+
+
+
+## 社区交流 🍟
+Fun-CineForge 开源项目由通义实验室语音团队和中国科学技术大学 NERCSLIP 学生开发并维护,我们欢迎您在 Fun-CineForge [GitHub Issues](https://github.com/FunAudioLLM/FunCineForge/issues) 参与问题讨论,或联系我们合作开发。
+有任何问题您可以联系[开发者](mailto:jxliu@mail.ustc.edu.cn)。
+
+⭐ 希望您你支持 Fun-CineForge,谢谢。
+
+### 免责声明
+
+该仓库包含的研究成果:
+
+⚠️ 目前非通义实验室商业化产品
+
+⚠️ 供学术研究/前沿探索用途
+
+⚠️ 数据集样例受特定许可条款约束
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..f569abb1de3bd49a35cfe0a9e89422d2b1d77d8c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,415 @@
+# app.py
+import os
+import json
+import torch
+import gradio as gr
+import typing
+import time
+import shutil
+from moviepy.video.io.VideoFileClip import VideoFileClip, AudioFileClip
+from moviepy.audio.AudioClip import CompositeAudioClip
+from modelscope import snapshot_download
+from utils import get_video_duration, generate_jsonl_data, validate_timestamps, parse_srt_content
+# 尝试导入模型库
+from funcineforge import AutoFrontend
+from speaker_diarization.run import GlobalModels
+snapshot_download(
+ repo_id="FunAudioLLM/Fun-CineForge",
+ revision='v1.0.0',
+ local_dir='pretrained_models',
+ ignore_patterns=[
+ "*.md",
+ ".git*",
+ "funcineforge_zh_en/llm/config.yaml"
+ ],
+ repo_type="model",
+)
+
+
+# ==================== 配置区域 ====================
+DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+SERVER_PORT = 7860
+TEMP_DIR = "temp_workdir"
+CONFIG_FRONTEND = "decode_conf/diar.yaml"
+CONFIG_MODEL = "decode_conf/decode.yaml"
+PRETRAIN = "pretrained_models"
+MAX_SEGMENTS = 8 # UI 片段数上限
+DEFAULT_VIDEO_PATH="data/sample.mp4"
+DEFAULT_AUDIO_PATH="data/ref.wav"
+DEFAULT_TEXT = "我军无粮,利在急战。今乘魏兵新败,不敢出兵,出其不意,乘机退去,方可平安无事。"
+DEFAULT_CLUE = "一位中年男性以沉稳但略带担忧的语调,分析我军无粮急战的困境与敌军心败状态。他随即提出一种撤退方案,整体流露出对战局的担忧和谋求生路。"
+# 全局模型实例(延迟加载)
+model_pool: typing.Optional[GlobalModels] = None
+engine = None
+
+def init_engine():
+ """延迟加载模型,避免启动时卡住"""
+ global engine
+ engine = AutoFrontend(PRETRAIN, CONFIG_MODEL, TEMP_DIR, DEVICE)
+ return engine
+
+def init_frontend_models():
+ global model_pool
+ model_pool = GlobalModels(
+ hf_token = None,
+ config_path = CONFIG_FRONTEND,
+ pretrained_dir= PRETRAIN,
+ device = DEVICE,
+ pool_sizes = {"face": 1, "asd": 1, "fr": 1},
+ batch_size = 1,
+ preload = True
+ )
+ return model_pool
+
+# ==================== Gradio UI 逻辑 ====================
+
+def create_segments_ui():
+ segments = []
+ accordions = []
+ for i in range(MAX_SEGMENTS):
+ with gr.Accordion(f"🎬 配音片段 {i + 1}", open=(i == 0), visible=(i == 0)) as acc:
+ accordions.append(acc)
+ with gr.Row():
+ text_input = gr.Textbox(label="📝 配音文本内容", placeholder="输入台词...", lines=2, scale=3, elem_id=f"text_{i}")
+ clue_input = gr.Textbox(label="💡 线索描述", placeholder="一位中年男性角色语气沉稳且坚定,流露出对自身忠诚的强烈自信与决心。整体情感是忠贞不渝的承诺和不容置疑的信念。", lines=2, scale=3, elem_id=f"clue_{i}")
+ with gr.Row():
+ start_time = gr.Number(label="⏱️ 起始时间 (s)", value=0.0 + i*5, precision=2, scale=2, elem_id=f"start_{i}")
+ end_time = gr.Number(label="⏱️ 终止时间 (s)", value=5.0 + i*5, precision=2, scale=2, elem_id=f"end_{i}")
+ with gr.Row():
+ age_input = gr.Dropdown(label="👤 年龄", choices=["儿童", "青年", "中年", "中老年", "老年", "不确定"], value="不确定", scale=2, elem_id=f"age_{i}")
+ gender_input = gr.Dropdown(label="👤 性别", choices=["男", "女", "不确定"], value="不确定", scale=2, elem_id=f"gender_{i}")
+ with gr.Row():
+ ref_audio = gr.Audio(label="🎤 参考语音 (可选,默认以视频原声作为参考音频)", sources=["upload"], type="filepath", scale=4,elem_id=f"audio_{i}")
+ load_audio_btn = gr.Button("📂 加载示例音频", size="sm", variant="secondary", scale=1) if i == 0 else None
+ with gr.Row():
+ enable_check = gr.Checkbox(label="启用此片段", value=(i == 0), scale=1, elem_id=f"enable_{i}")
+
+ segments.append({
+ "accordion": acc, "text": text_input, "clue": clue_input, "start": start_time, "end": end_time,
+ "age": age_input, "gender": gender_input, "audio": ref_audio,
+ "enable": enable_check, "index": i, "load_audio_btn": load_audio_btn})
+ return segments, accordions
+
+def add_segment_fn(current_count):
+ """点击加号:显示下一个片段,到达上限则禁用按钮"""
+ if current_count >= MAX_SEGMENTS:
+ return [current_count] + [gr.update() for _ in range(MAX_SEGMENTS)] + [gr.update(interactive=False, value=f"已达上限 ({MAX_SEGMENTS})")]
+
+ new_count = current_count + 1
+ vis = [gr.update(visible=(i < new_count)) for i in range(MAX_SEGMENTS)]
+ btn = gr.update(interactive=(new_count < MAX_SEGMENTS), value="➕新片段")
+ return [new_count] + vis + [btn]
+
+def load_srt_fn(srt_file, current_count):
+ empty_fields = [gr.update() for _ in range(MAX_SEGMENTS * 4)]
+ empty_vis = [gr.update() for _ in range(MAX_SEGMENTS)]
+ if not srt_file:
+ return [current_count] + empty_fields + empty_vis + [gr.update()]
+ try:
+ with open(srt_file, 'r', encoding='utf-8-sig') as f:
+ content = f.read()
+ except Exception as e:
+ gr.Warning(f"读取 SRT 文件失败: {e}")
+ return [current_count] + empty_fields + empty_vis + [gr.update()]
+ parsed = parse_srt_content(content)
+ if not parsed:
+ print(" 未解析到有效字幕,请检查 SRT 格式")
+ return [current_count] + empty_fields + empty_vis + [gr.update()]
+ updates = []
+ for i in range(MAX_SEGMENTS):
+ if i < len(parsed):
+ seg = parsed[i]
+ updates.append(gr.update(value=seg['text']))
+ updates.append(gr.update(value=round(seg['start'], 2)))
+ updates.append(gr.update(value=round(seg['end'], 2)))
+ updates.append(gr.update(value=True))
+ else:
+ updates.append(gr.update(value=""))
+ updates.append(gr.update(value=0.0))
+ updates.append(gr.update(value=5.0 + i*5))
+ updates.append(gr.update(value=False))
+ new_count = min(len(parsed), MAX_SEGMENTS)
+ vis = [gr.update(visible=(i < new_count)) for i in range(MAX_SEGMENTS)]
+ btn = gr.update(interactive=(new_count < MAX_SEGMENTS))
+ if len(parsed) > MAX_SEGMENTS:
+ gr.Warning(f"SRT 包含 {len(parsed)} 个片段,已截取前 {MAX_SEGMENTS} 条")
+
+ return [new_count] + updates + vis + [btn]
+
+def process_dubbing(video_file, *segment_inputs, progress=gr.Progress()):
+ """主推理流程"""
+ if not video_file:
+ return None, "❌ 请上传视频文件"
+
+ video_duration = get_video_duration(video_file)
+ if video_duration <= 0:
+ return None, "❌ 无法获取视频时长,请检查视频文件"
+
+ if os.path.exists(TEMP_DIR):
+ try:
+ shutil.rmtree(TEMP_DIR)
+ except Exception as e:
+ return None, f"❌ 清空临时目录失败:{e}"
+ os.makedirs(TEMP_DIR, exist_ok=True)
+
+ # 解析 segment_inputs
+ segments_data = []
+ for i in range(MAX_SEGMENTS):
+ base_idx = i * 8
+ enable = segment_inputs[base_idx + 7] # enable_check
+ if not enable: continue
+ text = segment_inputs[base_idx + 0]
+ if not text or not text.strip(): continue
+
+ clue = segment_inputs[base_idx + 1]
+ start = segment_inputs[base_idx + 2]
+ end = segment_inputs[base_idx + 3]
+ age = segment_inputs[base_idx + 4]
+ gender = segment_inputs[base_idx + 5]
+ ref_audio = segment_inputs[base_idx + 6]
+
+ errors = validate_timestamps(start, end, video_duration)
+ if errors:
+ return None, f"❌ 片段 {i+1} 时间戳错误:\n" + "\n".join(errors)
+
+ data = {
+ "text": str(text).strip(),
+ "clue": str(clue) if clue else "",
+ "start": float(start) if start else 0.0,
+ "end": float(end) if end else 0.0,
+ "age": str(age) if age else "不确定",
+ "gender": str(gender) if gender else "不确定",
+ "ref_audio": str(ref_audio) if ref_audio else ""
+ }
+
+ segments_data.append(data)
+
+ if not segments_data:
+ return None, "❌ 有效片段数据为空,请启用并填写至少一个片段"
+
+ try:
+ progress(0.1, desc="📋 预处理视频,生成 JSONL 数据...")
+ frontend = init_frontend_models()
+ jsonl_path, jsonl_items = generate_jsonl_data(frontend, video_file, segments_data, TEMP_DIR, video_duration)
+ report_lines = [f"✅ 任务完成!共生成 **{len(jsonl_items)}** 个片段数据。\n", "详细 JSONL 数据预览:**", "=" * 40]
+ for idx, item in enumerate(jsonl_items):
+ report_lines.extend([f"\n---片段 #{idx + 1} ---", json.dumps(item, ensure_ascii=False, indent=2), "-" * 40])
+ full_report = "\n".join(report_lines)
+
+ progress(0.3, desc="🔄 FunCineForge 模型加载中...")
+
+ eng = init_engine()
+ if eng and jsonl_items:
+ try:
+ progress(0.5, desc="🚀 FunCineForge 模型推理中...")
+ eng.inference(jsonl_path)
+
+ progress(0.8, desc="🎵 正在将配音语音粘贴回静音视频...")
+
+ output_wav_dir = os.path.join(TEMP_DIR, "wav")
+ final_video_path = os.path.join(TEMP_DIR, "dubbed_video.mp4")
+
+ if not os.path.exists(output_wav_dir):
+ return None, f"⚠️ 未找到音频输出目录:{output_wav_dir}"
+
+ wav_files = sorted([f for f in os.listdir(output_wav_dir) if f.endswith('.wav')])
+ if not wav_files:
+ return None, f"⚠️ 未生成任何音频文件:{output_wav_dir}"
+
+ time_mapping = {}
+ for item in jsonl_items:
+ for wf in wav_files:
+ if wf.startswith(item['utt']):
+ time_mapping[wf] = float(item['start'])
+ break
+
+ original_clip = VideoFileClip(video_file)
+ video_duration = original_clip.duration
+ is_silent = original_clip.audio is None
+ video_only = original_clip if is_silent else original_clip.without_audio()
+ audio_clips = []
+ for wav_file, start_time in time_mapping.items():
+ wav_path = os.path.join(output_wav_dir, wav_file)
+ audio_clip = AudioFileClip(wav_path).with_start(start_time)
+ audio_clips.append(audio_clip)
+
+ final_audio = CompositeAudioClip(audio_clips)
+ if final_audio.duration < video_duration:
+ final_audio = final_audio.with_duration(video_duration)
+ final_clip = video_only.with_audio(final_audio)
+ final_clip.write_videofile(
+ final_video_path,
+ codec='libx264',
+ audio_codec='aac',
+ preset='veryfast',
+ threads=8,
+ fps=original_clip.fps,
+ logger=None
+ )
+ original_clip.close(); video_only.close()
+ for ac in audio_clips: ac.close()
+ if 'final_audio' in locals(): final_audio.close()
+ final_clip.close()
+
+ progress(1.0, desc="✅ 配音完成")
+ return final_video_path, full_report
+ except Exception as e:
+ import traceback; traceback.print_exc()
+ if "index out of range" in str(e):
+ return None, f"⚠️ 模型推理失败。错误:{str(e)},建议补齐输入的线索描述和说话人属性"
+ else:
+ return None, f"⚠️ 模型推理失败。错误:{str(e)}"
+ else:
+ time.sleep(1)
+ progress(1.0, desc="模拟完成")
+ return video_file, full_report
+
+ except Exception as e:
+ import traceback; traceback.print_exc()
+ return None, f"❌ 发生错误:{str(e)}"
+
+
+# ==================== 主程序 ====================
+
+def main():
+ os.makedirs(TEMP_DIR, exist_ok=True)
+ with gr.Blocks(
+ title="Fun-CineForge 影视配音平台",
+ theme=gr.themes.Soft(),
+ css="""
+ .segment-accordion { margin: 10px 0; }
+ .gr-button-primary { background: #1976d2; }
+ .gr-button-stop { background: #d32f2f; }
+ """
+ ) as demo:
+
+ gr.Markdown("""
+ # 🎬 Fun-CineForge
+
+ **工作流程:** 上传短视频 → 配音片段信息(或上传 .srt 字幕文件) → 上传参考音色(可选) → 预处理、模型加载和推理 → 输出配音视频
+ """)
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ video_input = gr.Video(label="上传视频", sources=["upload"])
+ load_video_btn = gr.Button("📂 加载示例视频", variant="secondary", size="sm")
+ srt_input = gr.UploadButton("上传 SRT 字幕", file_types=[".srt"], size="sm", variant="secondary")
+ # with gr.Row(elem_classes=["srt-compact"]):
+ # srt_input = gr.File(label="上传 SRT 字幕", file_types=[".srt"], height="auto")
+ gr.Markdown("### 🎛️ 配音片段配置")
+
+ segments, accordions = create_segments_ui()
+ seg_count_state = gr.State(1) #🔑记录当前可见片段数
+ add_segment_btn = gr.Button("➕添加新片段", size="sm", variant="secondary")
+ submit_btn = gr.Button("🚀 开始生成配音", variant="stop", size="lg")
+
+ with gr.Column(scale=1):
+ video_output = gr.Video(label="📺 配音后视频", autoplay=True)
+
+ status_text = gr.Textbox(label="结果状态", interactive=False, lines=2)
+
+ gr.Markdown("""
+ ### 📝 使用说明
+ | 字段 | 说明 |
+ |------|------|
+ | 配音文本 | 该片段台词内容(支持中/英) |
+ | 线索描述 | 请参考样例格式,阐述配音要求,重点描述说话人的性别年龄、语气和情感 |
+ | 时间戳 | 起止时间戳 (可精确到毫秒),模型对时间戳敏感,建议紧邻有声区间。时长 ≤30s/片段 |
+ | 年龄/性别 | 说话人属性选项 |
+ | 参考语音 | 音色克隆参考 (可选) |
+
+ **⚠️ 注意:** 确保每个片段的时间戳不重叠,且时间戳不超过视频总时长。模型会根据片段的时间长度进行强制时间对齐,弱监督对齐唇部运动。
+ """)
+
+ # ==================== 事件绑定 ====================
+
+ # 收集所有片段组件作为输入
+ segment_inputs = []
+ for seg in segments:
+ segment_inputs.extend([
+ seg["text"],
+ seg["clue"],
+ seg["start"],
+ seg["end"],
+ seg["age"],
+ seg["gender"],
+ seg["audio"],
+ seg["enable"]
+ ])
+
+ srt_update_fields = []
+ for seg in segments:
+ srt_update_fields.extend([seg["text"], seg["start"], seg["end"], seg["enable"]])
+
+ # 动态添加片段
+ add_segment_btn.click(
+ fn=add_segment_fn,
+ inputs=[seg_count_state],
+ outputs=[seg_count_state] + accordions + [add_segment_btn]
+ )
+
+ # SRT 加载
+ srt_input.upload(
+ fn=load_srt_fn,
+ inputs=[srt_input, seg_count_state],
+ outputs=[seg_count_state] + srt_update_fields + accordions + [add_segment_btn]
+ )
+
+ # 主推理
+ submit_btn.click(
+ fn=process_dubbing,
+ inputs=[video_input] + segment_inputs,
+ outputs=[video_output, status_text]
+ )
+
+ # 视频上传联动时间戳
+ def update_timestamps(video):
+ if not video: return [gr.update() for _ in range(MAX_SEGMENTS * 2)]
+ dur = get_video_duration(video)
+ updates = []
+ for i in range(MAX_SEGMENTS):
+ updates.append(gr.update(value=0.0))
+ updates.append(gr.update(value=dur))
+ return updates
+
+ def load_default_video_fn():
+ return DEFAULT_VIDEO_PATH, DEFAULT_TEXT, DEFAULT_CLUE
+
+ def load_default_audio_fn():
+ return DEFAULT_AUDIO_PATH
+
+ load_video_btn.click(
+ fn=load_default_video_fn,
+ inputs=[],
+ outputs=[video_input, segments[0]["text"], segments[0]["clue"]]
+ ).then(
+ fn=update_timestamps,
+ inputs=[video_input],
+ outputs=[segment_inputs[i] for i in range(len(segment_inputs)) if i % 8 in [2, 3]]
+ )
+
+ video_input.change(
+ fn=update_timestamps,
+ inputs=[video_input],
+ outputs=[comp for pair in zip(segment_inputs[2::8], segment_inputs[3::8]) for comp in pair]
+ )
+
+ if segments and segments[0]["load_audio_btn"]:
+ segments[0]["load_audio_btn"].click(
+ fn=load_default_audio_fn,
+ inputs=[],
+ outputs=[segments[0]["audio"]]
+ )
+
+ # ==================== 启动服务 ====================
+
+ demo.launch(
+ server_name="0.0.0.0",
+ server_port=SERVER_PORT,
+ share=False,
+ show_error=True,
+ inbrowser=True,
+ )
+
+if __name__ == "__main__":
+ main()
diff --git a/data/ref.wav b/data/ref.wav
new file mode 100644
index 0000000000000000000000000000000000000000..e4ade7d66817e84e06ce8ff794e0630ca0a90e77
--- /dev/null
+++ b/data/ref.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8420568976edb1cf17a63d9fa968aedaf3c0f68cca4dbf75a409876b96ad700b
+size 788876
diff --git a/data/sample.mp4 b/data/sample.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b93cd03629902635e883db124fe29d43dd413ca7
--- /dev/null
+++ b/data/sample.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b901981a2213fc7f98cd6424869710e8396eb558ff2ff3e8ab5d52fe427e0ab6
+size 2567737
diff --git a/decode_conf/decode.yaml b/decode_conf/decode.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..96dfd2d9c9ce3d76beb29eac8ec8acce78ad8d8e
--- /dev/null
+++ b/decode_conf/decode.yaml
@@ -0,0 +1,42 @@
+model: FunCineForgeInferModel
+index_ds: FunCineForgeDS
+xvec_model: pretrained_models/funcineforge_zh_en/camplus.onnx
+model_conf: {}
+
+dataset_conf:
+# face is from the video, vocal is the reference audio, extract speaker ID and start-end timestamp from dialogue
+ load_meta_data_key: "text,clue,face,dialogue,vocal,video"
+ sos: 6561
+ eos: 6562
+ turn_of_speech: 6563
+ fill_token: 6564
+ ignore_id: -100
+ startofclue_token: 151646
+ endofclue_token: 151647
+ frame_shift: 25 # ms
+ timebook_size: 1500 # 60 * 25 = 1500
+ pangbai: 1500
+ dubai: 1501
+ duihua: 1502
+ duoren: 1503
+ male: 1504
+ female: 1505
+ child: 1506
+ youth: 1507
+ adult: 1508
+ middle: 1509
+ elderly: 1510
+ speaker_id_start: 1511
+
+
+sampling: ras
+lm_use_prompt: true
+fm_use_prompt: true
+use_llm_cache: true
+seed: 0
+max_length: 1500 # 60s * 25 fps
+min_length: 50 # 2s * 25 fps
+llm_dtype: fp32
+fm_dtype: fp32
+voc_dtype: fp32
+batch_size: 1
\ No newline at end of file
diff --git a/decode_conf/diar.yaml b/decode_conf/diar.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7b96a04d0ea119ebb5c1a02aecc8003b5c69ddf6
--- /dev/null
+++ b/decode_conf/diar.yaml
@@ -0,0 +1,51 @@
+# Diarization config
+
+fbank_dim: 80
+embedding_size: 192
+
+feature_extractor:
+ obj: speakerlab.process.processor.FBank
+ args:
+ n_mels:
+ sample_rate:
+ mean_nor: True
+
+embedding_model:
+ obj: speakerlab.models.campplus.DTDNN.CAMPPlus
+ args:
+ feat_dim:
+ embedding_size:
+
+# for visual embeddings extraction
+min_track: 10
+num_failed_det: 10
+crop_scale: 0.4
+min_face_size: 1
+face_det_stride: 5 # 每5帧检测一次人脸
+shot_stride: 50
+
+# for clustering
+audio_cluster:
+ obj: speakerlab.process.cluster.CommonClustering
+ args:
+ cluster_type: spectral
+ min_num_spks: 1
+ max_num_spks: 15
+ min_cluster_size: 1
+ oracle_num: null
+ pval: 0.032
+ mer_cos: 0.8
+
+vision_cluster:
+ obj: speakerlab.process.cluster.CommonClustering
+ args:
+ cluster_type: AHC
+ cluster_line: 2
+ min_cluster_size: 1
+ fix_cos_thr: 0.25
+
+cluster:
+ obj: speakerlab.process.cluster.JointClustering
+ args:
+ audio_cluster:
+ vision_cluster:
diff --git a/decode_conf/ds_stage0_fp32.json b/decode_conf/ds_stage0_fp32.json
new file mode 100644
index 0000000000000000000000000000000000000000..99941ce20cafff63703c7d8f25641d246c0defa3
--- /dev/null
+++ b/decode_conf/ds_stage0_fp32.json
@@ -0,0 +1,33 @@
+{
+ "train_micro_batch_size_per_gpu": 1,
+ "gradient_accumulation_steps": 1,
+ "steps_per_print": 100,
+ "gradient_clipping": 5,
+ "fp16": {
+ "enabled": false,
+ "auto_cast": false,
+ "loss_scale": 0,
+ "initial_scale_power": 16,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "consecutive_hysteresis": false,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": false
+ },
+ "zero_force_ds_cpu_optimizer": false,
+ "zero_optimization": {
+ "stage": 0,
+ "offload_optimizer": {
+ "device": "none",
+ "pin_memory": true
+ },
+ "allgather_partitions": true,
+ "allgather_bucket_size": 5e8,
+ "overlap_comm": true,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 5e8,
+ "contiguous_gradients" : true
+ }
+}
diff --git a/funcineforge/.DS_Store b/funcineforge/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..da9ee7ebd8e4b6b069dd5222172c97a4301bfe28
Binary files /dev/null and b/funcineforge/.DS_Store differ
diff --git a/funcineforge/__init__.py b/funcineforge/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a8970aa5066500a5bf6746506f440d575c453ef
--- /dev/null
+++ b/funcineforge/__init__.py
@@ -0,0 +1,7 @@
+"""Initialize package."""
+
+import os
+from funcineforge.auto.auto_model import AutoModel
+from funcineforge.auto.auto_frontend import AutoFrontend
+
+os.environ["HYDRA_FULL_ERROR"] = "1"
diff --git a/funcineforge/auto/__init__.py b/funcineforge/auto/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/funcineforge/auto/auto_frontend.py b/funcineforge/auto/auto_frontend.py
new file mode 100644
index 0000000000000000000000000000000000000000..b486f2aabc59e556153fd8ba7c0e8d8b85519baf
--- /dev/null
+++ b/funcineforge/auto/auto_frontend.py
@@ -0,0 +1,95 @@
+import os
+import torch
+import logging
+from omegaconf import OmegaConf
+from funcineforge.utils.hinter import get_logger
+from funcineforge.models.utils import dtype_map
+from funcineforge.datasets import FunCineForgeDS
+
+class AutoFrontend:
+ def __init__(
+ self,
+ ckpt_path: str,
+ config_path: str,
+ output_dir: str,
+ device: str = "cuda:0"
+ ):
+ self.logger = get_logger(log_level=logging.INFO, local_rank=1, world_size=1)
+ self.device = device
+ self.output_dir = output_dir
+ self.lm_model = None
+ self.fm_model = None
+ self.voc_model = None
+ self.model = None
+ self.index_ds_class = None
+
+ self.dataset_conf = None
+ self.kwargs = OmegaConf.load(config_path)
+
+ if device.startswith("cuda"):
+ try:
+ device_id = int(device.split(":")[-1])
+ torch.cuda.set_device(device_id)
+ except (ValueError, IndexError):
+ self.logger.warning(f"Invalid cuda device string {device}, defaulting to 0")
+ torch.cuda.set_device(0)
+ else:
+ self.logger.info(f"Running on CPU")
+
+
+ lm_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/llm/ds-model.pt.best/mp_rank_00_model_states.pt")
+ fm_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/flow/ds-model.pt.best/mp_rank_00_model_states.pt")
+ voc_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/vocoder/ds-model.pt.best/avg_5_removewn.pt")
+
+ lm_exp_dir, lm_model_name, lm_ckpt_id, _ = lm_ckpt_path.rsplit("/", 3)
+ self.logger.info(f"init LM model form {lm_ckpt_path}")
+
+ from funcineforge import AutoModel
+ self.lm_model = (AutoModel(
+ model=os.path.join(lm_exp_dir, lm_model_name),
+ init_param=lm_ckpt_path,
+ output_dir=None,
+ device=device,
+ ))
+ self.lm_model.model.to(dtype_map[self.kwargs.get("llm_dtype", "fp32")])
+
+ fm_exp_dir, fm_model_name, fm_ckpt_id, _ = fm_ckpt_path.rsplit("/", 3)
+ self.logger.info(f"build FM model form {fm_ckpt_path}")
+ self.fm_model = AutoModel(
+ model=os.path.join(fm_exp_dir, fm_model_name),
+ init_param=fm_ckpt_path,
+ output_dir=None,
+ device=device,
+ )
+ self.fm_model.model.to(dtype_map[self.kwargs.get("fm_dtype", "fp32")])
+
+ voc_exp_dir, voc_model_name, voc_ckpt_id, _ = voc_ckpt_path.rsplit("/", 3)
+ self.logger.info(f"build VOC model form {voc_ckpt_path}")
+ self.voc_model = AutoModel(
+ model=os.path.join(voc_exp_dir, voc_model_name),
+ init_param=voc_ckpt_path,
+ output_dir=None,
+ device=device,
+ )
+ self.voc_model.model.to(dtype_map[self.kwargs.get("voc_dtype", "fp32")])
+
+ self.logger.info(f"build inference model {self.kwargs.get('model')}")
+ self.kwargs["output_dir"] = output_dir
+ self.kwargs["tokenizer"] = None
+ self.model = AutoModel(
+ **self.kwargs,
+ lm_model=self.lm_model,
+ fm_model=self.fm_model,
+ voc_model=self.voc_model,
+ )
+ self.dataset_conf = self.kwargs.get("dataset_conf")
+
+ def inference(self, jsonl_path: str):
+ if not self.model:
+ raise RuntimeError("Model class not initialized.")
+
+ dataset = FunCineForgeDS(jsonl_path, **self.dataset_conf)
+ self.logger.info(f"Starting inference on {len(dataset)} items...")
+
+ self.model.inference(input=dataset, input_len=len(dataset))
+ self.logger.info("Inference finished.")
\ No newline at end of file
diff --git a/funcineforge/auto/auto_model.py b/funcineforge/auto/auto_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..31d35c48b81d188c51fa0b6b8d443401fc93c2ce
--- /dev/null
+++ b/funcineforge/auto/auto_model.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+
+import time
+import torch
+import logging
+import os
+from tqdm import tqdm
+from funcineforge.utils.misc import deep_update
+from funcineforge.utils.set_all_random_seed import set_all_random_seed
+from funcineforge.utils.load_pretrained_model import load_pretrained_model
+from funcineforge.download.download_model_from_hub import download_model
+from funcineforge.tokenizer import FunCineForgeTokenizer
+from funcineforge.face import FaceRecIR101
+import importlib
+
+
+def prepare_data_iterator(data_in, input_len):
+ """ """
+ data_list = []
+ key_list = []
+ for idx in range(input_len):
+ item = data_in[idx]
+ utt = item["utt"]
+ data_list.append(item)
+ key_list.append(utt)
+ return key_list, data_list
+
+
+class AutoModel:
+
+ def __init__(self, **kwargs):
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+ logging.basicConfig(level=log_level)
+ model, kwargs = self.build_model(**kwargs)
+ self.kwargs = kwargs
+ self.model = model
+ self.model_path = kwargs.get("model_path")
+
+ @staticmethod
+ def build_model(**kwargs):
+ assert "model" in kwargs
+ if "model_conf" not in kwargs:
+ logging.info("download models from {} or local dir".format(kwargs.get("hub", "ms")))
+ kwargs = download_model(**kwargs)
+
+ set_all_random_seed(kwargs.get("seed", 0))
+
+ device = kwargs.get("device", "cuda")
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+ device = "cpu"
+ kwargs["batch_size"] = 1
+ kwargs["device"] = device
+
+ torch.set_num_threads(kwargs.get("ncpu", 4))
+
+ # build tokenizer
+ tokenizer = kwargs.get("tokenizer", None)
+ if tokenizer is not None:
+ tokenizer = FunCineForgeTokenizer(**kwargs.get("tokenizer_conf", {}))
+ kwargs["token_list"] = (
+ tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+ )
+ kwargs["token_list"] = (
+ tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
+ )
+ vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
+ if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
+ vocab_size = tokenizer.get_vocab_size()
+ else:
+ vocab_size = -1
+ kwargs["tokenizer"] = tokenizer
+
+ # build face_encoder
+ face_encoder = kwargs.get("face_encoder", None)
+ if face_encoder is not None:
+ face_encoder = FaceRecIR101(**kwargs.get("face_encoder_conf", {}))
+ kwargs["face_encoder"] = face_encoder
+
+ model_conf = {}
+ model_class_name = kwargs["model"]
+ deep_update(model_conf, kwargs.get("model_conf", {}))
+ deep_update(model_conf, kwargs)
+ module = importlib.import_module("funcineforge.models")
+ model_class = getattr(module, model_class_name)
+ model = model_class(**model_conf, vocab_size=vocab_size)
+
+ # init_param
+ init_param = kwargs.get("init_param", None)
+ if init_param is not None and os.path.exists(init_param):
+ logging.info(f"Loading pretrained params from ckpt: {init_param}")
+ load_pretrained_model(
+ path=init_param,
+ model=model,
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+ scope_map=kwargs.get("scope_map", []),
+ excludes=kwargs.get("excludes", None),
+ use_deepspeed=kwargs.get("train_conf", {}).get("use_deepspeed", False),
+ save_deepspeed_zero_fp32=kwargs.get("save_deepspeed_zero_fp32", True),
+ )
+
+ # fp16
+ if kwargs.get("fp16", False):
+ model.to(torch.float16)
+ elif kwargs.get("bf16", False):
+ model.to(torch.bfloat16)
+ model.to(device)
+
+ return model, kwargs
+
+ def __call__(self, *args, **cfg):
+ kwargs = self.kwargs
+ deep_update(kwargs, cfg)
+ res = self.model(*args, kwargs)
+ return res
+
+
+ def inference(self, input, input_len=None, model=None, kwargs=None, **cfg):
+ kwargs = self.kwargs if kwargs is None else kwargs
+ deep_update(kwargs, cfg)
+ model = self.model if model is None else model
+ model.eval()
+ batch_size = kwargs.get("batch_size", 1)
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=input_len
+ )
+
+ speed_stats = {}
+ num_samples = len(data_list)
+ disable_pbar = self.kwargs.get("disable_pbar", False)
+ pbar = (
+ tqdm(colour="blue", total=num_samples, dynamic_ncols=True) if not disable_pbar else None
+ )
+ time_speech_total = 0.0
+ time_escape_total = 0.0
+ count = 0
+ log_interval = kwargs.get("log_interval", None)
+ for beg_idx in range(0, num_samples, batch_size):
+ end_idx = min(num_samples, beg_idx + batch_size)
+ data_batch = data_list[beg_idx:end_idx]
+ key_batch = key_list[beg_idx:end_idx]
+ batch = {"data_in": data_batch, "data_lengths": end_idx - beg_idx, "key": key_batch}
+
+ time1 = time.perf_counter()
+ with torch.no_grad():
+ res = model.inference(**batch, **kwargs)
+ if isinstance(res, (list, tuple)):
+ results = res[0] if len(res) > 0 else [{"text": ""}]
+ meta_data = res[1] if len(res) > 1 else {}
+ time2 = time.perf_counter()
+
+ batch_data_time = meta_data.get("batch_data_time", -1)
+ time_escape = time2 - time1
+ speed_stats["forward"] = f"{time_escape:0.3f}"
+ speed_stats["batch_size"] = f"{len(results)}"
+ speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
+ description = f"{speed_stats}, "
+ if pbar:
+ pbar.update(batch_size)
+ pbar.set_description(description)
+ else:
+ if log_interval is not None and count % log_interval == 0:
+ logging.info(
+ f"processed {count*batch_size}/{num_samples} samples: {key_batch[0]}"
+ )
+ time_speech_total += batch_data_time
+ time_escape_total += time_escape
+ count += 1
+
+ if pbar:
+ pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
+ torch.cuda.empty_cache()
+ return
diff --git a/funcineforge/datasets/__init__.py b/funcineforge/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b0b1bfa169dc90eed19feaa88de5bd609edd09d
--- /dev/null
+++ b/funcineforge/datasets/__init__.py
@@ -0,0 +1,2 @@
+from .index_ds import FunCineForgeDS
+from .datasets import FunCineForgeDataset
\ No newline at end of file
diff --git a/funcineforge/datasets/datasets.py b/funcineforge/datasets/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb5400426c9ac00f4240c96ab5d10a6ed21bde9
--- /dev/null
+++ b/funcineforge/datasets/datasets.py
@@ -0,0 +1,193 @@
+import logging
+import torch
+import pickle
+import numpy as np
+from funcineforge.utils.hinter import hint_once
+from funcineforge.datasets import FunCineForgeDS
+from funcineforge.models import FunCineForgeSpecAug
+
+class FunCineForgeDataset(torch.utils.data.Dataset):
+ """
+ Dataset for Mixed LM of FunCineForge
+ """
+
+ def __init__(
+ self,
+ path,
+ index_ds: str = None,
+ frontend=None,
+ tokenizer=None,
+ face_encoder=None,
+ int_pad_value: int = -1,
+ float_pad_value: float = 0.0,
+ **kwargs,
+ ):
+ super().__init__()
+ self.index_ds = FunCineForgeDS(path, **kwargs)
+ self.tokenizer = tokenizer
+ self.face_encoder = face_encoder
+
+ self.int_pad_value = int_pad_value
+ self.float_pad_value = float_pad_value
+ self.batch_size = kwargs.get("batch_size")
+ self.batch_type = kwargs.get("batch_type")
+ self.retry = kwargs.get("retry", 100)
+
+ # self.kwargs = kwargs
+ self.max_token_length = kwargs.get("max_token_length", 1500)
+ self.batch_size_scale_ratio_max = kwargs.get("batch_size_scale_ratio_max", 1.5)
+ self.batch_size_token_max = kwargs.get("batch_size_token_max", 2500)
+ self.multiturn_num_max = kwargs.get("multiturn_num_max", 1)
+ self.face_size = kwargs.get("face_size", 512)
+
+ self.codebook_size = kwargs.get("codebook_size", 6561)
+ self.sos = kwargs.get("sos", self.codebook_size)
+ self.eos = kwargs.get("eos", self.codebook_size + 1)
+ self.turn_of_speech = kwargs.get("turn_of_speech", self.codebook_size + 2)
+ self.ignore_id = kwargs.get("ignore_id", -100)
+
+ specaug = kwargs.get("specaug", None)
+ specaug_conf = kwargs.get("specaug_conf", {})
+ if specaug is not None:
+ specaug = FunCineForgeSpecAug(**specaug_conf)
+ self.specaug = specaug
+
+ self.set_invalid_xvec_zeros = kwargs.get("set_invalid_xvec_zeros", False)
+ self.use_emotion_clue = kwargs.get("use_emotion_clue", False)
+ logging.info(f"use_emotion_clue: {self.use_emotion_clue}")
+
+ def get_source_len(self, index):
+ item = self.index_ds[index]
+ source_len = self.index_ds.get_source_len(item)
+ return source_len
+
+ def get_target_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_target_len(item)
+
+ def __len__(self):
+ return len(self.index_ds)
+
+ def mixup_text_codec(self, text: torch.Tensor, aug_codec: torch.Tensor, timespk_ids: torch.Tensor, type_id: int):
+ text_len = text.shape[0]
+ timespk_len = timespk_ids.shape[0]
+ sequence = [self.sos, *text.tolist(), type_id, *timespk_ids.tolist(), self.turn_of_speech, *aug_codec.tolist(), self.eos]
+ # sequence = [self.sos, *text.tolist(), type_id, self.turn_of_speech, *aug_codec.tolist(), self.eos]
+ input_ids = torch.tensor(sequence, dtype=torch.int64)
+ text_flag = torch.zeros(len(sequence), dtype=torch.float32)
+ text_flag[1:text_len+1] = 1
+ timespk_flag = torch.zeros(len(sequence), dtype=torch.float32)
+ timespk_flag[text_len+1:text_len+2+timespk_len] = 1
+ # timespk_flag[text_len+1:text_len+2] = 1
+ codec_flag = 1 - (text_flag + timespk_flag)
+ labels = torch.tensor(sequence, dtype=torch.int64)
+ labels[:text_len+timespk_len+3] = self.ignore_id
+ # labels[:text_len+3] = self.ignore_id
+
+ return input_ids, labels, text_flag, codec_flag, timespk_flag
+
+ def __getitem__(self, index):
+ output = None
+ for idx in range(self.retry):
+ if idx == 0:
+ index_cur = index
+ else:
+ index_cur = torch.randint(0, len(self.index_ds), ()).item()
+ item = self.index_ds[index_cur]
+
+ # clue + text
+ text = item["text"]
+ clue = "<|startofclue|>" + item["clue"] + "<|endofclue|>"
+ if self.use_emotion_clue:
+ text = clue + text
+ text_ids = torch.tensor(self.tokenizer.encode(text), dtype=torch.int32)
+ hint_once(f"raw text: {text}", "log_text")
+
+ # speech tokens
+ target_out = item["token"]
+ codec = torch.from_numpy(np.load(target_out))
+ codec_len = codec.shape[0] # 可用数据集中的 speech_length 代替
+ aug_codec = codec.clone()
+ if self.specaug is not None: # aug_codec是随机mask的codec增强鲁棒性
+ aug_codec, _ = self.specaug(aug_codec.float().unsqueeze(0).unsqueeze(-1))
+ aug_codec = aug_codec.squeeze(0).squeeze(-1).long()
+
+ # dialogue
+ timespk_ids = torch.from_numpy(item["timespk_ids"])
+
+ # mixup
+ type_id = item["type_id"]
+ input_ids, labels, text_flag, codec_flag, timespk_flag = self.mixup_text_codec(
+ text_ids, aug_codec, timespk_ids, type_id
+ )
+
+ # face
+ face_features = item["face"]
+ face_emb = torch.zeros((codec_len, self.face_size), dtype=torch.float32) # face_emb 长度与 codec_len 相同
+ with open(face_features, 'rb') as f:
+ stat_obj = pickle.load(f)
+ embeddings = stat_obj['embeddings']
+ faceI = stat_obj['faceI']
+ for emb, frameI in zip(embeddings, faceI):
+ fi = int(frameI)
+ if 0 <= fi < codec_len:
+ end = min(fi + 5, codec_len)
+ face_emb[fi:end] = torch.from_numpy(emb).expand(end - fi, -1)
+
+ # attention_mask 对应序列长度包括input_id=(sos, <|startofclue|>, clue, <|endofclue|>, text, type_id, timespk_ids, turn_of_speech, speech, eos)
+ attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
+ codec_len = torch.tensor([codec_len], dtype=torch.int32)
+ output = {
+ "input_ids": input_ids,
+ "face_emb": face_emb,
+ "attention_mask": attention_mask,
+ "labels_ids": labels,
+ "text_flag": text_flag,
+ "codec_flag": codec_flag,
+ "timespk_flag": timespk_flag,
+ "codec_len": codec_len,
+ }
+ break
+ return output
+
+ def collator(self, samples: list = None):
+
+ for idx in range(self.retry):
+ badcase_flag = False
+
+ outputs = {}
+ for sample in samples:
+ if sample is None:
+ continue
+ for key in sample.keys():
+ if key not in outputs:
+ outputs[key] = []
+ if isinstance(sample[key], (list, tuple)):
+ outputs[key].extend(sample[key])
+ else:
+ outputs[key].append(sample[key])
+
+ for key, data_list in outputs.items():
+ if isinstance(data_list[0], torch.Tensor):
+ if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
+
+ pad_value = self.int_pad_value
+ else:
+ pad_value = self.float_pad_value
+
+ outputs[key] = torch.nn.utils.rnn.pad_sequence(
+ data_list, batch_first=True, padding_value=pad_value
+ )
+
+ if self.batch_type != "example":
+ b, t = outputs["input_ids"].shape
+ if b > 1 and b * t > self.batch_size_token_max:
+ logging.info(
+ f"Warning, {idx}th, b*t: {b}*{t}={b * t} > batch_size_token_max: {self.batch_size_token_max}, drop last data"
+ )
+ samples = samples[:-1]
+ continue
+
+ break
+
+ return outputs
\ No newline at end of file
diff --git a/funcineforge/datasets/index_ds.py b/funcineforge/datasets/index_ds.py
new file mode 100644
index 0000000000000000000000000000000000000000..b057da9617154e644be86b1cfc5ec810f4d2d569
--- /dev/null
+++ b/funcineforge/datasets/index_ds.py
@@ -0,0 +1,151 @@
+import json
+import torch
+import logging
+import numpy as np
+
+
+class FunCineForgeDS(torch.utils.data.Dataset):
+
+ def __init__(self, data_jsonl: str, **kwargs):
+ super().__init__()
+
+ self.max_source_length = kwargs.get("max_source_length", None)
+ self.max_text_length = kwargs.get("max_text_length", None)
+ self.max_token_length = kwargs.get("max_token_length", None)
+ self.ignore_id = kwargs.get("ignore_id", -100)
+ self.frame_shift = kwargs.get("frame_shift", 25)
+ self.timebook_size = kwargs.get("timebook_size", 1500)
+ self.type_map = {"旁白": kwargs.get("pangbai", self.timebook_size),
+ "独白": kwargs.get("dubai", self.timebook_size + 1),
+ "对话": kwargs.get("duihua", self.timebook_size + 2),
+ "多人": kwargs.get("duoren", self.timebook_size + 3),}
+ self.gender_map = {"男": kwargs.get("male", self.timebook_size + 4),
+ "male": kwargs.get("male", self.timebook_size + 4),
+ "女": kwargs.get("female", self.timebook_size + 5),
+ "female": kwargs.get("female", self.timebook_size + 5),}
+ self.age_map = {"儿童": kwargs.get("child", self.timebook_size + 6),
+ "child": kwargs.get("child", self.timebook_size + 6),
+ "青年": kwargs.get("youth", self.timebook_size + 7),
+ "teenager": kwargs.get("youth", self.timebook_size + 7),
+ "中年": kwargs.get("adult", self.timebook_size + 8),
+ "adult": kwargs.get("adult", self.timebook_size + 8),
+ "中老年": kwargs.get("middle", self.timebook_size + 9),
+ "middle-aged": kwargs.get("middle", self.timebook_size + 9),
+ "老年": kwargs.get("elderly", self.timebook_size + 10),
+ "elderly": kwargs.get("elderly", self.timebook_size + 10)}
+ self.speaker_id_start = kwargs.get("speaker_id_start", self.timebook_size + 11)
+
+ load_meta_data_key = kwargs.get("load_meta_data_key").split(",")
+
+ if not (data_jsonl.endswith(".jsonl") or data_jsonl.endswith(".json")):
+ # jsonl list file
+ with open(data_jsonl, encoding="utf-8") as fin:
+ file_list = fin.readlines()
+ logging.info(f"file_list: {file_list}")
+ else:
+ file_list = [data_jsonl]
+
+ contents = []
+ for file_json in file_list:
+ with open(file_json.strip(), encoding="utf-8") as fin:
+ for line in fin:
+ data_dict = json.loads(line.strip())
+ utt = data_dict["utt"]
+ data_type = data_dict.get("type")
+ type_id = self.type_map[data_type] if data_type in self.type_map else 1500
+ data = data_dict["messages"]
+ speech_length = data_dict.get("speech_length", -1)
+ # 2 for startofclue, endofclue
+ text_length = data_dict.get("text_length", -1) + data_dict.get("clue_length", -1) + 2
+ if self.max_token_length is not None and (speech_length > self.max_token_length or speech_length <= 0):
+ logging.info(
+ f"speech_length: {speech_length} > {self.max_token_length}, drop it: {data_dict}"
+ )
+ continue
+ if self.max_text_length is not None and (text_length > self.max_text_length or text_length <= 0):
+ logging.info(
+ f"text_length: {text_length} > {self.max_text_length}, drop it: {data_dict}"
+ )
+ continue
+
+ skip_flag = None
+ roles = {item.get("role") for item in data}
+ for key in load_meta_data_key:
+ if key not in roles:
+ skip_flag = key
+ break
+ if skip_flag is not None:
+ logging.info(
+ f"doesn't have {skip_flag}, drop it: {data_dict}")
+ continue
+
+ contents_i = {}
+ timespk_ids_len = 0
+ for i, item in enumerate(data):
+ role = item["role"]
+ content = item["content"]
+ for key in load_meta_data_key:
+ if role == key:
+ if key == "dialogue":
+ timespk_ids = self.timespk_to_codec(content)
+ timespk_ids_len = len(timespk_ids)
+ if timespk_ids_len == 0:
+ logging.info(f"[WARNING] len of timespk_ids is 0: {data_dict}")
+ contents_i["timespk_ids"] = timespk_ids
+ else:
+ contents_i[role] = content
+ contents_i["utt"] = utt
+ contents_i["type_id"] = type_id
+ # face embs len = speech tokens len, so need * 2;
+ # 4: sos, tos, eos; type_id
+ contents_i["source_len"] = speech_length * 2 + text_length + timespk_ids_len + 4
+ contents_i["speech_len"] = speech_length
+ contents_i["text_len"] = text_length # include clue_length
+ contents.append(contents_i)
+
+ self.contents = contents
+
+ logging.info("total_num of samplers: {}, {}".format(len(self.contents), data_jsonl))
+
+
+ def timespk_to_codec(self, dialogue):
+ # tuple tokens (start, spk, gender, age, end) * n_parts
+ n_parts = len(dialogue)
+ if n_parts == 0:
+ return np.array([], dtype=np.int64)
+ starts = np.array([part["start"] for part in dialogue])
+ durations = np.array([part["duration"] for part in dialogue])
+ speakers = np.array([int(part["spk"]) for part in dialogue])
+ genders = [part["gender"] for part in dialogue]
+ ages = [part["age"] for part in dialogue]
+
+ start_idxs = (starts * self.frame_shift + 1).astype(np.int64)
+ end_idxs = ((starts + durations) * self.frame_shift + 1).astype(np.int64)
+ spk_ids = (self.speaker_id_start + speakers - 1).astype(np.int64)
+ gender_ids = [self.gender_map.get(g, self.ignore_id) for g in genders]
+ age_ids = [self.age_map.get(a, self.ignore_id) for a in ages]
+
+ sequence = np.full(n_parts * 5, self.ignore_id, dtype=np.int64)
+ sequence[0::5] = start_idxs
+ sequence[1::5] = spk_ids
+ sequence[2::5] = gender_ids
+ sequence[3::5] = age_ids
+ sequence[4::5] = end_idxs
+ return sequence
+
+ def __len__(self):
+ return len(self.contents)
+
+ def __getitem__(self, index):
+
+ data = self.contents[index]
+
+ return data
+
+ def get_source_len(self, data_dict):
+ source_len = data_dict.get("source_len", 0)
+ return source_len
+
+ def get_target_len(self, data_dict):
+ target_len = data_dict.get("speech_len", 0)
+ return target_len
\ No newline at end of file
diff --git a/funcineforge/download/__init__.py b/funcineforge/download/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/funcineforge/download/download_model_from_hub.py b/funcineforge/download/download_model_from_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f6581eb56cefd23c6111869a791a258124d431
--- /dev/null
+++ b/funcineforge/download/download_model_from_hub.py
@@ -0,0 +1,220 @@
+import os
+import json
+from omegaconf import OmegaConf, DictConfig
+from funcineforge.download.name_maps_from_hub import name_maps_ms, name_maps_hf, name_maps_openai
+
+def download_model(**kwargs):
+ hub = kwargs.get("hub", "ms")
+ if hub == "ms":
+ kwargs = download_from_ms(**kwargs)
+ elif hub == "hf":
+ kwargs = download_from_hf(**kwargs)
+ elif hub == "openai":
+ model_or_path = kwargs.get("model")
+ if os.path.exists(model_or_path):
+ # local path
+ kwargs["model_path"] = model_or_path
+ kwargs["model"] = "WhisperWarp"
+ else:
+ # model name
+ if model_or_path in name_maps_openai:
+ model_or_path = name_maps_openai[model_or_path]
+ kwargs["model_path"] = model_or_path
+
+ return kwargs
+
+
+def download_from_ms(**kwargs):
+ model_or_path = kwargs.get("model")
+ if model_or_path in name_maps_ms:
+ model_or_path = name_maps_ms[model_or_path]
+ model_revision = kwargs.get("model_revision", "master")
+ if not os.path.exists(model_or_path) and "model_path" not in kwargs:
+ try:
+ model_or_path = get_or_download_model_dir(
+ model_or_path,
+ model_revision,
+ is_training=kwargs.get("is_training"),
+ check_latest=kwargs.get("check_latest", True),
+ )
+ except Exception as e:
+ print(f"Download: {model_or_path} failed!: {e}")
+
+ kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]
+
+ if os.path.exists(os.path.join(model_or_path, "configuration.json")):
+ with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
+ conf_json = json.load(f)
+
+ cfg = {}
+ if "file_path_metas" in conf_json:
+ add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+ # cfg.update(kwargs)
+ cfg = OmegaConf.merge(cfg, kwargs)
+ if "config" in cfg:
+ config = OmegaConf.load(cfg["config"])
+ kwargs = OmegaConf.merge(config, cfg)
+ kwargs["model"] = config["model"]
+ elif os.path.exists(os.path.join(model_or_path, "config.yaml")):
+ config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
+ kwargs = OmegaConf.merge(config, kwargs)
+
+ init_param = kwargs.get("init_param", "")
+ if (
+ isinstance(init_param, str)
+ and not os.path.exists(init_param)
+ or isinstance(init_param, (list, tuple))
+ ):
+ init_param_new = init_param
+ if isinstance(init_param, str):
+ init_param = init_param.split(",")
+ for init_param_i in init_param:
+ if not os.path.exists(init_param_i):
+ print(f"init_param: {init_param_i}, does not exist")
+ init_param_i = os.path.join(model_or_path, "model.pt")
+ init_param_new = f"{init_param_new},{init_param_i}"
+ kwargs["init_param"] = init_param_new
+ # assert os.path.exists(kwargs["init_param"]), "init_param does not exist"
+ if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+ if os.path.exists(os.path.join(model_or_path, "tokens.json")):
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
+ if os.path.exists(os.path.join(model_or_path, "seg_dict")):
+ kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+ if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+ kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+ kwargs["model"] = config["model"]
+ if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+ kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+ if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+ kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
+ if isinstance(kwargs, DictConfig):
+ kwargs = OmegaConf.to_container(kwargs, resolve=True)
+
+ return kwargs
+
+
+def download_from_hf(**kwargs):
+ model_or_path = kwargs.get("model")
+ if model_or_path in name_maps_hf:
+ model_or_path = name_maps_hf[model_or_path]
+ model_revision = kwargs.get("model_revision", "master")
+ if not os.path.exists(model_or_path) and "model_path" not in kwargs:
+ try:
+ model_or_path = get_or_download_model_dir_hf(
+ model_or_path,
+ model_revision,
+ is_training=kwargs.get("is_training"),
+ check_latest=kwargs.get("check_latest", True),
+ )
+ except Exception as e:
+ print(f"Download: {model_or_path} failed!: {e}")
+
+ kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]
+
+ if os.path.exists(os.path.join(model_or_path, "configuration.json")):
+ with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
+ conf_json = json.load(f)
+
+ cfg = {}
+ if "file_path_metas" in conf_json:
+ add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+ cfg = OmegaConf.merge(cfg, kwargs)
+ # cfg.update(kwargs)
+ if "config" in cfg:
+ config = OmegaConf.load(cfg["config"])
+ kwargs = OmegaConf.merge(config, cfg)
+ kwargs["model"] = config["model"]
+ elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
+ os.path.join(model_or_path, "model.pt")
+ ):
+ config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
+ kwargs = OmegaConf.merge(config, kwargs)
+ init_param = os.path.join(model_or_path, "model.pt")
+ kwargs["init_param"] = init_param
+ if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+ if os.path.exists(os.path.join(model_or_path, "tokens.json")):
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
+ if os.path.exists(os.path.join(model_or_path, "seg_dict")):
+ kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+ if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+ kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+ kwargs["model"] = config["model"]
+ if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+ kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+ if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+ kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
+ if isinstance(kwargs, DictConfig):
+ kwargs = OmegaConf.to_container(kwargs, resolve=True)
+
+ return kwargs
+
+
+def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):
+ print(file_path_metas)
+ if isinstance(file_path_metas, dict):
+ for k, v in file_path_metas.items():
+ if isinstance(v, str):
+ p = os.path.join(model_or_path, v)
+ if os.path.exists(p):
+ cfg[k] = p
+ elif isinstance(v, dict):
+ if k not in cfg:
+ cfg[k] = {}
+ add_file_root_path(model_or_path, v, cfg[k])
+ return cfg
+
+
+def get_or_download_model_dir(
+ model,
+ model_revision=None,
+ is_training=False,
+ check_latest=True,
+):
+ """Get local model directory or download model if necessary.
+
+ Args:
+ model (str): model id or path to local model directory.
+ model_revision (str, optional): model version number.
+ :param is_training:
+ """
+ from modelscope.hub.check_model import check_local_model_is_latest
+ from modelscope.hub.snapshot_download import snapshot_download
+
+ from modelscope.utils.constant import Invoke, ThirdParty
+
+ key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
+
+ if os.path.exists(model) and check_latest:
+ model_cache_dir = model if os.path.isdir(model) else os.path.dirname(model)
+ try:
+ check_local_model_is_latest(
+ model_cache_dir, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funcineforge"}
+ )
+ except:
+ print("could not check the latest version")
+ else:
+ model_cache_dir = snapshot_download(
+ model, revision=model_revision, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funcineforge"}
+ )
+ return model_cache_dir
+
+
+def get_or_download_model_dir_hf(
+ model,
+ model_revision=None,
+ is_training=False,
+ check_latest=True,
+):
+ """Get local model directory or download model if necessary.
+
+ Args:
+ model (str): model id or path to local model directory.
+ model_revision (str, optional): model version number.
+ :param is_training:
+ """
+ from huggingface_hub import snapshot_download
+
+ model_cache_dir = snapshot_download(model)
+ return model_cache_dir
diff --git a/funcineforge/download/file.py b/funcineforge/download/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8a13f8741aa1ba796d1c86eae103c6f05aa3b08
--- /dev/null
+++ b/funcineforge/download/file.py
@@ -0,0 +1,320 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import contextlib
+import os
+import tempfile
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
+from typing import Generator, Union
+
+import requests
+from urllib.parse import urlparse
+
+
+def download_from_url(url):
+ result = urlparse(url)
+ file_path = None
+ if result.scheme is not None and len(result.scheme) > 0:
+ storage = HTTPStorage()
+ # bytes
+ data = storage.read(url)
+ work_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(work_dir):
+ os.makedirs(work_dir)
+ file_path = os.path.join(work_dir, os.path.basename(url))
+ with open(file_path, "wb") as fb:
+ fb.write(data)
+ assert file_path is not None, f"failed to download: {url}"
+ return file_path
+
+
+class Storage(metaclass=ABCMeta):
+ """Abstract class of storage.
+
+ All backends need to implement two apis: ``read()`` and ``read_text()``.
+ ``read()`` reads the file as a byte stream and ``read_text()`` reads
+ the file as texts.
+ """
+
+ @abstractmethod
+ def read(self, filepath: str):
+ pass
+
+ @abstractmethod
+ def read_text(self, filepath: str):
+ pass
+
+ @abstractmethod
+ def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ pass
+
+ @abstractmethod
+ def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
+ pass
+
+
+class LocalStorage(Storage):
+ """Local hard disk storage"""
+
+ def read(self, filepath: Union[str, Path]) -> bytes:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ bytes: Expected bytes object.
+ """
+ with open(filepath, "rb") as f:
+ content = f.read()
+ return content
+
+ def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ with open(filepath, "r", encoding=encoding) as f:
+ value_buf = f.read()
+ return value_buf
+
+ def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'wb' mode.
+
+ Note:
+ ``write`` will create a directory if the directory of ``filepath``
+ does not exist.
+
+ Args:
+ obj (bytes): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ dirname = os.path.dirname(filepath)
+ if dirname and not os.path.exists(dirname):
+ os.makedirs(dirname, exist_ok=True)
+
+ with open(filepath, "wb") as f:
+ f.write(obj)
+
+ def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
+ """Write data to a given ``filepath`` with 'w' mode.
+
+ Note:
+ ``write_text`` will create a directory if the directory of
+ ``filepath`` does not exist.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+ """
+ dirname = os.path.dirname(filepath)
+ if dirname and not os.path.exists(dirname):
+ os.makedirs(dirname, exist_ok=True)
+
+ with open(filepath, "w", encoding=encoding) as f:
+ f.write(obj)
+
+ @contextlib.contextmanager
+ def as_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]:
+ """Only for unified API and do nothing."""
+ yield filepath
+
+
+class HTTPStorage(Storage):
+ """HTTP and HTTPS storage."""
+
+ def read(self, url):
+ # TODO @wenmeng.zwm add progress bar if file is too large
+ r = requests.get(url)
+ r.raise_for_status()
+ return r.content
+
+ def read_text(self, url):
+ r = requests.get(url)
+ r.raise_for_status()
+ return r.text
+
+ @contextlib.contextmanager
+ def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
+ """Download a file from ``filepath``.
+
+ ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+ can be called with ``with`` statement, and when exists from the
+ ``with`` statement, the temporary path will be released.
+
+ Args:
+ filepath (str): Download a file from ``filepath``.
+
+ Examples:
+ >>> storage = HTTPStorage()
+ >>> # After existing from the ``with`` clause,
+ >>> # the path will be removed
+ >>> with storage.get_local_path('http://path/to/file') as path:
+ ... # do something here
+ """
+ try:
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(self.read(filepath))
+ f.close()
+ yield f.name
+ finally:
+ os.remove(f.name)
+
+ def write(self, obj: bytes, url: Union[str, Path]) -> None:
+ raise NotImplementedError("write is not supported by HTTP Storage")
+
+ def write_text(self, obj: str, url: Union[str, Path], encoding: str = "utf-8") -> None:
+ raise NotImplementedError("write_text is not supported by HTTP Storage")
+
+
+class OSSStorage(Storage):
+ """OSS storage."""
+
+ def __init__(self, oss_config_file=None):
+ # read from config file or env var
+ raise NotImplementedError("OSSStorage.__init__ to be implemented in the future")
+
+ def read(self, filepath):
+ raise NotImplementedError("OSSStorage.read to be implemented in the future")
+
+ def read_text(self, filepath, encoding="utf-8"):
+ raise NotImplementedError("OSSStorage.read_text to be implemented in the future")
+
+ @contextlib.contextmanager
+ def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
+ """Download a file from ``filepath``.
+
+ ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+ can be called with ``with`` statement, and when exists from the
+ ``with`` statement, the temporary path will be released.
+
+ Args:
+ filepath (str): Download a file from ``filepath``.
+
+ Examples:
+ >>> storage = OSSStorage()
+ >>> # After existing from the ``with`` clause,
+ >>> # the path will be removed
+ >>> with storage.get_local_path('http://path/to/file') as path:
+ ... # do something here
+ """
+ try:
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(self.read(filepath))
+ f.close()
+ yield f.name
+ finally:
+ os.remove(f.name)
+
+ def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ raise NotImplementedError("OSSStorage.write to be implemented in the future")
+
+ def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
+ raise NotImplementedError("OSSStorage.write_text to be implemented in the future")
+
+
+G_STORAGES = {}
+
+
+class File(object):
+ _prefix_to_storage: dict = {
+ "oss": OSSStorage,
+ "http": HTTPStorage,
+ "https": HTTPStorage,
+ "local": LocalStorage,
+ }
+
+ @staticmethod
+ def _get_storage(uri):
+ assert isinstance(uri, str), f"uri should be str type, but got {type(uri)}"
+
+ if "://" not in uri:
+ # local path
+ storage_type = "local"
+ else:
+ prefix, _ = uri.split("://")
+ storage_type = prefix
+
+ assert storage_type in File._prefix_to_storage, (
+ f"Unsupported uri {uri}, valid prefixs: " f"{list(File._prefix_to_storage.keys())}"
+ )
+
+ if storage_type not in G_STORAGES:
+ G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
+
+ return G_STORAGES[storage_type]
+
+ @staticmethod
+ def read(uri: str) -> bytes:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ bytes: Expected bytes object.
+ """
+ storage = File._get_storage(uri)
+ return storage.read(uri)
+
+ @staticmethod
+ def read_text(uri: Union[str, Path], encoding: str = "utf-8") -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ storage = File._get_storage(uri)
+ return storage.read_text(uri)
+
+ @staticmethod
+ def write(obj: bytes, uri: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'wb' mode.
+
+ Note:
+ ``write`` will create a directory if the directory of ``filepath``
+ does not exist.
+
+ Args:
+ obj (bytes): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ storage = File._get_storage(uri)
+ return storage.write(obj, uri)
+
+ @staticmethod
+ def write_text(obj: str, uri: str, encoding: str = "utf-8") -> None:
+ """Write data to a given ``filepath`` with 'w' mode.
+
+ Note:
+ ``write_text`` will create a directory if the directory of
+ ``filepath`` does not exist.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+ """
+ storage = File._get_storage(uri)
+ return storage.write_text(obj, uri)
+
+ @contextlib.contextmanager
+ def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
+ """Only for unified API and do nothing."""
+ storage = File._get_storage(uri)
+ with storage.as_local_path(uri) as local_path:
+ yield local_path
diff --git a/funcineforge/download/name_maps_from_hub.py b/funcineforge/download/name_maps_from_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..647082ae94093ef758901bfb2cf152c0f36a747d
--- /dev/null
+++ b/funcineforge/download/name_maps_from_hub.py
@@ -0,0 +1,42 @@
+name_maps_ms = {
+ "paraformer": "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+ "paraformer-zh": "iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+ "paraformer-en": "iic/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+ "paraformer-en-spk": "iic/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+ "paraformer-zh-streaming": "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
+ "fsmn-vad": "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+ "ct-punc": "iic/punc_ct-transformer_cn-en-common-vocab471067-large",
+ "ct-punc-c": "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+ "fa-zh": "iic/speech_timestamp_prediction-v1-16k-offline",
+ "cam++": "iic/speech_campplus_sv_zh-cn_16k-common",
+ "Whisper-large-v3": "iic/Whisper-large-v3",
+ "Qwen-Audio": "Qwen/Qwen-Audio",
+ "emotion2vec_plus_large": "iic/emotion2vec_plus_large",
+ "emotion2vec_plus_base": "iic/emotion2vec_plus_base",
+ "emotion2vec_plus_seed": "iic/emotion2vec_plus_seed",
+}
+
+name_maps_hf = {
+ "paraformer": "funasr/paraformer-zh",
+ "paraformer-zh": "funasr/paraformer-zh",
+ "paraformer-en": "funasr/paraformer-zh",
+ "paraformer-zh-streaming": "funasr/paraformer-zh-streaming",
+ "fsmn-vad": "funasr/fsmn-vad",
+ "ct-punc": "funasr/ct-punc",
+ "ct-punc-c": "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+ "fa-zh": "funasr/fa-zh",
+ "cam++": "funasr/campplus",
+ "iic/emotion2vec_plus_large": "emotion2vec/emotion2vec_plus_large",
+ "iic/emotion2vec_plus_base": "emotion2vec/emotion2vec_plus_base",
+ "iic/emotion2vec_plus_seed": "emotion2vec/emotion2vec_plus_seed",
+}
+
+name_maps_openai = {
+ "Whisper-base.en": "base.en",
+ "Whisper-base": "base",
+ "Whisper-large": "large",
+ "Whisper-large-v1": "large-v1",
+ "Whisper-large-v2": "large-v2",
+ "Whisper-large-v3": "large-v3",
+ "Whisper-large-v3-turbo": "turbo",
+}
diff --git a/funcineforge/face/__init__.py b/funcineforge/face/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f0ad858946a8a78339eda53755de65884f0d764
--- /dev/null
+++ b/funcineforge/face/__init__.py
@@ -0,0 +1 @@
+from .face_recognition import FaceRecIR101
\ No newline at end of file
diff --git a/funcineforge/face/face_recognition.py b/funcineforge/face/face_recognition.py
new file mode 100644
index 0000000000000000000000000000000000000000..702b15a3a26633f8c353cb8f668abc610beed7e1
--- /dev/null
+++ b/funcineforge/face/face_recognition.py
@@ -0,0 +1,16 @@
+def FaceRecIR101(init_param_path, **kwargs):
+ """
+ Face embeddings extraction with CurricularFace pretrained model.
+ Reference:
+ - https://modelscope.cn/models/iic/cv_ir101_facerecognition_cfglint
+ """
+ import onnxruntime
+ options = onnxruntime.SessionOptions()
+ options.intra_op_num_threads = 8
+ options.inter_op_num_threads = 8
+ ort_session = onnxruntime.InferenceSession(
+ init_param_path,
+ sess_options=options,
+ providers=['CPUExecutionProvider']
+ )
+ return ort_session
diff --git a/funcineforge/models/__init__.py b/funcineforge/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..effa113c30b8bd759a108a61c3c47a99ed9c7a94
--- /dev/null
+++ b/funcineforge/models/__init__.py
@@ -0,0 +1,5 @@
+from .specaug.specaug import SpecAug as FunCineForgeSpecAug
+from .language_model import FunCineForgeLM
+from .causal_hifigan import CausalHifiGan
+from .flow_matching_model import CosyVoiceFlowMatching
+from .inference_model import FunCineForgeInferModel
\ No newline at end of file
diff --git a/funcineforge/models/causal_hifigan.py b/funcineforge/models/causal_hifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9c07dd8ff99da68521b8eee8aeefcea8a9c79c8
--- /dev/null
+++ b/funcineforge/models/causal_hifigan.py
@@ -0,0 +1,834 @@
+# Copyright 2023 KaiHu
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""HIFI-GAN"""
+
+from typing import Dict
+from typing import Tuple, List
+
+import numpy as np
+from scipy.signal import get_window
+import torch
+import torchaudio
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.utils import remove_weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+from torch.nn.utils.parametrizations import weight_norm
+import logging
+from funcineforge.utils.device_funcs import to_device
+import os
+from torch.nn.utils.rnn import pad_sequence
+from funcineforge.models.utils import dtype_map
+from funcineforge.models.modules.hifigan import init_weights
+from funcineforge.models.modules.hifigan.activations import Snake
+
+
+class LookRightConv1d(torch.nn.Conv1d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = 'zeros',
+ device=None,
+ dtype=None
+ ) -> None:
+ super(LookRightConv1d, self).__init__(in_channels, out_channels,
+ kernel_size, stride,
+ padding=0, dilation=dilation,
+ groups=groups, bias=bias,
+ padding_mode=padding_mode,
+ device=device, dtype=dtype)
+ assert stride == 1
+ self.causal_padding = kernel_size - 1
+
+ def forward(self, x: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+ if context.size(2) == 0:
+ x = F.pad(x, (0, self.causal_padding), value=0.0)
+ else:
+ assert context.size(2) == self.causal_padding
+ x = torch.concat([x, context], dim=2)
+ x = super(LookRightConv1d, self).forward(x)
+ return x
+
+class LookLeftConv1d(torch.nn.Conv1d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = 'zeros',
+ device=None,
+ dtype=None
+ ) -> None:
+ super(LookLeftConv1d, self).__init__(in_channels, out_channels,
+ kernel_size, stride,
+ padding=0, dilation=dilation,
+ groups=groups, bias=bias,
+ padding_mode=padding_mode,
+ device=device, dtype=dtype)
+ assert stride == 1 and dilation == 1
+ self.causal_padding = kernel_size - 1
+
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+ if cache.size(2) == 0:
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
+ else:
+ assert cache.size(2) == self.causal_padding
+ x = torch.concat([cache, x], dim=2)
+ # NOTE 兼容kernel_size=1的情况
+ if self.causal_padding == 0:
+ cache_new = x[:, :, :0]
+ else:
+ cache_new = x[:, :, -self.causal_padding:]
+ x = super(LookLeftConv1d, self).forward(x)
+ return x, cache_new
+
+
+class CausalConvRNNF0Predictor(nn.Module):
+ def __init__(self,
+ num_class: int = 1,
+ in_channels: int = 80,
+ cond_channels: int = 512
+ ):
+ super().__init__()
+
+ self.num_class = num_class
+ self.condnet = nn.Sequential(
+ weight_norm(
+ LookRightConv1d(in_channels, cond_channels, kernel_size=4)
+ ),
+ nn.ELU(),
+ weight_norm(
+ LookLeftConv1d(cond_channels, cond_channels, kernel_size=3)
+ ),
+ nn.ELU(),
+ weight_norm(
+ LookLeftConv1d(cond_channels, cond_channels, kernel_size=3)
+ ),
+ nn.ELU(),
+ weight_norm(
+ LookLeftConv1d(cond_channels, cond_channels, kernel_size=3)
+ ),
+ nn.ELU(),
+ weight_norm(
+ LookLeftConv1d(cond_channels, cond_channels, kernel_size=3)
+ ),
+ nn.ELU(),
+ )
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
+
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0, 0), finalize: bool = True) -> torch.Tensor:
+ if finalize is False:
+ x, context = x[:, :, :-self.condnet[0].causal_padding], x[:, :, -self.condnet[0].causal_padding:]
+ else:
+ x, context = x, x[:, :, :0]
+ x = self.condnet[0](x, context)
+ x = self.condnet[1](x)
+ if cache.size(0) != 0:
+ x, cache[0] = self.condnet[2](x, cache[0])
+ else:
+ x, _ = self.condnet[2](x)
+ x = self.condnet[3](x)
+ if cache.size(0) != 0:
+ x, cache[1] = self.condnet[4](x, cache[1])
+ else:
+ x, _ = self.condnet[4](x)
+ x = self.condnet[5](x)
+ if cache.size(0) != 0:
+ x, cache[2] = self.condnet[6](x, cache[2])
+ else:
+ x, _ = self.condnet[6](x)
+ x = self.condnet[7](x)
+ if cache.size(0) != 0:
+ x, cache[3] = self.condnet[8](x, cache[3])
+ else:
+ x, _ = self.condnet[8](x)
+ x = self.condnet[9](x)
+ x = x.transpose(1, 2)
+ x = torch.abs(self.classifier(x).squeeze(-1))
+ return x, cache
+
+ def init_cache(self, device):
+ return torch.zeros(4, 1, 512, 2).to(device)
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ try:
+ remove_weight_norm(self.condnet[0])
+ remove_weight_norm(self.condnet[2])
+ remove_weight_norm(self.condnet[4])
+ remove_weight_norm(self.condnet[6])
+ remove_weight_norm(self.condnet[8])
+ except:
+ remove_parametrizations(self.condnet[0], 'weight')
+ remove_parametrizations(self.condnet[2], 'weight')
+ remove_parametrizations(self.condnet[4], 'weight')
+ remove_parametrizations(self.condnet[6], 'weight')
+ remove_parametrizations(self.condnet[8], 'weight')
+
+
+class LookLeftConvTranspose1d(torch.nn.Conv1d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = 'zeros',
+ device=None,
+ dtype=None
+ ) -> None:
+ super(LookLeftConvTranspose1d, self).__init__(in_channels, out_channels,
+ kernel_size, 1,
+ padding=0, dilation=dilation,
+ groups=groups, bias=bias,
+ padding_mode=padding_mode,
+ device=device, dtype=dtype)
+ assert dilation == 1 and stride != 1
+ self.causal_padding = kernel_size - 1
+ self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')
+
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+ x = self.upsample(x)
+ if cache.size(2) == 0:
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
+ else:
+ assert cache.size(2) == self.causal_padding
+ x = torch.concat([cache, x], dim=2)
+ cache_new = x[:, :, -self.causal_padding:]
+ x = super(LookLeftConvTranspose1d, self).forward(x)
+ return x, cache_new
+
+
+class LookLeftConv1dWithStride(torch.nn.Conv1d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = 'zeros',
+ device=None,
+ dtype=None
+ ) -> None:
+ super(LookLeftConv1dWithStride, self).__init__(in_channels, out_channels,
+ kernel_size, stride,
+ padding=0, dilation=dilation,
+ groups=groups, bias=bias,
+ padding_mode=padding_mode,
+ device=device, dtype=dtype)
+ assert stride != 1 and dilation == 1
+ assert kernel_size % stride == 0
+ self.causal_padding = stride - 1
+
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+ if cache.size(2) == 0:
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
+ else:
+ assert cache.size(2) == self.causal_padding
+ x = torch.concat([cache, x], dim=2)
+ cache_new = x[:, :, -self.causal_padding:]
+ x = super(LookLeftConv1dWithStride, self).forward(x)
+ return x, cache_new
+
+
+class LookLeftConv1dWithDilation(torch.nn.Conv1d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = 'zeros',
+ device=None,
+ dtype=None
+ ) -> None:
+ super(LookLeftConv1dWithDilation, self).__init__(in_channels, out_channels,
+ kernel_size, stride,
+ padding=0, dilation=dilation,
+ groups=groups, bias=bias,
+ padding_mode=padding_mode,
+ device=device, dtype=dtype)
+ # NOTE(lyuxiang.lx) 这个causal_padding仅在kernel_size为奇数时才成立
+ assert kernel_size // 2 * dilation * 2 == int((kernel_size * dilation - dilation) / 2) * 2
+ self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2
+
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+ if cache.size(2) == 0:
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
+ else:
+ assert cache.size(2) == self.causal_padding
+ x = torch.concat([cache, x], dim=2)
+ cache_new = x[:, :, -self.causal_padding:]
+ x = super(LookLeftConv1dWithDilation, self).forward(x)
+ return x, cache_new
+
+
+class ResBlock(torch.nn.Module):
+ """Residual block module in HiFiGAN/BigVGAN."""
+ def __init__(
+ self,
+ channels: int = 512,
+ kernel_size: int = 3,
+ dilations: List[int] = [1, 3, 5],
+ ):
+ super(ResBlock, self).__init__()
+ self.convs1 = nn.ModuleList()
+ self.convs2 = nn.ModuleList()
+
+ for dilation in dilations:
+ self.convs1.append(
+ weight_norm(
+ LookLeftConv1dWithDilation(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation
+ ) if dilation != 1 else
+ LookLeftConv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation
+ )
+ )
+ )
+ self.convs2.append(
+ weight_norm(
+ LookLeftConv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1
+ )
+ )
+ )
+ self.convs1.apply(init_weights)
+ self.convs2.apply(init_weights)
+ self.activations1 = nn.ModuleList([
+ Snake(channels, alpha_logscale=False)
+ for _ in range(len(self.convs1))
+ ])
+ self.activations2 = nn.ModuleList([
+ Snake(channels, alpha_logscale=False)
+ for _ in range(len(self.convs2))
+ ])
+
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)) -> torch.Tensor:
+ for idx in range(len(self.convs1)):
+ xt = self.activations1[idx](x)
+ xt, _ = self.convs1[idx](xt)
+ xt = self.activations2[idx](xt)
+ xt, _ = self.convs2[idx](xt)
+ x = xt + x
+ return x, cache
+
+ def remove_weight_norm(self):
+ for idx in range(len(self.convs1)):
+ try:
+ remove_weight_norm(self.convs1[idx])
+ remove_weight_norm(self.convs2[idx])
+ except:
+ remove_parametrizations(self.convs1[idx], 'weight')
+ remove_parametrizations(self.convs2[idx], 'weight')
+
+
+class SineGen(torch.nn.Module):
+ """ Definition of sine generator
+ SineGen(samp_rate, harmonic_num = 0,
+ sine_amp = 0.1, noise_std = 0.003,
+ voiced_threshold = 0,
+ flag_for_pulse=False)
+ samp_rate: sampling rate in Hz
+ harmonic_num: number of harmonic overtones (default 0)
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
+ noise_std: std of Gaussian noise (default 0.003)
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
+ Note: when flag_for_pulse is True, the first time step of a voiced
+ segment is always sin(np.pi) or cos(0)
+ """
+
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
+ sine_amp=0.1, noise_std=0.003,
+ voiced_threshold=0,
+ flag_for_pulse=False):
+ super(SineGen, self).__init__()
+ self.sine_amp = sine_amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.dim = self.harmonic_num + 1
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+ self.flag_for_pulse = flag_for_pulse
+ self.upsample_scale = upsample_scale
+ self.rand_ini = torch.rand(1, 9)
+ self.rand_ini[:, 0] = 0
+ self.sine_waves = torch.rand(1, 300 * 24000, 9)
+
+ def _f02uv(self, f0):
+ # generate uv signal
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
+ return uv
+
+ def _f02sine(self, f0_values):
+ """ f0_values: (batchsize, length, dim)
+ where dim indicates fundamental tone and overtones
+ """
+ # convert to F0 in rad. The interger part n can be ignored
+ # because 2 * np.pi * n doesn't affect phase
+ rad_values = (f0_values / self.sampling_rate) % 1
+
+ # initial phase noise (no noise for fundamental component)
+ rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
+
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
+ if not self.flag_for_pulse:
+# # for normal case
+
+# # To prevent torch.cumsum numerical overflow,
+# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
+# # Buffer tmp_over_one_idx indicates the time step to add -1.
+# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
+# tmp_over_one = torch.cumsum(rad_values, 1) % 1
+# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
+# cumsum_shift = torch.zeros_like(rad_values)
+# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
+ scale_factor=1/self.upsample_scale,
+ mode="linear").transpose(1, 2)
+
+# tmp_over_one = torch.cumsum(rad_values, 1) % 1
+# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
+# cumsum_shift = torch.zeros_like(rad_values)
+# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
+ scale_factor=self.upsample_scale, mode="nearest").transpose(1, 2)
+ sines = torch.sin(phase)
+
+ else:
+ # If necessary, make sure that the first time step of every
+ # voiced segments is sin(pi) or cos(0)
+ # This is used for pulse-train generation
+
+ # identify the last time step in unvoiced segments
+ uv = self._f02uv(f0_values)
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
+ uv_1[:, -1, :] = 1
+ u_loc = (uv < 1) * (uv_1 > 0)
+
+ # get the instantanouse phase
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
+ # different batch needs to be processed differently
+ for idx in range(f0_values.shape[0]):
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+ # stores the accumulation of i.phase within
+ # each voiced segments
+ tmp_cumsum[idx, :, :] = 0
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
+ # within the previous voiced segment.
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+ # get the sines
+ sines = torch.cos(i_phase * 2 * np.pi)
+ return sines
+
+ def forward(self, f0):
+ """ sine_tensor, uv = forward(f0)
+ input F0: tensor(batchsize=1, length, dim=1)
+ f0 for unvoiced steps should be 0
+ output sine_tensor: tensor(batchsize=1, length, dim)
+ output uv: tensor(batchsize=1, length, 1)
+ """
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
+ device=f0.device)
+ # fundamental component
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
+
+ # generate sine waveforms
+ sine_waves = self._f02sine(fn) * self.sine_amp
+
+ # generate uv signal
+ # uv = torch.ones(f0.shape)
+ # uv = uv * (f0 > self.voiced_threshold)
+ uv = self._f02uv(f0)
+
+ # noise: for unvoiced should be similar to sine_amp
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
+ # . for voiced regions is self.noise_std
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+ noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)
+
+ # first: set the unvoiced part to 0 by uv
+ # then: additive noise
+ sine_waves = sine_waves * uv + noise
+ return sine_waves, uv, noise
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+ """ SourceModule for hn-nsf
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0)
+ sampling_rate: sampling_rate in Hz
+ harmonic_num: number of harmonic above F0 (default: 0)
+ sine_amp: amplitude of sine source signal (default: 0.1)
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
+ note that amplitude of noise in unvoiced is decided
+ by sine_amp
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ uv (batchsize, length, 1)
+ """
+
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0):
+ super(SourceModuleHnNSF, self).__init__()
+
+ self.sine_amp = sine_amp
+ self.noise_std = add_noise_std
+
+ # to produce sine waveforms
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
+ sine_amp, add_noise_std, voiced_threshod)
+
+ # to merge source harmonics into a single excitation
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+ self.l_tanh = torch.nn.Tanh()
+ self.uv = torch.rand(1, 300 * 24000, 1)
+
+ def forward(self, x):
+ """
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ """
+ # source for harmonic branch
+ with torch.no_grad():
+ sine_wavs, uv, _ = self.l_sin_gen(x)
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+ # source for noise branch, in the same shape as uv
+ noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
+ return sine_merge, noise, uv
+
+
+class CausalHiFTGenerator(nn.Module):
+ """
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
+ https://arxiv.org/abs/2309.09493
+ """
+ def __init__(
+ self,
+ in_channels: int = 80,
+ base_channels: int = 512,
+ nb_harmonics: int = 8,
+ sampling_rate: int = 22050,
+ nsf_alpha: float = 0.1,
+ nsf_sigma: float = 0.003,
+ nsf_voiced_threshold: float = 10,
+ upsample_rates: List[int] = [8, 8],
+ upsample_kernel_sizes: List[int] = [16, 16],
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ source_resblock_kernel_sizes: List[int] = [7, 11],
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
+ lrelu_slope: float = 0.1,
+ audio_limit: float = 0.99,
+ f0_predictor: torch.nn.Module = None,
+ ):
+ super(CausalHiFTGenerator, self).__init__()
+
+ self.out_channels = 1
+ self.nb_harmonics = nb_harmonics
+ self.sampling_rate = sampling_rate
+ self.istft_params = istft_params
+ self.lrelu_slope = lrelu_slope
+ self.audio_limit = audio_limit
+
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ self.m_source = SourceModuleHnNSF(
+ sampling_rate=sampling_rate,
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
+ harmonic_num=nb_harmonics,
+ sine_amp=nsf_alpha,
+ add_noise_std=nsf_sigma,
+ voiced_threshod=nsf_voiced_threshold)
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"], mode='nearest')
+
+ self.conv_pre = weight_norm(
+ LookRightConv1d(in_channels, base_channels, 5, 1)
+ )
+
+ # Up
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ LookLeftConvTranspose1d(
+ base_channels // (2**i),
+ base_channels // (2**(i + 1)),
+ k,
+ u
+ )
+ )
+ )
+
+ # Down
+ self.source_downs = nn.ModuleList()
+ self.source_resblocks = nn.ModuleList()
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
+ downsample_cum_rates = np.cumprod(downsample_rates)
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
+ if u == 1:
+ self.source_downs.append(
+ LookLeftConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
+ )
+ else:
+ self.source_downs.append(
+ LookLeftConv1dWithStride(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
+ )
+
+ self.source_resblocks.append(
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = base_channels // (2**(i + 1))
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+ self.resblocks.append(ResBlock(ch, k, d))
+
+ self.conv_post = weight_norm(LookLeftConv1d(ch, istft_params["n_fft"] + 2, 7, 1))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
+ self.f0_predictor = f0_predictor
+ # f0回退3帧,hift回退5帧
+ self.context_size = 8
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ try:
+ remove_weight_norm(l)
+ except:
+ remove_parametrizations(l, 'weight')
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ try:
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+ except:
+ remove_parametrizations(self.conv_pre, 'weight')
+ remove_parametrizations(self.conv_post, 'weight')
+ self.f0_predictor.remove_weight_norm()
+ for l in self.source_resblocks:
+ l.remove_weight_norm()
+
+ def _stft(self, x):
+ spec = torch.stft(
+ x,
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
+ return_complex=True)
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
+ return spec[..., 0], spec[..., 1]
+
+ def _istft(self, magnitude, phase):
+ magnitude = torch.clip(magnitude, max=1e2)
+ real = magnitude * torch.cos(phase)
+ img = magnitude * torch.sin(phase)
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
+ return inverse_transform
+
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(0, 0, 0), finalize: bool = True) -> torch.Tensor:
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
+ # NOTE(lyuxiang.lx) 回退4帧
+ if finalize is False:
+ s_stft_real, s_stft_imag = s_stft_real[:, :, :-int(480 * 4 / self.istft_params["hop_len"])], s_stft_imag[:, :, :-int(480 * 4 / self.istft_params["hop_len"])]
+ x = self.conv_pre(x[:, :, :-4], x[:, :, -4:])
+ else:
+ x = self.conv_pre(x)
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, self.lrelu_slope)
+ x, _ = self.ups[i](x)
+
+ if i == self.num_upsamples - 1:
+ x = self.reflection_pad(x)
+
+ # fusion
+ si, _ = self.source_downs[i](s_stft)
+ si, _ = self.source_resblocks[i](si)
+ x = x + si
+
+ xs = None
+ for j in range(self.num_kernels):
+ this_xs, _ = self.resblocks[i * self.num_kernels + j](x)
+ if xs is None:
+ xs = this_xs
+ else:
+ xs += this_xs
+ x = xs / self.num_kernels
+
+ x = F.leaky_relu(x)
+ x, _ = self.conv_post(x)
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
+
+ x = self._istft(magnitude, phase)
+ # NOTE(lyuxiang.lx) 回退1帧
+ if finalize is False:
+ x = x[:, :-480]
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
+ return x
+
+ @torch.inference_mode()
+ def inference(self, speech_feat: torch.Tensor, f0_cpu: bool = False, finalize: bool = True) -> torch.Tensor:
+ # mel->f0->source
+ if f0_cpu is True:
+ self.f0_predictor.to('cpu')
+ f0, _ = self.f0_predictor(speech_feat.cpu(), finalize=finalize)
+ f0 = f0.to(speech_feat.device)
+ else:
+ self.f0_predictor.to(speech_feat.device)
+ f0, _ = self.f0_predictor(speech_feat, finalize=finalize)
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
+ s, _, _ = self.m_source(s)
+ s = s.transpose(1, 2)
+ if finalize is False:
+ generated_speech = self.decode(speech_feat[:, :, :-3], s, finalize=finalize)
+ else:
+ generated_speech = self.decode(speech_feat, s, finalize=finalize)
+ return generated_speech, []
+
+
+class CausalHifiGan(nn.Module):
+ """HIFIGAN-style vocoders (generator [stack of time-level-upsampling blocks] + discriminator).
+ NSF-HIFIGAN, HiFTNet Optional.
+ """
+
+ def __init__(
+ self,
+ CausalHiFTGenerator_conf: dict = {},
+ CausalConvRNNF0Predictor_conf: dict = {},
+ sample_rate: float = 24000,
+ **kwargs
+ ):
+ super().__init__()
+ self.generator = CausalHiFTGenerator(**CausalHiFTGenerator_conf)
+ self.generator.f0_predictor = CausalConvRNNF0Predictor(**CausalConvRNNF0Predictor_conf)
+ self.generator.remove_weight_norm()
+ self.sample_rate = sample_rate
+
+ def inference_prepare(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ **kwargs,
+ ):
+ if kwargs.get("batch_size", 1) > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+
+ feat_list = []
+ feat_len_list = []
+ for i, feat in enumerate(data_in):
+ if isinstance(feat, str) and os.path.exists(feat):
+ feat = np.load(feat)
+ if isinstance(feat, np.ndarray):
+ feat = torch.from_numpy(feat)
+
+ feat_list.append(feat)
+ feat_len_list.append(feat.shape[0])
+
+ batch = {
+ "x": pad_sequence(feat_list, batch_first=True),
+ "x_lengths": torch.tensor(feat_len_list, dtype=torch.int64),
+ }
+ batch = to_device(batch, kwargs["device"])
+
+ return batch
+
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ f0_cpu: bool = True,
+ finalize: bool = True,
+ **kwargs,
+ ) -> torch.Tensor:
+ """Run inference.
+
+ Args:
+ x (torch.Tensor): input representation, B x T x C
+
+ Returns:
+ Dict[str, Tensor]:
+ * recon_speech (Tensor): Reconstructed waveform tensor (B, T_wav).
+
+ """
+ uttid = key[0]
+ batch = self.inference_prepare(data_in, data_lengths, key, **kwargs)
+ voc_dtype = dtype_map[kwargs.get("voc_dtype", "fp32")]
+ x = batch["x"].to(voc_dtype)
+ recon_speech = self.generator.inference(x.transpose(1, 2), f0_cpu=f0_cpu, finalize=finalize)[0].squeeze(1)
+ recon_speech = recon_speech.float()
+ logging.info(f"{uttid}: wav lengths {recon_speech.shape[1]}")
+
+ output_dir = kwargs.get("output_dir", None)
+ output_sr = kwargs.get("output_sr", None)
+ if output_dir is not None:
+ wav_out_dir = os.path.join(output_dir, "wav")
+ os.makedirs(wav_out_dir, exist_ok=True)
+ wav_sr = self.sample_rate
+ if output_sr is not None and output_sr != self.sample_rate:
+ recon_speech = torchaudio.functional.resample(
+ recon_speech,
+ orig_freq=self.sample_rate,
+ new_freq=output_sr
+ )
+ wav_sr = output_sr
+ torchaudio.save(
+ os.path.join(wav_out_dir, f"{key[0]}.wav"), recon_speech.cpu(),
+ sample_rate=wav_sr, encoding='PCM_S', bits_per_sample=16
+ )
+
+ return recon_speech
\ No newline at end of file
diff --git a/funcineforge/models/flow_matching_model.py b/funcineforge/models/flow_matching_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..53736709addd0aa1a0d561469b60af9174b24423
--- /dev/null
+++ b/funcineforge/models/flow_matching_model.py
@@ -0,0 +1,514 @@
+import os.path
+
+import torch
+import torch.nn as nn
+from typing import Dict
+import logging
+from librosa.filters import mel as librosa_mel_fn
+import torch.nn.functional as F
+from funcineforge.models.utils.nets_utils import make_pad_mask
+from funcineforge.utils.device_funcs import to_device
+import numpy as np
+from funcineforge.utils.load_utils import extract_campp_xvec
+import time
+from funcineforge.models.utils import dtype_map
+from funcineforge.utils.hinter import hint_once
+from funcineforge.models.utils.masks import add_optional_chunk_mask
+from .modules.dit_flow_matching.dit_model import DiT
+
+
+class Audio2Mel(nn.Module):
+ def __init__(
+ self,
+ n_fft=1024,
+ hop_length=256,
+ win_length=1024,
+ sampling_rate=22050,
+ n_mel_channels=80,
+ mel_fmin=0.0,
+ mel_fmax=None,
+ center=False,
+ device='cuda',
+ feat_type="power_log",
+ ):
+ super().__init__()
+ ##############################################
+ # FFT Parameters
+ ##############################################
+ window = torch.hann_window(win_length, device=device).float()
+ mel_basis = librosa_mel_fn(
+ sr=sampling_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
+ )
+ mel_basis = torch.from_numpy(mel_basis).float().to(device)
+ self.register_buffer("mel_basis", mel_basis)
+ self.register_buffer("window", window)
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.sampling_rate = sampling_rate
+ self.n_mel_channels = n_mel_channels
+ self.mel_fmax = mel_fmax
+ self.center = center
+ self.feat_type = feat_type
+
+ def forward(self, audioin):
+ p = (self.n_fft - self.hop_length) // 2
+ audio = F.pad(audioin, (p, p), "reflect").squeeze(1)
+ fft = torch.stft(
+ audio,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=self.window,
+ center=self.center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ if self.feat_type == "mag_log10":
+ power_spec = torch.sqrt(torch.pow(fft.imag, 2) + torch.pow(fft.real, 2))
+ mel_output = torch.matmul(self.mel_basis, power_spec)
+ return torch.log10(torch.clamp(mel_output, min=1e-5))
+ power_spec = torch.pow(fft.imag, 2) + torch.pow(fft.real, 2)
+ mel_spec = torch.matmul(self.mel_basis, torch.sqrt(power_spec + 1e-9))
+ return self.spectral_normalize(mel_spec)
+
+ @classmethod
+ def spectral_normalize(cls, spec, C=1, clip_val=1e-5):
+ output = cls.dynamic_range_compression(spec, C, clip_val)
+ return output
+
+ @classmethod
+ def spectral_de_normalize_torch(cls, spec, C=1, clip_val=1e-5):
+ output = cls.dynamic_range_decompression(spec, C, clip_val)
+ return output
+
+ @staticmethod
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+ @staticmethod
+ def dynamic_range_decompression(x, C=1):
+ return torch.exp(x) / C
+
+
+class LookaheadBlock(nn.Module):
+ def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
+ super().__init__()
+ self.channels = channels
+ self.pre_lookahead_len = pre_lookahead_len
+ self.conv1 = nn.Conv1d(
+ in_channels, channels,
+ kernel_size=pre_lookahead_len+1,
+ stride=1, padding=0,
+ )
+ self.conv2 = nn.Conv1d(
+ channels, in_channels,
+ kernel_size=3, stride=1, padding=0,
+ )
+
+ def forward(self, inputs, ilens, context: torch.Tensor = torch.zeros(0, 0, 0)):
+ """
+ inputs: (batch_size, seq_len, channels)
+ """
+ outputs = inputs.transpose(1, 2).contiguous()
+ context = context.transpose(1, 2).contiguous()
+ # look ahead
+ if context.size(2) == 0:
+ outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0)
+ else:
+ assert context.size(2) == self.pre_lookahead_len
+ outputs = torch.concat([outputs, context], dim=2)
+ outputs = F.leaky_relu(self.conv1(outputs))
+ # outputs
+ outputs = F.pad(outputs, (2, 0), mode='constant', value=0)
+ outputs = self.conv2(outputs)
+ outputs = outputs.transpose(1, 2).contiguous()
+
+ mask = (~make_pad_mask(ilens).unsqueeze(-1).to(inputs.device))
+ # residual connection
+ outputs = (outputs + inputs) * mask
+
+ return outputs, ilens
+
+
+class CosyVoiceFlowMatching(nn.Module):
+ def __init__(
+ self,
+ codebook_size: int,
+ model_size: int,
+ xvec_size: int = 198,
+ dit_conf: Dict = {},
+ mel_feat_conf: Dict = {},
+ prompt_conf: Dict = None,
+ **kwargs):
+ super().__init__()
+
+ # feat related
+ self.feat_token_ratio = kwargs.get("feat_token_ratio", None)
+ try:
+ self.mel_extractor = Audio2Mel(**mel_feat_conf)
+ self.sample_rate = self.mel_extractor.sampling_rate
+ except:
+ self.mel_extractor = None
+ self.sample_rate = 24000
+ self.mel_norm_type = kwargs.get("mel_norm_type", None)
+ self.num_mels = num_mels = mel_feat_conf["n_mel_channels"]
+ self.token_rate = kwargs.get("token_rate", 25)
+ self.model_dtype = kwargs.get("model_dtype", "fp32")
+ self.codebook_size = codebook_size
+
+ # condition related
+ self.prompt_conf = prompt_conf
+ if self.prompt_conf is not None:
+ self.prompt_masker = self.build_prompt_masker()
+
+ # codec related
+ self.codec_embedder = nn.Embedding(codebook_size, num_mels)
+ lookahead_length = kwargs.get("lookahead_length", 4)
+ self.lookahead_conv1d = LookaheadBlock(num_mels, model_size, lookahead_length)
+
+ # spk embed related
+ if xvec_size is not None:
+ self.xvec_proj = torch.nn.Linear(xvec_size, num_mels)
+
+ # dit model related
+ self.dit_conf = dit_conf
+ self.dit_model = DiT(**dit_conf)
+
+ self.training_cfg_rate = kwargs.get("training_cfg_rate", 0)
+ self.only_mask_loss = kwargs.get("only_mask_loss", True)
+
+ # NOTE fm需要右看的下文
+ self.context_size = self.lookahead_conv1d.pre_lookahead_len
+
+ def build_prompt_masker(self):
+ prompt_type = self.prompt_conf.get("prompt_type", "free")
+ if prompt_type == "prefix":
+ from funcineforge.models.utils.mask_along_axis import MaskTailVariableMaxWidth
+ masker = MaskTailVariableMaxWidth(
+ mask_width_ratio_range=self.prompt_conf["prompt_width_ratio_range"],
+ )
+ else:
+ raise NotImplementedError
+
+ return masker
+
+ @staticmethod
+ def norm_spk_emb(xvec):
+ xvec_mask = (~xvec.norm(dim=-1).isnan()) * (~xvec.norm(dim=-1).isinf())
+ xvec = xvec * xvec_mask.unsqueeze(-1)
+ xvec = xvec.mean(dim=1)
+ xvec = F.normalize(xvec, dim=1)
+
+ return xvec
+
+ def select_target_prompt(self, y: torch.Tensor, y_lengths: torch.Tensor):
+ # cond_mask: 1, 1, 1, ..., 0, 0, 0
+ cond_mask = self.prompt_masker(y, y_lengths, return_mask=True)
+
+ return cond_mask
+
+ @torch.no_grad()
+ def normalize_mel_feat(self, feat, feat_lengths):
+ # feat in B,T,D
+ if self.mel_norm_type == "mean_std":
+ max_length = feat.shape[1]
+ mask = (~make_pad_mask(feat_lengths, maxlen=max_length))
+ mask = mask.unsqueeze(-1).to(feat)
+ mean = ((feat * mask).sum(dim=(1, 2), keepdim=True) /
+ (mask.sum(dim=(1, 2), keepdim=True) * feat.shape[-1]))
+ var = (((feat - mean)**2 * mask).sum(dim=(1, 2), keepdim=True) /
+ (mask.sum(dim=(1, 2), keepdim=True) * feat.shape[-1] - 1)) # -1 for unbiased estimation
+ std = torch.sqrt(var)
+ feat = (feat - mean) / std
+ feat = feat * mask
+ return feat
+ if self.mel_norm_type == "min_max":
+ bb, tt, dd = feat.shape
+ mask = (~make_pad_mask(feat_lengths, maxlen=tt))
+ mask = mask.unsqueeze(-1).to(feat)
+ feat_min = (feat * mask).reshape([bb, tt * dd]).min(dim=1, keepdim=True).values.unsqueeze(-1)
+ feat_max = (feat * mask).reshape([bb, tt * dd]).max(dim=1, keepdim=True).values.unsqueeze(-1)
+ feat = (feat - feat_min) / (feat_max - feat_min)
+ # noise ~ N(0, I), P(x >= 3sigma) = 0.001, 3 is enough.
+ feat = (feat * 3) * mask # feat in [-3, 3]
+ return feat
+ else:
+ raise NotImplementedError
+
+ @torch.no_grad()
+ def extract_feat(self, y: torch.Tensor, y_lengths: torch.Tensor):
+ mel_extractor = self.mel_extractor.float()
+ feat = mel_extractor(y)
+ feat = feat.transpose(1, 2)
+ feat_lengths = (y_lengths / self.mel_extractor.hop_length).to(y_lengths)
+ if self.mel_norm_type is not None:
+ feat = self.normalize_mel_feat(feat, feat_lengths)
+ return feat, feat_lengths
+
+ def load_data(self, contents: dict, **kwargs):
+ fm_use_prompt = kwargs.get("fm_use_prompt", True)
+
+ # codec
+ codec = contents["codec"]
+ if isinstance(codec, np.ndarray):
+ codec = torch.from_numpy(codec)
+ # codec = torch.from_numpy(codec)[None, :]
+ codec_lengths = torch.tensor([codec.shape[1]], dtype=torch.int64)
+
+ # prompt codec (optional)
+ prompt_codec = kwargs.get("prompt_codec", None)
+ prompt_codec_lengths = None
+ if prompt_codec is not None and fm_use_prompt:
+ if isinstance(prompt_codec, str) and os.path.exists(prompt_codec):
+ prompt_codec = np.load(prompt_codec)
+ if isinstance(prompt_codec, np.ndarray):
+ prompt_codec = torch.from_numpy(prompt_codec)[None, :]
+ prompt_codec_lengths = torch.tensor([prompt_codec.shape[1]], dtype=torch.int64)
+ else:
+ prompt_codec = None
+ spk_emb = kwargs.get("spk_emb", None)
+ spk_emb_lengths = None
+ if spk_emb is not None:
+ if isinstance(spk_emb, str) and os.path.exists(spk_emb):
+ spk_emb = np.load(spk_emb)
+ if isinstance(spk_emb, np.ndarray):
+ spk_emb = torch.from_numpy(spk_emb)[None, :]
+ spk_emb_lengths = torch.tensor([spk_emb.shape[1]], dtype=torch.int64)
+
+ # prompt wav as condition
+ prompt_wav = contents["vocal"]
+ prompt_wav_lengths = None
+ if prompt_wav is not None and fm_use_prompt and os.path.exists(prompt_wav):
+ if prompt_wav.endswith(".npy"):
+ spk_emb = np.load(prompt_wav)
+ spk_emb_lengths = torch.tensor([spk_emb.shape[1]], dtype=torch.int64)
+ else:
+ spk_emb = extract_campp_xvec(prompt_wav, **kwargs)
+ spk_emb = torch.from_numpy(spk_emb)
+ spk_emb_lengths = torch.tensor([spk_emb.shape[1]], dtype=torch.int64)
+ # prompt_wav = load_audio_text_image_video(prompt_wav, fs=self.sample_rate)
+ # prompt_wav = prompt_wav[None, :]
+ # prompt_wav_lengths = torch.tensor([prompt_wav.shape[1]], dtype=torch.int64)
+ else:
+ logging.info("[error] prompt_wav is None or not path or path not exists! Please provide the correct speaker embedding.")
+
+ output = {
+ "codec": codec,
+ "codec_lengths": codec_lengths,
+ "prompt_codec": prompt_codec,
+ "prompt_codec_lengths": prompt_codec_lengths,
+ "prompt_wav": None,
+ "prompt_wav_lengths": None,
+ "xvec": spk_emb,
+ "xvec_lengths": spk_emb_lengths,
+ }
+
+ return output
+
+ @torch.no_grad()
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ chunk_size: int = -1,
+ finalize: bool = True,
+ **kwargs,
+ ):
+ uttid = key[0]
+ if kwargs.get("batch_size", 1) > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ batch = self.load_data(data_in[0], **kwargs)
+ batch = to_device(batch, kwargs["device"])
+ batch.update({'finalize': finalize, 'chunk_size': chunk_size})
+ feat = self._inference(**batch, **kwargs)
+ feat = feat.float()
+ logging.info(f"{uttid}: feat lengths {feat.shape[1]}")
+
+ return feat
+
+ @torch.no_grad()
+ def _inference(
+ self,
+ codec, codec_lengths,
+ prompt_codec=None, prompt_codec_lengths=None,
+ prompt_wav=None, prompt_wav_lengths=None,
+ xvec=None, xvec_lengths=None, chunk_size=-1, finalize=False,
+ **kwargs
+ ):
+ fm_dtype = dtype_map[kwargs.get("fm_dtype", "fp32")]
+ rand_xvec = None
+ if xvec is not None:
+ if xvec.dim() == 2:
+ xvec = xvec.unsqueeze(1)
+ xvec_lens = torch.ones_like(xvec_lengths)
+ rand_xvec = self.norm_spk_emb(xvec)
+ self.xvec_proj.to(fm_dtype)
+ rand_xvec = self.xvec_proj(rand_xvec.to(fm_dtype))
+ rand_xvec = rand_xvec.unsqueeze(1)
+
+ if (codec >= self.codebook_size).any():
+ new_codec = codec[codec < self.codebook_size].unsqueeze(0)
+ logging.info(f"remove out-of-range token for FM: from {codec.shape[1]} to {new_codec.shape[1]}.")
+ codec_lengths = codec_lengths - (codec.shape[1] - new_codec.shape[1])
+ codec = new_codec
+ if prompt_codec is not None:
+ codec, codec_lengths = self.concat_prompt(prompt_codec, prompt_codec_lengths, codec, codec_lengths)
+ mask = (codec != -1).float().unsqueeze(-1)
+ codec_emb = self.codec_embedder(torch.clamp(codec, min=0)) * mask
+
+ self.lookahead_conv1d.to(fm_dtype)
+ if finalize is True:
+ context = torch.zeros(1, 0, self.codec_embedder.embedding_dim).to(fm_dtype)
+ else:
+ codec_emb, context = codec_emb[:, :-self.context_size].to(fm_dtype), codec_emb[:, -self.context_size:].to(fm_dtype)
+ codec_lengths = codec_lengths - self.context_size
+ mu, _ = self.lookahead_conv1d(codec_emb, codec_lengths, context)
+ mu = mu.repeat_interleave(self.feat_token_ratio, dim=1)
+ # print(mu.size())
+ conditions = torch.zeros([mu.size(0), mu.shape[1], self.num_mels]).to(mu)
+ # get conditions
+ if prompt_wav is not None:
+ if prompt_wav.ndim == 2:
+ prompt_wav, prompt_wav_lengths = self.extract_feat(prompt_wav, prompt_wav_lengths)
+ # NOTE 在fmax12k fm中,尝试mel interploate成token 2倍shape,而不是强制截断
+ prompt_wav = prompt_wav.to(fm_dtype)
+ for i, _len in enumerate(prompt_wav_lengths):
+ conditions[i, :_len] = prompt_wav[i]
+
+ feat_lengths = codec_lengths * self.feat_token_ratio
+ # NOTE add_optional_chunk_mask支持生成-1/1/15/30不同chunk_size的mask
+ mask = add_optional_chunk_mask(mu, torch.ones([1, 1, mu.shape[1]]).to(mu).bool(), False, False, 0, chunk_size, -1)
+ feat = self.solve_ode(mu, rand_xvec, conditions.to(fm_dtype), mask, **kwargs)
+
+ if prompt_codec is not None and prompt_wav is not None:
+ feat, feat_lens = self.remove_prompt(None, prompt_wav_lengths, feat, feat_lengths)
+
+ return feat
+
+ @staticmethod
+ def concat_prompt(prompt, prompt_lengths, text, text_lengths):
+ xs_list, x_len_list = [], []
+ for idx, (_prompt_len, _text_len) in enumerate(zip(prompt_lengths, text_lengths)):
+ xs_list.append(torch.concat([prompt[idx, :_prompt_len], text[idx, :_text_len]], dim=0))
+ x_len_list.append(_prompt_len + _text_len)
+
+ xs = torch.nn.utils.rnn.pad_sequence(xs_list, batch_first=True, padding_value=0.0)
+ x_lens = torch.tensor(x_len_list, dtype=torch.int64).to(xs.device)
+
+ return xs, x_lens
+
+ @staticmethod
+ def remove_prompt(prompt, prompt_lengths, padded, padded_lengths):
+ xs_list = []
+ for idx, (_prompt_len, _x_len) in enumerate(zip(prompt_lengths, padded_lengths)):
+ xs_list.append(padded[idx, _prompt_len: _x_len])
+
+ xs = torch.nn.utils.rnn.pad_sequence(xs_list, batch_first=True, padding_value=0.0)
+
+ return xs, padded_lengths - prompt_lengths
+
+ def get_rand_noise(self, mu: torch.Tensor, **kwargs):
+ use_fixed_noise_infer = kwargs.get("use_fixed_noise_infer", True)
+ max_len = kwargs.get("max_len", 50*300)
+ if use_fixed_noise_infer:
+ if not hasattr(self, "rand_noise") or self.rand_noise is None or self.rand_noise.shape[2] < mu.shape[2]:
+ self.rand_noise = torch.randn([1, max_len, mu.shape[2]]).to(mu)
+ logging.info("init random noise for Flow")
+ # return self.rand_noise[:, :mu.shape[1], :]
+ return torch.concat([self.rand_noise[:, :mu.shape[1], :] for _ in range(mu.size(0))], dim = 0)
+ else:
+ return torch.randn_like(mu)
+
+ def solve_ode(self, mu, rand_xvec, conditions, mask, **kwargs):
+ fm_dtype = dtype_map[kwargs.get("fm_dtype", "fp32")]
+ temperature = kwargs.get("temperature", 1.0)
+ n_timesteps = kwargs.get("n_timesteps", 10)
+ infer_t_scheduler = kwargs.get("infer_t_scheduler", "cosine")
+ z = self.get_rand_noise(mu) * temperature
+ # print("z", z.size(), "mu", mu.size())
+ t_span = torch.linspace(0, 1, n_timesteps + 1).to(mu)
+ # print("t_span", t_span)
+ if infer_t_scheduler == 'cosine':
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+ fm_time = time.time()
+ self.dit_model.to(fm_dtype)
+ feat = self.solve_euler(
+ z.to(fm_dtype), t_span=t_span.to(fm_dtype), mu=mu.to(fm_dtype), mask=mask,
+ spks=rand_xvec.to(fm_dtype), cond=conditions.to(fm_dtype), **kwargs
+ )
+ escape_time = (time.time() - fm_time) * 1000.0
+ logging.info(f"fm dec {n_timesteps} step time: {escape_time:.2f}, avg {escape_time/n_timesteps:.2f} ms")
+ return feat
+
+ def solve_euler(self, x, t_span, mu, mask, spks=None, cond=None, **kwargs):
+ """
+ Fixed euler solver for ODEs.
+ Args:
+ x (torch.Tensor): random noise
+ t_span (torch.Tensor): n_timesteps interpolated
+ shape: (n_timesteps + 1,)
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): output_mask
+ shape: (batch_size, 1, mel_timesteps)
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+ cond: Not used but kept for future purposes
+ """
+ inference_cfg_rate = kwargs.get("inference_cfg_rate", 0.7)
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+ # print("solve_euler cond", cond.size())
+ steps = 1
+ z, bz = x, x.shape[0]
+ while steps <= len(t_span) - 1:
+ if inference_cfg_rate > 0:
+ x_in = torch.concat([x, x], dim=0)
+ spks_in = torch.cat([spks, torch.zeros_like(spks)], dim=0)
+ mask_in = torch.concat([mask, mask], dim=0)
+ mu_in = torch.concat([mu, torch.zeros_like(mu)], dim=0)
+ t_in = torch.concat([t.unsqueeze(0) for _ in range(mu_in.size(0))], dim=0)
+ if isinstance(cond, torch.Tensor):
+ cond_in = torch.concat([cond, torch.zeros_like(cond)], dim=0)
+ else:
+ cond_in = dict(
+ prompt=[
+ torch.concat([cond["prompt"][0], torch.zeros_like(cond["prompt"][0])], dim=0),
+ torch.concat([cond["prompt"][1], cond["prompt"][1]], dim=0),
+ ]
+ )
+ else:
+ x_in, mask_in, mu_in, spks_in, t_in, cond_in = x, mask, mu, spks, t, cond
+
+ # if spks is not None:
+ # cond_in = cond_in + spks
+
+ infer_causal_mask_type = kwargs.get("infer_causal_mask_type", 0)
+ chunk_mask_value = self.dit_model.causal_mask_type[infer_causal_mask_type]["prob_min"]
+ hint_once(
+ f"flow mask type: {infer_causal_mask_type}, mask_rank value: {chunk_mask_value}.",
+ "chunk_mask_value"
+ )
+ # print("dit_model cond", x_in.size(), cond_in.size(), mu_in.size(), spks_in.size(), t_in.size())
+ # print(t_in)
+ dphi_dt = self.dit_model(
+ x_in, cond_in, mu_in, spks_in, t_in,
+ mask=mask_in,
+ mask_rand=torch.ones_like(t_in).reshape(-1, 1, 1) * chunk_mask_value
+ )
+ if inference_cfg_rate > 0:
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [bz, bz], dim=0)
+ dphi_dt = ((1.0 + inference_cfg_rate) * dphi_dt -
+ inference_cfg_rate * cfg_dphi_dt)
+
+ x = x + dt * dphi_dt
+ t = t + dt
+ # sol.append(x)
+ if steps < len(t_span) - 1:
+ dt = t_span[steps + 1] - t
+ steps += 1
+
+ return x
\ No newline at end of file
diff --git a/funcineforge/models/inference_model.py b/funcineforge/models/inference_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..76cd4424eee02ecfc579fb5b668bc9672f425877
--- /dev/null
+++ b/funcineforge/models/inference_model.py
@@ -0,0 +1,116 @@
+import torch
+import torch.nn as nn
+import logging
+import numpy as np
+import os
+import torchaudio
+import time
+import shutil
+from funcineforge.utils.set_all_random_seed import set_all_random_seed
+from moviepy.video.io.VideoFileClip import VideoFileClip, AudioFileClip
+
+
+class FunCineForgeInferModel(nn.Module):
+ def __init__(
+ self,
+ lm_model,
+ fm_model,
+ voc_model,
+ **kwargs
+ ):
+ from funcineforge.auto.auto_model import AutoModel
+ super().__init__()
+ self.tokenizer = lm_model.kwargs["tokenizer"]
+ self.frontend = fm_model.kwargs["frontend"]
+ self.lm_model = lm_model.model
+ self.fm_model = fm_model.model
+ self.voc_model = voc_model.model
+ mel_extractor = self.fm_model.mel_extractor
+ if mel_extractor:
+ self.mel_frame_rate = mel_extractor.sampling_rate // mel_extractor.hop_length
+ self.sample_rate = mel_extractor.sampling_rate
+ else:
+ self.mel_frame_rate = self.fm_model.sample_rate // 480
+ self.sample_rate = self.fm_model.sample_rate
+
+ @torch.no_grad()
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ **kwargs,
+ ):
+ uttid = key[0]
+ logging.info(f"generating {uttid}")
+ # text -> codec in [1, T]
+ kwargs["tokenizer"] = self.tokenizer
+ set_all_random_seed(kwargs.get("random_seed", 0))
+ lm_time = time.time()
+ codec, hit_eos, states = self.lm_model.inference(data_in, data_lengths, key, **kwargs)
+ logging.info(f"[llm time]: {((time.time()-lm_time)*1000):.2f} ms, [hit_eos]: {hit_eos}, [gen len]: {codec.shape[1]}, [speech tokens]: {codec[0].cpu().tolist()}")
+ wav, batch_data_time = None, 1.0
+ if codec.shape[1] > 0:
+ fm_time = time.time()
+ data_in[0]["codec"] = codec
+ set_all_random_seed(kwargs.get("random_seed", 0))
+ feat = self.fm_model.inference(data_in, data_lengths, key, **kwargs)
+ # feat -> wav
+ set_all_random_seed(kwargs.get("random_seed", 0))
+ wav = self.voc_model.inference([feat[0]], data_lengths, key, **kwargs)
+ # output save
+ output_dir = kwargs.get("output_dir", None)
+ if output_dir is not None:
+ feat_out_dir = os.path.join(output_dir, "feat")
+ os.makedirs(feat_out_dir, exist_ok=True)
+ np.save(os.path.join(feat_out_dir, f"{key[0]}.npy"), feat[0].cpu().numpy())
+
+ wav_out_dir = os.path.join(output_dir, "wav")
+ os.makedirs(wav_out_dir, exist_ok=True)
+ output_wav_path = os.path.join(wav_out_dir, f"{key[0]}.wav")
+ torchaudio.save(
+ output_wav_path, wav.cpu(),
+ sample_rate=self.sample_rate, encoding='PCM_S', bits_per_sample=16
+ )
+
+ silent_video_path = data_in[0]["video"]
+ if os.path.exists(silent_video_path):
+ video_out_dir = os.path.join(output_dir, "mp4")
+ video_gt_dir = os.path.join(output_dir, "gt")
+ os.makedirs(video_out_dir, exist_ok=True)
+ os.makedirs(video_gt_dir, exist_ok=True)
+ output_video_path = os.path.join(video_out_dir, f"{key[0]}.mp4")
+ copy_video_path = os.path.join(video_gt_dir, f"{key[0]}.mp4")
+ shutil.copy2(silent_video_path, copy_video_path)
+ self.merge_video_audio(
+ silent_video_path=silent_video_path,
+ wav_path=output_wav_path,
+ output_path=output_video_path,
+ )
+
+ logging.info(f"fm_voc time: {((time.time()-fm_time)*1000):.2f} ms")
+
+ batch_data_time = wav.shape[1] / self.voc_model.sample_rate
+
+ return [[wav]], {"batch_data_time": batch_data_time}
+
+ def merge_video_audio(self, silent_video_path, wav_path, output_path):
+
+ video_clip = VideoFileClip(silent_video_path)
+ video_duration = video_clip.duration
+ audio_clip = AudioFileClip(wav_path)
+ audio_duration = audio_clip.duration
+
+ if audio_duration >= video_duration:
+ audio_clip = audio_clip.subclipped(0, video_duration)
+
+ video_clip = video_clip.with_audio(audio_clip)
+ video_clip.write_videofile(
+ output_path,
+ codec='libx264',
+ audio_codec='aac',
+ fps=video_clip.fps,
+ logger=None
+ )
+ video_clip.close()
+ audio_clip.close()
\ No newline at end of file
diff --git a/funcineforge/models/language_model.py b/funcineforge/models/language_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0580e8c71b76ec43a608139f2273564053ce6c22
--- /dev/null
+++ b/funcineforge/models/language_model.py
@@ -0,0 +1,274 @@
+import logging
+import os
+import torch
+import torch.nn as nn
+from funcineforge.models.utils.llm_decoding import LLMDecoder
+from funcineforge.utils.device_funcs import to_device
+import numpy as np
+from funcineforge.models.utils import dtype_map
+from funcineforge.models import FunCineForgeSpecAug
+from transformers import AutoModelForCausalLM
+import pickle
+
+
+
+class FunCineForgeLM(nn.Module):
+ def __init__(
+ self,
+ llm: str = None,
+ llm_conf: dict = None,
+ input_size: int = 80,
+ length_normalized_loss: bool = False,
+ **kwargs,
+ ):
+ super().__init__()
+
+ # llm
+ self.llm_conf = llm_conf
+ self.llm = None
+
+ init_param_path = llm_conf.get("init_param_path", "")
+ llm_load_kwargs = llm_conf.get("load_kwargs", {})
+ self.sample_rate = kwargs.get("sample_rate", 24000)
+ self.token_rate = kwargs.get("token_rate", 25)
+
+ if kwargs.get("infer_lora_merged", False):
+ llm_conf["use_qlora"] = False
+ llm_conf["use_lora"] = False
+ kwargs["infer_use_lora"] = False
+
+
+ model = AutoModelForCausalLM.from_pretrained(
+ init_param_path,
+ load_in_8bit=None,
+ device_map=None,
+ use_cache=None,
+ **llm_load_kwargs,
+ )
+
+ freeze = llm_conf.get("freeze", True)
+ if freeze:
+ for name, param in model.named_parameters():
+ param.requires_grad = False
+ model.eval()
+
+ logging.info(f"use_lora: {llm_conf.get('use_lora', False)}, use_qlora: {llm_conf.get('use_qlora', False)}, infer_use_lora: {kwargs.get('infer_use_lora',False)}, infer_lora_merged: {kwargs.get('infer_lora_merged',False)}")
+
+ if llm_conf.get("activation_checkpoint", False):
+ model.gradient_checkpointing_enable()
+
+ self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
+ self.llm = model.to(dtype_map[self.llm_dtype])
+ llm_dim = model.get_input_embeddings().weight.shape[-1]
+
+ if (not llm_conf.get("use_lora", False)) and (not kwargs.get("infer_use_lora",False)):
+ del self.llm.lm_head
+ self.codec_unit = kwargs.get("codec_unit", 6761)
+ self.timespk_unit = kwargs.get("timespk_unit", 1550)
+ self.codec_embed = nn.Embedding(self.codec_unit, llm_dim, 0)
+ self.timespk_embed = nn.Embedding(self.timespk_unit, llm_dim, 0)
+ self.codec_head = nn.Linear(llm_dim, self.codec_unit, bias=False)
+ self.face_size = kwargs.get("face_size", 512)
+ self.face_linear = nn.Linear(self.face_size, llm_dim)
+
+ self.length_normalized_loss = length_normalized_loss
+ self.ignore_id = kwargs.get("ignore_id", -100)
+
+ specaug = kwargs.get("specaug", None)
+ specaug_conf = kwargs.get("specaug_conf", {})
+ if specaug is not None:
+ specaug = FunCineForgeSpecAug(**specaug_conf)
+ self.specaug = specaug
+ rank = int(os.environ.get("RANK", 0))
+ logging.info(f"rank: {rank}, model is builded.")
+
+
+ def insert_face_embeddings(
+ self, inputs_embeds, face_emb, attention_mask, labels_ids,
+ codec_len, insert_pos, device
+ ):
+ """
+ 将face_emb插入到inputs_embeds中的指定位置, 同步更新attention_mask和labels_ids
+ Args:
+ inputs_embeds: (batch_size, token_num, dims) 输入embedding
+ face_emb: (batch_size, max_face_len, dims) 面部embedding
+ attention_mask: (batch_size, token_num) 注意力mask
+ labels_ids: (batch_size, token_num) 标签ID
+ codec_len: (batch_size,) 每个样本的实际face_emb长度
+ insert_pos: int 插入位置, SOS token之后
+ device
+ Returns:
+ padded_inputs_embeds: 插入face_emb并padding后的inputs_embeds
+ padded_attention_mask: 更新后的attention_mask
+ padded_labels: 更新后的labels_ids
+ """
+ batch_size, token_num, dims = inputs_embeds.shape
+ max_face_len = face_emb.size(1)
+
+ # 预计算新序列的最大长度
+ new_max_length = token_num + max_face_len
+
+ # 预分配输出张量
+ padded_inputs_embeds = torch.zeros(batch_size, new_max_length, dims, device=device)
+ padded_attention_mask = torch.zeros(batch_size, new_max_length, device=device, dtype=attention_mask.dtype)
+ padded_labels = torch.full((batch_size, new_max_length), self.ignore_id, device=device, dtype=labels_ids.dtype)
+
+ for i in range(batch_size):
+ current_face_len = codec_len[i].item()
+
+ # 直接填充,避免中间拼接
+ padded_inputs_embeds[i, :insert_pos] = inputs_embeds[i, :insert_pos]
+ padded_inputs_embeds[i, insert_pos:insert_pos+current_face_len] = face_emb[i, :current_face_len]
+ padded_inputs_embeds[i, insert_pos+current_face_len:token_num+current_face_len] = inputs_embeds[i, insert_pos:]
+
+ # 同样处理mask和labels
+ padded_attention_mask[i, :insert_pos] = attention_mask[i, :insert_pos]
+ padded_attention_mask[i, insert_pos:insert_pos+current_face_len] = 1
+ padded_attention_mask[i, insert_pos+current_face_len:token_num+current_face_len] = attention_mask[i, insert_pos:]
+
+ padded_labels[i, :insert_pos] = labels_ids[i, :insert_pos]
+ padded_labels[i, insert_pos:insert_pos+current_face_len] = self.ignore_id
+ padded_labels[i, insert_pos+current_face_len:token_num+current_face_len] = labels_ids[i, insert_pos:]
+
+ return padded_inputs_embeds, padded_attention_mask, padded_labels
+
+
+ def load_data(self, contents: dict, **kwargs):
+ lm_use_prompt = kwargs.get("lm_use_prompt", True)
+ tokenizer = kwargs.get("tokenizer")
+ # text + clue
+ text = contents["text"]
+ clue = "<|startofclue|>" + contents["clue"] + "<|endofclue|>"
+ if lm_use_prompt:
+ text = clue + text
+ text_ids = tokenizer.encode(text)
+ text_len = len(text_ids)
+ # timespk_ids
+ timespk_ids = contents["timespk_ids"].tolist()
+ type_id = contents["type_id"]
+ # sequence
+ sequence = [
+ kwargs['dataset_conf']["sos"],
+ *text_ids,
+ type_id,
+ *timespk_ids,
+ kwargs['dataset_conf']["turn_of_speech"]
+ ]
+ input_ids = torch.tensor(sequence, dtype=torch.int64)
+
+ # flag tensors
+ text_flag = torch.zeros(len(sequence), dtype=torch.float32)
+ timespk_flag = torch.zeros(len(sequence), dtype=torch.float32)
+ codec_flag = torch.zeros(len(sequence), dtype=torch.float32)
+ text_flag[1: text_len+1] = 1
+ timespk_flag[text_len+1: -1] = 1
+ codec_flag = 1 - text_flag - timespk_flag
+
+ # face embs
+ speech_len = contents["speech_len"]
+ face_embs = torch.zeros((speech_len, self.face_size), dtype=torch.float32)
+ face_path = contents.get("face")
+ with open(face_path, 'rb') as f:
+ stat_obj = pickle.load(f)
+ embeddings = stat_obj['embeddings']
+ faceI = stat_obj['faceI']
+ for emb, frameI in zip(embeddings, faceI):
+ fi = int(frameI)
+ if 0 <= fi < speech_len:
+ end = min(fi + 5, speech_len)
+ face_embs[fi:end] = torch.from_numpy(emb).expand(end - fi, -1)
+
+ # batch dimension
+ input_ids = input_ids[None, :]
+ text_flag = text_flag[None, :]
+ timespk_flag = timespk_flag[None, :]
+ codec_flag = codec_flag[None, :]
+ face_embs = face_embs[None, :, :]
+ output = {
+ "input_ids": input_ids,
+ "face_embs": face_embs,
+ "text_flag": text_flag > 0,
+ "timespk_flag": timespk_flag > 0,
+ "codec_flag": codec_flag > 0,
+ "prompt_codec": None, # you can add prompt codec here if needed
+ }
+ return output
+
+ def inference_prepare(self, data_in, **kwargs):
+ if kwargs.get("batch_size", 1) > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ output = self.load_data(data_in[0], **kwargs)
+ batch = to_device(output, kwargs["device"])
+ input_ids = batch["input_ids"]
+ input_ids = input_ids * (input_ids > 0)
+ text_flag = batch["text_flag"]
+ timespk_flag = batch["timespk_flag"]
+ codec_flag = batch["codec_flag"]
+ face_embs = batch["face_embs"]
+
+ if (kwargs.get("use_qlora",False) or kwargs.get("infer_use_lora",False)) and (not kwargs.get("infer_lora_merged",False)):
+ text_embeds = self.llm.base_model.model.model.get_input_embeddings()(input_ids * text_flag) * text_flag.unsqueeze(-1)
+ else:
+ text_embeds = self.llm.model.get_input_embeddings()(input_ids * text_flag) * text_flag.unsqueeze(-1)
+ timespk_embeds = self.timespk_embed(input_ids * timespk_flag) * timespk_flag.unsqueeze(-1)
+ codec_embs = self.codec_embed(input_ids * codec_flag) * codec_flag.unsqueeze(-1)
+ face_embs = self.face_linear(face_embs)
+
+ inputs_embeds = text_embeds + timespk_embeds + codec_embs
+
+ inputs_embeds = torch.cat([
+ inputs_embeds[:, 0:1, :], # sos token
+ face_embs, # face embeddings
+ inputs_embeds[:, 1:, :] # inputs_embeds after sos
+ ], dim=1)
+
+ prompt_codec = batch.get("prompt_codec", None)
+ if prompt_codec is not None:
+ codec_emb = self.codec_embed(prompt_codec)
+ inputs_embeds = torch.cat((inputs_embeds, codec_emb), dim=1)
+
+ return inputs_embeds
+
+ @torch.no_grad()
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ **kwargs,
+ ):
+ uttid = key[0]
+ inputs_emb = self.inference_prepare(data_in, **kwargs)
+
+ logging.info(f"{uttid}: min length: {kwargs['min_length']}, max length: {kwargs['max_length']}")
+
+ dtype = dtype_map[kwargs.get("llm_dtype", "fp32")]
+ if not hasattr(self, "llm_generator"):
+ llm_generator_conf = kwargs.get("dataset_conf", {})
+ self.llm_generator = LLMDecoder(
+ token_embeder=self.codec_embed,
+ **llm_generator_conf
+ ).to(dtype)
+
+ if (kwargs.get("use_qlora",False) or kwargs.get("infer_use_lora",False)) and (not kwargs.get("infer_lora_merged",False)):
+ self.llm.base_model.model.lm_head = self.codec_head.to(dtype)
+ else:
+ self.llm.lm_head = self.codec_head.to(dtype)
+
+ gen_codec, hit_eos, states = self.llm_generator(
+ inputs_emb.to(dtype),
+ self.llm,
+ states=kwargs.get("states", {}),
+ **kwargs
+ )
+
+ output_dir = kwargs.get("output_dir", None)
+ if output_dir is not None:
+ output_dir = os.path.join(output_dir, "codec")
+ os.makedirs(output_dir, exist_ok=True)
+ np.save(
+ os.path.join(output_dir, f"{key[0]}.npy"),
+ gen_codec[0].cpu().numpy()
+ )
+
+ return gen_codec, hit_eos, states
\ No newline at end of file
diff --git a/funcineforge/models/modules/__init__.py b/funcineforge/models/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/funcineforge/models/modules/dit_flow_matching/__init__.py b/funcineforge/models/modules/dit_flow_matching/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/funcineforge/models/modules/dit_flow_matching/dit_model.py b/funcineforge/models/modules/dit_flow_matching/dit_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..55a69dc70d2a7f90a7bc9a7a5e68839c63464296
--- /dev/null
+++ b/funcineforge/models/modules/dit_flow_matching/dit_model.py
@@ -0,0 +1,208 @@
+"""
+ein notation:
+b - batch
+n - sequence
+nt - text sequence
+nw - raw wave length
+d - dimension
+"""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from einops import repeat
+from x_transformers.x_transformers import RotaryEmbedding
+from funcineforge.models.utils.masks import causal_block_mask
+
+from .dit_modules import (
+ TimestepEmbedding,
+ ConvNeXtV2Block,
+ CausalConvPositionEmbedding,
+ DiTBlock,
+ AdaLayerNormZero_Final,
+ precompute_freqs_cis,
+ get_pos_embed_indices,
+)
+
+
+# Text embedding
+
+
+class TextEmbedding(nn.Module):
+ def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
+ super().__init__()
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
+
+ if conv_layers > 0:
+ self.extra_modeling = True
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
+ self.text_blocks = nn.Sequential(
+ *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
+ )
+ else:
+ self.extra_modeling = False
+
+ def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
+ batch, text_len = text.shape[0], text.shape[1]
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
+ text = F.pad(text, (0, seq_len - text_len), value=0)
+
+ if drop_text: # cfg for text
+ text = torch.zeros_like(text)
+
+ text = self.text_embed(text) # b n -> b n d
+
+ # possible extra modeling
+ if self.extra_modeling:
+ # sinus pos emb
+ batch_start = torch.zeros((batch,), dtype=torch.long)
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
+ text_pos_embed = self.freqs_cis[pos_idx]
+ text = text + text_pos_embed
+
+ # convnextv2 blocks
+ text = self.text_blocks(text)
+
+ return text
+
+
+# noised input audio and context mixing embedding
+
+
+class InputEmbedding(nn.Module):
+ def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
+ super().__init__()
+ spk_dim = 0 if spk_dim is None else spk_dim
+ self.spk_dim = spk_dim
+ self.proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim)
+ self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)
+
+ def forward(
+ self,
+ x: float["b n d"],
+ cond: float["b n d"],
+ text_embed: float["b n d"],
+ spks: float["b d"],
+ ):
+ to_cat = [x, cond, text_embed]
+ if self.spk_dim > 0:
+ spks = repeat(spks, "b c -> b t c", t=x.shape[1])
+ to_cat.append(spks)
+
+ x = self.proj(torch.cat(to_cat, dim=-1))
+ x = self.conv_pos_embed(x) + x
+ return x
+
+
+# Transformer backbone using DiT blocks
+
+
+class DiT(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth=8,
+ heads=8,
+ dim_head=64,
+ dropout=0.1,
+ ff_mult=4,
+ mel_dim=80,
+ mu_dim=None,
+ long_skip_connection=False,
+ spk_dim=None,
+ **kwargs
+ ):
+ super().__init__()
+
+ self.time_embed = TimestepEmbedding(dim)
+ if mu_dim is None:
+ mu_dim = mel_dim
+ self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)
+
+ self.rotary_embed = RotaryEmbedding(dim_head)
+
+ self.dim = dim
+ self.depth = depth
+
+ self.transformer_blocks = nn.ModuleList(
+ [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
+ )
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
+
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
+ self.proj_out = nn.Linear(dim, mel_dim)
+ self.causal_mask_type = kwargs.get("causal_mask_type", None)
+
+ def build_mix_causal_mask(self, attn_mask, rand=None, ratio=None):
+ b, _, _, t = attn_mask.shape
+ if rand is None:
+ rand = torch.rand((b, 1, 1, 1), device=attn_mask.device, dtype=torch.float32)
+ mixed_mask = attn_mask.clone()
+ for item in self.causal_mask_type:
+ prob_min, prob_max = item["prob_min"], item["prob_max"]
+ _ratio = 1
+ if "ratio" in item:
+ _ratio = item["ratio"]
+ if ratio is not None:
+ _ratio = ratio
+ block_size = item["block_size"] * _ratio
+ if block_size <= 0:
+ causal_mask = attn_mask
+ else:
+ causal_mask = causal_block_mask(
+ t, block_size, attn_mask.device, torch.float32
+ ).unsqueeze(0).unsqueeze(1) # 1,1,T,T
+ flag = (prob_min <= rand) & (rand < prob_max)
+ mixed_mask = mixed_mask * (~flag) + (causal_mask * attn_mask) * flag
+
+ return mixed_mask
+
+ def forward(
+ self,
+ x: float["b n d"], # nosied input audio
+ cond: float["b n d"], # masked cond audio
+ mu: int["b nt d"], # mu
+ spks: float["b 1 d"], # spk xvec
+ time: float["b"] | float[""], # time step
+ return_hidden: bool = False,
+ mask: bool["b 1 n"] | None = None,
+ mask_rand: float["b 1 1"] = None, # for mask flag type
+ **kwargs,
+ ):
+ batch, seq_len = x.shape[0], x.shape[1]
+ if time.ndim == 0:
+ time = time.repeat(batch)
+
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
+ t = self.time_embed(time)
+ x = self.input_embed(x, cond, mu, spks.squeeze(1))
+
+ rope = self.rotary_embed.forward_from_seq_len(seq_len)
+
+ if self.long_skip_connection is not None:
+ residual = x
+
+ mask = mask.unsqueeze(1) # B,1,1,T
+ if self.causal_mask_type is not None:
+ mask = self.build_mix_causal_mask(mask, rand=mask_rand.unsqueeze(-1))
+
+ for block in self.transformer_blocks:
+ # mask-out padded values for amp training
+ x = x * mask[:, 0, -1, :].unsqueeze(-1)
+ x = block(x, t, mask=mask.bool(), rope=rope)
+
+ if self.long_skip_connection is not None:
+ x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
+
+ x = self.norm_out(x, t)
+ output = self.proj_out(x)
+
+ if return_hidden:
+ return output, None
+
+ return output
diff --git a/funcineforge/models/modules/dit_flow_matching/dit_modules.py b/funcineforge/models/modules/dit_flow_matching/dit_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..98657e03737d2253e232a140c382a2612949ca5f
--- /dev/null
+++ b/funcineforge/models/modules/dit_flow_matching/dit_modules.py
@@ -0,0 +1,622 @@
+"""
+ein notation:
+b - batch
+n - sequence
+nt - text sequence
+nw - raw wave length
+d - dimension
+"""
+
+from __future__ import annotations
+from typing import Optional
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+
+from x_transformers.x_transformers import apply_rotary_pos_emb
+
+
+# raw wav to mel spec
+class MelSpec(nn.Module):
+ def __init__(
+ self,
+ filter_length=1024,
+ hop_length=256,
+ win_length=1024,
+ n_mel_channels=100,
+ target_sample_rate=24_000,
+ normalize=False,
+ power=1,
+ norm=None,
+ center=True,
+ ):
+ super().__init__()
+ self.n_mel_channels = n_mel_channels
+
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
+ sample_rate=target_sample_rate,
+ n_fft=filter_length,
+ win_length=win_length,
+ hop_length=hop_length,
+ n_mels=n_mel_channels,
+ power=power,
+ center=center,
+ normalized=normalize,
+ norm=norm,
+ )
+
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
+
+ def forward(self, inp):
+ if len(inp.shape) == 3:
+ inp = inp.squeeze(1) # 'b 1 nw -> b nw'
+
+ assert len(inp.shape) == 2
+
+ if self.dummy.device != inp.device:
+ self.to(inp.device)
+
+ mel = self.mel_stft(inp)
+ mel = mel.clamp(min=1e-5).log()
+ return mel
+
+
+# sinusoidal position embedding
+
+
+class SinusPositionEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x, scale=1000):
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
+
+
+# convolutional position embedding
+
+
+class ConvPositionEmbedding(nn.Module):
+ def __init__(self, dim, kernel_size=31, groups=16):
+ super().__init__()
+ assert kernel_size % 2 != 0
+ self.conv1d = nn.Sequential(
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
+ nn.Mish(),
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
+ nn.Mish(),
+ )
+
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
+ if mask is not None:
+ mask = mask[..., None]
+ x = x.masked_fill(~mask, 0.0)
+
+ x = x.permute(0, 2, 1)
+ x = self.conv1d(x)
+ out = x.permute(0, 2, 1)
+
+ if mask is not None:
+ out = out.masked_fill(~mask, 0.0)
+
+ return out
+
+
+class CausalConvPositionEmbedding(nn.Module):
+ def __init__(self, dim, kernel_size=31, groups=16):
+ super().__init__()
+ assert kernel_size % 2 != 0
+ self.kernel_size = kernel_size
+ self.conv1 = nn.Sequential(
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
+ nn.Mish(),
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
+ nn.Mish(),
+ )
+
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
+ if mask is not None:
+ mask = mask[..., None]
+ x = x.masked_fill(~mask, 0.0)
+
+ x = x.permute(0, 2, 1)
+ x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
+ x = self.conv1(x)
+ x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
+ x = self.conv2(x)
+ out = x.permute(0, 2, 1)
+
+ if mask is not None:
+ out = out.masked_fill(~mask, 0.0)
+
+ return out
+
+
+# rotary positional embedding related
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+ # has some connection to NTK literature
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+ t = torch.arange(end, device=freqs.device) # type: ignore
+ freqs = torch.outer(t, freqs).float() # type: ignore
+ freqs_cos = torch.cos(freqs) # real part
+ freqs_sin = torch.sin(freqs) # imaginary part
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
+
+
+def get_pos_embed_indices(start, length, max_pos, scale=1.0):
+ # length = length if isinstance(length, int) else length.max()
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
+ pos = (
+ start.unsqueeze(1)
+ + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
+ )
+ # avoid extra long error.
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
+ return pos
+
+
+# Global Response Normalization layer (Instance Normalization ?)
+
+
+class GRN(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(enabled=False):
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (x * Nx) + self.beta + x
+
+
+# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
+# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
+
+
+class ConvNeXtV2Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int,
+ dilation: int = 1,
+ ):
+ super().__init__()
+ padding = (dilation * (7 - 1)) // 2
+ self.dwconv = nn.Conv1d(
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
+ ) # depthwise conv
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.grn = GRN(intermediate_dim)
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ residual = x
+ x = x.transpose(1, 2) # b n d -> b d n
+ x = self.dwconv(x)
+ x = x.transpose(1, 2) # b d n -> b n d
+ with torch.cuda.amp.autocast(enabled=False):
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.grn(x)
+ x = self.pwconv2(x)
+ return residual + x
+
+
+# AdaLayerNormZero
+# return with modulated x for attn input, and params for later mlp modulation
+
+
+class AdaLayerNormZero(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(dim, dim * 6)
+
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+ def forward(self, x, emb=None):
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+# AdaLayerNormZero for final layer
+# return only with modulated x for attn input, cuz no more mlp modulation
+
+
+class AdaLayerNormZero_Final(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(dim, dim * 2)
+
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+ def forward(self, x, emb):
+ emb = self.linear(self.silu(emb))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
+
+
+# FeedForward
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ activation = nn.GELU(approximate=approximate)
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
+ self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+
+ def forward(self, x):
+ return self.ff(x)
+
+
+# Attention with possible joint part
+# modified from diffusers/src/diffusers/models/attention_processor.py
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ processor: JointAttnProcessor | AttnProcessor,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ context_dim: Optional[int] = None, # if not None -> joint attention
+ context_pre_only=None,
+ ):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ self.processor = processor
+
+ self.dim = dim
+ self.heads = heads
+ self.inner_dim = dim_head * heads
+ self.dropout = dropout
+
+ self.context_dim = context_dim
+ self.context_pre_only = context_pre_only
+
+ self.to_q = nn.Linear(dim, self.inner_dim)
+ self.to_k = nn.Linear(dim, self.inner_dim)
+ self.to_v = nn.Linear(dim, self.inner_dim)
+
+ if self.context_dim is not None:
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
+ if self.context_pre_only is not None:
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ if self.context_pre_only is not None and not self.context_pre_only:
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
+
+ def forward(
+ self,
+ x: float["b n d"], # noised input x # noqa: F722
+ c: float["b n d"] = None, # context c # noqa: F722
+ mask: bool["b n"] | None = None, # noqa: F722
+ rope=None, # rotary position embedding for x
+ c_rope=None, # rotary position embedding for c
+ ) -> torch.Tensor:
+ if c is not None:
+ return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
+ else:
+ return self.processor(self, x, mask=mask, rope=rope)
+
+
+# Attention processor
+
+
+class AttnProcessor:
+ def __init__(self):
+ pass
+
+ def __call__(
+ self,
+ attn: Attention,
+ x: float["b n d"], # noised input x # noqa: F722
+ mask: bool["b n"] | None = None, # noqa: F722
+ rope=None, # rotary position embedding
+ ) -> torch.FloatTensor:
+ batch_size = x.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(x)
+ key = attn.to_k(x)
+ value = attn.to_v(x)
+
+ # apply rotary position embedding
+ if rope is not None:
+ freqs, xpos_scale = rope
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
+
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
+
+ # attention
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
+ if mask is not None:
+ attn_mask = mask
+ if attn_mask.dim() == 2:
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
+ else:
+ attn_mask = None
+
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ x = x.to(query.dtype)
+
+ # linear proj
+ x = attn.to_out[0](x)
+ # dropout
+ x = attn.to_out[1](x)
+
+ if mask is not None:
+ if mask.dim() == 2:
+ mask = mask.unsqueeze(-1)
+ else:
+ mask = mask[:, 0, -1].unsqueeze(-1)
+ x = x.masked_fill(~mask, 0.0)
+
+ return x
+
+
+# Joint Attention processor for MM-DiT
+# modified from diffusers/src/diffusers/models/attention_processor.py
+
+
+class JointAttnProcessor:
+ def __init__(self):
+ pass
+
+ def __call__(
+ self,
+ attn: Attention,
+ x: float["b n d"], # noised input x # noqa: F722
+ c: float["b nt d"] = None, # context c, here text # noqa: F722
+ mask: bool["b n"] | None = None, # noqa: F722
+ rope=None, # rotary position embedding for x
+ c_rope=None, # rotary position embedding for c
+ ) -> torch.FloatTensor:
+ residual = x
+
+ batch_size = c.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(x)
+ key = attn.to_k(x)
+ value = attn.to_v(x)
+
+ # `context` projections.
+ c_query = attn.to_q_c(c)
+ c_key = attn.to_k_c(c)
+ c_value = attn.to_v_c(c)
+
+ # apply rope for context and noised input independently
+ if rope is not None:
+ freqs, xpos_scale = rope
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
+ if c_rope is not None:
+ freqs, xpos_scale = c_rope
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
+
+ # attention
+ query = torch.cat([query, c_query], dim=1)
+ key = torch.cat([key, c_key], dim=1)
+ value = torch.cat([value, c_value], dim=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
+ if mask is not None:
+ attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
+ else:
+ attn_mask = None
+
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ x = x.to(query.dtype)
+
+ # Split the attention outputs.
+ x, c = (
+ x[:, : residual.shape[1]],
+ x[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ x = attn.to_out[0](x)
+ # dropout
+ x = attn.to_out[1](x)
+ if not attn.context_pre_only:
+ c = attn.to_out_c(c)
+
+ if mask is not None:
+ mask = mask.unsqueeze(-1)
+ x = x.masked_fill(~mask, 0.0)
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
+
+ return x, c
+
+
+# DiT Block
+
+
+class DiTBlock(nn.Module):
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
+ super().__init__()
+
+ self.attn_norm = AdaLayerNormZero(dim)
+ self.attn = Attention(
+ processor=AttnProcessor(),
+ dim=dim,
+ heads=heads,
+ dim_head=dim_head,
+ dropout=dropout,
+ )
+
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
+
+ def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
+ # pre-norm & modulation for attention input
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
+
+ # attention
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
+
+ # process attention output for input x
+ x = x + gate_msa.unsqueeze(1) * attn_output
+
+ with torch.cuda.amp.autocast(enabled=False):
+ ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ ff_output = self.ff(ff_norm)
+ x = x + gate_mlp.unsqueeze(1) * ff_output
+
+ return x
+
+
+# MMDiT Block https://arxiv.org/abs/2403.03206
+
+
+class MMDiTBlock(nn.Module):
+ r"""
+ modified from diffusers/src/diffusers/models/attention.py
+
+ notes.
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
+ _x: noised input related. (right part)
+ context_pre_only: last layer only do prenorm + modulation cuz no more ffn
+ """
+
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
+ super().__init__()
+
+ self.context_pre_only = context_pre_only
+
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
+ self.attn_norm_x = AdaLayerNormZero(dim)
+ self.attn = Attention(
+ processor=JointAttnProcessor(),
+ dim=dim,
+ heads=heads,
+ dim_head=dim_head,
+ dropout=dropout,
+ context_dim=dim,
+ context_pre_only=context_pre_only,
+ )
+
+ if not context_pre_only:
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
+ else:
+ self.ff_norm_c = None
+ self.ff_c = None
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
+
+ def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
+ # pre-norm & modulation for attention input
+ if self.context_pre_only:
+ norm_c = self.attn_norm_c(c, t)
+ else:
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
+
+ # attention
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
+
+ # process attention output for context c
+ if self.context_pre_only:
+ c = None
+ else: # if not last layer
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
+
+ with torch.cuda.amp.autocast(enabled=False):
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+ c_ff_output = self.ff_c(norm_c)
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
+
+ # process attention output for input x
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
+
+ with torch.cuda.amp.autocast(enabled=False):
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
+ x_ff_output = self.ff_x(norm_x)
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
+
+ return c, x
+
+
+# time step conditioning embedding
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(self, dim, freq_embed_dim=256):
+ super().__init__()
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
+ self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
+
+ def forward(self, timestep: float["b"]): # noqa: F821
+ time_hidden = self.time_embed(timestep)
+ time_hidden = time_hidden.to(timestep.dtype)
+ time = self.time_mlp(time_hidden) # b d
+ return time
diff --git a/funcineforge/models/modules/hifigan/__init__.py b/funcineforge/models/modules/hifigan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd4293fb351e905fc3a6823aeec87f6d3fe02464
--- /dev/null
+++ b/funcineforge/models/modules/hifigan/__init__.py
@@ -0,0 +1,14 @@
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+from funcineforge.models.modules.hifigan.generator import HifiGenerator, NsfHifiGenerator, HiFTGenerator
+from funcineforge.models.modules.hifigan.discriminator import MultipleDiscriminator
+from funcineforge.models.modules.hifigan.nsf_utils import ConvRNNF0Predictor
diff --git a/funcineforge/models/modules/hifigan/activations.py b/funcineforge/models/modules/hifigan/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..61f2808a5466b3cf4d041059700993af5527dd29
--- /dev/null
+++ b/funcineforge/models/modules/hifigan/activations.py
@@ -0,0 +1,120 @@
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Snake(nn.Module):
+ '''
+ Implementation of a sine-based periodic activation function
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter
+ References:
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+ >>> a1 = snake(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha: trainable parameter
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(Snake, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ Snake ∶= x + 1/a * sin^2 (xa)
+ '''
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
+
+
+class SnakeBeta(nn.Module):
+ '''
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ References:
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+ >>> a1 = snakebeta(256)
+ >>> x = torch.randn(256)
+ >>> x = a1(x)
+ '''
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ beta is initialized to 1 by default, higher values = higher-magnitude.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(SnakeBeta, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+ self.beta = Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
+ '''
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+ return x
\ No newline at end of file
diff --git a/funcineforge/models/modules/hifigan/discriminator.py b/funcineforge/models/modules/hifigan/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..c88d5bb9f0c64d0778584e58811ccbefbd409963
--- /dev/null
+++ b/funcineforge/models/modules/hifigan/discriminator.py
@@ -0,0 +1,299 @@
+"""hifigan based dicriminator implementation.
+
+This code is modified from https://github.com/jik876/hifi-gan and https://github.com/kan-bayashi/ParallelWaveGAN.
+
+"""
+
+import typing as tp
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv2d, AvgPool1d, Conv1d
+from torch.nn.utils import weight_norm, spectral_norm
+
+from funcineforge.models.modules.hifigan import get_padding
+
+
+class DiscriminatorP(torch.nn.Module):
+ def __init__(self, period, kernel_size=5, stride=3,
+ use_spectral_norm=False, lrelu_slope=0.1):
+ super(DiscriminatorP, self).__init__()
+ self.period = period
+ self.lrelu_slope = lrelu_slope
+
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+ self.convs = nn.ModuleList([
+ norm_f(
+ Conv2d(
+ 1,
+ 32, (kernel_size, 1), (stride, 1),
+ padding=(get_padding(5, 1), 0))),
+ norm_f(
+ Conv2d(
+ 32,
+ 128, (kernel_size, 1), (stride, 1),
+ padding=(get_padding(5, 1), 0))),
+ norm_f(
+ Conv2d(
+ 128,
+ 512, (kernel_size, 1), (stride, 1),
+ padding=(get_padding(5, 1), 0))),
+ norm_f(
+ Conv2d(
+ 512,
+ 1024, (kernel_size, 1), (stride, 1),
+ padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+ ])
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, self.lrelu_slope)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+ def __init__(self,
+ in_channels: int = 1,
+ periods: tp.List[int] = [2, 3, 5, 7, 11]):
+ super(MultiPeriodDiscriminator, self).__init__()
+ self.discriminators = nn.ModuleList([
+ DiscriminatorP(p) for p in periods
+ ])
+
+ def forward(self, x: torch.Tensor, return_intermediates: bool = True):
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ List: List of list of each discriminator outputs, which consists of each
+ layer output tensors.
+
+ """
+ outs = []
+ for f in self.discriminators:
+ # outs += [f(x)]
+ if return_intermediates:
+ outs.append(f(x))
+ else:
+ outs.append(f(x)[0])
+
+ return outs
+
+
+class DiscriminatorS(torch.nn.Module):
+ def __init__(self, use_spectral_norm=False, lrelu_slope=0.1):
+ super(DiscriminatorS, self).__init__()
+ self.lrelu_slope = lrelu_slope
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+ self.convs = nn.ModuleList([
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+ ])
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+ def forward(self, x):
+ fmap = []
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, self.lrelu_slope)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiScaleDiscriminator(torch.nn.Module):
+ def __init__(self, in_channels: int = 1, nb_scales: int = 3):
+ super(MultiScaleDiscriminator, self).__init__()
+ self.discriminators = nn.ModuleList([
+ DiscriminatorS(use_spectral_norm=True),
+ DiscriminatorS(),
+ DiscriminatorS(),
+ ])
+ self.meanpools = nn.ModuleList(
+ [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
+
+ def forward(self, x: torch.Tensor, return_intermediates: bool = True):
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ List: List of list of each discriminator outputs, which consists of each
+ layer output tensors.
+
+ """
+ outs = []
+ for i, f in enumerate(self.discriminators):
+ if i != 0:
+ x = self.meanpools[i - 1](x)
+ if return_intermediates:
+ outs.append(f(x))
+ else:
+ outs.append(f(x)[0])
+
+ return outs
+
+
+class DiscriminatorR(nn.Module):
+ def __init__(
+ self,
+ stft_params: tp.List[int],
+ lrelu_slope: float = 0.1,
+ use_spectral_norm: bool = False,
+ ):
+ super().__init__()
+
+ self.stft_params = stft_params
+ self.lrelu_slope = lrelu_slope
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+
+ self.convs = nn.ModuleList([
+ norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
+ norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+ norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+ norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+ norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
+ ])
+ self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
+
+ def spectrogram(self, x):
+ n_fft, hop_length, win_length = self.stft_params
+ x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
+ x = x.squeeze(1)
+ spec = torch.stft(x, n_fft, hop_length=hop_length, win_length=win_length,
+ center=False, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
+
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
+ mag = torch.norm(spec, p=2, dim =-1) #[B, F, TT]
+
+ return mag
+
+ def forward(self, x):
+ fmap = []
+
+ x = self.spectrogram(x).unsqueeze(1)
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, self.lrelu_slope)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiResolutionDiscriminator(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ fft_sizes: tp.List[int] = [1024, 2048, 512],
+ hop_sizes: tp.List[int] = [120, 240, 50],
+ win_lengths: tp.List[int] = [600, 1200, 240],
+ lrelu_slope: float = 0.1,
+ ):
+ super().__init__()
+
+ self.discriminators = nn.ModuleList()
+
+ for fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths):
+ self.discriminators.append(DiscriminatorR([fft, hop, win], lrelu_slope))
+
+ def forward(self, x: torch.Tensor, return_intermediates: bool = True):
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ List: List of list of each discriminator outputs, which consists of each
+ layer output tensors.
+
+ """
+ outs = []
+ for f in self.discriminators:
+ if return_intermediates:
+ outs.append(f(x))
+ else:
+ outs.append(f(x)[0])
+
+ return outs
+
+
+class MultipleDiscriminator(nn.Module):
+ def __init__(
+ self,
+ input_size: int = 1,
+ disc_conf_list: tp.List[tp.Dict[str, tp.Any]] = None,
+ ):
+ super().__init__()
+
+ self.support_disc_choices = dict(
+ mpd=MultiPeriodDiscriminator,
+ msd=MultiScaleDiscriminator,
+ mrd=MultiResolutionDiscriminator,
+ )
+
+ self.discriminators = nn.ModuleList()
+ self.discriminator_type_lst = []
+ for args in disc_conf_list:
+ assert "name" in args, "disc_conf must have `name` attr to specific disc type."
+ disc_type = args.pop("name")
+ assert disc_type in self.support_disc_choices, \
+ "Unsupported discriminator type, only support {}".format(
+ ",".join(self.support_disc_choices.keys())
+ )
+
+ disc_class = self.support_disc_choices[disc_type]
+ one_disc = disc_class(in_channels=input_size, **args)
+ self.discriminators.append(one_disc)
+ # add back to the args for dump config.yaml
+ args["name"] = disc_type
+ self.discriminator_type_lst.append(disc_type)
+
+ def get_discriminator_type_lst(self) -> tp.List[str]:
+ return self.discriminator_type_lst
+
+ def forward(self, x, return_intermediates=True):
+ retval = []
+ for disc in self.discriminators:
+ out = disc(x, return_intermediates=return_intermediates)
+ if isinstance(out, tuple):
+ retval.append(out)
+ elif isinstance(out, list):
+ retval.extend(out)
+ else:
+ raise TypeError("The return value of discriminator must be tuple or list[tuple]")
+
+ return retval
\ No newline at end of file
diff --git a/funcineforge/models/modules/hifigan/generator.py b/funcineforge/models/modules/hifigan/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c68d3b797b02ea0f21aae683c60ce47edade3cd
--- /dev/null
+++ b/funcineforge/models/modules/hifigan/generator.py
@@ -0,0 +1,625 @@
+"""hifigan based generator implementation.
+
+This code is modified from https://github.com/jik876/hifi-gan
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
+ https://github.com/NVIDIA/BigVGAN
+
+"""
+
+import typing as tp
+
+import numpy as np
+from scipy.signal import get_window
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm
+from torch.nn.utils import remove_weight_norm
+
+from funcineforge.models.modules.hifigan import get_padding, init_weights
+from funcineforge.models.modules.hifigan.activations import Snake, SnakeBeta
+from funcineforge.models.modules.hifigan.nsf_utils import SourceModule, SourceModuleHnNSF
+
+
+class ResBlock(torch.nn.Module):
+ """Residual block module in HiFiGAN/BigVGAN."""
+ def __init__(
+ self,
+ channels: int = 512,
+ kernel_size: int = 3,
+ dilations: tp.List[int] = [1, 3, 5],
+ use_additional_convs: bool = True,
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
+ ):
+ super(ResBlock, self).__init__()
+ self.use_additional_convs = use_additional_convs
+
+ self.convs1 = nn.ModuleList()
+ if use_additional_convs:
+ self.convs2 = nn.ModuleList()
+
+ for dilation in dilations:
+ self.convs1.append(
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation,
+ padding=get_padding(kernel_size, dilation)
+ )
+ )
+ )
+
+ if use_additional_convs:
+ self.convs2.append(
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1)
+ )
+ )
+ )
+
+ self.convs1.apply(init_weights)
+ if use_additional_convs:
+ self.convs2.apply(init_weights)
+
+ if nonlinear_activation == "LeakyReLU":
+ self.activations1 = nn.ModuleList([
+ nn.LeakyReLU(nonlinear_activation_params["negative_slope"])
+ for _ in range(len(self.convs1))
+ ])
+ if use_additional_convs:
+ self.activations2 = nn.ModuleList([
+ nn.LeakyReLU(nonlinear_activation_params["negative_slope"])
+ for _ in range(len(self.convs2))
+ ])
+
+ elif nonlinear_activation == "Snake":
+ self.activations1 = nn.ModuleList([
+ Snake(channels, alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False))
+ for _ in range(len(self.convs1))
+ ])
+ if use_additional_convs:
+ self.activations2 = nn.ModuleList([
+ Snake(channels, alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False))
+ for _ in range(len(self.convs2))
+ ])
+
+ elif nonlinear_activation == "SnakeBeta":
+ self.activations1 = nn.ModuleList([
+ SnakeBeta(channels, alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False))
+ for _ in range(len(self.convs1))
+ ])
+ if use_additional_convs:
+ self.activations2 = nn.ModuleList([
+ SnakeBeta(channels, alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False))
+ for _ in range(len(self.convs2))
+ ])
+
+ else:
+ raise NotImplementedError
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for idx in range(len(self.convs1)):
+ xt = self.activations1[idx](x)
+ xt = self.convs1[idx](xt)
+ if self.use_additional_convs:
+ xt = self.activations2[idx](xt)
+ xt = self.convs2[idx](xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for idx in range(len(self.convs1)):
+ remove_weight_norm(self.convs1[idx])
+ if self.use_additional_convs:
+ remove_weight_norm(self.convs2[idx])
+
+
+class HifiGenerator(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 80,
+ base_channels: int = 512,
+ global_channels: int = -1,
+ upsample_rates: tp.List[int] = [8, 8, 2, 2],
+ upsample_kernel_sizes: tp.List[int] = [16, 16, 4, 4],
+ resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
+ resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ resblock_nonlinear_activation: str = "LeakyReLU",
+ resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
+ use_additional_convs: bool = True,
+ cond_in_each_up_layer: bool = False,
+ lrelu_slope: float = 0.1,
+ act_pre_each_up_layer: bool = True
+ ):
+ super(HifiGenerator, self).__init__()
+
+ self.out_channels = 1
+ self.global_channels = global_channels
+ self.use_additional_convs = use_additional_convs
+ self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
+ self.lrelu_slope = lrelu_slope
+ self.act_pre_each_up_layer = act_pre_each_up_layer
+
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+
+ self.conv_pre = weight_norm(
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
+ )
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ base_channels // (2**i),
+ base_channels // (2**(i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = base_channels // (2**(i + 1))
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+ self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
+ resblock_nonlinear_activation,
+ resblock_nonlinear_activation_params))
+
+ if self.global_channels > 0:
+ self.conv_global_cond = weight_norm(
+ Conv1d(global_channels, base_channels, 1)
+ )
+ self.conv_global_cond.apply(init_weights)
+
+ if self.cond_in_each_up_layer:
+ self.conv_conds = nn.ModuleList()
+ for i in range(len(self.ups)):
+ self.conv_conds.append(weight_norm(
+ nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
+ )
+ self.conv_conds.apply(init_weights)
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def output_size(self):
+ return self.out_channels
+
+ def forward(self, x: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
+ # x in (B, in_channels, T), g in (B, global_channels, 1)
+ x = self.conv_pre(x)
+ if self.global_channels > 0 and g is not None:
+ x = x + self.conv_global_cond(g)
+
+ for i in range(self.num_upsamples):
+ if self.act_pre_each_up_layer:
+ x = F.leaky_relu(x, self.lrelu_slope)
+ x = self.ups[i](x)
+
+ if self.cond_in_each_up_layer and g is not None:
+ x = x + self.conv_conds[i](g)
+
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+ if self.global_channels > 0:
+ remove_weight_norm(self.conv_global_cond)
+ if self.cond_in_each_up_layer:
+ for l in self.conv_conds:
+ remove_weight_norm(l)
+
+
+class NsfHifiGenerator(nn.Module):
+ """
+ Neural Source Filter + HifiGan
+ """
+ def __init__(
+ self,
+ in_channels: int = 80,
+ base_channels: int = 512,
+ global_channels: int = -1,
+ nb_harmonics: int = 7,
+ sampling_rate: int = 22050,
+ nsf_alpha: float = 0.1,
+ nsf_sigma: float = 0.003,
+ nsf_voiced_threshold: float = 10,
+ upsample_rates: tp.List[int] = [8, 8, 2, 2],
+ upsample_kernel_sizes: tp.List[int] = [16, 16, 4, 4],
+ resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
+ resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ resblock_nonlinear_activation: str = "LeakyReLU",
+ resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
+ use_additional_convs: bool = True,
+ cond_in_each_up_layer: bool = False,
+ lrelu_slope: float = 0.1,
+ act_pre_each_up_layer: bool = True
+ ):
+ super(NsfHifiGenerator, self).__init__()
+
+ self.out_channels = 1
+ self.global_channels = global_channels
+ self.nb_harmonics = nb_harmonics
+ self.sampling_rate = sampling_rate
+ self.use_additional_convs = use_additional_convs
+ self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
+ self.lrelu_slope = lrelu_slope
+ self.act_pre_each_up_layer = act_pre_each_up_layer
+
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+
+ self.source_module = SourceModule(nb_harmonics, np.cumprod(upsample_rates)[-1],
+ sampling_rate, nsf_alpha, nsf_sigma, nsf_voiced_threshold)
+ self.conv_pre = weight_norm(
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
+ )
+
+ # Up
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ base_channels // (2**i),
+ base_channels // (2**(i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+ # Down
+ self.source_downs = nn.ModuleList()
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
+ downsample_cum_rates = np.cumprod(downsample_rates)
+ for i, u in enumerate(downsample_cum_rates[::-1]):
+ if (u == 1):
+ self.source_downs.append(
+ weight_norm(Conv1d(1, base_channels // (2 ** (i + 1)), 1, 1))
+ )
+ else:
+ self.source_downs.append(
+ weight_norm(Conv1d(1, base_channels // (2 ** (i + 1)), u*2, u, padding=(u//2)))
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = base_channels // (2**(i + 1))
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+ self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
+ resblock_nonlinear_activation,
+ resblock_nonlinear_activation_params))
+
+ if self.global_channels > 0:
+ self.conv_global_cond = weight_norm(
+ Conv1d(global_channels, base_channels, 1)
+ )
+ self.conv_global_cond.apply(init_weights)
+
+ if self.cond_in_each_up_layer:
+ self.conv_conds = nn.ModuleList()
+ for i in range(len(self.ups)):
+ self.conv_conds.append(weight_norm(
+ nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
+ )
+ self.conv_conds.apply(init_weights)
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def output_size(self):
+ return self.out_channels
+
+ def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
+ return self.source_module(f0.unsqueeze(1))
+
+ def forward(self, x: torch.Tensor, f0: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
+ # x in (B, in_channels, T), f0 in (B, T), g in (B, global_channels, 1)
+
+ s = self._f02source(f0)
+
+ x = self.conv_pre(x)
+ if self.global_channels > 0 and g is not None:
+ x = x + self.conv_global_cond(g)
+
+ for i in range(self.num_upsamples):
+ if self.act_pre_each_up_layer:
+ x = F.leaky_relu(x, self.lrelu_slope)
+ x = self.ups[i](x)
+
+ if self.cond_in_each_up_layer and g is not None:
+ x = x + self.conv_conds[i](g)
+
+ # fusion
+ x = x + self.source_downs[i](s)
+
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+ if self.global_channels > 0:
+ remove_weight_norm(self.conv_global_cond)
+ if self.cond_in_each_up_layer:
+ for l in self.conv_conds:
+ remove_weight_norm(l)
+ self.source_module.remove_weight_norm()
+ for l in self.source_downs:
+ remove_weight_norm(l)
+
+
+class HiFTGenerator(nn.Module):
+ """
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
+ https://arxiv.org/abs/2309.09493
+ """
+ def __init__(
+ self,
+ in_channels: int = 80,
+ base_channels: int = 512,
+ global_channels: int = -1,
+ nb_harmonics: int = 8,
+ sampling_rate: int = 22050,
+ nsf_alpha: float = 0.1,
+ nsf_sigma: float = 0.003,
+ nsf_voiced_threshold: float = 10,
+ upsample_rates: tp.List[int] = [8, 8],
+ upsample_kernel_sizes: tp.List[int] = [16, 16],
+ istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
+ resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
+ resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ resblock_nonlinear_activation: str = "Snake",
+ resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"alpha_logscale": False},
+ source_resblock_kernel_sizes: tp.List[int] = [7, 11],
+ source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
+ source_resblock_nonlinear_activation: str = "Snake",
+ source_resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"alpha_logscale": False},
+ use_additional_convs: bool = True,
+ cond_in_each_up_layer: bool = False,
+ lrelu_slope: float = 0.1,
+ act_pre_each_up_layer: bool = True,
+ audio_limit: float = 0.99,
+ ):
+ super(HiFTGenerator, self).__init__()
+
+ self.out_channels = 1
+ self.global_channels = global_channels
+ self.nb_harmonics = nb_harmonics
+ self.sampling_rate = sampling_rate
+ self.istft_params = istft_params
+ self.use_additional_convs = use_additional_convs
+ self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
+ self.lrelu_slope = lrelu_slope
+ self.act_pre_each_up_layer = act_pre_each_up_layer
+ self.audio_limit = audio_limit
+
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ self.m_source = SourceModuleHnNSF(
+ sampling_rate=sampling_rate,
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
+ harmonic_num=nb_harmonics,
+ sine_amp=nsf_alpha,
+ add_noise_std=nsf_sigma,
+ voiced_threshod=nsf_voiced_threshold)
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
+
+ self.conv_pre = weight_norm(
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
+ )
+
+ # Up
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ base_channels // (2**i),
+ base_channels // (2**(i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ # Down
+ self.source_downs = nn.ModuleList()
+ self.source_resblocks = nn.ModuleList()
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
+ downsample_cum_rates = np.cumprod(downsample_rates)
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
+ source_resblock_dilation_sizes)):
+ if u == 1:
+ self.source_downs.append(
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
+ )
+ else:
+ self.source_downs.append(
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u*2, u, padding=(u//2))
+ )
+
+ self.source_resblocks.append(
+ ResBlock(base_channels // (2 ** (i + 1)), k, d,
+ use_additional_convs, source_resblock_nonlinear_activation,
+ source_resblock_nonlinear_activation_params)
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = base_channels // (2**(i + 1))
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+ self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
+ resblock_nonlinear_activation,
+ resblock_nonlinear_activation_params))
+
+ if self.global_channels > 0:
+ self.conv_global_cond = weight_norm(
+ Conv1d(global_channels, base_channels, 1)
+ )
+ self.conv_global_cond.apply(init_weights)
+
+ if self.cond_in_each_up_layer:
+ self.conv_conds = nn.ModuleList()
+ for i in range(len(self.ups)):
+ self.conv_conds.append(weight_norm(
+ nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
+ )
+ self.conv_conds.apply(init_weights)
+
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
+ window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
+ self.register_buffer("stft_window", window)
+
+ def output_size(self):
+ return self.out_channels
+
+ def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
+
+ har_source, _, _ = self.m_source(f0)
+ return har_source.transpose(1, 2)
+
+ def forward(self, x: torch.Tensor, f0: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
+ # x in (B, in_channels, T), f0 in (B, T), g in (B, global_channels, 1)
+
+ s = self._f02source(f0)
+
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
+
+ x = self.conv_pre(x)
+ if self.global_channels > 0 and g is not None:
+ x = x + self.conv_global_cond(g)
+
+ for i in range(self.num_upsamples):
+ if self.act_pre_each_up_layer:
+ x = F.leaky_relu(x, self.lrelu_slope)
+ x = self.ups[i](x)
+
+ if self.cond_in_each_up_layer and g is not None:
+ x = x + self.conv_conds[i](g)
+
+ if i == self.num_upsamples - 1:
+ x = self.reflection_pad(x)
+
+ # fusion
+ si = self.source_downs[i](s_stft)
+ si = self.source_resblocks[i](si)
+ x = x + si
+
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
+
+ x = self._istft(magnitude, phase)
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+ if self.global_channels > 0:
+ remove_weight_norm(self.conv_global_cond)
+ if self.cond_in_each_up_layer:
+ for l in self.conv_conds:
+ remove_weight_norm(l)
+ self.source_module.remove_weight_norm()
+ for l in self.source_downs:
+ remove_weight_norm(l)
+ for l in self.source_resblocks:
+ l.remove_weight_norm()
+
+ def _stft(self, x):
+ spec = torch.stft(
+ x,
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window,
+ return_complex=True)
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
+ return spec[...,0], spec[...,1]
+
+ def _istft(self, magnitude, phase):
+ magnitude = torch.clip(magnitude, max=1e2)
+ real = magnitude * torch.cos(phase)
+ img = magnitude * torch.sin(phase)
+ inverse_transform = torch.istft(
+ # torch.cat([real.unsqueeze(-1), img.unsqueeze(-1)], dim=-1),
+ torch.complex(real, img),
+ self.istft_params["n_fft"], self.istft_params["hop_len"],
+ self.istft_params["n_fft"], window=self.stft_window,
+ return_complex=False
+ )
+
+ return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
diff --git a/funcineforge/models/modules/hifigan/mel_spectrum.py b/funcineforge/models/modules/hifigan/mel_spectrum.py
new file mode 100644
index 0000000000000000000000000000000000000000..37e768e75ac2f1b3ba638d5e6a0a3906324821e6
--- /dev/null
+++ b/funcineforge/models/modules/hifigan/mel_spectrum.py
@@ -0,0 +1,93 @@
+import torch
+import torch.utils.data
+import numpy as np
+from librosa.filters import mel as librosa_mel_fn
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+ output = dynamic_range_decompression_torch(magnitudes)
+ return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+ if torch.min(y) < -1.:
+ print('min value is ', torch.min(y))
+ if torch.max(y) > 1.:
+ print('max value is ', torch.max(y))
+
+ global mel_basis, hann_window
+ if fmax not in mel_basis:
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+ y = y.squeeze(1)
+
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+ center=center, pad_mode='reflect', normalized=False, onesided=True)
+
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
+
+ spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
+
+
+def power_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+ if torch.min(y) < -1.:
+ print('min value is ', torch.min(y))
+ if torch.max(y) > 1.:
+ print('max value is ', torch.max(y))
+
+ global mel_basis, hann_window
+ if fmax not in mel_basis:
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+ y = y.squeeze(1)
+
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+ center=center, pad_mode='reflect', normalized=False, onesided=True)
+
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
+ spec = spectral_normalize_torch(spec)
+
+ return spec
+
+
+def mel_from_power_spectrogram(spec, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+ global mel_basis, hann_window
+ spec = spectral_de_normalize_torch(spec)
+ spec = torch.matmul(mel_basis[str(fmax) + '_' + str(spec.device)], spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
diff --git a/funcineforge/models/modules/hifigan/nsf_utils.py b/funcineforge/models/modules/hifigan/nsf_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..66d955c35664a4d7efb1f3502e1b572b2c6341b6
--- /dev/null
+++ b/funcineforge/models/modules/hifigan/nsf_utils.py
@@ -0,0 +1,253 @@
+"""
+Neural Source Filter based modules implementation.
+
+Neural source-filter waveform models for statistical parametric speech synthesis
+
+"""
+
+import numpy as np
+import typing as tp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import weight_norm, remove_weight_norm
+from torch.distributions.uniform import Uniform
+from torch.distributions.normal import Normal
+
+class SineGen(torch.nn.Module):
+ """ Definition of sine generator
+ SineGen(samp_rate, harmonic_num = 0,
+ sine_amp = 0.1, noise_std = 0.003,
+ voiced_threshold = 0,
+ flag_for_pulse=False)
+ samp_rate: sampling rate in Hz
+ harmonic_num: number of harmonic overtones (default 0)
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
+ noise_std: std of Gaussian noise (default 0.003)
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
+ Note: when flag_for_pulse is True, the first time step of a voiced
+ segment is always sin(np.pi) or cos(0)
+ """
+
+ def __init__(self, samp_rate, harmonic_num=0,
+ sine_amp=0.1, noise_std=0.003,
+ voiced_threshold=0):
+ super(SineGen, self).__init__()
+ self.sine_amp = sine_amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+
+ def _f02uv(self, f0):
+ # generate uv signal
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
+ return uv
+
+ @torch.no_grad()
+ def forward(self, f0):
+ """
+ :param f0: [B, 1, sample_len], Hz
+ :return: [B, 1, sample_len]
+ """
+
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
+ for i in range(self.harmonic_num + 1):
+ F_mat[:, i:i+1, :] = f0 * (i+1) / self.sampling_rate
+
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
+ u_dist = Uniform(low=-np.pi, high=np.pi)
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
+ phase_vec[:, 0, :] = 0
+
+ # generate sine waveforms
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
+
+ # generate uv signal
+ uv = self._f02uv(f0)
+
+ # noise: for unvoiced should be similar to sine_amp
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
+ # . for voiced regions is self.noise_std
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+ noise = noise_amp * torch.randn_like(sine_waves)
+
+ # first: set the unvoiced part to 0 by uv
+ # then: additive noise
+ sine_waves = sine_waves * uv + noise
+ return sine_waves, uv, noise
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+ """ SourceModule for hn-nsf
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0)
+ sampling_rate: sampling_rate in Hz
+ harmonic_num: number of harmonic above F0 (default: 0)
+ sine_amp: amplitude of sine source signal (default: 0.1)
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
+ note that amplitude of noise in unvoiced is decided
+ by sine_amp
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ uv (batchsize, length, 1)
+ """
+
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0):
+ super(SourceModuleHnNSF, self).__init__()
+
+ self.sine_amp = sine_amp
+ self.noise_std = add_noise_std
+
+ # to produce sine waveforms
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+ sine_amp, add_noise_std, voiced_threshod)
+
+ # to merge source harmonics into a single excitation
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+ self.l_tanh = torch.nn.Tanh()
+
+ def forward(self, x):
+ """
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ """
+ # source for harmonic branch
+ with torch.no_grad():
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1,2))
+ sine_wavs = sine_wavs.transpose(1,2)
+ uv = uv.transpose(1,2)
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+ # source for noise branch, in the same shape as uv
+ noise = torch.randn_like(uv) * self.sine_amp / 3
+ return sine_merge, noise, uv
+
+
+class SourceModule(torch.nn.Module):
+ def __init__(self,
+ nb_harmonics: int,
+ upsample_ratio: int,
+ sampling_rate: int,
+ alpha: float = 0.1,
+ sigma: float = 0.003,
+ voiced_threshold: float = 10
+ ):
+ super(SourceModule, self).__init__()
+
+ self.nb_harmonics = nb_harmonics
+ self.upsample_ratio = upsample_ratio
+ self.sampling_rate = sampling_rate
+ self.alpha = alpha
+ self.sigma = sigma
+ self.voiced_threshold = voiced_threshold
+
+ self.ffn = nn.Sequential(
+ weight_norm(nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)),
+ nn.Tanh())
+
+ def f02uv(self, f0):
+ # generate uv signal
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
+ return uv
+
+ def forward(self, f0):
+ """
+ :param f0: [B, 1, frame_len], Hz
+ :return: [B, 1, sample_len]
+ """
+ with torch.no_grad():
+ uv = self.f02uv(f0)
+ f0_samples = F.interpolate(f0, scale_factor=(self.upsample_ratio), mode='nearest')
+ uv_samples = F.interpolate(uv, scale_factor=(self.upsample_ratio), mode='nearest')
+
+ F_mat = torch.zeros((f0_samples.size(0), self.nb_harmonics + 1, f0_samples.size(-1))).to(f0_samples.device)
+ for i in range(self.nb_harmonics + 1):
+ F_mat[:, i:i+1, :] = f0_samples * (i+1) / self.sampling_rate
+
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
+ u_dist = Uniform(low=-np.pi, high=np.pi)
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.nb_harmonics + 1, 1)).to(F_mat.device)
+ phase_vec[:, 0, :] = 0
+
+ n_dist = Normal(loc=0., scale=self.sigma)
+ noise = n_dist.sample(sample_shape=(f0_samples.size(0), self.nb_harmonics + 1, f0_samples.size(-1))).to(F_mat.device)
+
+ e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise
+ e_unvoice = self.alpha / 3 / self.sigma * noise
+
+ e = e_voice * uv_samples + e_unvoice * (1 - uv_samples)
+
+ return self.ffn(e)
+
+ def remove_weight_norm(self):
+ remove_weight_norm(self.ffn[0])
+
+
+class ConvRNNF0Predictor(nn.Module):
+ def __init__(self,
+ num_class: int = 1,
+ in_channels: int = 80,
+ cond_channels: int = 512,
+ use_cond_rnn: bool = True,
+ bidirectional_rnn: bool = False,
+ ):
+
+ super().__init__()
+
+ self.num_class = num_class
+ self.use_cond_rnn = use_cond_rnn
+
+ self.condnet = nn.Sequential(
+ weight_norm(
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ weight_norm(
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
+ ),
+ nn.ELU(),
+ )
+
+ if self.use_cond_rnn:
+ self.rnn = nn.GRU(
+ cond_channels,
+ cond_channels // 2 if bidirectional_rnn else cond_channels,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=bidirectional_rnn,
+ )
+
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.condnet(x)
+ if self.use_cond_rnn:
+ x, _ = self.rnn(x.transpose(1, 2))
+ else:
+ x = x.transpose(1, 2)
+
+ return torch.abs(self.classifier(x).squeeze(-1))
+
+
+
diff --git a/funcineforge/models/specaug/__init__.py b/funcineforge/models/specaug/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/funcineforge/models/specaug/mask_along_axis.py b/funcineforge/models/specaug/mask_along_axis.py
new file mode 100644
index 0000000000000000000000000000000000000000..51543c2ac0dab931fddde3395ba2c374c4038c46
--- /dev/null
+++ b/funcineforge/models/specaug/mask_along_axis.py
@@ -0,0 +1,204 @@
+import math
+import torch
+from typing import Sequence
+from typing import Union
+
+
+def mask_along_axis(
+ spec: torch.Tensor,
+ spec_lengths: torch.Tensor,
+ mask_width_range: Sequence[int] = (0, 30),
+ dim: int = 1,
+ num_mask: int = 2,
+ replace_with_zero: bool = True,
+ fill_value: float = 0.0,
+):
+ """Apply mask along the specified direction.
+
+ Args:
+ spec: (Batch, Length, Freq)
+ spec_lengths: (Length): Not using lengths in this implementation
+ mask_width_range: Select the width randomly between this range
+ """
+
+ org_size = spec.size()
+ if spec.dim() == 4:
+ # spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
+ spec = spec.view(-1, spec.size(2), spec.size(3))
+
+ B = spec.shape[0]
+ # D = Length or Freq
+ D = spec.shape[dim]
+ # mask_length: (B, num_mask, 1)
+ mask_length = torch.randint(
+ mask_width_range[0],
+ mask_width_range[1],
+ (B, num_mask),
+ device=spec.device,
+ ).unsqueeze(2)
+
+ # mask_pos: (B, num_mask, 1)
+ mask_pos = torch.randint(
+ 0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
+ ).unsqueeze(2)
+
+ # aran: (1, 1, D)
+ aran = torch.arange(D, device=spec.device)[None, None, :]
+ # mask: (Batch, num_mask, D)
+ mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
+ # Multiply masks: (Batch, num_mask, D) -> (Batch, D)
+ mask = mask.any(dim=1)
+ if dim == 1:
+ # mask: (Batch, Length, 1)
+ mask = mask.unsqueeze(2)
+ elif dim == 2:
+ # mask: (Batch, 1, Freq)
+ mask = mask.unsqueeze(1)
+
+ if replace_with_zero:
+ value = fill_value
+ else:
+ value = spec.mean()
+
+ spec = spec.masked_fill(mask, value)
+ spec = spec.view(*org_size)
+ return spec, spec_lengths
+
+
+class MaskAlongAxis(torch.nn.Module):
+ def __init__(
+ self,
+ mask_width_range: Union[int, Sequence[int]] = (0, 30),
+ num_mask: int = 2,
+ dim: Union[int, str] = "time",
+ replace_with_zero: bool = True,
+ fill_value: float = 0.0,
+ ):
+ if isinstance(mask_width_range, int):
+ mask_width_range = (0, mask_width_range)
+ if len(mask_width_range) != 2:
+ raise TypeError(
+ f"mask_width_range must be a tuple of int and int values: " f"{mask_width_range}",
+ )
+
+ assert mask_width_range[1] > mask_width_range[0]
+ if isinstance(dim, str):
+ if dim == "time":
+ dim = 1
+ elif dim == "freq":
+ dim = 2
+ else:
+ raise ValueError("dim must be int, 'time' or 'freq'")
+ if dim == 1:
+ self.mask_axis = "time"
+ elif dim == 2:
+ self.mask_axis = "freq"
+ else:
+ self.mask_axis = "unknown"
+
+ super().__init__()
+ self.mask_width_range = mask_width_range
+ self.num_mask = num_mask
+ self.dim = dim
+ self.replace_with_zero = replace_with_zero
+ self.fill_value = fill_value
+
+ def extra_repr(self):
+ return (
+ f"mask_width_range={self.mask_width_range}, "
+ f"num_mask={self.num_mask}, axis={self.mask_axis}"
+ )
+
+ def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
+ """Forward function.
+
+ Args:
+ spec: (Batch, Length, Freq)
+ """
+
+ return mask_along_axis(
+ spec,
+ spec_lengths,
+ mask_width_range=self.mask_width_range,
+ dim=self.dim,
+ num_mask=self.num_mask,
+ replace_with_zero=self.replace_with_zero,
+ fill_value=self.fill_value,
+ )
+
+
+class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
+ """Mask input spec along a specified axis with variable maximum width.
+
+ Formula:
+ max_width = max_width_ratio * seq_len
+ """
+
+ def __init__(
+ self,
+ mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
+ num_mask: int = 2,
+ dim: Union[int, str] = "time",
+ replace_with_zero: bool = True,
+ fill_value: float = 0.0,
+ ):
+ if isinstance(mask_width_ratio_range, float):
+ mask_width_ratio_range = (0.0, mask_width_ratio_range)
+ if len(mask_width_ratio_range) != 2:
+ raise TypeError(
+ f"mask_width_ratio_range must be a tuple of float and float values: "
+ f"{mask_width_ratio_range}",
+ )
+
+ assert mask_width_ratio_range[1] > mask_width_ratio_range[0]
+ if isinstance(dim, str):
+ if dim == "time":
+ dim = 1
+ elif dim == "freq":
+ dim = 2
+ else:
+ raise ValueError("dim must be int, 'time' or 'freq'")
+ if dim == 1:
+ self.mask_axis = "time"
+ elif dim == 2:
+ self.mask_axis = "freq"
+ else:
+ self.mask_axis = "unknown"
+
+ super().__init__()
+ self.mask_width_ratio_range = mask_width_ratio_range
+ self.num_mask = num_mask
+ self.dim = dim
+ self.replace_with_zero = replace_with_zero
+ self.fill_value = fill_value
+
+ def extra_repr(self):
+ return (
+ f"mask_width_ratio_range={self.mask_width_ratio_range}, "
+ f"num_mask={self.num_mask}, axis={self.mask_axis}"
+ )
+
+ def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
+ """Forward function.
+
+ Args:
+ spec: (Batch, Length, Freq)
+ """
+
+ max_seq_len = spec.shape[self.dim]
+ min_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[0])
+ min_mask_width = max([0, min_mask_width])
+ max_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[1])
+ max_mask_width = min([max_seq_len, max_mask_width])
+
+ if max_mask_width > min_mask_width:
+ return mask_along_axis(
+ spec,
+ spec_lengths,
+ mask_width_range=(min_mask_width, max_mask_width),
+ dim=self.dim,
+ num_mask=self.num_mask,
+ replace_with_zero=self.replace_with_zero,
+ fill_value=self.fill_value,
+ )
+ return spec, spec_lengths
diff --git a/funcineforge/models/specaug/specaug.py b/funcineforge/models/specaug/specaug.py
new file mode 100644
index 0000000000000000000000000000000000000000..74dd93dd4b88e452294912043b1540fee742487a
--- /dev/null
+++ b/funcineforge/models/specaug/specaug.py
@@ -0,0 +1,103 @@
+"""SpecAugment module."""
+
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+from funcineforge.models.specaug.mask_along_axis import MaskAlongAxis
+from funcineforge.models.specaug.mask_along_axis import MaskAlongAxisVariableMaxWidth
+from funcineforge.models.specaug.time_warp import TimeWarp
+
+import torch.nn as nn
+
+
+class SpecAug(nn.Module):
+ """Implementation of SpecAug.
+
+ Reference:
+ Daniel S. Park et al.
+ "SpecAugment: A Simple Data
+ Augmentation Method for Automatic Speech Recognition"
+
+ .. warning::
+ When using cuda mode, time_warp doesn't have reproducibility
+ due to `torch.nn.functional.interpolate`.
+
+ """
+
+ def __init__(
+ self,
+ apply_time_warp: bool = True,
+ time_warp_window: int = 5,
+ time_warp_mode: str = "bicubic",
+ apply_freq_mask: bool = True,
+ freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
+ num_freq_mask: int = 2,
+ apply_time_mask: bool = True,
+ time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
+ time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
+ num_time_mask: int = 2,
+ fill_value: float = 0.0,
+ ):
+ if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
+ raise ValueError("Either one of time_warp, time_mask, or freq_mask should be applied")
+ if (
+ apply_time_mask
+ and (time_mask_width_range is not None)
+ and (time_mask_width_ratio_range is not None)
+ ):
+ raise ValueError(
+ 'Either one of "time_mask_width_range" or '
+ '"time_mask_width_ratio_range" can be used'
+ )
+ super().__init__()
+ self.apply_time_warp = apply_time_warp
+ self.apply_freq_mask = apply_freq_mask
+ self.apply_time_mask = apply_time_mask
+
+ if apply_time_warp:
+ self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
+ else:
+ self.time_warp = None
+
+ if apply_freq_mask:
+ self.freq_mask = MaskAlongAxis(
+ dim="freq",
+ mask_width_range=freq_mask_width_range,
+ num_mask=num_freq_mask,
+ fill_value=fill_value,
+ )
+ else:
+ self.freq_mask = None
+
+ if apply_time_mask:
+ if time_mask_width_range is not None:
+ self.time_mask = MaskAlongAxis(
+ dim="time",
+ mask_width_range=time_mask_width_range,
+ num_mask=num_time_mask,
+ fill_value=fill_value,
+ )
+ elif time_mask_width_ratio_range is not None:
+ self.time_mask = MaskAlongAxisVariableMaxWidth(
+ dim="time",
+ mask_width_ratio_range=time_mask_width_ratio_range,
+ num_mask=num_time_mask,
+ fill_value=fill_value,
+ )
+ else:
+ raise ValueError(
+ 'Either one of "time_mask_width_range" or '
+ '"time_mask_width_ratio_range" should be used.'
+ )
+ else:
+ self.time_mask = None
+
+ def forward(self, x, x_lengths=None):
+ if self.time_warp is not None:
+ x, x_lengths = self.time_warp(x, x_lengths)
+ if self.freq_mask is not None:
+ x, x_lengths = self.freq_mask(x, x_lengths)
+ if self.time_mask is not None:
+ x, x_lengths = self.time_mask(x, x_lengths)
+ return x, x_lengths
diff --git a/funcineforge/models/specaug/time_warp.py b/funcineforge/models/specaug/time_warp.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fbea6f7b557976de0dfe857f8edf198cee84ce4
--- /dev/null
+++ b/funcineforge/models/specaug/time_warp.py
@@ -0,0 +1,89 @@
+"""Time warp module."""
+
+import torch
+
+from funcineforge.models.utils.nets_utils import pad_list
+
+DEFAULT_TIME_WARP_MODE = "bicubic"
+
+
+def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
+ """Time warping using torch.interpolate.
+
+ Args:
+ x: (Batch, Time, Freq)
+ window: time warp parameter
+ mode: Interpolate mode
+ """
+
+ # bicubic supports 4D or more dimension tensor
+ org_size = x.size()
+ if x.dim() == 3:
+ # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
+ x = x[:, None]
+
+ t = x.shape[2]
+ if t - window <= window:
+ return x.view(*org_size)
+
+ center = torch.randint(window, t - window, (1,))[0]
+ warped = torch.randint(center - window, center + window, (1,))[0] + 1
+
+ # left: (Batch, Channel, warped, Freq)
+ # right: (Batch, Channel, time - warped, Freq)
+ left = torch.nn.functional.interpolate(
+ x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False
+ )
+ right = torch.nn.functional.interpolate(
+ x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=False
+ )
+
+ if x.requires_grad:
+ x = torch.cat([left, right], dim=-2)
+ else:
+ x[:, :, :warped] = left
+ x[:, :, warped:] = right
+
+ return x.view(*org_size)
+
+
+class TimeWarp(torch.nn.Module):
+ """Time warping using torch.interpolate.
+
+ Args:
+ window: time warp parameter
+ mode: Interpolate mode
+ """
+
+ def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
+ super().__init__()
+ self.window = window
+ self.mode = mode
+
+ def extra_repr(self):
+ return f"window={self.window}, mode={self.mode}"
+
+ def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
+ """Forward function.
+
+ Args:
+ x: (Batch, Time, Freq)
+ x_lengths: (Batch,)
+ """
+
+ if x_lengths is None or all(le == x_lengths[0] for le in x_lengths):
+ # Note that applying same warping for each sample
+ y = time_warp(x, window=self.window, mode=self.mode)
+ else:
+ # FIXME(kamo): I have no idea to batchify Timewarp
+ ys = []
+ for i in range(x.size(0)):
+ _y = time_warp(
+ x[i][None, : x_lengths[i]],
+ window=self.window,
+ mode=self.mode,
+ )[0]
+ ys.append(_y)
+ y = pad_list(ys, 0.0)
+
+ return y, x_lengths
diff --git a/funcineforge/models/utils/__init__.py b/funcineforge/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d0e31162cb06125ed890867c0d021889a7a2b6c
--- /dev/null
+++ b/funcineforge/models/utils/__init__.py
@@ -0,0 +1,2 @@
+import torch
+dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
diff --git a/funcineforge/models/utils/llm_decoding.py b/funcineforge/models/utils/llm_decoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d974edf6df4172604874e92cb009f604066c23e
--- /dev/null
+++ b/funcineforge/models/utils/llm_decoding.py
@@ -0,0 +1,178 @@
+from contextlib import nullcontext
+import torch
+import torch.nn as nn
+from typing import Union
+from funcineforge.utils.hinter import hint_once
+import numpy as np
+dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
+
+
+class LLMDecoder(nn.Module):
+ def __init__(self, **kwargs):
+ super(LLMDecoder, self).__init__()
+ self.eos_token = kwargs["eos"]
+ if isinstance(self.eos_token, int):
+ self.eos_token = [self.eos_token]
+ self.token_embeder = kwargs["token_embeder"]
+ self.ras_conf = kwargs.get("ras_conf", {})
+ self.token_offset = kwargs.get("token_offset", 0)
+
+ def nucleus_sampling(self, weighted_scores, top_p=0.8, top_k=25, beam_size=1):
+ prob, indices = [], []
+ cum_prob = 0.0
+ sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
+ for i in range(len(sorted_idx)):
+ # sampling both top-p and numbers.
+ if cum_prob < top_p and len(prob) < top_k:
+ cum_prob += sorted_value[i]
+ prob.append(sorted_value[i])
+ indices.append(sorted_idx[i])
+ else:
+ break
+ prob = torch.tensor(prob).to(weighted_scores)
+ indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
+ sampling_ids = prob.multinomial(beam_size, replacement=True)
+ top_ids = indices[sampling_ids]
+ return top_ids
+
+ def random_sampling(self, weighted_scores, beam_size=1):
+ top_ids = weighted_scores.softmax(dim=0).multinomial(beam_size, replacement=True)
+ return top_ids
+
+ # Repetition Aware Sampling in VALL-E 2
+ def ras_sampling(
+ self, weighted_scores, decoded_tokens, *,
+ top_p=0.8, top_k=25, win_size=10, tau_r=0.1
+ ):
+ if self.ras_conf is not None:
+ top_p = self.ras_conf.get("top_p", top_p)
+ top_k = self.ras_conf.get("top_k", top_k)
+ win_size = self.ras_conf.get("win_size", win_size)
+ tau_r = self.ras_conf.get("tau_r", tau_r)
+
+ hint_once(f"using Repetition Aware Sampling: top_p: {top_p}, top_k: {top_k},win_size: {win_size}, tau_r: {tau_r}", "ras_sampling")
+ top_ids = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
+ rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(top_ids) == top_ids).sum().item()
+ if rep_num >= win_size * tau_r:
+ top_ids = self.random_sampling(weighted_scores)
+
+ return top_ids
+
+ def sampling_ids(
+ self,
+ weighted_scores: torch.Tensor,
+ sampling: Union[bool, int, float] = True,
+ decoded_tokens: list = None,
+ ):
+ if isinstance(sampling, bool):
+ if sampling:
+ top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
+ else:
+ top_ids = weighted_scores.topk(1)[1]
+ elif isinstance(sampling, int):
+ prob, indices = weighted_scores.softmax(dim=0).topk(sampling)
+ sampling_ids = prob.multinomial(1, replacement=True)
+ top_ids = indices[sampling_ids]
+ elif isinstance(sampling, float):
+ prob, indices = [], []
+ cum_prob = 0.0
+ sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
+ for i in range(len(sorted_idx)):
+ # sampling both top-p and numbers.
+ if cum_prob < sampling and len(prob) < 25:
+ cum_prob += sorted_value[i]
+ prob.append(sorted_value[i])
+ indices.append(sorted_idx[i])
+ else:
+ break
+ prob = torch.tensor(prob).to(weighted_scores)
+ indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
+ sampling_ids = prob.multinomial(1, replacement=True)
+ top_ids = indices[sampling_ids]
+ elif isinstance(sampling, str) and sampling.lower() == "ras":
+ top_ids = self.ras_sampling(weighted_scores, decoded_tokens=decoded_tokens)
+ else:
+ raise NotImplementedError(f"Not implemented for {type(sampling)} sampling")
+
+ return top_ids
+
+ def __call__(self, input_embeddings, llm, states, quantize=False, **kwargs):
+ max_length = kwargs.get("max_length", 60 * 25)
+ min_length = kwargs.get("min_length", 2 * 25)
+ sampling = kwargs.get("sampling", True)
+ device = kwargs.get("device", "cuda")
+ llm_dtype = kwargs.get("llm_dtype", "fp32")
+ use_llm_cache = kwargs.get("use_llm_cache", True)
+ include_eos = kwargs.get("include_eos", False)
+ custom_eos_token = kwargs.get("custom_eos_token", self.eos_token)
+ avoid_token = kwargs.get("avoid_token", None)
+
+ llm_cache = states.get("llm_cache", None)
+ out_tokens, hit_eos = [], False
+ for i in range(max_length):
+ with torch.cuda.amp.autocast(
+ enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
+ ) if quantize is False else nullcontext():
+ # default attention_mask is causal, no longer need manually construct
+ # input_masks = torch.ones((1, input_embeddings.shape[1]), device=input_embeddings.device).to(torch.bool)
+
+ if (kwargs.get("use_qlora",False) or kwargs.get("infer_use_lora",False)) and (not kwargs.get("infer_lora_merged",False)):
+ outputs = llm.base_model.model(
+ inputs_embeds=input_embeddings.to(torch.bfloat16) if quantize is True else input_embeddings,
+ # attention_mask=input_masks,
+ output_hidden_states=True,
+ return_dict=True,
+ use_cache=use_llm_cache,
+ past_key_values=llm_cache,
+ )
+ else:
+ outputs = llm(
+ inputs_embeds=input_embeddings.to(torch.bfloat16) if quantize is True else input_embeddings,
+ # attention_mask=input_masks,
+ output_hidden_states=True,
+ return_dict=True,
+ use_cache=use_llm_cache,
+ past_key_values=llm_cache,
+ )
+ lm_hidden_states = outputs.hidden_states[-1]
+ h = llm.lm_head(lm_hidden_states[:, -1])
+ # logp = h.log_softmax(dim=-1).squeeze(0)
+ logp = h.squeeze(0)
+ if use_llm_cache:
+ llm_cache = outputs.past_key_values
+
+ pred = torch.log_softmax(logp, dim=-1)
+ if min_length is not None and i < min_length:
+ for x in custom_eos_token:
+ if pred.dtype == torch.bfloat16:
+ pred[x] = float(np.finfo(np.float16).min)
+ else:
+ pred[x] = float(np.finfo(np.float32).min)
+ if avoid_token is not None and len(avoid_token) > 0:
+ for x in avoid_token:
+ if pred.dtype == torch.bfloat16:
+ pred[x] = float(np.finfo(np.float16).min)
+ else:
+ pred[x] = float(np.finfo(np.float32).min)
+ top_id = self.sampling_ids(pred, sampling, out_tokens)[0].item()
+
+ if top_id in custom_eos_token:
+ if include_eos:
+ out_tokens.append(top_id)
+ hit_eos = True
+ break
+
+ out_tokens.append(top_id)
+ if use_llm_cache:
+ input_embeddings = self.token_embeder(torch.tensor([[top_id]], dtype=torch.int64, device=device) + self.token_offset)
+ else:
+ input_embeddings = torch.cat([
+ input_embeddings,
+ self.token_embeder(torch.tensor([[top_id]], dtype=torch.int64, device=device) + self.token_offset)
+ ], dim=1)
+
+ out_tokens = torch.tensor([out_tokens], dtype=torch.int64, device=device)
+
+ states = {"llm_cache": llm_cache}
+
+ return out_tokens, hit_eos, states
diff --git a/funcineforge/models/utils/mask_along_axis.py b/funcineforge/models/utils/mask_along_axis.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1004eaa492687898f8e8f8232e11d3d450f9d0c
--- /dev/null
+++ b/funcineforge/models/utils/mask_along_axis.py
@@ -0,0 +1,76 @@
+import torch
+from typing import Sequence
+from typing import Union
+
+
+class MaskTailVariableMaxWidth(torch.nn.Module):
+ def __init__(
+ self,
+ mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
+ replace_value: float = 0.0,
+ ):
+ super().__init__()
+ self.mask_width_ratio_range = mask_width_ratio_range
+ self.replace_value = replace_value
+
+ def extra_repr(self):
+ return (
+ f"mask_width_ratio_range={self.mask_width_ratio_range}, "
+ )
+
+ def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
+ bb, tt, _ = spec.shape
+
+ mask_width_ratio = torch.rand((bb, 1), device=spec.device)
+ ratio_st, ratio_ed = self.mask_width_ratio_range
+ mask_width_ratio = mask_width_ratio * (ratio_ed - ratio_st) + ratio_st
+ mask_length = (mask_width_ratio * spec_lengths.unsqueeze(1)).to(spec_lengths)
+
+ # mask_pos: (B, 1)
+ mask_start_pos = spec_lengths.unsqueeze(-1) - mask_length
+
+ aran = torch.arange(tt, device=spec.device)[None, :]
+ # mask: (Batch, L)
+ mask = aran < mask_start_pos
+ # (Batch, L) -> (Batch, L, 1)
+ mask = mask.unsqueeze(2)
+
+ return mask
+
+class PrefixMaskVariableMaxWidth(torch.nn.Module):
+ def __init__(
+ self,
+ mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
+ replace_value: float = 0.0,
+ ):
+ super().__init__()
+ self.mask_width_ratio_range = mask_width_ratio_range
+ self.replace_value = replace_value
+
+ def extra_repr(self):
+ return (
+ f"mask_width_ratio_range={self.mask_width_ratio_range}, "
+ )
+
+ def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None, return_mask: bool = False):
+ bb, tt, _ = spec.shape
+
+ mask_width_ratio_range = torch.tensor(self.mask_width_ratio_range, dtype=torch.float32, device=spec.device)
+ mask_width_range = (mask_width_ratio_range * tt).long()
+ mask_length = torch.randint(
+ mask_width_range[0],
+ mask_width_range[1],
+ (bb, 1),
+ device=spec.device,
+ ).unsqueeze(2)
+
+ # mask_pos: (B, num_mask, 1)
+ mask_pos = tt - mask_length
+
+ aran = torch.arange(tt, device=spec.device)[None, None, :]
+ # mask: (Batch, num_mask, L)
+ mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
+ # Multiply masks: (Batch, num_mask, L) -> (Batch, L, 1)
+ mask = mask.any(dim=1).unsqueeze(2)
+
+ return mask
diff --git a/funcineforge/models/utils/masks.py b/funcineforge/models/utils/masks.py
new file mode 100644
index 0000000000000000000000000000000000000000..a67c79c020063880399874443dd8a8ec9ece657a
--- /dev/null
+++ b/funcineforge/models/utils/masks.py
@@ -0,0 +1,132 @@
+import torch
+
+def add_optional_chunk_mask(xs: torch.Tensor,
+ masks: torch.Tensor,
+ use_dynamic_chunk: bool,
+ use_dynamic_left_chunk: bool,
+ decoding_chunk_size: int,
+ static_chunk_size: int,
+ num_decoding_left_chunks: int,
+ enable_full_context: bool = True):
+ """ Apply optional mask for encoder.
+
+ Args:
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
+ mask (torch.Tensor): mask for xs, (B, 1, L)
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+ training.
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ static_chunk_size (int): chunk size for static chunk training/decoding
+ if it's greater than 0, if use_dynamic_chunk is true,
+ this parameter will be ignored
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+ enable_full_context (bool):
+ True: chunk size is either [1, 25] or full context(max_len)
+ False: chunk size ~ U[1, 25]
+
+ Returns:
+ torch.Tensor: chunk mask of the input xs.
+ """
+ # Whether to use chunk mask or not
+ if use_dynamic_chunk:
+ max_len = xs.size(1)
+ if decoding_chunk_size < 0:
+ chunk_size = max_len
+ num_left_chunks = -1
+ elif decoding_chunk_size > 0:
+ chunk_size = decoding_chunk_size
+ num_left_chunks = num_decoding_left_chunks
+ else:
+ # chunk size is either [1, 25] or full context(max_len).
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
+ # delay, the maximum frame is 100 / 4 = 25.
+ chunk_size = torch.randint(1, max_len, (1, )).item()
+ num_left_chunks = -1
+ if chunk_size > max_len // 2 and enable_full_context:
+ chunk_size = max_len
+ else:
+ chunk_size = chunk_size % 25 + 1
+ if use_dynamic_left_chunk:
+ max_left_chunks = (max_len - 1) // chunk_size
+ num_left_chunks = torch.randint(0, max_left_chunks,
+ (1, )).item()
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+ num_left_chunks,
+ xs.device) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ elif static_chunk_size > 0:
+ num_left_chunks = num_decoding_left_chunks
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+ num_left_chunks,
+ xs.device) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ else:
+ chunk_masks = masks
+ assert chunk_masks.dtype == torch.bool
+ if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
+ print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
+ chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
+ return chunk_masks
+
+
+def subsequent_chunk_mask(
+ size: int,
+ chunk_size: int,
+ num_left_chunks: int = -1,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size) with chunk size,
+ this is for streaming encoder
+
+ Args:
+ size (int): size of mask
+ chunk_size (int): size of chunk
+ num_left_chunks (int): number of left chunks
+ <0: use full chunk
+ >=0: use num_left_chunks
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_chunk_mask(4, 2)
+ [[1, 1, 0, 0],
+ [1, 1, 0, 0],
+ [1, 1, 1, 1],
+ [1, 1, 1, 1]]
+ """
+ # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
+ pos_idx = torch.arange(size, device=device)
+ block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
+ ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
+ return ret
+
+def causal_block_mask(size, block_size=1, device="cpu", dtype=torch.bool):
+ """Create mask for subsequent steps (size, size).
+
+ :param int size: size of mask
+ :param int block_size: block size of mask
+ :param str device: "cpu" or "cuda" or torch.Tensor.device
+ :param torch.dtype dtype: result dtype
+ :rtype: torch.Tensor
+ >>> causal_block_mask(4, 2)
+ [[1, 1, 0, 0],
+ [1, 1, 0, 0],
+ [1, 1, 1, 1],
+ [1, 1, 1, 1]]
+ """
+ # assert size % block_size == 0
+ pos_idx = torch.arange(size, device=device)
+ block_value = (torch.div(pos_idx, block_size, rounding_mode='trunc') + 1) * block_size
+ ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
+ return ret.to(dtype)
\ No newline at end of file
diff --git a/funcineforge/models/utils/nets_utils.py b/funcineforge/models/utils/nets_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..29d23ee59c75068d2bafc1f09b76753636e0be7f
--- /dev/null
+++ b/funcineforge/models/utils/nets_utils.py
@@ -0,0 +1,734 @@
+# -*- coding: utf-8 -*-
+
+"""Network related utility tools."""
+
+import logging
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+
+
+def to_device(m, x):
+ """Send tensor into the device of the module.
+
+ Args:
+ m (torch.nn.Module): Torch module.
+ x (Tensor): Torch tensor.
+
+ Returns:
+ Tensor: Torch tensor located in the same place as torch module.
+
+ """
+ if isinstance(m, torch.nn.Module):
+ device = next(m.parameters()).device
+ elif isinstance(m, torch.Tensor):
+ device = m.device
+ else:
+ raise TypeError("Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}")
+ return x.to(device)
+
+
+def pad_list(xs, pad_value):
+ """Perform padding for the list of tensors.
+
+ Args:
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
+ pad_value (float): Value for padding.
+
+ Returns:
+ Tensor: Padded tensor (B, Tmax, `*`).
+
+ Examples:
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
+ >>> x
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
+ >>> pad_list(x, 0)
+ tensor([[1., 1., 1., 1.],
+ [1., 1., 0., 0.],
+ [1., 0., 0., 0.]])
+
+ """
+ n_batch = len(xs)
+ max_len = max(x.size(0) for x in xs)
+ pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+
+ for i in range(n_batch):
+ pad[i, : xs[i].size(0)] = xs[i]
+
+ return pad
+
+
+def pad_list_all_dim(xs, pad_value):
+ """Perform padding for the list of tensors.
+
+ Args:
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
+ pad_value (float): Value for padding.
+
+ Returns:
+ Tensor: Padded tensor (B, Tmax, `*`).
+
+ Examples:
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
+ >>> x
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
+ >>> pad_list(x, 0)
+ tensor([[1., 1., 1., 1.],
+ [1., 1., 0., 0.],
+ [1., 0., 0., 0.]])
+
+ """
+ n_batch = len(xs)
+ num_dim = len(xs[0].shape)
+ max_len_all_dim = []
+ for i in range(num_dim):
+ max_len_all_dim.append(max(x.size(i) for x in xs))
+ pad = xs[0].new(n_batch, *max_len_all_dim).fill_(pad_value)
+
+ for i in range(n_batch):
+ if num_dim == 1:
+ pad[i, : xs[i].size(0)] = xs[i]
+ elif num_dim == 2:
+ pad[i, : xs[i].size(0), : xs[i].size(1)] = xs[i]
+ elif num_dim == 3:
+ pad[i, : xs[i].size(0), : xs[i].size(1), : xs[i].size(2)] = xs[i]
+ else:
+ raise ValueError(
+ "pad_list_all_dim only support 1-D, 2-D and 3-D tensors, not {}-D.".format(num_dim)
+ )
+
+ return pad
+
+
+def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
+ """Make mask tensor containing indices of padded part.
+
+ Args:
+ lengths (LongTensor or List): Batch of lengths (B,).
+ xs (Tensor, optional): The reference tensor.
+ If set, masks will be the same shape as this tensor.
+ length_dim (int, optional): Dimension indicator of the above tensor.
+ See the example.
+
+ Returns:
+ Tensor: Mask tensor containing indices of padded part.
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+
+ Examples:
+ With only lengths.
+
+ >>> lengths = [5, 3, 2]
+ >>> make_pad_mask(lengths)
+ masks = [[0, 0, 0, 0 ,0],
+ [0, 0, 0, 1, 1],
+ [0, 0, 1, 1, 1]]
+
+ With the reference tensor.
+
+ >>> xs = torch.zeros((3, 2, 4))
+ >>> make_pad_mask(lengths, xs)
+ tensor([[[0, 0, 0, 0],
+ [0, 0, 0, 0]],
+ [[0, 0, 0, 1],
+ [0, 0, 0, 1]],
+ [[0, 0, 1, 1],
+ [0, 0, 1, 1]]], dtype=torch.uint8)
+ >>> xs = torch.zeros((3, 2, 6))
+ >>> make_pad_mask(lengths, xs)
+ tensor([[[0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1]],
+ [[0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1]],
+ [[0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+
+ With the reference tensor and dimension indicator.
+
+ >>> xs = torch.zeros((3, 6, 6))
+ >>> make_pad_mask(lengths, xs, 1)
+ tensor([[[0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1]],
+ [[0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1]],
+ [[0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
+ >>> make_pad_mask(lengths, xs, 2)
+ tensor([[[0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1]],
+ [[0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1]],
+ [[0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+
+ """
+ if length_dim == 0:
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+ if not isinstance(lengths, list):
+ lengths = lengths.tolist()
+ bs = int(len(lengths))
+ if maxlen is None:
+ if xs is None:
+ maxlen = int(max(lengths))
+ else:
+ maxlen = xs.size(length_dim)
+ else:
+ assert xs is None
+ assert maxlen >= int(max(lengths))
+
+ seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+ seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+
+ if xs is not None:
+ assert xs.size(0) == bs, (xs.size(0), bs)
+
+ if length_dim < 0:
+ length_dim = xs.dim() + length_dim
+ # ind = (:, None, ..., None, :, , None, ..., None)
+ ind = tuple(slice(None) if i in (0, length_dim) else None for i in range(xs.dim()))
+ mask = mask[ind].expand_as(xs).to(xs.device)
+ return mask
+
+
+def make_non_pad_mask(lengths, xs=None, length_dim=-1):
+ """Make mask tensor containing indices of non-padded part.
+
+ Args:
+ lengths (LongTensor or List): Batch of lengths (B,).
+ xs (Tensor, optional): The reference tensor.
+ If set, masks will be the same shape as this tensor.
+ length_dim (int, optional): Dimension indicator of the above tensor.
+ See the example.
+
+ Returns:
+ ByteTensor: mask tensor containing indices of padded part.
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+
+ Examples:
+ With only lengths.
+
+ >>> lengths = [5, 3, 2]
+ >>> make_non_pad_mask(lengths)
+ masks = [[1, 1, 1, 1 ,1],
+ [1, 1, 1, 0, 0],
+ [1, 1, 0, 0, 0]]
+
+ With the reference tensor.
+
+ >>> xs = torch.zeros((3, 2, 4))
+ >>> make_non_pad_mask(lengths, xs)
+ tensor([[[1, 1, 1, 1],
+ [1, 1, 1, 1]],
+ [[1, 1, 1, 0],
+ [1, 1, 1, 0]],
+ [[1, 1, 0, 0],
+ [1, 1, 0, 0]]], dtype=torch.uint8)
+ >>> xs = torch.zeros((3, 2, 6))
+ >>> make_non_pad_mask(lengths, xs)
+ tensor([[[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0]],
+ [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0]],
+ [[1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+
+ With the reference tensor and dimension indicator.
+
+ >>> xs = torch.zeros((3, 6, 6))
+ >>> make_non_pad_mask(lengths, xs, 1)
+ tensor([[[1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0]],
+ [[1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0]],
+ [[1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
+ >>> make_non_pad_mask(lengths, xs, 2)
+ tensor([[[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0]],
+ [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0]],
+ [[1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+
+ """
+ return ~make_pad_mask(lengths, xs, length_dim)
+
+
+def mask_by_length(xs, lengths, fill=0):
+ """Mask tensor according to length.
+
+ Args:
+ xs (Tensor): Batch of input tensor (B, `*`).
+ lengths (LongTensor or List): Batch of lengths (B,).
+ fill (int or float): Value to fill masked part.
+
+ Returns:
+ Tensor: Batch of masked input tensor (B, `*`).
+
+ Examples:
+ >>> x = torch.arange(5).repeat(3, 1) + 1
+ >>> x
+ tensor([[1, 2, 3, 4, 5],
+ [1, 2, 3, 4, 5],
+ [1, 2, 3, 4, 5]])
+ >>> lengths = [5, 3, 2]
+ >>> mask_by_length(x, lengths)
+ tensor([[1, 2, 3, 4, 5],
+ [1, 2, 3, 0, 0],
+ [1, 2, 0, 0, 0]])
+
+ """
+ assert xs.size(0) == len(lengths)
+ ret = xs.data.new(*xs.size()).fill_(fill)
+ for i, l in enumerate(lengths):
+ ret[i, :l] = xs[i, :l]
+ return ret
+
+
+def to_torch_tensor(x):
+ """Change to torch.Tensor or ComplexTensor from numpy.ndarray.
+
+ Args:
+ x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
+
+ Returns:
+ Tensor or ComplexTensor: Type converted inputs.
+
+ Examples:
+ >>> xs = np.ones(3, dtype=np.float32)
+ >>> xs = to_torch_tensor(xs)
+ tensor([1., 1., 1.])
+ >>> xs = torch.ones(3, 4, 5)
+ >>> assert to_torch_tensor(xs) is xs
+ >>> xs = {'real': xs, 'imag': xs}
+ >>> to_torch_tensor(xs)
+ ComplexTensor(
+ Real:
+ tensor([1., 1., 1.])
+ Imag;
+ tensor([1., 1., 1.])
+ )
+
+ """
+ # If numpy, change to torch tensor
+ if isinstance(x, np.ndarray):
+ if x.dtype.kind == "c":
+ # Dynamically importing because torch_complex requires python3
+ from torch_complex.tensor import ComplexTensor
+
+ return ComplexTensor(x)
+ else:
+ return torch.from_numpy(x)
+
+ # If {'real': ..., 'imag': ...}, convert to ComplexTensor
+ elif isinstance(x, dict):
+ # Dynamically importing because torch_complex requires python3
+ from torch_complex.tensor import ComplexTensor
+
+ if "real" not in x or "imag" not in x:
+ raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
+ # Relative importing because of using python3 syntax
+ return ComplexTensor(x["real"], x["imag"])
+
+ # If torch.Tensor, as it is
+ elif isinstance(x, torch.Tensor):
+ return x
+
+ else:
+ error = (
+ "x must be numpy.ndarray, torch.Tensor or a dict like "
+ "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
+ "but got {}".format(type(x))
+ )
+ try:
+ from torch_complex.tensor import ComplexTensor
+ except Exception:
+ # If PY2
+ raise ValueError(error)
+ else:
+ # If PY3
+ if isinstance(x, ComplexTensor):
+ return x
+ else:
+ raise ValueError(error)
+
+
+def get_subsample(train_args, mode, arch):
+ """Parse the subsampling factors from the args for the specified `mode` and `arch`.
+
+ Args:
+ train_args: argument Namespace containing options.
+ mode: one of ('asr', 'mt', 'st')
+ arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
+
+ Returns:
+ np.ndarray / List[np.ndarray]: subsampling factors.
+ """
+ if arch == "transformer":
+ return np.array([1])
+
+ elif mode == "mt" and arch == "rnn":
+ # +1 means input (+1) and layers outputs (train_args.elayer)
+ subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
+ logging.warning("Subsampling is not performed for machine translation.")
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
+ return subsample
+
+ elif (
+ (mode == "asr" and arch in ("rnn", "rnn-t"))
+ or (mode == "mt" and arch == "rnn")
+ or (mode == "st" and arch == "rnn")
+ ):
+ subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
+ if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
+ ss = train_args.subsample.split("_")
+ for j in range(min(train_args.elayers + 1, len(ss))):
+ subsample[j] = int(ss[j])
+ else:
+ logging.warning(
+ "Subsampling is not performed for vgg*. "
+ "It is performed in max pooling layers at CNN."
+ )
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
+ return subsample
+
+ elif mode == "asr" and arch == "rnn_mix":
+ subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32)
+ if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
+ ss = train_args.subsample.split("_")
+ for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
+ subsample[j] = int(ss[j])
+ else:
+ logging.warning(
+ "Subsampling is not performed for vgg*. "
+ "It is performed in max pooling layers at CNN."
+ )
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
+ return subsample
+
+ elif mode == "asr" and arch == "rnn_mulenc":
+ subsample_list = []
+ for idx in range(train_args.num_encs):
+ subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int32)
+ if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
+ ss = train_args.subsample[idx].split("_")
+ for j in range(min(train_args.elayers[idx] + 1, len(ss))):
+ subsample[j] = int(ss[j])
+ else:
+ logging.warning(
+ "Encoder %d: Subsampling is not performed for vgg*. "
+ "It is performed in max pooling layers at CNN.",
+ idx + 1,
+ )
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
+ subsample_list.append(subsample)
+ return subsample_list
+
+ else:
+ raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
+
+
+def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
+ """Replace keys of old prefix with new prefix in state dict."""
+ # need this list not to break the dict iterator
+ old_keys = [k for k in state_dict if k.startswith(old_prefix)]
+ if len(old_keys) > 0:
+ logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
+ for k in old_keys:
+ v = state_dict.pop(k)
+ new_k = k.replace(old_prefix, new_prefix)
+ state_dict[new_k] = v
+
+
+class Swish(torch.nn.Module):
+ """Swish activation definition.
+
+ Swish(x) = (beta * x) * sigmoid(x)
+ where beta = 1 defines standard Swish activation.
+
+ References:
+ https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
+ E-swish variant: https://arxiv.org/abs/1801.07145.
+
+ Args:
+ beta: Beta parameter for E-Swish.
+ (beta >= 1. If beta < 1, use standard Swish).
+ use_builtin: Whether to use PyTorch function if available.
+
+ """
+
+ def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
+ super().__init__()
+
+ self.beta = beta
+
+ if beta > 1:
+ self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
+ else:
+ if use_builtin:
+ self.swish = torch.nn.SiLU()
+ else:
+ self.swish = lambda x: x * torch.sigmoid(x)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward computation."""
+ return self.swish(x)
+
+
+def get_activation(act):
+ """Return activation function."""
+
+ activation_funcs = {
+ "hardtanh": torch.nn.Hardtanh,
+ "tanh": torch.nn.Tanh,
+ "relu": torch.nn.ReLU,
+ "selu": torch.nn.SELU,
+ "swish": Swish,
+ }
+
+ return activation_funcs[act]()
+
+
+class TooShortUttError(Exception):
+ """Raised when the utt is too short for subsampling.
+
+ Args:
+ message: Error message to display.
+ actual_size: The size that cannot pass the subsampling.
+ limit: The size limit for subsampling.
+
+ """
+
+ def __init__(self, message: str, actual_size: int, limit: int) -> None:
+ """Construct a TooShortUttError module."""
+ super().__init__(message)
+
+ self.actual_size = actual_size
+ self.limit = limit
+
+
+def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
+ """Check if the input is too short for subsampling.
+
+ Args:
+ sub_factor: Subsampling factor for Conv2DSubsampling.
+ size: Input size.
+
+ Returns:
+ : Whether an error should be sent.
+ : Size limit for specified subsampling factor.
+
+ """
+ if sub_factor == 2 and size < 3:
+ return True, 7
+ elif sub_factor == 4 and size < 7:
+ return True, 7
+ elif sub_factor == 6 and size < 11:
+ return True, 11
+
+ return False, -1
+
+
+def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
+ """Get conv2D second layer parameters for given subsampling factor.
+
+ Args:
+ sub_factor: Subsampling factor (1/X).
+ input_size: Input size.
+
+ Returns:
+ : Kernel size for second convolution.
+ : Stride for second convolution.
+ : Conv2DSubsampling output size.
+
+ """
+ if sub_factor == 2:
+ return 3, 1, (((input_size - 1) // 2 - 2))
+ elif sub_factor == 4:
+ return 3, 2, (((input_size - 1) // 2 - 1) // 2)
+ elif sub_factor == 6:
+ return 5, 3, (((input_size - 1) // 2 - 2) // 3)
+ else:
+ raise ValueError("subsampling_factor parameter should be set to either 2, 4 or 6.")
+
+
+def make_chunk_mask(
+ size: int,
+ chunk_size: int,
+ left_chunk_size: int = 0,
+ device: torch.device = None,
+) -> torch.Tensor:
+ """Create chunk mask for the subsequent steps (size, size).
+
+ Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+ Args:
+ size: Size of the source mask.
+ chunk_size: Number of frames in chunk.
+ left_chunk_size: Size of the left context in chunks (0 means full context).
+ device: Device for the mask tensor.
+
+ Returns:
+ mask: Chunk mask. (size, size)
+
+ """
+ mask = torch.zeros(size, size, device=device, dtype=torch.bool)
+
+ for i in range(size):
+ if left_chunk_size < 0:
+ start = 0
+ else:
+ start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
+
+ end = min((i // chunk_size + 1) * chunk_size, size)
+ mask[i, start:end] = True
+
+ return ~mask
+
+
+def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
+ """Create source mask for given lengths.
+
+ Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+ Args:
+ lengths: Sequence lengths. (B,)
+
+ Returns:
+ : Mask for the sequence lengths. (B, max_len)
+
+ """
+ max_len = lengths.max()
+ batch_size = lengths.size(0)
+
+ expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
+
+ return expanded_lengths >= lengths.unsqueeze(1)
+
+
+def get_transducer_task_io(
+ labels: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Get Transducer loss I/O.
+
+ Args:
+ labels: Label ID sequences. (B, L)
+ encoder_out_lens: Encoder output lengths. (B,)
+ ignore_id: Padding symbol ID.
+ blank_id: Blank symbol ID.
+
+ Returns:
+ decoder_in: Decoder inputs. (B, U)
+ target: Target label ID sequences. (B, U)
+ t_len: Time lengths. (B,)
+ u_len: Label lengths. (B,)
+
+ """
+
+ def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
+ """Create padded batch of labels from a list of labels sequences.
+
+ Args:
+ labels: Labels sequences. [B x (?)]
+ padding_value: Padding value.
+
+ Returns:
+ labels: Batch of padded labels sequences. (B,)
+
+ """
+ batch_size = len(labels)
+
+ padded = (
+ labels[0]
+ .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
+ .fill_(padding_value)
+ )
+
+ for i in range(batch_size):
+ padded[i, : labels[i].size(0)] = labels[i]
+
+ return padded
+
+ device = labels.device
+
+ labels_unpad = [y[y != ignore_id] for y in labels]
+ blank = labels[0].new([blank_id])
+
+ decoder_in = pad_list(
+ [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
+ ).to(device)
+
+ target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
+
+ encoder_out_lens = list(map(int, encoder_out_lens))
+ t_len = torch.IntTensor(encoder_out_lens).to(device)
+
+ u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
+
+ return decoder_in, target, t_len, u_len
+
+
+def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
+ """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
+ if t.size(dim) == pad_len:
+ return t
+ else:
+ pad_size = list(t.shape)
+ pad_size[dim] = pad_len - t.size(dim)
+ return torch.cat([t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim)
diff --git a/funcineforge/tokenizer/__init__.py b/funcineforge/tokenizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9d3087cb7e88fcf4d5171563c4a76327dffd33a
--- /dev/null
+++ b/funcineforge/tokenizer/__init__.py
@@ -0,0 +1 @@
+from .tokenizer import FunCineForgeTokenizer
\ No newline at end of file
diff --git a/funcineforge/tokenizer/tokenizer.py b/funcineforge/tokenizer/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..69e66c57956214f600f38ac95a30f7f35d625a01
--- /dev/null
+++ b/funcineforge/tokenizer/tokenizer.py
@@ -0,0 +1,20 @@
+def FunCineForgeTokenizer(init_param_path, **kwargs):
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(init_param_path)
+ special_tokens = {
+ 'eos_token': '<|endoftext|>',
+ 'pad_token': '<|endoftext|>',
+ 'additional_special_tokens': [
+ '<|im_start|>', '<|im_end|>',
+ '<|startofclue|>', '<|endofclue|>', '<|endofprompt|>',
+ '[breath]', '', '', '[noise]',
+ '[laughter]', '[cough]', '[clucking]', '[accent]',
+ '[quick_breath]',
+ "", "",
+ "[hissing]", "[sigh]", "[vocalized-noise]",
+ "[lipsmack]", "[mn]", "<|endofsystem|>"
+ ]
+ }
+ tokenizer.add_special_tokens(special_tokens)
+
+ return tokenizer
\ No newline at end of file
diff --git a/funcineforge/utils/__init__.py b/funcineforge/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/funcineforge/utils/device_funcs.py b/funcineforge/utils/device_funcs.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd7fc76a2862a273d90b0d0aaeb6ad88243a52ff
--- /dev/null
+++ b/funcineforge/utils/device_funcs.py
@@ -0,0 +1,64 @@
+import dataclasses
+import warnings
+
+import numpy as np
+import torch
+
+
+def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
+ """Change the device of object recursively"""
+ if isinstance(data, dict):
+ return {k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()}
+ elif dataclasses.is_dataclass(data) and not isinstance(data, type):
+ return type(data)(
+ *[to_device(v, device, dtype, non_blocking, copy) for v in dataclasses.astuple(data)]
+ )
+ # maybe namedtuple. I don't know the correct way to judge namedtuple.
+ elif isinstance(data, tuple) and type(data) is not tuple:
+ return type(data)(*[to_device(o, device, dtype, non_blocking, copy) for o in data])
+ elif isinstance(data, (list, tuple)):
+ return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
+ elif isinstance(data, np.ndarray):
+ return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
+ elif isinstance(data, torch.Tensor):
+ return data.to(device, dtype, non_blocking, copy)
+ else:
+ return data
+
+
+def force_gatherable(data, device):
+ """Change object to gatherable in torch.nn.DataParallel recursively
+
+ The difference from to_device() is changing to torch.Tensor if float or int
+ value is found.
+
+ The restriction to the returned value in DataParallel:
+ The object must be
+ - torch.cuda.Tensor
+ - 1 or more dimension. 0-dimension-tensor sends warning.
+ or a list, tuple, dict.
+
+ """
+ if isinstance(data, dict):
+ return {k: force_gatherable(v, device) for k, v in data.items()}
+ # DataParallel can't handle NamedTuple well
+ elif isinstance(data, tuple) and type(data) is not tuple:
+ return type(data)(*[force_gatherable(o, device) for o in data])
+ elif isinstance(data, (list, tuple, set)):
+ return type(data)(force_gatherable(v, device) for v in data)
+ elif isinstance(data, np.ndarray):
+ return force_gatherable(torch.from_numpy(data), device)
+ elif isinstance(data, torch.Tensor):
+ if data.dim() == 0:
+ # To 1-dim array
+ data = data[None]
+ return data.to(device)
+ elif isinstance(data, float):
+ return torch.tensor([data], dtype=torch.float, device=device)
+ elif isinstance(data, int):
+ return torch.tensor([data], dtype=torch.long, device=device)
+ elif data is None:
+ return None
+ else:
+ warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
+ return data
diff --git a/funcineforge/utils/export_utils.py b/funcineforge/utils/export_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e4f281715758b444d2054544ae2105701ec7050
--- /dev/null
+++ b/funcineforge/utils/export_utils.py
@@ -0,0 +1,196 @@
+import os
+import torch
+import functools
+
+
+def export(
+ model, data_in=None, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs
+):
+ model_scripts = model.export(**kwargs)
+ export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
+ os.makedirs(export_dir, exist_ok=True)
+
+ if not isinstance(model_scripts, (list, tuple)):
+ model_scripts = (model_scripts,)
+ for m in model_scripts:
+ m.eval()
+ if type == "onnx":
+ _onnx(
+ m,
+ data_in=data_in,
+ quantize=quantize,
+ opset_version=opset_version,
+ export_dir=export_dir,
+ **kwargs,
+ )
+ elif type == "torchscript":
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print("Exporting torchscripts on device {}".format(device))
+ _torchscripts(m, path=export_dir, device=device)
+ elif type == "bladedisc":
+ assert (
+ torch.cuda.is_available()
+ ), "Currently bladedisc optimization only supports GPU"
+ # bladedisc only optimizes encoder/decoder modules
+ if hasattr(m, "encoder") and hasattr(m, "decoder"):
+ _bladedisc_opt_for_encdec(m, path=export_dir, enable_fp16=True)
+ else:
+ _torchscripts(m, path=export_dir, device="cuda")
+ print("output dir: {}".format(export_dir))
+
+ return export_dir
+
+
+def _onnx(
+ model,
+ data_in=None,
+ quantize: bool = False,
+ opset_version: int = 14,
+ export_dir: str = None,
+ **kwargs,
+):
+
+ dummy_input = model.export_dummy_inputs()
+
+ verbose = kwargs.get("verbose", False)
+
+ export_name = model.export_name + ".onnx"
+ model_path = os.path.join(export_dir, export_name)
+ torch.onnx.export(
+ model,
+ dummy_input,
+ model_path,
+ verbose=verbose,
+ opset_version=opset_version,
+ input_names=model.export_input_names(),
+ output_names=model.export_output_names(),
+ dynamic_axes=model.export_dynamic_axes(),
+ )
+
+ if quantize:
+ from onnxruntime.quantization import QuantType, quantize_dynamic
+ import onnx
+
+ quant_model_path = model_path.replace(".onnx", "_quant.onnx")
+ if not os.path.exists(quant_model_path):
+ onnx_model = onnx.load(model_path)
+ nodes = [n.name for n in onnx_model.graph.node]
+ nodes_to_exclude = [
+ m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
+ ]
+ quantize_dynamic(
+ model_input=model_path,
+ model_output=quant_model_path,
+ op_types_to_quantize=["MatMul"],
+ per_channel=True,
+ reduce_range=False,
+ weight_type=QuantType.QUInt8,
+ nodes_to_exclude=nodes_to_exclude,
+ )
+
+
+def _torchscripts(model, path, device="cuda"):
+ dummy_input = model.export_dummy_inputs()
+
+ if device == "cuda":
+ model = model.cuda()
+ if isinstance(dummy_input, torch.Tensor):
+ dummy_input = dummy_input.cuda()
+ else:
+ dummy_input = tuple([i.cuda() for i in dummy_input])
+
+ model_script = torch.jit.trace(model, dummy_input)
+ model_script.save(os.path.join(path, f"{model.export_name}.torchscript"))
+
+
+def _bladedisc_opt(model, model_inputs, enable_fp16=True):
+ model = model.eval()
+ try:
+ import torch_blade
+ except Exception as e:
+ print(
+ f"Warning, if you are exporting bladedisc, please install it and try it again: pip install -U torch_blade\n"
+ )
+ torch_config = torch_blade.config.Config()
+ torch_config.enable_fp16 = enable_fp16
+ with torch.no_grad(), torch_config:
+ opt_model = torch_blade.optimize(
+ model,
+ allow_tracing=True,
+ model_inputs=model_inputs,
+ )
+ return opt_model
+
+
+def _rescale_input_hook(m, x, scale):
+ if len(x) > 1:
+ return (x[0] / scale, *x[1:])
+ else:
+ return (x[0] / scale,)
+
+
+def _rescale_output_hook(m, x, y, scale):
+ if isinstance(y, tuple):
+ return (y[0] / scale, *y[1:])
+ else:
+ return y / scale
+
+
+def _rescale_encoder_model(model, input_data):
+ # Calculate absmax
+ absmax = torch.tensor(0).cuda()
+
+ def stat_input_hook(m, x, y):
+ val = x[0] if isinstance(x, tuple) else x
+ absmax.copy_(torch.max(absmax, val.detach().abs().max()))
+
+ encoders = model.encoder.model.encoders
+ hooks = [m.register_forward_hook(stat_input_hook) for m in encoders]
+ model = model.cuda()
+ model(*input_data)
+ for h in hooks:
+ h.remove()
+
+ # Rescale encoder modules
+ fp16_scale = int(2 * absmax // 65536)
+ print(f"rescale encoder modules with factor={fp16_scale}")
+ model.encoder.model.encoders0.register_forward_pre_hook(
+ functools.partial(_rescale_input_hook, scale=fp16_scale),
+ )
+ for name, m in model.encoder.model.named_modules():
+ if name.endswith("self_attn"):
+ m.register_forward_hook(functools.partial(_rescale_output_hook, scale=fp16_scale))
+ if name.endswith("feed_forward.w_2"):
+ state_dict = {k: v / fp16_scale for k, v in m.state_dict().items()}
+ m.load_state_dict(state_dict)
+
+
+def _bladedisc_opt_for_encdec(model, path, enable_fp16):
+ # Get input data
+ # TODO: better to use real data
+ input_data = model.export_dummy_inputs()
+ if isinstance(input_data, torch.Tensor):
+ input_data = input_data.cuda()
+ else:
+ input_data = tuple([i.cuda() for i in input_data])
+
+ # Get input data for decoder module
+ decoder_inputs = list()
+
+ def get_input_hook(m, x):
+ decoder_inputs.extend(list(x))
+
+ hook = model.decoder.register_forward_pre_hook(get_input_hook)
+ model = model.cuda()
+ model(*input_data)
+ hook.remove()
+
+ # Prevent FP16 overflow
+ if enable_fp16:
+ _rescale_encoder_model(model, input_data)
+
+ # Export and optimize encoder/decoder modules
+ model.encoder = _bladedisc_opt(model.encoder, input_data[:2])
+ model.decoder = _bladedisc_opt(model.decoder, tuple(decoder_inputs))
+ model_script = torch.jit.trace(model, input_data)
+ model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript"))
diff --git a/funcineforge/utils/hinter.py b/funcineforge/utils/hinter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4f1a85c20889a74e705c30b782511a71ab7764c
--- /dev/null
+++ b/funcineforge/utils/hinter.py
@@ -0,0 +1,62 @@
+import sys
+import logging
+import os
+import torch
+
+HINTED = set()
+
+
+def hint_once(content, uid, rank=None):
+ if (rank is None) or (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == rank:
+ if uid not in HINTED:
+ logging.info(content, stacklevel=3)
+ HINTED.add(uid)
+
+
+def get_logger(fpath=None, log_level=logging.INFO, local_rank=0, world_size=1):
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+ formatter = logging.Formatter(
+ f"[{os.uname()[1].split('.')[0]}]({local_rank}/{world_size})"
+ f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+
+ logging.basicConfig(
+ level=log_level,
+ format=f"[{os.uname()[1].split('.')[0]}]({local_rank}/{world_size})"
+ f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logger = logging.getLogger("Pyobj, f")
+ if fpath is not None:
+ # Dump log to file
+ fh = logging.FileHandler(fpath)
+ fh.setFormatter(formatter)
+ logger.addHandler(fh)
+ return logger
+
+
+def get_current_command():
+ # Get the command-line arguments (including the script name)
+ command_line_args = sys.argv
+
+ # Get the full path of the Python interpreter
+ python_interpreter = os.path.abspath(sys.executable)
+
+ # Combine the interpreter and command-line arguments to reconstruct the command
+ full_command = ' '.join([python_interpreter] + command_line_args)
+
+ return full_command
+
+
+def get_gpu_info():
+ gpu_info = (
+ "GPU, memory: usage: {:.3f} GB, "
+ "peak: {:.3f} GB, "
+ "cache: {:.3f} GB, "
+ "cache_peak: {:.3f} GB".format(
+ torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+ torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
+ torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
+ torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
+ )
+ )
+ return gpu_info
diff --git a/funcineforge/utils/load_pretrained_model.py b/funcineforge/utils/load_pretrained_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe993fd5baaa86e390be519fa0d2604600d51ab9
--- /dev/null
+++ b/funcineforge/utils/load_pretrained_model.py
@@ -0,0 +1,139 @@
+from typing import Any
+from typing import Dict
+from typing import Union
+from io import BytesIO
+import os
+import logging
+import torch
+import torch.nn
+import torch.optim
+import pdb
+
+
+def load_pretrained_model(
+ path,
+ model: torch.nn.Module,
+ ignore_init_mismatch: bool = True,
+ map_location: str = "cpu",
+ oss_bucket=None,
+ scope_map=[],
+ excludes=None,
+ **kwargs,
+):
+ """Load a model state and set it to the model.
+
+ Args:
+ init_param: :::
+
+ Examples:
+
+ """
+
+ obj = model
+ dst_state = obj.state_dict()
+ use_deepspeed = kwargs.get("use_deepspeed", False)
+
+ logging.info(f"ckpt: {path}, use_deepspeed: {use_deepspeed}")
+
+ if use_deepspeed and os.path.isdir(path):
+ ckpt_dir = os.path.dirname(path)
+ ckpt_name = os.path.basename(path)
+ if os.path.exists(f"{ckpt_dir}/zero_to_fp32.py"):
+ print("Detect zero_to_fp32, begin to convert fp32 model")
+ ckpt_fp32 = f"{ckpt_dir}/{ckpt_name[3:]}"
+ if os.path.exists(ckpt_fp32):
+ print(f"Detect zero_to_fp32 already exist! Loading it directly. {ckpt_fp32}")
+ src_state = torch.load(ckpt_fp32, map_location=map_location)
+ else:
+ with open(f"{ckpt_dir}/latest", "w") as latest:
+ latest.write(ckpt_name)
+ latest.flush()
+ from deepspeed.utils.zero_to_fp32 import (
+ get_fp32_state_dict_from_zero_checkpoint,
+ )
+
+ src_state = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir) # already on cpu
+ if kwargs.get("save_deepspeed_zero_fp32", False):
+ print(
+ f'save_deepspeed_zero_fp32: {kwargs.get("save_deepspeed_zero_fp32", False)}, {ckpt_fp32}'
+ )
+ torch.save({"state_dict": src_state}, ckpt_fp32)
+ else:
+ print("Detect deepspeed without zero, load fp32 model directly")
+ for item in os.listdir(path):
+ if item.endswith(".pt"):
+ src_state = torch.load(f"{path}/{item}", map_location=map_location)
+
+ else:
+ src_state = torch.load(path, map_location=map_location)
+
+ src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
+ src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state
+ src_state = src_state["model"] if "model" in src_state else src_state
+
+ if isinstance(scope_map, str):
+ scope_map = scope_map.split(",")
+ scope_map += ["module.", "None"]
+ logging.info(f"scope_map: {scope_map}")
+
+ if excludes is not None:
+ if isinstance(excludes, str):
+ excludes = excludes.split(",")
+
+ logging.info(f"excludes: {excludes}")
+
+ param_mapping_count = 0
+ exclusion_match_count = 0
+ missing_key_count = 0
+
+ for k in dst_state.keys():
+ excludes_flag = False
+ if excludes is not None:
+ for k_ex in excludes:
+ if k.startswith(k_ex):
+ logging.info(f"key: {k} matching: {k_ex}, excluded")
+ excludes_flag = True
+ break
+ if excludes_flag:
+ continue
+
+ k_src = k
+
+ if scope_map is not None:
+ src_prefix = ""
+ dst_prefix = ""
+ for i in range(0, len(scope_map), 2):
+ src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
+ dst_prefix = scope_map[i + 1] if scope_map[i + 1].lower() != "none" else ""
+
+ if dst_prefix == "" and (src_prefix + k) in src_state.keys():
+ k_src = src_prefix + k
+ if not k_src.startswith("module."):
+ logging.info(f"init param, map: {k} from {k_src} in ckpt")
+ elif (
+ k.startswith(dst_prefix)
+ and k.replace(dst_prefix, src_prefix, 1) in src_state.keys()
+ ):
+ k_src = k.replace(dst_prefix, src_prefix, 1)
+ if not k_src.startswith("module."):
+ logging.info(f"init param, map: {k} from {k_src} in ckpt")
+
+ if k_src in src_state.keys():
+ if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
+ logging.info(
+ f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}"
+ )
+ exclusion_match_count += 1
+ else:
+ dst_state[k] = src_state[k_src]
+ param_mapping_count += 1
+
+
+ else:
+ print(f"Warning, miss key in ckpt: {k}, {path}")
+ missing_key_count +=1
+
+ logging.info(f"matched keys: {param_mapping_count}, missing keys: {missing_key_count}, exclusion_match_count: {exclusion_match_count}")
+
+ flag = obj.load_state_dict(dst_state, strict=True)
+ logging.info(f"Loading ckpt: {path}, status: {flag}")
diff --git a/funcineforge/utils/load_utils.py b/funcineforge/utils/load_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ad0ec36f351f3979fa52b74874b36cf8931ce18
--- /dev/null
+++ b/funcineforge/utils/load_utils.py
@@ -0,0 +1,275 @@
+import os
+import torch
+import numpy as np
+import kaldiio
+import librosa
+import torchaudio
+import torchaudio.compliance.kaldi as Kaldi
+from torch.nn.utils.rnn import pad_sequence
+import onnxruntime as ort
+
+try:
+ from funcineforge.download.file import download_from_url
+except:
+ print("urllib is not installed, if you infer from url, please install it first.")
+import subprocess
+from subprocess import CalledProcessError, run
+
+
+def is_ffmpeg_installed():
+ try:
+ output = subprocess.check_output(["ffmpeg", "-version"], stderr=subprocess.STDOUT)
+ return "ffmpeg version" in output.decode("utf-8")
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ return False
+
+
+use_ffmpeg = False
+if is_ffmpeg_installed():
+ use_ffmpeg = True
+else:
+ print(
+ "Notice: ffmpeg is not installed. torchaudio is used to load audio\n"
+ "If you want to use ffmpeg backend to load audio, please install it by:"
+ "\n\tsudo apt install ffmpeg # ubuntu"
+ "\n\t# brew install ffmpeg # mac"
+ )
+
+
+def load_audio_text_image_video(
+ data_or_path_or_list,
+ fs: int = 16000,
+ audio_fs: int = 16000,
+ data_type="sound",
+ tokenizer=None,
+ **kwargs,
+):
+ if isinstance(data_or_path_or_list, (list, tuple)):
+ if data_type is not None and isinstance(data_type, (list, tuple)):
+ data_types = [data_type] * len(data_or_path_or_list)
+ data_or_path_or_list_ret = [[] for d in data_type]
+ for i, (data_type_i, data_or_path_or_list_i) in enumerate(
+ zip(data_types, data_or_path_or_list)
+ ):
+ for j, (data_type_j, data_or_path_or_list_j) in enumerate(
+ zip(data_type_i, data_or_path_or_list_i)
+ ):
+ data_or_path_or_list_j = load_audio_text_image_video(
+ data_or_path_or_list_j,
+ fs=fs,
+ audio_fs=audio_fs,
+ data_type=data_type_j,
+ tokenizer=tokenizer,
+ **kwargs,
+ )
+ data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
+
+ return data_or_path_or_list_ret
+ else:
+ return [
+ load_audio_text_image_video(
+ audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs
+ )
+ for audio in data_or_path_or_list
+ ]
+ if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith(
+ "http"
+ ): # download url to local file
+ data_or_path_or_list = download_from_url(data_or_path_or_list)
+ if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
+ if data_type is None or data_type in ["sound", "kaldi_ark_or_sound"]:
+ if kwargs.get("use_ffmpeg", False):
+ data_or_path_or_list = _load_audio_ffmpeg(data_or_path_or_list, sr=fs)
+ data_or_path_or_list = torch.from_numpy(
+ data_or_path_or_list
+ ).squeeze() # [n_samples,]
+
+ else:
+ try:
+ data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
+ if kwargs.get("reduce_channels", True):
+ data_or_path_or_list = data_or_path_or_list.mean(0)
+ except:
+ data_or_path_or_list = _load_audio_ffmpeg(data_or_path_or_list, sr=fs)
+ data_or_path_or_list = torch.from_numpy(
+ data_or_path_or_list
+ ).squeeze() # [n_samples,]
+ elif data_type == "text" and tokenizer is not None:
+ data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
+ elif data_type == "image": # undo
+ pass
+ elif data_type == "video": # undo
+ pass
+
+ # if data_in is a file or url, set is_final=True
+ if "cache" in kwargs:
+ kwargs["cache"]["is_final"] = True
+ kwargs["cache"]["is_streaming_input"] = False
+ elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
+ data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
+ elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point
+ data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,]
+ elif isinstance(data_or_path_or_list, str) and data_type in ["kaldi_ark", "kaldi_ark_or_sound", "sound"]:
+ if len(data_or_path_or_list.split()) == 2:
+ data_or_path_or_list, audio_fs = data_or_path_or_list.split()
+ audio_fs = int(audio_fs)
+ data_mat = kaldiio.load_mat(data_or_path_or_list)
+ if isinstance(data_mat, tuple):
+ audio_fs, mat = data_mat
+ else:
+ mat = data_mat
+ if mat.dtype == "int16":
+ mat = mat.astype(np.float32)
+ mat = mat / (2 ** 16)
+ elif mat.dtype == "int32":
+ mat = mat.astype(np.float32)
+ mat = mat / (2 ** 32)
+ if mat.ndim == 2:
+ mat = mat[:, 0]
+ data_or_path_or_list = torch.from_numpy(mat)
+ elif isinstance(data_or_path_or_list, bytes): # audio bytes
+ data_or_path_or_list = load_bytes(data_or_path_or_list)
+ else:
+ pass
+ print(f"unsupport data type: {data_or_path_or_list}, return raw data")
+
+ if audio_fs != fs and data_type != "text":
+ resampler = torchaudio.transforms.Resample(audio_fs, fs, dtype=data_or_path_or_list.dtype)
+ data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
+ return data_or_path_or_list
+
+
+class FBank(object):
+ def __init__(self,
+ n_mels,
+ sample_rate,
+ mean_nor: bool = False,
+ ):
+ self.n_mels = n_mels
+ self.sample_rate = sample_rate
+ self.mean_nor = mean_nor
+
+ def __call__(self, wav, dither=0):
+ sr = 16000
+ assert sr == self.sample_rate
+ if len(wav.shape) == 1:
+ wav = wav.unsqueeze(0)
+ if wav.shape[0] > 1:
+ wav = torch.mean(wav, dim=0, keepdim=True)
+ assert len(wav.shape) == 2 and wav.shape[0] == 1, wav.shape
+ feat = Kaldi.fbank(wav, num_mel_bins=self.n_mels,
+ sample_frequency=sr, dither=dither)
+ # feat: [T, N]
+ if self.mean_nor:
+ feat = feat - feat.mean(0, keepdim=True)
+ return feat
+
+class OnnxModel(object):
+ def __init__(self, pretrained_model):
+ session_options = ort.SessionOptions()
+ self.model = ort.InferenceSession(pretrained_model, session_options)
+ self.input_name = self.model.get_inputs()[0].name
+ self.output_name = self.model.get_outputs()[0].name
+ self.feature_extractor = FBank(n_mels=80, sample_rate=16000, mean_nor=True)
+
+ def __call__(self, wav):
+ feat = self.feature_extractor(torch.as_tensor(wav))
+ feat = feat.float().unsqueeze(0).numpy()
+ emb = self.model.run([self.output_name], {self.input_name: feat})[0]
+ return emb
+
+def extract_campp_xvec(
+ wav_path: str = "",
+ target_sr: int = 16000,
+ **kwargs,
+):
+ wav, sr = librosa.load(wav_path, dtype=np.float32, sr=target_sr, mono=True)
+ if sr != target_sr:
+ wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
+ onnx_path = kwargs.get("xvec_model", None)
+ model = OnnxModel(onnx_path)
+ xvec = model(wav)
+ return xvec
+
+
+
+def load_bytes(input):
+ middle_data = np.frombuffer(input, dtype=np.int16)
+ middle_data = np.asarray(middle_data)
+ if middle_data.dtype.kind not in "iu":
+ raise TypeError("'middle_data' must be an array of integers")
+ dtype = np.dtype("float32")
+ if dtype.kind != "f":
+ raise TypeError("'dtype' must be a floating point type")
+
+ i = np.iinfo(middle_data.dtype)
+ abs_max = 2 ** (i.bits - 1)
+ offset = i.min + abs_max
+ array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
+ return array
+
+
+def extract_fbank(data, data_len=None, data_type: str = "sound", frontend=None, **kwargs):
+ if isinstance(data, np.ndarray):
+ data = torch.from_numpy(data)
+ if len(data.shape) < 2:
+ data = data[None, :] # data: [batch, N]
+ data_len = [data.shape[1]] if data_len is None else data_len
+ elif isinstance(data, torch.Tensor):
+ if len(data.shape) < 2:
+ data = data[None, :] # data: [batch, N]
+ data_len = [data.shape[1]] if data_len is None else data_len
+ elif isinstance(data, (list, tuple)):
+ data_list, data_len = [], []
+ for data_i in data:
+ if isinstance(data_i, np.ndarray):
+ data_i = torch.from_numpy(data_i)
+ data_list.append(data_i)
+ data_len.append(data_i.shape[0])
+ data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
+
+ data, data_len = frontend(data, data_len, **kwargs)
+
+ if isinstance(data_len, (list, tuple)):
+ data_len = torch.tensor([data_len])
+ return data.to(torch.float32), data_len.to(torch.int32)
+
+
+def _load_audio_ffmpeg(file: str, sr: int = 16000):
+ """
+ Open an audio file and read as mono waveform, resampling as necessary
+
+ Parameters
+ ----------
+ file: str
+ The audio file to open
+
+ sr: int
+ The sample rate to resample the audio if necessary
+
+ Returns
+ -------
+ A NumPy array containing the audio waveform, in float32 dtype.
+ """
+
+ # This launches a subprocess to decode audio while down-mixing
+ # and resampling as necessary. Requires the ffmpeg CLI in PATH.
+ # fmt: off
+ cmd = [
+ "ffmpeg",
+ "-nostdin",
+ "-threads", "0",
+ "-i", file,
+ "-f", "s16le",
+ "-ac", "1",
+ "-acodec", "pcm_s16le",
+ "-ar", str(sr),
+ "-"
+ ]
+ # fmt: on
+ try:
+ out = run(cmd, capture_output=True, check=True).stdout
+ except CalledProcessError as e:
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
diff --git a/funcineforge/utils/misc.py b/funcineforge/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..f460b5a0f46118bf8766cb360e61cc0118fba162
--- /dev/null
+++ b/funcineforge/utils/misc.py
@@ -0,0 +1,126 @@
+import os
+import io
+import shutil
+import logging
+from collections import OrderedDict
+import numpy as np
+from omegaconf import DictConfig, OmegaConf
+import torch
+
+
+def statistic_model_parameters(model, prefix=None):
+ var_dict = model.state_dict()
+ numel = 0
+ for i, key in enumerate(
+ sorted(list([x for x in var_dict.keys() if "num_batches_tracked" not in x]))
+ ):
+ if prefix is None or key.startswith(prefix):
+ numel += var_dict[key].numel()
+ return numel
+
+
+def int2vec(x, vec_dim=8, dtype=np.int32):
+ b = ("{:0" + str(vec_dim) + "b}").format(x)
+ # little-endian order: lower bit first
+ return (np.array(list(b)[::-1]) == "1").astype(dtype)
+
+
+def seq2arr(seq, vec_dim=8):
+ return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
+
+
+def load_scp_as_dict(scp_path, value_type="str", kv_sep=" "):
+ with io.open(scp_path, "r", encoding="utf-8") as f:
+ ret_dict = OrderedDict()
+ for one_line in f.readlines():
+ one_line = one_line.strip()
+ pos = one_line.find(kv_sep)
+ key, value = one_line[:pos], one_line[pos + 1 :]
+ if value_type == "list":
+ value = value.split(" ")
+ ret_dict[key] = value
+ return ret_dict
+
+
+def load_scp_as_list(scp_path, value_type="str", kv_sep=" "):
+ with io.open(scp_path, "r", encoding="utf8") as f:
+ ret_dict = []
+ for one_line in f.readlines():
+ one_line = one_line.strip()
+ pos = one_line.find(kv_sep)
+ key, value = one_line[:pos], one_line[pos + 1 :]
+ if value_type == "list":
+ value = value.split(" ")
+ ret_dict.append((key, value))
+ return ret_dict
+
+
+def deep_update(original, update):
+ for key, value in update.items():
+ if isinstance(value, dict) and key in original:
+ if len(value) == 0:
+ original[key] = value
+ deep_update(original[key], value)
+ else:
+ original[key] = value
+
+
+def prepare_model_dir(**kwargs):
+
+ os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
+
+ yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
+ OmegaConf.save(config=kwargs, f=yaml_file)
+ logging.info(f"kwargs: {kwargs}")
+ logging.info("config.yaml is saved to: %s", yaml_file)
+
+ model_path = kwargs.get("model_path", None)
+ if model_path is not None:
+ config_json = os.path.join(model_path, "configuration.json")
+ if os.path.exists(config_json):
+ shutil.copy(
+ config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json")
+ )
+
+
+def extract_filename_without_extension(file_path):
+ """
+ 从给定的文件路径中提取文件名(不包含路径和扩展名)
+ :param file_path: 完整的文件路径
+ :return: 文件名(不含路径和扩展名)
+ """
+ # 首先,使用os.path.basename获取路径中的文件名部分(含扩展名)
+ filename_with_extension = os.path.basename(file_path)
+ # 然后,使用os.path.splitext分离文件名和扩展名
+ filename, extension = os.path.splitext(filename_with_extension)
+ # 返回不包含扩展名的文件名
+ return filename
+
+
+def smart_remove(path):
+ """Intelligently removes files, empty directories, and non-empty directories recursively."""
+ # Check if the provided path exists
+ if not os.path.exists(path):
+ print(f"{path} does not exist.")
+ return
+
+ # If the path is a file, delete it
+ if os.path.isfile(path):
+ os.remove(path)
+ print(f"File {path} has been deleted.")
+ # If the path is a directory
+ elif os.path.isdir(path):
+ try:
+ # Attempt to remove an empty directory
+ os.rmdir(path)
+ print(f"Empty directory {path} has been deleted.")
+ except OSError:
+ # If the directory is not empty, remove it along with all its contents
+ shutil.rmtree(path)
+ print(f"Non-empty directory {path} has been recursively deleted.")
+
+
+def tensor_to_scalar(x):
+ if torch.is_tensor(x):
+ return x.detach().item()
+ return x
diff --git a/funcineforge/utils/postprocess_utils.py b/funcineforge/utils/postprocess_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..34f933c967d4add7680ad371cd2565af708d461b
--- /dev/null
+++ b/funcineforge/utils/postprocess_utils.py
@@ -0,0 +1,301 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import string
+import logging
+from typing import Any, List, Union
+
+
+def isChinese(ch: str):
+ if "\u4e00" <= ch <= "\u9fff" or "\u0030" <= ch <= "\u0039" or ch == "@":
+ return True
+ return False
+
+
+def isAllChinese(word: Union[List[Any], str]):
+ word_lists = []
+ for i in word:
+ cur = i.replace(" ", "")
+ cur = cur.replace("", "")
+ cur = cur.replace("", "")
+ cur = cur.replace("", "")
+ cur = cur.replace("", "")
+ word_lists.append(cur)
+
+ if len(word_lists) == 0:
+ return False
+
+ for ch in word_lists:
+ if isChinese(ch) is False:
+ return False
+ return True
+
+
+def isAllAlpha(word: Union[List[Any], str]):
+ word_lists = []
+ for i in word:
+ cur = i.replace(" ", "")
+ cur = cur.replace("", "")
+ cur = cur.replace("", "")
+ cur = cur.replace("", "")
+ cur = cur.replace("", "")
+ word_lists.append(cur)
+
+ if len(word_lists) == 0:
+ return False
+
+ for ch in word_lists:
+ if ch.isalpha() is False and ch != "'":
+ return False
+ elif ch.isalpha() is True and isChinese(ch) is True:
+ return False
+
+ return True
+
+
+# def abbr_dispose(words: List[Any]) -> List[Any]:
+def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
+ words_size = len(words)
+ word_lists = []
+ abbr_begin = []
+ abbr_end = []
+ last_num = -1
+ ts_lists = []
+ ts_nums = []
+ ts_index = 0
+ for num in range(words_size):
+ if num <= last_num:
+ continue
+
+ if len(words[num]) == 1 and words[num].encode("utf-8").isalpha():
+ if (
+ num + 1 < words_size
+ and words[num + 1] == " "
+ and num + 2 < words_size
+ and len(words[num + 2]) == 1
+ and words[num + 2].encode("utf-8").isalpha()
+ ):
+ # found the begin of abbr
+ abbr_begin.append(num)
+ num += 2
+ abbr_end.append(num)
+ # to find the end of abbr
+ while True:
+ num += 1
+ if num < words_size and words[num] == " ":
+ num += 1
+ if (
+ num < words_size
+ and len(words[num]) == 1
+ and words[num].encode("utf-8").isalpha()
+ ):
+ abbr_end.pop()
+ abbr_end.append(num)
+ last_num = num
+ else:
+ break
+ else:
+ break
+
+ for num in range(words_size):
+ if words[num] == " ":
+ ts_nums.append(ts_index)
+ else:
+ ts_nums.append(ts_index)
+ ts_index += 1
+ last_num = -1
+ for num in range(words_size):
+ if num <= last_num:
+ continue
+
+ if num in abbr_begin:
+ if time_stamp is not None:
+ begin = time_stamp[ts_nums[num]][0]
+ abbr_word = words[num].upper()
+ num += 1
+ while num < words_size:
+ if num in abbr_end:
+ abbr_word += words[num].upper()
+ last_num = num
+ break
+ else:
+ if words[num].encode("utf-8").isalpha():
+ abbr_word += words[num].upper()
+ num += 1
+ word_lists.append(abbr_word)
+ if time_stamp is not None:
+ end = time_stamp[ts_nums[num]][1]
+ ts_lists.append([begin, end])
+ else:
+ word_lists.append(words[num])
+ if time_stamp is not None and words[num] != " ":
+ begin = time_stamp[ts_nums[num]][0]
+ end = time_stamp[ts_nums[num]][1]
+ ts_lists.append([begin, end])
+ begin = end
+
+ if time_stamp is not None:
+ return word_lists, ts_lists
+ else:
+ return word_lists
+
+
+def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
+ middle_lists = []
+ word_lists = []
+ word_item = ""
+ ts_lists = []
+
+ # wash words lists
+ for i in words:
+ word = ""
+ if isinstance(i, str):
+ word = i
+ else:
+ word = i.decode("utf-8")
+
+ if word in ["", "", "", ""]:
+ continue
+ else:
+ middle_lists.append(word)
+
+ # all chinese characters
+ if isAllChinese(middle_lists):
+ for i, ch in enumerate(middle_lists):
+ word_lists.append(ch.replace(" ", ""))
+ if time_stamp is not None:
+ ts_lists = time_stamp
+
+ # all alpha characters
+ elif isAllAlpha(middle_lists):
+ ts_flag = True
+ for i, ch in enumerate(middle_lists):
+ if ts_flag and time_stamp is not None:
+ begin = time_stamp[i][0]
+ end = time_stamp[i][1]
+ word = ""
+ if "@@" in ch:
+ word = ch.replace("@@", "")
+ word_item += word
+ if time_stamp is not None:
+ ts_flag = False
+ end = time_stamp[i][1]
+ else:
+ word_item += ch
+ word_lists.append(word_item)
+ word_lists.append(" ")
+ word_item = ""
+ if time_stamp is not None:
+ ts_flag = True
+ end = time_stamp[i][1]
+ ts_lists.append([begin, end])
+ begin = end
+
+ # mix characters
+ else:
+ alpha_blank = False
+ ts_flag = True
+ begin = -1
+ end = -1
+ for i, ch in enumerate(middle_lists):
+ if ts_flag and time_stamp is not None:
+ begin = time_stamp[i][0]
+ end = time_stamp[i][1]
+ word = ""
+ if isAllChinese(ch):
+ if alpha_blank is True:
+ word_lists.pop()
+ word_lists.append(ch)
+ alpha_blank = False
+ if time_stamp is not None:
+ ts_flag = True
+ ts_lists.append([begin, end])
+ begin = end
+ elif "@@" in ch:
+ word = ch.replace("@@", "")
+ word_item += word
+ alpha_blank = False
+ if time_stamp is not None:
+ ts_flag = False
+ end = time_stamp[i][1]
+ elif isAllAlpha(ch):
+ word_item += ch
+ word_lists.append(word_item)
+ word_lists.append(" ")
+ word_item = ""
+ alpha_blank = True
+ if time_stamp is not None:
+ ts_flag = True
+ end = time_stamp[i][1]
+ ts_lists.append([begin, end])
+ begin = end
+ else:
+ word_lists.append(ch)
+
+ if time_stamp is not None:
+ word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
+ real_word_lists = []
+ for ch in word_lists:
+ if ch != " ":
+ real_word_lists.append(ch)
+ sentence = " ".join(real_word_lists).strip()
+ return sentence, ts_lists, real_word_lists
+ else:
+ word_lists = abbr_dispose(word_lists)
+ real_word_lists = []
+ for ch in word_lists:
+ if ch != " ":
+ real_word_lists.append(ch)
+ sentence = "".join(word_lists).strip()
+ return sentence, real_word_lists
+
+
+def sentence_postprocess_sentencepiece(words):
+ middle_lists = []
+ word_lists = []
+ word_item = ""
+
+ # wash words lists
+ for i in words:
+ word = ""
+ if isinstance(i, str):
+ word = i
+ else:
+ word = i.decode("utf-8")
+
+ if word in ["", "", "", ""]:
+ continue
+ else:
+ middle_lists.append(word)
+
+ # all alpha characters
+ for i, ch in enumerate(middle_lists):
+ word = ""
+ if "\u2581" in ch and i == 0:
+ word_item = ""
+ word = ch.replace("\u2581", "")
+ word_item += word
+ elif "\u2581" in ch and i != 0:
+ word_lists.append(word_item)
+ word_lists.append(" ")
+ word_item = ""
+ word = ch.replace("\u2581", "")
+ word_item += word
+ else:
+ word_item += ch
+ if word_item is not None:
+ word_lists.append(word_item)
+ # word_lists = abbr_dispose(word_lists)
+ real_word_lists = []
+ for ch in word_lists:
+ if ch != " ":
+ if ch == "i":
+ ch = ch.replace("i", "I")
+ elif ch == "i'm":
+ ch = ch.replace("i'm", "I'm")
+ elif ch == "i've":
+ ch = ch.replace("i've", "I've")
+ elif ch == "i'll":
+ ch = ch.replace("i'll", "I'll")
+ real_word_lists.append(ch)
+ sentence = "".join(word_lists)
+ return sentence, real_word_lists
diff --git a/funcineforge/utils/set_all_random_seed.py b/funcineforge/utils/set_all_random_seed.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebdca3f537aac53bdc6e6cea168c49805bdf2d2f
--- /dev/null
+++ b/funcineforge/utils/set_all_random_seed.py
@@ -0,0 +1,10 @@
+import random
+
+import numpy as np
+import torch
+
+
+def set_all_random_seed(seed: int):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
diff --git a/funcineforge/utils/torch_function.py b/funcineforge/utils/torch_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..f637bbf82e5848dd4ae75133dd457eca46429ab2
--- /dev/null
+++ b/funcineforge/utils/torch_function.py
@@ -0,0 +1,84 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+import numpy as np
+
+
+class MakePadMask(nn.Module):
+ def __init__(self, max_seq_len=512, flip=True):
+ super().__init__()
+ if flip:
+ self.mask_pad = torch.Tensor(1 - np.tri(max_seq_len)).type(torch.bool)
+ else:
+ self.mask_pad = torch.Tensor(np.tri(max_seq_len)).type(torch.bool)
+
+ def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
+ """Make mask tensor containing indices of padded part.
+ This implementation creates the same mask tensor with original make_pad_mask,
+ which can be converted into onnx format.
+ Dimension length of xs should be 2 or 3.
+ """
+ if length_dim == 0:
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+ if xs is not None and len(xs.shape) == 3:
+ if length_dim == 1:
+ lengths = lengths.unsqueeze(1).expand(*xs.transpose(1, 2).shape[:2])
+ else:
+ lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])
+
+ if maxlen is not None:
+ m = maxlen
+ elif xs is not None:
+ m = xs.shape[-1]
+ else:
+ m = torch.max(lengths)
+
+ mask = self.mask_pad[lengths - 1][..., :m].type(torch.float32)
+
+ if length_dim == 1:
+ return mask.transpose(1, 2)
+ else:
+ return mask
+
+
+class sequence_mask(nn.Module):
+ def __init__(self, max_seq_len=512, flip=True):
+ super().__init__()
+
+ def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
+ if max_seq_len is None:
+ max_seq_len = lengths.max()
+ row_vector = torch.arange(0, max_seq_len, 1).to(lengths.device)
+ matrix = torch.unsqueeze(lengths, dim=-1)
+ mask = row_vector < matrix
+
+ return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
+
+
+def normalize(
+ input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+ if out is None:
+ denom = input.norm(p, dim, keepdim=True).expand_as(input)
+ return input / denom
+ else:
+ denom = input.norm(p, dim, keepdim=True).expand_as(input)
+ return torch.div(input, denom, out=out)
+
+
+def subsequent_mask(size: torch.Tensor):
+ return torch.ones(size, size).tril()
+
+
+def MakePadMask_test():
+ feats_length = torch.tensor([10]).type(torch.long)
+ mask_fn = MakePadMask()
+ mask = mask_fn(feats_length)
+ print(mask)
+
+
+if __name__ == "__main__":
+ MakePadMask_test()
diff --git a/funcineforge/utils/types.py b/funcineforge/utils/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b36f9c4b87ed9258a5d1e254ba298ed5dbc01d2
--- /dev/null
+++ b/funcineforge/utils/types.py
@@ -0,0 +1,149 @@
+from distutils.util import strtobool
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import humanfriendly
+
+
+def str2bool(value: str) -> bool:
+ return bool(strtobool(value))
+
+
+def remove_parenthesis(value: str):
+ value = value.strip()
+ if value.startswith("(") and value.endswith(")"):
+ value = value[1:-1]
+ elif value.startswith("[") and value.endswith("]"):
+ value = value[1:-1]
+ return value
+
+
+def remove_quotes(value: str):
+ value = value.strip()
+ if value.startswith('"') and value.endswith('"'):
+ value = value[1:-1]
+ elif value.startswith("'") and value.endswith("'"):
+ value = value[1:-1]
+ return value
+
+
+def int_or_none(value: str) -> Optional[int]:
+ """int_or_none.
+
+ Examples:
+ >>> import argparse
+ >>> parser = argparse.ArgumentParser()
+ >>> _ = parser.add_argument('--foo', type=int_or_none)
+ >>> parser.parse_args(['--foo', '456'])
+ Namespace(foo=456)
+ >>> parser.parse_args(['--foo', 'none'])
+ Namespace(foo=None)
+ >>> parser.parse_args(['--foo', 'null'])
+ Namespace(foo=None)
+ >>> parser.parse_args(['--foo', 'nil'])
+ Namespace(foo=None)
+
+ """
+ if value.strip().lower() in ("none", "null", "nil"):
+ return None
+ return int(value)
+
+
+def float_or_none(value: str) -> Optional[float]:
+ """float_or_none.
+
+ Examples:
+ >>> import argparse
+ >>> parser = argparse.ArgumentParser()
+ >>> _ = parser.add_argument('--foo', type=float_or_none)
+ >>> parser.parse_args(['--foo', '4.5'])
+ Namespace(foo=4.5)
+ >>> parser.parse_args(['--foo', 'none'])
+ Namespace(foo=None)
+ >>> parser.parse_args(['--foo', 'null'])
+ Namespace(foo=None)
+ >>> parser.parse_args(['--foo', 'nil'])
+ Namespace(foo=None)
+
+ """
+ if value.strip().lower() in ("none", "null", "nil"):
+ return None
+ return float(value)
+
+
+def humanfriendly_parse_size_or_none(value) -> Optional[float]:
+ if value.strip().lower() in ("none", "null", "nil"):
+ return None
+ return humanfriendly.parse_size(value)
+
+
+def str_or_int(value: str) -> Union[str, int]:
+ try:
+ return int(value)
+ except ValueError:
+ return value
+
+
+def str_or_none(value: str) -> Optional[str]:
+ """str_or_none.
+
+ Examples:
+ >>> import argparse
+ >>> parser = argparse.ArgumentParser()
+ >>> _ = parser.add_argument('--foo', type=str_or_none)
+ >>> parser.parse_args(['--foo', 'aaa'])
+ Namespace(foo='aaa')
+ >>> parser.parse_args(['--foo', 'none'])
+ Namespace(foo=None)
+ >>> parser.parse_args(['--foo', 'null'])
+ Namespace(foo=None)
+ >>> parser.parse_args(['--foo', 'nil'])
+ Namespace(foo=None)
+
+ """
+ if value.strip().lower() in ("none", "null", "nil"):
+ return None
+ return value
+
+
+def str2pair_str(value: str) -> Tuple[str, str]:
+ """str2pair_str.
+
+ Examples:
+ >>> import argparse
+ >>> str2pair_str('abc,def ')
+ ('abc', 'def')
+ >>> parser = argparse.ArgumentParser()
+ >>> _ = parser.add_argument('--foo', type=str2pair_str)
+ >>> parser.parse_args(['--foo', 'abc,def'])
+ Namespace(foo=('abc', 'def'))
+
+ """
+ value = remove_parenthesis(value)
+ a, b = value.split(",")
+
+ # Workaround for configargparse issues:
+ # If the list values are given from yaml file,
+ # the value givent to type() is shaped as python-list,
+ # e.g. ['a', 'b', 'c'],
+ # so we need to remove double quotes from it.
+ return remove_quotes(a), remove_quotes(b)
+
+
+def str2triple_str(value: str) -> Tuple[str, str, str]:
+ """str2triple_str.
+
+ Examples:
+ >>> str2triple_str('abc,def ,ghi')
+ ('abc', 'def', 'ghi')
+ """
+ value = remove_parenthesis(value)
+ a, b, c = value.split(",")
+
+ # Workaround for configargparse issues:
+ # If the list values are given from yaml file,
+ # the value givent to type() is shaped as python-list,
+ # e.g. ['a', 'b', 'c'],
+ # so we need to remove quotes from it.
+ return remove_quotes(a), remove_quotes(b), remove_quotes(c)
diff --git a/funcineforge/utils/vad_utils.py b/funcineforge/utils/vad_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..eba48a9723be04b193bc919572f3bdcc08e741fe
--- /dev/null
+++ b/funcineforge/utils/vad_utils.py
@@ -0,0 +1,55 @@
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+def slice_padding_fbank(speech, speech_lengths, vad_segments):
+ speech_list = []
+ speech_lengths_list = []
+ for i, segment in enumerate(vad_segments):
+
+ bed_idx = int(segment[0][0] * 16)
+ end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
+ speech_i = speech[0, bed_idx:end_idx]
+ speech_lengths_i = end_idx - bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+ feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
+ speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
+ return feats_pad, speech_lengths_pad
+
+
+def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
+ speech_list = []
+ speech_lengths_list = []
+ for i, segment in enumerate(vad_segments):
+ bed_idx = int(segment[0][0] * 16)
+ end_idx = min(int(segment[0][1] * 16), speech_lengths)
+ speech_i = speech[bed_idx:end_idx]
+ speech_lengths_i = end_idx - bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+
+ return speech_list, speech_lengths_list
+
+
+def merge_vad(vad_result, max_length=15000):
+ new_result = []
+ time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
+ time_step = sorted(list(set(time_step)))
+ if len(time_step) == 0:
+ return []
+ bg = 0
+ for i in range(len(time_step) - 1):
+ time = time_step[i]
+ if time_step[i + 1] - bg < max_length:
+ continue
+ if time - bg < max_length * 1.5:
+ new_result.append([bg, time])
+ else:
+ split_num = int(time - bg) // max_length + 1
+ spl_l = int(time - bg) // split_num
+ for j in range(split_num):
+ new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
+ bg = time
+ new_result.append([bg, time_step[-1]])
+ return new_result
diff --git a/pretrained_models/.DS_Store b/pretrained_models/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..2bc766ccc1da6a842ff19ab6e2bcd352c4580792
Binary files /dev/null and b/pretrained_models/.DS_Store differ
diff --git a/pretrained_models/funcineforge_zh_en/.DS_Store b/pretrained_models/funcineforge_zh_en/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..76d0dc84236f1517ac6122b0e56f649074da2d07
Binary files /dev/null and b/pretrained_models/funcineforge_zh_en/.DS_Store differ
diff --git a/pretrained_models/funcineforge_zh_en/llm/config.yaml b/pretrained_models/funcineforge_zh_en/llm/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5175b2b84430d0c3f99f8dce6254a15af5903237
--- /dev/null
+++ b/pretrained_models/funcineforge_zh_en/llm/config.yaml
@@ -0,0 +1,108 @@
+model: FunCineForgeLM
+model_conf:
+ lsm_weight: 0.0
+ length_normalized_loss: true
+ codec_unit: 6761
+ timespk_unit: 1550
+ face_size: 512
+llm: Qwen2-0.5B
+llm_conf:
+ hub: hf
+ freeze: false
+ llm_dtype: fp32
+ init_param_path: pretrained_models/Qwen2-0.5B-CosyVoice-BlankEN
+ use_lora: false
+ lora_conf:
+ task_type: CAUSAL_LM
+ r: 16
+ lora_alpha: 32
+ lora_dropout: 0.05
+ bias: none
+ target_modules:
+ - q_proj
+ - v_proj
+train_conf:
+ use_lora: ${llm_conf.use_lora}
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 200
+ log_interval: 100
+ effective_save_name_excludes:
+ - none
+ resume: true
+ validate_interval: 5000
+ save_checkpoint_interval: 5000
+ keep_nbest_models: 100000
+ avg_nbest_model: 5
+ use_bf16: false
+ save_init_model: false
+ loss_rescale_by_rank: false
+ use_deepspeed: true
+ deepspeed_config: decode_conf/ds_stage0_fp32.json
+optim: adamw
+optim_conf:
+ lr: 8.0e-05
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 2000
+dataset: FunCineForgeDataset
+dataset_conf:
+ use_emotion_clue: true
+ codebook_size: 6561
+ sos: 6561
+ eos: 6562
+ turn_of_speech: 6563
+ fill_token: 6564
+ ignore_id: -100
+ startofclue_token: 151646
+ endofclue_token: 151647
+ frame_shift: 25
+ timebook_size: 1500
+ pangbai: 1500
+ dubai: 1501
+ duihua: 1502
+ duoren: 1503
+ male: 1504
+ female: 1505
+ child: 1506
+ youth: 1507
+ adult: 1508
+ middle: 1509
+ elderly: 1510
+ speaker_id_start: 1511
+ index_ds: CosyVoice
+ dataloader: DataloaderMapStyle
+ load_meta_data_key: text,clue,token,face,dialogue
+ data_split_num: 1
+ batch_sampler: BatchSampler
+ shuffle: true
+ sort_size: 512
+ face_size: 512
+ batch_type: token
+ batch_size: 3000
+ batch_size_token_max: 20000
+ batch_size_sample_max: 100
+ max_token_length: 5000
+ max_text_length: 300
+ batch_size_scale_threshold: 3000
+ num_workers: 20
+ retry: 100
+ specaug: FunCineForgeSpecAug
+ specaug_conf:
+ apply_time_warp: false
+ apply_freq_mask: false
+ apply_time_mask: true
+ time_mask_width_ratio_range:
+ - 0
+ - 0.05
+ num_time_mask: 10
+ fill_value: -100
+tokenizer: FunCineForgeTokenizer
+tokenizer_conf:
+ init_param_path: ${llm_conf.init_param_path}
+face_encoder: FaceRecIR101
+face_encoder_conf:
+ init_param_path: pretrained_models/face_recog_ir101.onnx
+enable_tf32: true
+debug: false
+device: cpu
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7eebc2ef21dd59dc462928fcd9ea4273822872e4
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,66 @@
+addict==2.4.0
+aiohttp==3.12.15
+anyio==4.9.0
+beartype==0.14.1
+click==8.2.2
+cryptography==45.0.5
+curl_cffi==0.12.0
+dashscope
+datasets==3.6.0
+deepspeed==0.18.0
+einops==0.8.1
+face_alignment==1.4.1
+fastapi==0.116.1
+fastcluster==1.3.0
+ffmpy==0.6.1
+filelock==3.18.0
+funasr==1.2.6
+gradio
+g4f
+hdbscan==0.8.40
+huggingface_hub==0.35.3
+imageio==2.37.0
+imageio-ffmpeg==0.6.0
+kaldiio==2.18.1
+joblib==1.5.1
+librosa==0.11.0
+matplotlib
+ml_collections==0.1.1
+modelscope==1.30.0
+moviepy==2.2.1
+numba==0.61.2
+numpy==2.2.6
+omegaconf==2.3.0
+onnx==1.20.1
+onnxruntime==1.23.1
+openai
+openai_whisper==20250625
+opencc_python_reimplemented==0.1.7
+opencv_python==4.12.0.88
+packaging
+pandas==2.3.1
+pillow==11.3.0
+propcache==0.3.2
+pyannote.audio
+pypinyin==0.44.0
+python_speech_features==0.6
+python-Levenshtein
+pytorch-lightning==2.6.0
+PyYAML==6.0.2
+Requests==2.32.5
+rotary_embedding_torch==0.8.9
+scikit-learn==1.7.1
+scipy==1.15.3
+setuptools==81.0.0
+simplejson
+soundfile==0.13.1
+starlette==0.47.2
+tensorboardX==2.6.4
+tensorflow>=2.16
+torch==2.4.1
+torchaudio==2.4.1
+torchvision==0.19.1
+tqdm==4.67.1
+transformers==4.57.0
+x_transformers==2.16.2
+umap_learn==0.5.7
diff --git a/speaker_diarization/.DS_Store b/speaker_diarization/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..4cc3bb9ccd677c3e99d7f05c1a8d811bea6fbed4
Binary files /dev/null and b/speaker_diarization/.DS_Store differ
diff --git a/speaker_diarization/README.md b/speaker_diarization/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fbf3b3b8e234696e8d0b68b0502441d52cdd43f9
--- /dev/null
+++ b/speaker_diarization/README.md
@@ -0,0 +1,50 @@
+# Speaker Diarization
+
+## Introduction
+This recipe offers speaker diarization methods that address the problem of "who spoke when". It provides multimodal diarization. The audio diarization comprises overlap detection, voice activity detection, speech segmentation, speaker embedding extraction. The video diarization comprises face detection, cctive speaker detection, face recognition, lip recognition.
+Then multimodal speaker clustering results are achieved.
+
+
+The DER results of two diarization pipelines on a multi-person conversation video dataset.
+| Pipeline | DER |
+|:-----:|:------:|
+|Audio-only diarization|5.3%|
+|Multimodal diarization|3.7%|
+
+## Usage
+### Quick Start
+
+Ensure that ffmpeg is available in your environment.
+``` sh
+sudo apt-get update
+sudo apt-get install ffmpeg
+```
+The [pyannote/segmentation-3.0](https://huggingface.co/pyannote/segmentation-3.0) is used as a overlapping speech detection module. Make sure to accept [pyannote/segmentation-3.0](https://huggingface.co/pyannote/segmentation-3.0) user conditions and create an access token at [hf.co/settings/tokens](https://hf.co/settings/tokens)
+
+- Stage1: Generate video.list and wav.list
+- Stage2: Process the wav and use the CAM++ speaker recognition model (Tongyi) to extract speaker embeddings (auditory modality) for each sub-segment of the audio.
+ - First, perform speaker overlap detection to obtain overlap.list.
+ - Delete speaker overlap samples to obtain clean_wav.list and clean_video.list.
+ - Use the [FSMN-Monophone VAD](https://www.modelscope.cn/models/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch) model to perform VAD on the audio and perform fine-grained segmentation to obtain json/vad.json.
+ - Prepare subsegment information to obtain json/subseg.json.
+ - Use [CAM++](https://www.modelscope.cn/models/iic/speech_campplus_sv_zh_en_16k-common_advanced) model to extract the speaker embedding of wav audio and save it to embs_wav.
+- Stage 3: Process the video and extract the speaker's facial data (visual modality) through a face detection model, an active speaker detection model, a face recognition model, and a facial landmark detection model.
+ - For 25fps video, sample one frame every 5 frames (every 0.2 seconds).
+ - Detect all faces in the sampled frames using the a lightweight fast [face detection](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB) model
+ - Score all faces using the [TalkNet-ASD](https://github.com/TaoRuijie/TalkNet-ASD) model, and use the face with the highest score as the active speaker's face
+ - (Optional, but not recommended) Use a [face quality assessment](https://modelscope.cn/models/iic/cv_manual_face-quality-assessment_fqa) model to filter out faces with poor quality.
+ - Use the [CurricularFace](https://github.com/HuangYG123/CurricularFace) model to extract the face embedding of the speaker in the active frame.
+ - Use the [FAN](https://github.com/1adrianb/face-alignment) model to perform 2D facial key point detection on the speaker's face, obtain the mouth coordinate (relative coordinates) of each face frame and extract the raw face and mouth data.
+- Stage 4: Joint cluster the audio and visual embeddings to obtain the multimodal active speaker detection results and save them in RTTM file.
+
+
+hf_access_token is your access token
+``` sh
+bash run.sh --stage 1 --stop_stage 4 --hf_access_token hf_xxx --root datasets/clean/zh --gpus "0 1 2 3"
+```
+
+To better understand the source code, you can refer to the **sample.mp4** and **run.sh** files in the subfolder **speaker_diarization_sample** to perform single-sample inference.
+
+## Limitations
+- It may not perform well when the audio duration is too short and when the number of speakers is too large.
+- The final accuracy is highly dependent on the performance of each modules. Among them, the ASD model affects the quality of the results
\ No newline at end of file
diff --git a/speaker_diarization/local/models/campplus/DTDNN.py b/speaker_diarization/local/models/campplus/DTDNN.py
new file mode 100644
index 0000000000000000000000000000000000000000..b17d3c332364a6158fd84a1ad974674b97c0183d
--- /dev/null
+++ b/speaker_diarization/local/models/campplus/DTDNN.py
@@ -0,0 +1,112 @@
+from collections import OrderedDict
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from local.models.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, BasicResBlock, get_nonlinear
+
+
+class FCM(nn.Module):
+ def __init__(self,
+ block=BasicResBlock,
+ num_blocks=[2, 2],
+ m_channels=32,
+ feat_dim=80):
+ super(FCM, self).__init__()
+ self.in_planes = m_channels
+ self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(m_channels)
+
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
+ self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)
+
+ self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(m_channels)
+ self.out_channels = m_channels * (feat_dim // 8)
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = x.unsqueeze(1)
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = F.relu(self.bn2(self.conv2(out)))
+
+ shape = out.shape
+ out = out.reshape(shape[0], shape[1]*shape[2], shape[3])
+ return out
+
+class CAMPPlus(nn.Module):
+ def __init__(self,
+ feat_dim=80,
+ embedding_size=512,
+ growth_rate=32,
+ bn_size=4,
+ init_channels=128,
+ config_str='batchnorm-relu',
+ memory_efficient=True):
+ super(CAMPPlus, self).__init__()
+
+ self.head = FCM(feat_dim=feat_dim)
+ channels = self.head.out_channels
+
+ self.xvector = nn.Sequential(
+ OrderedDict([
+
+ ('tdnn',
+ TDNNLayer(channels,
+ init_channels,
+ 5,
+ stride=2,
+ dilation=1,
+ padding=-1,
+ config_str=config_str)),
+ ]))
+ channels = init_channels
+ for i, (num_layers, kernel_size,
+ dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
+ block = CAMDenseTDNNBlock(num_layers=num_layers,
+ in_channels=channels,
+ out_channels=growth_rate,
+ bn_channels=bn_size * growth_rate,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ config_str=config_str,
+ memory_efficient=memory_efficient)
+ self.xvector.add_module('block%d' % (i + 1), block)
+ channels = channels + num_layers * growth_rate
+ self.xvector.add_module(
+ 'transit%d' % (i + 1),
+ TransitLayer(channels,
+ channels // 2,
+ bias=False,
+ config_str=config_str))
+ channels //= 2
+
+ self.xvector.add_module(
+ 'out_nonlinear', get_nonlinear(config_str, channels))
+
+ self.xvector.add_module('stats', StatsPool())
+ self.xvector.add_module(
+ 'dense',
+ DenseLayer(channels * 2, embedding_size, config_str='batchnorm_'))
+
+ for m in self.modules():
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.kaiming_normal_(m.weight.data)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
+ x = self.head(x)
+ x = self.xvector(x)
+ return x
diff --git a/speaker_diarization/local/models/campplus/classifier.py b/speaker_diarization/local/models/campplus/classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6c5e9d24d59d9af24797a1c20a3480dafc615c4
--- /dev/null
+++ b/speaker_diarization/local/models/campplus/classifier.py
@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from local.models.campplus.layers import DenseLayer
+
+
+class CosineClassifier(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ num_blocks=0,
+ inter_dim=512,
+ out_neurons=1000,
+ ):
+
+ super().__init__()
+ self.blocks = nn.ModuleList()
+
+ for index in range(num_blocks):
+ self.blocks.append(
+ DenseLayer(input_dim, inter_dim, config_str='batchnorm')
+ )
+ input_dim = inter_dim
+
+ self.weight = nn.Parameter(
+ torch.FloatTensor(out_neurons, input_dim)
+ )
+ nn.init.xavier_uniform_(self.weight)
+
+ def forward(self, x):
+ # x: [B, dim]
+ for layer in self.blocks:
+ x = layer(x)
+
+ # normalized
+ x = F.linear(F.normalize(x), F.normalize(self.weight))
+ return x
+
+class LinearClassifier(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ num_blocks=0,
+ inter_dim=512,
+ out_neurons=1000,
+ ):
+
+ super().__init__()
+ self.blocks = nn.ModuleList()
+
+ self.nonlinear = nn.ReLU(inplace=True)
+ for index in range(num_blocks):
+ self.blocks.append(
+ DenseLayer(input_dim, inter_dim, bias=True)
+ )
+ input_dim = inter_dim
+
+ self.linear = nn.Linear(input_dim, out_neurons, bias=True)
+
+ def forward(self, x):
+ # x: [B, dim]
+ x = self.nonlinear(x)
+ for layer in self.blocks:
+ x = layer(x)
+ x = self.linear(x)
+ return x
diff --git a/speaker_diarization/local/models/campplus/layers.py b/speaker_diarization/local/models/campplus/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc3de0bbb94a042c577a371d3e8eb000ec81a730
--- /dev/null
+++ b/speaker_diarization/local/models/campplus/layers.py
@@ -0,0 +1,250 @@
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from torch import nn
+
+
+def get_nonlinear(config_str, channels):
+ nonlinear = nn.Sequential()
+ for name in config_str.split('-'):
+ if name == 'relu':
+ nonlinear.add_module('relu', nn.ReLU(inplace=True))
+ elif name == 'prelu':
+ nonlinear.add_module('prelu', nn.PReLU(channels))
+ elif name == 'batchnorm':
+ nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels))
+ elif name == 'batchnorm_':
+ nonlinear.add_module('batchnorm',
+ nn.BatchNorm1d(channels, affine=False))
+ else:
+ raise ValueError('Unexpected module ({}).'.format(name))
+ return nonlinear
+
+def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
+ mean = x.mean(dim=dim)
+ std = x.std(dim=dim, unbiased=unbiased)
+ stats = torch.cat([mean, std], dim=-1)
+ if keepdim:
+ stats = stats.unsqueeze(dim=dim)
+ return stats
+
+
+class StatsPool(nn.Module):
+ def forward(self, x):
+ return statistics_pooling(x)
+
+
+class TDNNLayer(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ bias=False,
+ config_str='batchnorm-relu'):
+ super(TDNNLayer, self).__init__()
+ if padding < 0:
+ assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
+ kernel_size)
+ padding = (kernel_size - 1) // 2 * dilation
+ self.linear = nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias)
+ self.nonlinear = get_nonlinear(config_str, out_channels)
+
+ def forward(self, x):
+ x = self.linear(x)
+ x = self.nonlinear(x)
+ return x
+
+
+class CAMLayer(nn.Module):
+ def __init__(self,
+ bn_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ bias,
+ reduction=2):
+ super(CAMLayer, self).__init__()
+ self.linear_local = nn.Conv1d(bn_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias)
+ self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
+ self.relu = nn.ReLU(inplace=True)
+ self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ y = self.linear_local(x)
+ context = x.mean(-1, keepdim=True)+self.seg_pooling(x)
+ context = self.relu(self.linear1(context))
+ m = self.sigmoid(self.linear2(context))
+ return y*m
+
+ def seg_pooling(self, x, seg_len=100, stype='avg'):
+ if stype == 'avg':
+ seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
+ elif stype == 'max':
+ seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
+ else:
+ raise ValueError('Wrong segment pooling type.')
+ shape = seg.shape
+ seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
+ seg = seg[..., :x.shape[-1]]
+ return seg
+
+
+class CAMDenseTDNNLayer(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ bn_channels,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ bias=False,
+ config_str='batchnorm-relu',
+ memory_efficient=False):
+ super(CAMDenseTDNNLayer, self).__init__()
+ assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
+ kernel_size)
+ padding = (kernel_size - 1) // 2 * dilation
+ self.memory_efficient = memory_efficient
+ self.nonlinear1 = get_nonlinear(config_str, in_channels)
+ self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
+ self.nonlinear2 = get_nonlinear(config_str, bn_channels)
+ self.cam_layer = CAMLayer(bn_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias)
+
+ def bn_function(self, x):
+ return self.linear1(self.nonlinear1(x))
+
+ def forward(self, x):
+ if self.training and self.memory_efficient:
+ x = cp.checkpoint(self.bn_function, x)
+ else:
+ x = self.bn_function(x)
+ x = self.cam_layer(self.nonlinear2(x))
+ return x
+
+
+class CAMDenseTDNNBlock(nn.ModuleList):
+ def __init__(self,
+ num_layers,
+ in_channels,
+ out_channels,
+ bn_channels,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ bias=False,
+ config_str='batchnorm-relu',
+ memory_efficient=False):
+ super(CAMDenseTDNNBlock, self).__init__()
+ for i in range(num_layers):
+ layer = CAMDenseTDNNLayer(in_channels=in_channels + i * out_channels,
+ out_channels=out_channels,
+ bn_channels=bn_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ bias=bias,
+ config_str=config_str,
+ memory_efficient=memory_efficient)
+ self.add_module('tdnnd%d' % (i + 1), layer)
+
+ def forward(self, x):
+ for layer in self:
+ x = torch.cat([x, layer(x)], dim=1)
+ return x
+
+
+class TransitLayer(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ bias=True,
+ config_str='batchnorm-relu'):
+ super(TransitLayer, self).__init__()
+ self.nonlinear = get_nonlinear(config_str, in_channels)
+ self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
+
+ def forward(self, x):
+ x = self.nonlinear(x)
+ x = self.linear(x)
+ return x
+
+
+class DenseLayer(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ bias=False,
+ config_str='batchnorm-relu'):
+ super(DenseLayer, self).__init__()
+ self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
+ self.nonlinear = get_nonlinear(config_str, out_channels)
+
+ def forward(self, x):
+ if len(x.shape) == 2:
+ x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
+ else:
+ x = self.linear(x)
+ x = self.nonlinear(x)
+ return x
+
+
+class BasicResBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(BasicResBlock, self).__init__()
+ self.conv1 = nn.Conv2d(in_planes,
+ planes,
+ kernel_size=3,
+ stride=(stride, 1),
+ padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=(stride, 1),
+ bias=False),
+ nn.BatchNorm2d(self.expansion * planes))
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
diff --git a/speaker_diarization/local/models/talknet/attentionLayer.py b/speaker_diarization/local/models/talknet/attentionLayer.py
new file mode 100644
index 0000000000000000000000000000000000000000..17853119aea22050ada569b2c0d1bbfa918b9074
--- /dev/null
+++ b/speaker_diarization/local/models/talknet/attentionLayer.py
@@ -0,0 +1,35 @@
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.nn import MultiheadAttention
+
+class attentionLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dropout=0.1):
+ super(attentionLayer, self).__init__()
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
+
+ self.linear1 = nn.Linear(d_model, d_model * 4)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_model * 4, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = F.relu
+
+ def forward(self, src, tar):
+ # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
+ src = src.transpose(0, 1) # B, T, C -> T, B, C
+ tar = tar.transpose(0, 1) # B, T, C -> T, B, C
+ src2 = self.self_attn(tar, src, src, attn_mask=None,
+ key_padding_mask=None)[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ src = src.transpose(0, 1) # T, B, C -> B, T, C
+ return src
diff --git a/speaker_diarization/local/models/talknet/audioEncoder.py b/speaker_diarization/local/models/talknet/audioEncoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..262a45dbcce1d12f76902b7a1688549e66a9193d
--- /dev/null
+++ b/speaker_diarization/local/models/talknet/audioEncoder.py
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class SEBasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
+ super(SEBasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.se = SELayer(planes, reduction)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.relu(out)
+ out = self.bn1(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.se(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+ return out
+
+class SELayer(nn.Module):
+ def __init__(self, channel, reduction=8):
+ super(SELayer, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction),
+ nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y
+
+class audioEncoder(nn.Module):
+ def __init__(self, layers, num_filters, **kwargs):
+ super(audioEncoder, self).__init__()
+ block = SEBasicBlock
+ self.inplanes = num_filters[0]
+
+ self.conv1 = nn.Conv2d(1, num_filters[0] , kernel_size=7, stride=(2, 1), padding=3,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(num_filters[0])
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, num_filters[0], layers[0])
+ self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2))
+ self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2))
+ self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(1, 1))
+ out_dim = num_filters[3] * block.expansion
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = torch.mean(x, dim=2, keepdim=True)
+ x = x.view((x.size()[0], x.size()[1], -1))
+ x = x.transpose(1, 2)
+
+ return x
diff --git a/speaker_diarization/local/models/talknet/talknet.py b/speaker_diarization/local/models/talknet/talknet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9e5c259b721ffb48a7bc9d8b82f92706518fb43
--- /dev/null
+++ b/speaker_diarization/local/models/talknet/talknet.py
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+from local.models.talknet.audioEncoder import audioEncoder
+from local.models.talknet.visualEncoder import visualFrontend, visualTCN, visualConv1D
+from local.models.talknet.attentionLayer import attentionLayer
+
+class talkNetModel(nn.Module):
+ """
+ TalkNet model for active speaker detection task.
+ Reference:
+ - Is someone talking? TalkNet: Audio-visual active speaker detection Model.
+ - https://github.com/TaoRuijie/TalkNet-ASD
+ """
+ def __init__(self):
+ super(talkNetModel, self).__init__()
+ # Visual Temporal Encoder
+ self.visualFrontend = visualFrontend() # Visual Frontend
+ self.visualTCN = visualTCN() # Visual Temporal Network TCN
+ self.visualConv1D = visualConv1D() # Visual Temporal Network Conv1d
+
+ # Audio Temporal Encoder
+ self.audioEncoder = audioEncoder(layers = [3, 4, 6, 3], num_filters = [16, 32, 64, 128])
+
+ # Audio-visual Cross Attention
+ self.crossA2V = attentionLayer(d_model = 128, nhead = 8)
+ self.crossV2A = attentionLayer(d_model = 128, nhead = 8)
+
+ # Audio-visual Self Attention
+ self.selfAV = attentionLayer(d_model = 256, nhead = 8)
+
+ # Classifier
+ self.fcAV = nn.Linear(256, 2)
+ self.fcA = nn.Linear(128, 2)
+ self.fcV = nn.Linear(128, 2)
+
+ def visual_frontend(self, x):
+ B, T, W, H = x.shape
+ x = x.view(B*T, 1, 1, W, H)
+ x = (x / 255 - 0.4161) / 0.1688
+ x = self.visualFrontend(x)
+ x = x.view(B, T, 512)
+ x = x.transpose(1,2)
+ x = self.visualTCN(x)
+ x = self.visualConv1D(x)
+ x = x.transpose(1,2)
+ return x
+
+ def audio_frontend(self, x):
+ x = x.unsqueeze(1).transpose(2, 3)
+ x = self.audioEncoder(x)
+ return x
+
+ def cross_attention(self, x1, x2):
+ x1_c = self.crossA2V(src = x1, tar = x2)
+ x2_c = self.crossV2A(src = x2, tar = x1)
+ return x1_c, x2_c
+
+ def audio_visual_backend(self, x1, x2):
+ x = torch.cat((x1,x2), 2)
+ x = self.selfAV(src = x, tar = x)
+ return x
+
+ def forward(self, audioX, visualX):
+ audioX = self.audio_frontend(audioX)
+ visualX = self.visual_frontend(visualX)
+ audioX, visualX = self.cross_attention(audioX, visualX)
+ audio_visualX = self.audio_visual_backend(audioX, visualX)
+
+ return self.fcAV(audio_visualX), self.fcA(audioX), self.fcV(visualX)
diff --git a/speaker_diarization/local/models/talknet/visualEncoder.py b/speaker_diarization/local/models/talknet/visualEncoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..be9d954bcfdf4aedcf63324dc90ad743e54a65c1
--- /dev/null
+++ b/speaker_diarization/local/models/talknet/visualEncoder.py
@@ -0,0 +1,163 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ResNetLayer(nn.Module):
+
+ """
+ A ResNet layer used to build the ResNet network.
+ Architecture:
+ --> conv-bn-relu -> conv -> + -> bn-relu -> conv-bn-relu -> conv -> + -> bn-relu -->
+ | | | |
+ -----> downsample ------> ------------------------------------->
+ """
+
+ def __init__(self, inplanes, outplanes, stride):
+ super(ResNetLayer, self).__init__()
+ self.conv1a = nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1a = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
+ self.conv2a = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.downsample = nn.Sequential()
+ if stride != 1:
+ self.downsample = nn.Conv2d(inplanes, outplanes, kernel_size=(1,1), stride=stride, bias=False)
+ self.outbna = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
+
+ self.conv1b = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1b = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
+ self.conv2b = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.outbnb = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
+
+ def forward(self, inputBatch):
+ batch = F.relu(self.bn1a(self.conv1a(inputBatch)))
+ batch = self.conv2a(batch)
+ residualBatch = self.downsample(inputBatch)
+ batch = batch + residualBatch
+ intermediateBatch = batch
+ batch = F.relu(self.outbna(batch))
+
+ batch = F.relu(self.bn1b(self.conv1b(batch)))
+ batch = self.conv2b(batch)
+ residualBatch = intermediateBatch
+ batch = batch + residualBatch
+ outputBatch = F.relu(self.outbnb(batch))
+ return outputBatch
+
+
+
+class ResNet(nn.Module):
+
+ """
+ An 18-layer ResNet architecture.
+ """
+
+ def __init__(self):
+ super(ResNet, self).__init__()
+ self.layer1 = ResNetLayer(64, 64, stride=1)
+ self.layer2 = ResNetLayer(64, 128, stride=2)
+ self.layer3 = ResNetLayer(128, 256, stride=2)
+ self.layer4 = ResNetLayer(256, 512, stride=2)
+ self.avgpool = nn.AvgPool2d(kernel_size=(4,4), stride=(1,1))
+
+ return
+
+
+ def forward(self, inputBatch):
+ batch = self.layer1(inputBatch)
+ batch = self.layer2(batch)
+ batch = self.layer3(batch)
+ batch = self.layer4(batch)
+ outputBatch = self.avgpool(batch)
+ return outputBatch
+
+
+class GlobalLayerNorm(nn.Module):
+ def __init__(self, channel_size):
+ super(GlobalLayerNorm, self).__init__()
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ self.gamma.data.fill_(1)
+ self.beta.data.zero_()
+
+ def forward(self, y):
+ mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) #[M, 1, 1]
+ var = (torch.pow(y-mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
+ gLN_y = self.gamma * (y - mean) / torch.pow(var + 1e-8, 0.5) + self.beta
+ return gLN_y
+
+class visualFrontend(nn.Module):
+
+ """
+ A visual feature extraction module. Generates a 512-dim feature vector per video frame.
+ Architecture: A 3D convolution block followed by an 18-layer ResNet.
+ """
+
+ def __init__(self):
+ super(visualFrontend, self).__init__()
+ self.frontend3D = nn.Sequential(
+ nn.Conv3d(1, 64, kernel_size=(5,7,7), stride=(1,2,2), padding=(2,3,3), bias=False),
+ nn.BatchNorm3d(64, momentum=0.01, eps=0.001),
+ nn.ReLU(),
+ nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))
+ )
+ self.resnet = ResNet()
+ return
+
+
+ def forward(self, inputBatch):
+ inputBatch = inputBatch.transpose(0, 1).transpose(1, 2)
+ batchsize = inputBatch.shape[0]
+ batch = self.frontend3D(inputBatch)
+
+ batch = batch.transpose(1, 2)
+ batch = batch.reshape(batch.shape[0]*batch.shape[1], batch.shape[2], batch.shape[3], batch.shape[4])
+ outputBatch = self.resnet(batch)
+ outputBatch = outputBatch.reshape(batchsize, -1, 512)
+ outputBatch = outputBatch.transpose(1 ,2)
+ outputBatch = outputBatch.transpose(1, 2).transpose(0, 1)
+ return outputBatch
+
+class DSConv1d(nn.Module):
+ def __init__(self):
+ super(DSConv1d, self).__init__()
+ self.net = nn.Sequential(
+ nn.ReLU(),
+ nn.BatchNorm1d(512),
+ nn.Conv1d(512, 512, 3, stride=1, padding=1,dilation=1, groups=512, bias=False),
+ nn.PReLU(),
+ GlobalLayerNorm(512),
+ nn.Conv1d(512, 512, 1, bias=False),
+ )
+
+ def forward(self, x):
+ out = self.net(x)
+ return out + x
+
+class visualTCN(nn.Module):
+ def __init__(self):
+ super(visualTCN, self).__init__()
+ stacks = []
+ for x in range(5):
+ stacks += [DSConv1d()]
+ self.net = nn.Sequential(*stacks) # Visual Temporal Network V-TCN
+
+ def forward(self, x):
+ out = self.net(x)
+ return out
+
+class visualConv1D(nn.Module):
+ def __init__(self):
+ super(visualConv1D, self).__init__()
+ self.net = nn.Sequential(
+ nn.Conv1d(512, 256, 5, stride=1, padding=2),
+ nn.BatchNorm1d(256),
+ nn.ReLU(),
+ nn.Conv1d(256, 128, 1),
+ )
+
+ def forward(self, x):
+ out = self.net(x)
+ return out
diff --git a/speaker_diarization/local/process/augmentation.py b/speaker_diarization/local/process/augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bad4b7d976ec202464874828988bb384e1dee67
--- /dev/null
+++ b/speaker_diarization/local/process/augmentation.py
@@ -0,0 +1,92 @@
+import torch
+import torchaudio
+from scipy import signal
+import numpy as np
+import random
+
+from local.utils.fileio import load_wav_scp
+
+def addreverb(wav, rir_wav):
+ # wav: [T,], rir_wav: [T,]
+ wav = wav.numpy()
+ rir_wav = rir_wav.numpy()
+ wav_len = wav.shape[0]
+ rir_wav = rir_wav / np.sqrt(np.sum(rir_wav**2))
+ out_wav = signal.convolve(wav, rir_wav,
+ mode='full')[:wav_len]
+
+ out_wav = out_wav / (np.max(np.abs(out_wav)) + 1e-6)
+ return torch.from_numpy(out_wav)
+
+def addnoise(wav, noise=None, snr_high=15, snr_low=0):
+ # wav: [T,], noise: [T,]
+ if noise is None:
+ noise = torch.randn_like(wav)
+ noise = noise.numpy()
+ wav = wav.numpy()
+
+ wav_len = wav.shape[0]
+ noise_len = noise.shape[0]
+ if noise_len >= wav_len:
+ start = random.randint(0, noise_len - wav_len)
+ noise = noise[start:start + wav_len]
+ else:
+ noise = noise.repeat(wav_len // noise_len + 1)
+ noise = noise[:wav_len]
+
+ wav_db = 10 * np.log10(np.mean(wav**2) + 1e-6)
+ noise_db = 10 * np.log10(np.mean(noise**2) + 1e-6)
+ noise_snr = random.uniform(snr_low, snr_high)
+ noise = np.sqrt(10**(
+ (wav_db - noise_db - noise_snr) / 10)) * noise
+ out_wav = wav + noise
+
+ out_wav = out_wav / (np.max(np.abs(out_wav)) + 1e-6)
+ return torch.from_numpy(out_wav)
+
+
+class NoiseReverbCorrupter(object):
+ def __init__(
+ self,
+ noise_prob=0.0,
+ reverb_prob=0.0,
+ noise_file=None,
+ reverb_file=None,
+ noise_snr_low=0,
+ noise_snr_high=15,
+ ):
+ if reverb_prob > 0.0:
+ if reverb_file is None:
+ raise ValueError('Reverb_file not be assigned.')
+ self.add_reverb = addreverb
+ self.reverb_data = load_wav_scp(reverb_file)
+ self.reverb_data_keys = list(self.reverb_data.keys())
+
+ if noise_prob > 0.0:
+ if noise_file is None:
+ raise ValueError('Noise_file not be assigned.')
+
+ self.add_noise = addnoise
+ self.noise_data = load_wav_scp(noise_file)
+ self.noise_data_keys = list(self.noise_data.keys())
+
+ self.reverb_prob = reverb_prob
+ self.noise_prob = noise_prob
+ self.noise_snr_low = noise_snr_low
+ self.noise_snr_high = noise_snr_high
+
+ def __call__(self, wav, fs=16000):
+ if self.reverb_prob > random.random():
+ reverb_path = self.reverb_data[random.choice(self.reverb_data_keys)]
+ reverb, fs_rir = torchaudio.load(reverb_path)
+ assert fs_rir == fs
+ wav = self.add_reverb(wav, reverb[0])
+ if self.noise_prob > random.random():
+ noise_path = self.noise_data[random.choice(self.noise_data_keys)]
+ noise, fs_noise = torchaudio.load(noise_path)
+ assert fs_noise == fs
+ wav = self.add_noise(
+ wav, noise[0],
+ snr_high=self.noise_snr_high,
+ snr_low=self.noise_snr_low,)
+ return wav
diff --git a/speaker_diarization/local/process/cluster.py b/speaker_diarization/local/process/cluster.py
new file mode 100644
index 0000000000000000000000000000000000000000..e61b179d21c90e086c355ddbe6fb7131b55c4d0f
--- /dev/null
+++ b/speaker_diarization/local/process/cluster.py
@@ -0,0 +1,361 @@
+import numpy as np
+import scipy
+from sklearn.cluster._kmeans import k_means
+from sklearn.metrics.pairwise import cosine_similarity
+
+import fastcluster
+from scipy.cluster.hierarchy import fcluster
+from scipy.spatial.distance import squareform
+
+try:
+ import umap, hdbscan
+except ImportError:
+ raise ImportError(
+ "Package \"umap\" or \"hdbscan\" not found. \
+ Please install them first by \"pip install umap-learn hdbscan\"."
+ )
+
+
+class SpectralCluster:
+ """A spectral clustering method using unnormalized Laplacian of affinity matrix.
+ This implementation is adapted from https://github.com/speechbrain/speechbrain.
+ """
+
+ def __init__(self, min_num_spks=1, max_num_spks=10, pval=0.02, min_pnum=6, oracle_num=None):
+ self.min_num_spks = min_num_spks
+ self.max_num_spks = max_num_spks
+ self.min_pnum = min_pnum
+ self.pval = pval
+ self.k = oracle_num
+
+ def __call__(self, X, **kwargs):
+ pval = kwargs.get('pval', None)
+ oracle_num = kwargs.get('speaker_num', None)
+
+ # Similarity matrix computation
+ sim_mat = self.get_sim_mat(X)
+
+ # Refining similarity matrix with pval
+ prunned_sim_mat = self.p_pruning(sim_mat, pval)
+
+ # Symmetrization
+ sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
+
+ # Laplacian calculation
+ laplacian = self.get_laplacian(sym_prund_sim_mat)
+
+ # Get Spectral Embeddings
+ emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
+
+ # Perform clustering
+ labels = self.cluster_embs(emb, num_of_spk)
+
+ return labels
+
+ def get_sim_mat(self, X):
+ # Cosine similarities
+ M = cosine_similarity(X, X)
+ return M
+
+ def p_pruning(self, A, pval=None):
+ if pval is None:
+ pval = self.pval
+ n_elems = int((1 - pval) * A.shape[0])
+ n_elems = min(n_elems, A.shape[0]-self.min_pnum)
+
+ # For each row in a affinity matrix
+ for i in range(A.shape[0]):
+ low_indexes = np.argsort(A[i, :])
+ low_indexes = low_indexes[0:n_elems]
+
+ # Replace smaller similarity values by 0s
+ A[i, low_indexes] = 0
+ return A
+
+ def get_laplacian(self, M):
+ M[np.diag_indices(M.shape[0])] = 0
+ D = np.sum(np.abs(M), axis=1)
+ D = np.diag(D)
+ L = D - M
+ return L
+
+ def get_spec_embs(self, L, k_oracle=None):
+ if k_oracle is None:
+ k_oracle = self.k
+
+ lambdas, eig_vecs = scipy.sparse.linalg.eigsh(L, k=min(self.max_num_spks+1, L.shape[0]), which='SM')
+
+ if k_oracle is not None:
+ num_of_spk = k_oracle
+ else:
+ lambda_gap_list = self.getEigenGaps(
+ lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
+ num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
+
+ emb = eig_vecs[:, :num_of_spk]
+ return emb, num_of_spk
+
+ def cluster_embs(self, emb, k):
+ # k-means
+ _, labels, _ = k_means(emb, k)
+ return labels
+
+ def getEigenGaps(self, eig_vals):
+ eig_vals_gap_list = []
+ for i in range(len(eig_vals) - 1):
+ gap = float(eig_vals[i + 1]) - float(eig_vals[i])
+ eig_vals_gap_list.append(gap)
+ return eig_vals_gap_list
+
+
+class UmapHdbscan:
+ """
+ Reference:
+ - Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
+ Emphasis On Topological Structure. ICASSP2022
+ """
+
+ def __init__(self, n_neighbors=20, n_components=60, min_samples=20, min_cluster_size=10, metric='euclidean'):
+ self.n_neighbors = n_neighbors
+ self.n_components = n_components
+ self.min_samples = min_samples
+ self.min_cluster_size = min_cluster_size
+ self.metric = metric
+
+ def __call__(self, X, **kwargs):
+ umap_X = umap.UMAP(
+ n_neighbors=self.n_neighbors,
+ min_dist=0.0,
+ n_components=min(self.n_components, X.shape[0]-2),
+ metric=self.metric,
+ ).fit_transform(X)
+ labels = hdbscan.HDBSCAN(min_samples=self.min_samples, min_cluster_size=self.min_cluster_size).fit_predict(umap_X)
+ return labels
+
+class AHCluster:
+ """
+ Agglomerative Hierarchical Clustering, a bottom-up approach which iteratively merges
+ the closest clusters until a termination condition is reached.
+ This implementation is adapted from https://github.com/BUTSpeechFIT/VBx.
+ """
+
+ def __init__(self, fix_cos_thr=0.4):
+ self.fix_cos_thr = fix_cos_thr
+
+ def __call__(self, X, **kwargs):
+ scr_mx = cosine_similarity(X)
+ scr_mx = squareform(-scr_mx, checks=False)
+ lin_mat = fastcluster.linkage(scr_mx, method='average', preserve_input='False')
+ adjust = abs(lin_mat[:, 2].min())
+ lin_mat[:, 2] += adjust
+ labels = fcluster(lin_mat, -self.fix_cos_thr + adjust, criterion='distance') - 1
+ return labels
+
+
+class CommonClustering:
+ """Perfom clustering for input embeddings and output the labels.
+ """
+
+ def __init__(self, cluster_type, cluster_line=40, mer_cos=None, min_cluster_size=4, **kwargs):
+ self.cluster_type = cluster_type
+ self.cluster_line = cluster_line
+ self.min_cluster_size = min_cluster_size
+ self.mer_cos = mer_cos
+ if self.cluster_type == 'spectral':
+ self.cluster = SpectralCluster(**kwargs)
+ elif self.cluster_type == 'umap_hdbscan':
+ kwargs['min_cluster_size'] = min_cluster_size
+ self.cluster = UmapHdbscan(**kwargs)
+ elif self.cluster_type == 'AHC':
+ self.cluster = AHCluster(**kwargs)
+ else:
+ raise ValueError(
+ '%s is not currently supported.' % self.cluster_type
+ )
+ if self.cluster_type != 'AHC':
+ self.cluster_for_short = AHCluster()
+ else:
+ self.cluster_for_short = self.cluster
+
+ def __call__(self, X, **kwargs):
+ # clustering and return the labels
+ assert len(X.shape) == 2, 'Shape of input should be [N, C]'
+ if X.shape[0] <= 1:
+ return np.zeros(X.shape[0], dtype=int)
+ if X.shape[0] < self.cluster_line:
+ labels = self.cluster_for_short(X)
+ else:
+ labels = self.cluster(X, **kwargs)
+
+ # remove extremely minor cluster
+ labels = self.filter_minor_cluster(labels, X, self.min_cluster_size)
+ # merge similar speaker
+ if self.mer_cos is not None:
+ labels = self.merge_by_cos(labels, X, self.mer_cos)
+
+ return labels
+
+ def filter_minor_cluster(self, labels, x, min_cluster_size):
+ cset = np.unique(labels)
+ csize = np.array([(labels == i).sum() for i in cset])
+ minor_idx = np.where(csize <= self.min_cluster_size)[0]
+ if len(minor_idx) == 0:
+ return labels
+
+ minor_cset = cset[minor_idx]
+ major_idx = np.where(csize > self.min_cluster_size)[0]
+ if len(major_idx) == 0:
+ return np.zeros_like(labels)
+ major_cset = cset[major_idx]
+ major_center = np.stack([x[labels == i].mean(0) \
+ for i in major_cset])
+ for i in range(len(labels)):
+ if labels[i] in minor_cset:
+ cos_sim = cosine_similarity(x[i][np.newaxis], major_center)
+ labels[i] = major_cset[cos_sim.argmax()]
+
+ return labels
+
+ def merge_by_cos(self, labels, x, cos_thr):
+ # merge the similar speakers by cosine similarity
+ assert cos_thr > 0 and cos_thr <= 1
+ while True:
+ cset = np.unique(labels)
+ if len(cset) == 1:
+ break
+ centers = np.stack([x[labels == i].mean(0) \
+ for i in cset])
+ affinity = cosine_similarity(centers, centers)
+ affinity = np.triu(affinity, 1)
+ idx = np.unravel_index(np.argmax(affinity), affinity.shape)
+ if affinity[idx] < cos_thr:
+ break
+ c1, c2 = cset[np.array(idx)]
+ labels[labels==c2]=c1
+ return labels
+
+
+class JointClustering:
+ """Perfom joint clustering for input audio and visual embeddings and output the labels.
+ """
+
+ def __init__(self, audio_cluster, vision_cluster):
+ self.audio_cluster = audio_cluster
+ self.vision_cluster = vision_cluster
+
+ def __call__(self, audioX, visionX, audioT, visionT, conf):
+ # audio-only and video-only clustering
+ alabels = self.audio_cluster(audioX)
+ vlabels = self.vision_cluster(visionX)
+
+ alabels = self.arrange_labels(alabels)
+ vlist, vspk_embs, vspk_dur = self.get_vlist_embs(audioX, alabels, vlabels, audioT, visionT, conf)
+
+ # modify alabels according to vlabels
+ aspk_num = alabels.max()+1
+ for i in range(aspk_num):
+ aspki_index = np.where(alabels==i)[0]
+ aspki_embs = audioX[alabels==i]
+
+ aspkiT_part = np.array(audioT)[alabels==i]
+ overlap_vspk = self.overlap_spks(self.cast_overlap(aspkiT_part), vlist, vspk_dur)
+ if len(overlap_vspk) > 1:
+ centers = np.stack([vspk_embs[s] for s in overlap_vspk])
+ distribute_labels = self.distribute_embs(aspki_embs, centers)
+ for j in range(distribute_labels.max()+1):
+ for loc in aspki_index[distribute_labels==j]:
+ alabels[loc] = overlap_vspk[j]
+ elif len(overlap_vspk) == 1:
+ for loc in aspki_index:
+ alabels[loc] = overlap_vspk[0]
+
+ alabels = self.arrange_labels(alabels)
+ return alabels
+
+ def overlap_spks(self, times, vlist, vspk_dur=None):
+ # get the vspk that overlaps with times.
+ overlap_dur = {}
+ for [a_st, a_ed] in times:
+ for [v_st, v_ed, v_id] in vlist:
+ if a_ed > v_st and v_ed > a_st:
+ if v_id not in overlap_dur:
+ overlap_dur[v_id]=0
+ overlap_dur[v_id] += min(a_ed, v_ed) - max(a_st, v_st)
+ vspk_list = []
+ for v_id, dur in overlap_dur.items():
+ # set the criteria for confirming overlap.
+ if (vspk_dur is None and dur > 0.5) or (vspk_dur is not None and dur > min(vspk_dur[v_id]*0.5, 0.5)):
+ vspk_list.append(v_id)
+ return vspk_list
+
+ def distribute_embs(self, embs, centers):
+ # embs: [n, D]. centers: [k, D]
+ norm_centers = centers / np.linalg.norm(centers, axis=1, keepdims=True)
+ norm_embs = embs / np.linalg.norm(embs, axis=1, keepdims=True)
+ similarity = np.matmul(norm_embs, norm_centers.T) # [n, k]
+ argsort = np.argsort(similarity, axis=-1)
+ return argsort[:, -1]
+
+ def get_vlist_embs(self, audioX, alabels, vlabels, audioT, visionT, conf):
+ assert len(vlabels) == len(visionT)
+ vlist = []
+ for i, ti in enumerate(visionT):
+ if len(vlist)==0 or vlabels[i] != vlist[-1][2] or ti - visionT[i-1] > conf.face_det_stride*0.04 + 1e-4:
+ if len(vlist) > 0 and vlist[-1][1] - vlist[-1][0] < 1e-4:
+ # remove too short intervals.
+ vlist.pop()
+ vlist.append([ti, ti, vlabels[i]])
+ else:
+ vlist[-1][1] = ti
+
+ # adjust vision labels
+ vlabels_arrange = self.arrange_labels([i[2] for i in vlist], a_st=alabels.max()+1)
+ vlist = [[i[0], i[1], j] for i, j in zip(vlist, vlabels_arrange)]
+
+ # get audio spk embs aligning with 'vlist'
+ vspk_embs = {}
+ for [v_st, v_ed, v_id] in vlist:
+ for i, [a_st, a_ed] in enumerate(audioT):
+ if a_ed >= v_st and v_ed >= a_st:
+ if min(a_ed, v_ed) - max(a_st, v_st) > 1:
+ if v_id not in vspk_embs:
+ vspk_embs[v_id] = []
+ vspk_embs[v_id].append(audioX[i])
+ for k in vspk_embs:
+ vspk_embs[k] = np.stack(vspk_embs[k]).mean(0)
+
+ vlist_new = []
+ for i in vlist:
+ if i[2] in vspk_embs:
+ vlist_new.append(i)
+ # get duration of v_spk
+ vspk_dur = {}
+ for i in vlist_new:
+ if i[2] not in vspk_dur:
+ vspk_dur[i[2]]=0
+ vspk_dur[i[2]] += i[1]-i[0]
+
+ return vlist_new, vspk_embs, vspk_dur
+
+ def cast_overlap(self, input_time):
+ if len(input_time)==0:
+ return input_time
+ output_time = []
+ for i in range(0, len(input_time)-1):
+ if i == 0 or output_time[-1][1] < input_time[i][0]:
+ output_time.append(input_time[i])
+ else:
+ output_time[-1][1] = input_time[i][1]
+ return output_time
+
+ def arrange_labels(self, labels, a_st=0):
+ # arrange labels in order from 0.
+ new_labels = []
+ labels_dict = {}
+ idx = a_st
+ for i in labels:
+ if i not in labels_dict:
+ labels_dict[i] = idx
+ idx += 1
+ new_labels.append(labels_dict[i])
+ return np.array(new_labels)
diff --git a/speaker_diarization/local/process/processor.py b/speaker_diarization/local/process/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b69a93649ec01a197c66c638e343a59bad531ca
--- /dev/null
+++ b/speaker_diarization/local/process/processor.py
@@ -0,0 +1,155 @@
+import random
+import pickle
+import torch
+import torchaudio
+import torch.nn.functional as F
+import torchaudio.compliance.kaldi as Kaldi
+
+from local.process.augmentation import NoiseReverbCorrupter
+from local.utils.fileio import load_data_csv
+
+
+class WavReader(object):
+ def __init__(self,
+ sample_rate = 16000,
+ duration: float = 3.0,
+ speed_pertub: bool = False,
+ lm: bool = True,
+ ):
+ self.duration = duration
+ self.sample_rate = sample_rate
+ self.speed_pertub = speed_pertub
+ self.lm = lm
+
+ def __call__(self, wav_path):
+ wav, sr = torchaudio.load(wav_path)
+ assert sr == self.sample_rate
+ wav = wav[0]
+
+ if self.speed_pertub and self.lm:
+ speeds = [1.0, 0.9, 1.1]
+ speed_idx = random.randint(0, 2)
+ if speed_idx > 0:
+ wav, _ = torchaudio.sox_effects.apply_effects_tensor(
+ wav.unsqueeze(0), self.sample_rate, [['speed', str(speeds[speed_idx])], ['rate', str(self.sample_rate)]])
+ else:
+ speed_idx = 0
+
+ wav = wav.squeeze(0)
+ data_len = wav.shape[0]
+
+ chunk_len = int(self.duration * sr)
+ if data_len >= chunk_len:
+ start = random.randint(0, data_len - chunk_len)
+ end = start + chunk_len
+ wav = wav[start:end]
+ else:
+ wav = F.pad(wav, (0, chunk_len - data_len))
+
+ return wav, speed_idx
+
+class SpkLabelEncoder(object):
+ def __init__(self, data_file):
+ self.lab2ind = {}
+ self.ind2lab = {}
+ self.starting_index = -1
+ self.load_from_csv(data_file)
+
+ def __call__(self, spk, speed_idx=0):
+ spkid = self.lab2ind[spk]
+ spkid = spkid + len(self.lab2ind) * speed_idx
+ return spkid
+
+ def load_from_csv(self, path):
+ self.data = load_data_csv(path)
+ for key in self.data:
+ self.add(self.data[key]['spk'])
+
+ def add(self, label):
+ if label in self.lab2ind:
+ return
+ index = self._next_index()
+ self.lab2ind[label] = index
+ self.ind2lab[index] = label
+
+ def _next_index(self):
+ self.starting_index += 1
+ return self.starting_index
+
+ def __len__(self):
+ return len(self.lab2ind)
+
+ def save(self, path, device=None):
+ with open(path, 'wb') as f:
+ pickle.dump(self.lab2ind, f)
+
+ def load(self, path, device=None):
+ self.lab2ind = {}
+ self.ind2lab = {}
+ with open(path, 'rb') as f:
+ self.lab2ind = pickle.load(f)
+ for label in self.lab2ind:
+ self.ind2lab[self.lab2ind[label]] = label
+
+
+class SpkVeriAug(object):
+ def __init__(
+ self,
+ aug_prob: float = 0.0,
+ noise_file: str = None,
+ reverb_file: str = None,
+ ):
+ self.aug_prob = aug_prob
+ if aug_prob > 0:
+ self.add_noise = NoiseReverbCorrupter(
+ noise_prob=1.0,
+ noise_file=noise_file,
+ )
+ self.add_rir = NoiseReverbCorrupter(
+ reverb_prob=1.0,
+ reverb_file=reverb_file,
+ )
+ self.add_rir_noise = NoiseReverbCorrupter(
+ noise_prob=1.0,
+ reverb_prob=1.0,
+ noise_file=noise_file,
+ reverb_file=reverb_file,
+ )
+
+ self.augmentations = [self.add_noise, self.add_rir, self.add_rir_noise]
+
+ def __call__(self, wav):
+ sample_rate = 16000
+ if self.aug_prob > random.random():
+ aug = random.choice(self.augmentations)
+ wav = aug(wav, sample_rate)
+
+ return wav
+
+
+class FBank(object):
+ def __init__(self,
+ n_mels,
+ sample_rate,
+ mean_nor: bool = False,
+ ):
+ self.n_mels = n_mels
+ self.sample_rate = sample_rate
+ self.mean_nor = mean_nor
+
+ def __call__(self, wav, dither=0):
+ sr = 16000
+ assert sr==self.sample_rate
+ if len(wav.shape) == 1:
+ wav = wav.unsqueeze(0)
+ # select single channel
+ if wav.shape[0] > 1:
+ wav = wav[0, :]
+ wav = wav.unsqueeze(0)
+ assert len(wav.shape) == 2 and wav.shape[0]==1
+ feat = Kaldi.fbank(wav, num_mel_bins=self.n_mels,
+ sample_frequency=sr, dither=dither)
+ # feat: [T, N]
+ if self.mean_nor:
+ feat = feat - feat.mean(0, keepdim=True)
+ return feat
diff --git a/speaker_diarization/local/utils/builder.py b/speaker_diarization/local/utils/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9117224a81a64d7084ff030d347e73258504bef6
--- /dev/null
+++ b/speaker_diarization/local/utils/builder.py
@@ -0,0 +1,98 @@
+import re
+import importlib
+from speaker_diarization.local.utils.config import Config
+
+
+def dynamic_import(import_path):
+ module_name, obj_name = import_path.rsplit('.', 1)
+ m = importlib.import_module(module_name)
+ return getattr(m, obj_name)
+
+def is_ref_type(value: str):
+ assert isinstance(value, str), 'Input value is not a str.'
+ if re.match('^<[a-zA-Z]\w*>$', value):
+ return True
+ else:
+ return False
+
+def is_built(ins):
+ if isinstance(ins, dict):
+ if 'obj' in ins and 'args' in ins:
+ return False
+ for i in ins.values():
+ if not is_built(i):
+ return False
+ elif isinstance(ins, str):
+ if '/' in ins: # reference may exist in a path string.
+ inss = ins.split('/')
+ return is_built(inss)
+ elif is_ref_type(ins):
+ return False
+ elif isinstance(ins, list):
+ for i in ins:
+ if not is_built(i):
+ return False
+ return True
+
+def deep_build(ins, config, build_space: set = None):
+ if is_built(ins):
+ return ins
+
+ if build_space is None:
+ build_space = set()
+
+ if isinstance(ins, list):
+ for i in range(len(ins)):
+ ins[i] = deep_build(ins[i], config, build_space)
+ return ins
+ elif isinstance(ins, dict):
+ if 'obj' in ins and 'args' in ins: # return a instantiated module.
+ obj = ins['obj']
+ args = ins['args']
+ assert isinstance(args, dict), f"Args for {obj} must be a dict."
+ args = deep_build(args, config, build_space)
+
+ module_cls = dynamic_import(obj)
+ mm = module_cls(**args)
+ return mm
+ else: # return a nomal dict.
+ for k in ins:
+ ins[k] = deep_build(ins[k], config, build_space)
+ return ins
+ elif isinstance(ins, str):
+ if '/' in ins: # reference may exist in a path string.
+ inss = ins.split('/')
+ inss = deep_build(inss, config, build_space)
+ ins = '/'.join(inss)
+ return ins
+ elif is_ref_type(ins):
+ ref = ins[1:-1]
+ if ref in build_space:
+ raise ValueError("Cross referencing is not allowed in config.")
+ build_space.add(ref)
+
+ if isinstance(config, dict):
+ if ref not in config:
+ raise AssertionError(f"Key name {ins} not found in config.")
+ attr = config[ref]
+ else:
+ if not hasattr(config, ref):
+ raise AssertionError(f"Key name {ins} not found in config.")
+ attr = getattr(config, ref)
+
+ attr = deep_build(attr, config, build_space)
+
+ if isinstance(config, dict):
+ config[ref] = attr
+ else:
+ setattr(config, ref, attr)
+
+ build_space.remove(ref)
+ return attr
+ else:
+ return ins
+ else:
+ return ins
+
+def build(name: str, config: Config):
+ return deep_build(f"<{name}>", config)
diff --git a/speaker_diarization/local/utils/config.py b/speaker_diarization/local/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fdceaec5109569ab1db2b929e783849bf57b895
--- /dev/null
+++ b/speaker_diarization/local/utils/config.py
@@ -0,0 +1,62 @@
+import os
+import yaml
+
+class Config(object):
+ def __init__(self, conf_dict):
+ for key, value in conf_dict.items():
+ self.__dict__[key] = value
+
+
+def convert_to_yaml(overrides):
+ """Convert args to yaml for overrides"""
+ yaml_string = ""
+
+ # Handle '--arg=val' type args
+ joined_args = "=".join(overrides)
+ split_args = joined_args.split("=")
+
+ for arg in split_args:
+ if arg.startswith("--"):
+ yaml_string += "\n" + arg[len("--") :] + ":"
+ else:
+ yaml_string += " " + arg
+
+ return yaml_string.strip()
+
+
+def yaml_config_loader(conf_file, overrides=None):
+ with open(conf_file, 'r') as f:
+ config = yaml.load(f, Loader=yaml.FullLoader)
+
+ if overrides is not None:
+ config.update(yaml.load(overrides, Loader=yaml.FullLoader))
+
+ variables = {k: v for k, v in config.items() if isinstance(k, str) and not k.startswith('_') and isinstance(v, (int, float, str, bool))}
+
+ def resolve(x):
+ if isinstance(x, dict):
+ return {k: resolve(v) for k, v in x.items()}
+ elif isinstance(x, list):
+ return [resolve(item) for item in x]
+ elif isinstance(x, str) and x.startswith('<') and x.endswith('>'):
+ key = x[1:-1]
+ return variables.get(key, x)
+ else:
+ return x
+ return resolve(config)
+
+
+def build_config(config_file, overrides=None, copy=False):
+ if config_file.endswith(".yaml"):
+ if overrides is not None:
+ overrides = convert_to_yaml(overrides)
+ conf_dict = yaml_config_loader(config_file, overrides)
+ if copy and 'exp_dir' in conf_dict:
+ os.makedirs(conf_dict['exp_dir'], exist_ok=True)
+ saved_path = os.path.join(conf_dict['exp_dir'], 'config.yaml')
+ with open(saved_path, 'w') as f:
+ f.write(yaml.dump(conf_dict))
+ else:
+ raise ValueError("Unknown config file format")
+
+ return Config(conf_dict)
diff --git a/speaker_diarization/local/utils/epoch.py b/speaker_diarization/local/utils/epoch.py
new file mode 100644
index 0000000000000000000000000000000000000000..760cf3dbf2786ddafb162a73202b1f2eea1cccd1
--- /dev/null
+++ b/speaker_diarization/local/utils/epoch.py
@@ -0,0 +1,62 @@
+import logging
+logger = logging.getLogger(__name__)
+
+class EpochLogger(object):
+ def __init__(self, save_file, precision=2):
+ self.save_file = save_file
+ self.precision = precision
+
+ def item_to_string(self, key, value, prefix=None):
+ if isinstance(value, float) and 1.0 < value < 100.0:
+ value = f"{value:.{self.precision}f}"
+ elif isinstance(value, float):
+ value = f"{value:.{self.precision}e}"
+ if prefix is not None:
+ key = f"{prefix} {key}"
+ return f"{key}: {value}"
+
+ def stats_to_string(self, stats, prefix=None):
+ return ", ".join(
+ [self.item_to_string(k, v, prefix) for k, v in stats.items()]
+ )
+
+ def log_stats(
+ self,
+ stats_meta,
+ stats,
+ stage='train',
+ verbose=True,
+ ):
+ string = self.stats_to_string(stats_meta)
+ if stats is not None:
+ string += " - " + self.stats_to_string(stats, stage)
+
+ with open(self.save_file, "a") as fw:
+ print(string, file=fw)
+ if verbose:
+ logger.info(string)
+
+
+class EpochCounter(object):
+ def __init__(self, limit):
+ self.current = 0
+ self.limit = limit
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.current < self.limit:
+ self.current += 1
+ logger.info(f"Going into epoch {self.current}")
+ return self.current
+ raise StopIteration
+
+ def save(self, path, device=None):
+ with open(path, "w") as f:
+ f.write(str(self.current))
+
+ def load(self, path, device=None):
+ with open(path) as f:
+ saved_value = int(f.read())
+ self.current = saved_value
diff --git a/speaker_diarization/local/utils/fileio.py b/speaker_diarization/local/utils/fileio.py
new file mode 100644
index 0000000000000000000000000000000000000000..e958fac2160cd54327474a3949240dc0ec0283ef
--- /dev/null
+++ b/speaker_diarization/local/utils/fileio.py
@@ -0,0 +1,126 @@
+import csv
+import yaml
+import codecs
+import json
+import torch
+import torchaudio
+import numpy as np
+
+
+def load_yaml(yaml_path):
+ with open(yaml_path) as f:
+ config = yaml.load(f, Loader=yaml.FullLoader)
+ return config
+
+
+def load_data_csv(fpath):
+ with open(fpath, newline="") as f:
+ result = {}
+ reader = csv.DictReader(f, skipinitialspace=True)
+ for row in reader:
+ if 'ID' not in row:
+ raise KeyError(
+ "CSV file has to have an 'ID' field, with unique ids for all data points."
+ )
+
+ data_id = row["ID"]
+ del row["ID"]
+
+ if data_id in result:
+ raise ValueError(f"Duplicate id: {data_id}")
+ result[data_id] = row
+ return result
+
+
+def load_data_list(fpath):
+ with open(fpath) as f:
+ rows = [i.strip() for i in f.readlines()]
+ result = {idx: row for idx, row in enumerate(rows)}
+ return result
+
+
+def load_wav_scp(fpath):
+ with open(fpath) as f:
+ rows = [i.strip() for i in f.readlines()]
+ result = {i.split()[0]: i.split()[1] for i in rows}
+ return result
+
+
+def load_json_file(json_file):
+ with codecs.open(json_file, "r", encoding="utf-8") as fr:
+ data_dict = json.load(fr)
+ return data_dict
+
+
+def load_trans7time_list(filename):
+ """
+ trans7time: (spk_id, st, ed, content)
+ """
+ with open(filename, "r") as fr:
+ trans7time_list = []
+ lines = fr.readlines()
+ for line in lines:
+ trans7time_list.append(line.strip().split())
+ result_trans7time_list = []
+ for index, item in enumerate(trans7time_list):
+ if len(item) <= 2:
+ raise ValueError(f"filename {filename}: item - {index} = {item}")
+ if len(item) == 3:
+ st = float(item[1])
+ ed = float(item[2])
+ result_trans7time_list.append((
+ item[0], st, ed, ""
+ ))
+ else:
+ result_trans7time_list.append((
+ item[0], float(item[1]), float(item[2]), "".join(item[3:])
+ ))
+ return result_trans7time_list
+
+
+def write_json_file(json_file, data):
+ assert str(json_file).endswith(".json") or str(json_file).endswith(".JSON")
+ with codecs.open(json_file, "w", encoding="utf-8") as fw:
+ json.dump(data, fw, indent=2, ensure_ascii=False)
+
+
+def write_wav_scp(fpath, wav_scp):
+ with open(fpath, "w") as f:
+ for key, value in wav_scp.items():
+ f.write(f"{key} {value}\n")
+
+
+def write_trans7time_list(fpath, trans7time_list):
+ """
+ trans7time_list: [(spk_id, start_time, end_time, text)]
+ """
+ with open(fpath, 'w') as fw:
+ for spk_id, start_time, end_time, text in trans7time_list:
+ text = text.replace("\n", "").replace("\r", "")
+ fw.write(f'{spk_id} {start_time} {end_time} {text}\n')
+
+def load_audio(input, ori_fs=None, obj_fs=None):
+ if isinstance(input, str):
+ wav, fs = torchaudio.load(input)
+ wav = wav.mean(dim=0, keepdim=True)
+ if obj_fs is not None and fs != obj_fs:
+ wav = torchaudio.functional.resample(wav, orig_freq=fs, new_freq=obj_fs)
+ return wav
+ elif isinstance(input, np.ndarray) or isinstance(input, torch.Tensor):
+ wav = torch.from_numpy(input) if isinstance(input, np.ndarray) else input
+ if wav.dtype in (torch.int16, torch.int32, torch.int64):
+ wav = wav.type(torch.float32)
+ wav = wav / 32768
+ wav = wav.type(torch.float32)
+ assert wav.ndim <= 2
+ if wav.ndim == 2:
+ if wav.shape[0] > wav.shape[1]:
+ wav = torch.transpose(wav, 0, 1)
+ wav = wav.mean(dim=0, keepdim=True)
+ if wav.ndim == 1:
+ wav = wav.unsqueeze(0)
+ if ori_fs is not None and obj_fs is not None and ori_fs!=obj_fs:
+ wav = torchaudio.functional.resample(wav, orig_freq=ori_fs, new_freq=obj_fs)
+ return wav
+ else:
+ return input
diff --git a/speaker_diarization/local/utils/score_metrics.py b/speaker_diarization/local/utils/score_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebd6ce11629ce23ddfb99ff30781a4174da012f8
--- /dev/null
+++ b/speaker_diarization/local/utils/score_metrics.py
@@ -0,0 +1,188 @@
+"""
+This script computes the official performance metrics for the NIST 2016 SRE.
+The metrics include EER and DCFs (min/act).
+"""
+
+__author__ = "Omid Sadjadi"
+__email__ = "omid.sadjadi@nist.gov"
+__version__ = "4.1"
+
+import numpy as np
+from scipy.stats import norm
+import matplotlib.pyplot as plt
+import sys
+
+
+def compute_norm_counts(scores, edges, wghts=None):
+ """ computes normalized (and optionally weighted) score counts for the
+ bin edges.
+ """
+
+ if scores.size > 0:
+ score_counts = np.histogram(scores, bins=edges,
+ weights=wghts)[0].astype('f')
+ norm_counts = np.cumsum(score_counts) / score_counts.sum()
+ else:
+ norm_counts = None
+ return norm_counts
+
+
+def compute_pmiss_pfa(scores, labels, weights=None):
+ """ computes false positive rate (FPR) and false negative rate (FNR)
+ given trial socres and their labels. A weights option is also provided to
+ equalize the counts over score partitions (if there is such partitioning).
+ """
+
+ tgt_scores = scores[labels == 1] # target trial scores
+ imp_scores = scores[labels == 0] # impostor trial scores
+
+ resol = max(
+ [np.count_nonzero(labels == 0),
+ np.count_nonzero(labels == 1), 1.e6])
+ edges = np.linspace(np.min(scores), np.max(scores), resol)
+
+ if weights is not None:
+ tgt_weights = weights[labels == 1]
+ imp_weights = weights[labels == 0]
+ else:
+ tgt_weights = None
+ imp_weights = None
+
+ fnr = compute_norm_counts(tgt_scores, edges, tgt_weights)
+ fpr = 1 - compute_norm_counts(imp_scores, edges, imp_weights)
+
+ return fnr, fpr
+
+
+def compute_pmiss_pfa_rbst(scores, labels, weights=None):
+ """ computes false positive rate (FPR) and false negative rate (FNR)
+ given trial socres and their labels. A weights option is also provided to
+ equalize the counts over score partitions (if there is such partitioning).
+ """
+
+ sorted_ndx = np.argsort(scores)
+ labels = labels[sorted_ndx]
+ if weights is not None:
+ weights = weights[sorted_ndx]
+ else:
+ weights = np.ones((labels.shape), dtype='f8')
+
+ tgt_wghts = weights * (labels == 1).astype('f8')
+ imp_wghts = weights * (labels == 0).astype('f8')
+
+ fnr = np.cumsum(tgt_wghts) / np.sum(tgt_wghts)
+ fpr = 1 - np.cumsum(imp_wghts) / np.sum(imp_wghts)
+ return fnr, fpr
+
+
+def compute_eer(fnr, fpr, scores=None):
+ """ computes the equal error rate (EER) given FNR and FPR values calculated
+ for a range of operating points on the DET curve
+ """
+
+ diff_pm_fa = fnr - fpr
+ x1 = np.flatnonzero(diff_pm_fa >= 0)[0]
+ x2 = np.flatnonzero(diff_pm_fa < 0)[-1]
+ a = (fnr[x1] - fpr[x1]) / (fpr[x2] - fpr[x1] - (fnr[x2] - fnr[x1]))
+
+ if scores is not None:
+ score_sort = np.sort(scores)
+ return fnr[x1] + a * (fnr[x2] - fnr[x1]), score_sort[x1]
+
+ return fnr[x1] + a * (fnr[x2] - fnr[x1])
+
+
+def compute_c_norm(fnr, fpr, p_target, c_miss=1, c_fa=1):
+ """ computes normalized minimum detection cost function (DCF) given
+ the costs for false accepts and false rejects as well as a priori
+ probability for target speakers
+ """
+
+ c_det = min(c_miss * fnr * p_target + c_fa * fpr * (1 - p_target))
+ c_def = min(c_miss * p_target, c_fa * (1 - p_target))
+
+ return c_det / c_def
+
+
+def compute_c_dcf(fnr, fpr, p_target, c_miss=1, c_fa=1):
+ """ computes normalized minimum detection cost function (DCF) given
+ the costs for false accepts and false rejects as well as a priori
+ probability for target speakers
+ """
+
+ c_det = min(c_miss * fnr * p_target + c_fa * fpr * (1 - p_target))
+
+ return c_det
+
+
+def plot_det_curve(fnr, fpr, save_path=None):
+ """ plots the detection error trade-off (DET) curve
+ """
+
+ p_miss = norm.ppf(fnr)
+ p_fa = norm.ppf(fpr)
+
+ xytick = [
+ 0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1,
+ 0.2, 0.4
+ ]
+ xytick_labels = map(str, [x * 100 for x in xytick])
+
+ plt.plot(p_fa, p_miss, 'r')
+ plt.xticks(norm.ppf(xytick), xytick_labels)
+ plt.yticks(norm.ppf(xytick), xytick_labels)
+ plt.xlim(norm.ppf([0.00051, 0.5]))
+ plt.ylim(norm.ppf([0.00051, 0.5]))
+ plt.xlabel("false-alarm rate [%]", fontsize=12)
+ plt.ylabel("false-reject rate [%]", fontsize=12)
+ eer = compute_eer(fnr, fpr)
+ plt.plot(norm.ppf(eer), norm.ppf(eer), 'o')
+ plt.annotate(
+ "EER = %.2f%%" % (eer * 100),
+ xy=(norm.ppf(eer), norm.ppf(eer)),
+ xycoords='data',
+ xytext=(norm.ppf(eer + 0.05), norm.ppf(eer + 0.05)),
+ textcoords='data',
+ arrowprops=dict(arrowstyle="-|>",
+ connectionstyle="arc3, rad=+0.2",
+ fc="w"),
+ size=12,
+ va='center',
+ ha='center',
+ bbox=dict(boxstyle="round4", fc="w"),
+ )
+ plt.grid()
+ if save_path is not None:
+ plt.savefig(save_path)
+ plt.clf()
+ else:
+ plt.show()
+
+
+def compute_equalized_scores(max_tar_imp_counts, sc, labs, masks):
+
+ count_weights = []
+ scores = []
+ labels = []
+ for ix in range(len(masks)):
+ amask = masks[ix]
+ alabs = labs[amask]
+ num_targets = np.count_nonzero(alabs == 1)
+ num_non_targets = alabs.size - num_targets
+ labels.append(alabs)
+ scores.append(sc[amask])
+ tar_weight = max_tar_imp_counts[
+ 0] / num_targets if num_targets > 0 else 0
+ imp_weight = max_tar_imp_counts[
+ 1] / num_non_targets if num_non_targets > 0 else 0
+
+ acount_weights = np.empty(alabs.shape, dtype='f')
+ acount_weights[alabs == 1] = np.array([tar_weight] * num_targets)
+ acount_weights[alabs == 0] = np.array([imp_weight] * num_non_targets)
+ count_weights.append(acount_weights)
+
+ scores = np.hstack(scores)
+ labels = np.hstack(labels)
+ count_weights = np.hstack(count_weights)
+
+ return scores, labels, count_weights
diff --git a/speaker_diarization/local/utils/utils.py b/speaker_diarization/local/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..48826c3f0004d710f3746514625c608939b80796
--- /dev/null
+++ b/speaker_diarization/local/utils/utils.py
@@ -0,0 +1,234 @@
+import sys
+import os
+import random
+import logging
+import yaml
+import numpy as np
+from contextlib import contextmanager
+import torch
+from speaker_diarization.local.utils.fileio import load_yaml
+
+def parse_config(config_file):
+ if config_file.endwith('.yaml'):
+ config = load_yaml(config_file)
+ else:
+ raise Exception("Other formats not currently supported.")
+ return config
+
+def set_seed(seed=0):
+ np.random.seed(seed)
+ random.seed(seed)
+
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ # torch.backends.cudnn.deterministic = True
+ # torch.backends.cudnn.benchmark = True
+
+def get_logger(fpath=None, fmt=None):
+ if fmt is None:
+ fmt = "%(asctime)s - %(levelname)s: %(message)s"
+ logging.basicConfig(level=logging.INFO, format=fmt)
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.INFO)
+ if fpath is not None:
+ handler = logging.FileHandler(fpath)
+ handler.setFormatter(logging.Formatter(fmt))
+ logger.addHandler(handler)
+ return logger
+
+def get_utt2spk_dict(utt2spk, suffix=''):
+ temp_dict={}
+ with open(utt2spk,'r') as utt2spk_f:
+ lines = utt2spk_f.readlines()
+ for i in lines:
+ i=i.strip().split()
+ if suffix == '' or suffix is None:
+ key_i = i[0]
+ value_spk = i[1]
+ else:
+ key_i = i[0]+'_'+suffix
+ value_spk = i[1]+'_'+suffix
+ if key_i in temp_dict:
+ raise ValueError('The key must be unique.')
+ temp_dict[key_i]=value_spk
+ return temp_dict
+
+def get_wavscp_dict(wavscp, suffix=''):
+ temp_dict={}
+ with open(wavscp, 'r') as wavscp_f:
+ lines = wavscp_f.readlines()
+ for i in lines:
+ i=i.strip().split()
+ if suffix == '' or suffix is None:
+ key_i = i[0]
+ else:
+ key_i = i[0]+'_'+suffix
+ value_path = i[1]
+ if key_i in temp_dict:
+ raise ValueError('The key must be unique.')
+ temp_dict[key_i]=value_path
+ return temp_dict
+
+def accuracy(x, target):
+ # x: [*, C], target: [*,]
+ _, pred = x.topk(1)
+ pred = pred.squeeze(-1)
+ acc = pred.eq(target).float().mean()
+ return acc*100
+
+def average_precision(scores, labels):
+ # scores: [N, ], labels: [N, ]
+ if torch.is_tensor(scores):
+ scores = scores.cpu().numpy()
+ if torch.is_tensor(labels):
+ labels = labels.cpu().numpy()
+ if isinstance(scores, list):
+ scores = np.array(scores)
+ if isinstance(labels, list):
+ labels = np.array(labels)
+ assert isinstance(scores, np.ndarray) and isinstance(
+ labels, np.ndarray), 'Input should be numpy.array.'
+ assert len(scores.shape)==1 and len(labels.shape)==1 and \
+ scores.shape[0]==labels.shape[0]
+
+ sort_idx = np.argsort(scores)[::-1]
+ scores = scores[sort_idx]
+ labels = labels[sort_idx]
+ tp_count = (labels==1).sum()
+ tp = labels.cumsum()
+ recall = tp / tp_count
+ precision = tp / (np.arange(len(labels)) + 1)
+
+ recall = np.concatenate([[0], recall, [1]])
+ precision = np.concatenate([[0], precision, [0]])
+
+ # Smooth precision to be monotonically decreasing.
+ for i in range(len(precision) - 2, -1, -1):
+ precision[i] = np.maximum(precision[i], precision[i + 1])
+
+ indices = np.where(recall[1:] != recall[:-1])[0] + 1
+ average_precision = np.sum(
+ (recall[indices] - recall[indices - 1]) * precision[indices])
+ return average_precision
+
+def load_params(dst_model, src_state, strict=True):
+ dst_state = {}
+ for k in src_state:
+ if k.startswith('module'):
+ dst_state[k[7:]] = src_state[k]
+ else:
+ dst_state[k] = src_state[k]
+ dst_model.load_state_dict(dst_state, strict=strict)
+ return dst_model
+
+def merge_vad(vad1: list, vad2: list):
+ intervals = vad1 + vad2
+ intervals.sort(key=lambda x: x[0])
+ merged = []
+ for interval in intervals:
+ if not merged or merged[-1][1] < interval[0]:
+ merged.append(interval)
+ else:
+ merged[-1][1] = max(merged[-1][1], interval[1])
+ return merged
+
+class AverageMeter(object):
+ def __init__(self, name, fmt=':f'):
+ self.name = name
+ self.fmt = fmt
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def __str__(self):
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+ return fmtstr.format(**self.__dict__)
+
+class AverageMeters(object):
+ def __init__(self, names: list = None, fmts: list = None):
+ self.cont = dict()
+ if names is None or fmts is None:
+ return
+ for name, fmt in zip(names, fmts):
+ self.cont[name] = AverageMeter(name, fmt)
+
+ def add(self, name, fmt=':f'):
+ self.cont[name] = AverageMeter(name, fmt)
+
+ def update(self, name, val, n=1):
+ self.cont[name].update(val, n)
+
+ def avg(self, name):
+ return self.cont[name].avg
+
+ def val(self, name):
+ return self.cont[name].val
+
+ def __str__(self):
+ return '\t'.join([str(s) for s in self.cont.values()])
+
+
+class ProgressMeter(object):
+ def __init__(self, num_batches, meters, prefix=""):
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+ self.meters = meters
+ self.prefix = prefix
+
+ def display(self, batch):
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
+ entries += [str(self.meters)]
+ return '\t'.join(entries)
+
+ def _get_batch_fmtstr(self, num_batches):
+ num_digits = len(str(num_batches // 1))
+ fmt = '{:' + str(num_digits) + 'd}'
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+@contextmanager
+def silent_print():
+ original_stdout = sys.stdout
+ sys.stdout = open(os.devnull, 'w')
+ try:
+ yield
+ finally:
+ sys.stdout.close()
+ sys.stdout = original_stdout
+
+def download_model_from_modelscope(model_id, model_revision=None, cache_dir=None):
+ from modelscope.hub.snapshot_download import snapshot_download
+ if cache_dir is None:
+ cache_dir = snapshot_download(
+ model_id,
+ revision=model_revision,
+ )
+ else:
+ cfg_file = os.path.join(cache_dir, model_id, 'configuration.json')
+ if not os.path.exists(cfg_file):
+ cache_dir = snapshot_download(
+ model_id,
+ revision=model_revision,
+ cache_dir=cache_dir,
+ )
+ else:
+ cache_dir = os.path.join(cache_dir, model_id)
+ return cache_dir
+
+def circle_pad(x: torch.Tensor, target_len, dim=0):
+ xlen = x.shape[dim]
+ if xlen >= target_len:
+ return x
+ n = int(np.ceil(target_len/xlen))
+ xcat = torch.cat([x for _ in range(n)], dim=dim)
+ return torch.narrow(xcat, dim, 0, target_len)
diff --git a/speaker_diarization/local/vision_processer.py b/speaker_diarization/local/vision_processer.py
new file mode 100644
index 0000000000000000000000000000000000000000..588b86014759cb9d43beb47d774b82eb990bb2a0
--- /dev/null
+++ b/speaker_diarization/local/vision_processer.py
@@ -0,0 +1,449 @@
+"""
+This script uses pretrained models to perform speaker visual embeddings extracting.
+This script use following open source models:
+ 1. Face detection: https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
+ 2. Active speaker detection: TalkNet, https://github.com/TaoRuijie/TalkNet-ASD
+ 3. Face quality assessment: https://modelscope.cn/models/iic/cv_manual_face-quality-assessment_fqa
+ 4. Face recognition: https://modelscope.cn/models/iic/cv_ir101_facerecognition_cfglint
+ 5. Lip detection: https://huggingface.co/pyannote/segmentation-3.0
+Processing pipeline:
+ 1. Face detection (input: video frames)
+ 2. Active speaker detection (input: consecutive face frames, audio)
+ 3. Face quality assessment (input: video frames)
+ 4. Face recognition (input: video frames)
+ 5. Lip detection (input: video frames)
+"""
+
+import numpy as np
+from scipy.io import wavfile
+from scipy.interpolate import interp1d
+import time, torch, cv2, pickle, gc, python_speech_features
+from scipy import signal
+
+
+class VisionProcesser():
+ def __init__(
+ self,
+ video_file_path,
+ audio_file_path,
+ audio_vad,
+ out_feat_path,
+ visual_models,
+ conf=None,
+ out_video_path=None
+ ):
+ # read audio data and check the samplerate.
+ fs, audio = wavfile.read(audio_file_path)
+ if len(audio.shape) > 1:
+ audio = audio.mean(axis=1)
+ duration = audio.shape[0] / fs
+ target_length = int(duration * 16000)
+ self.audio = signal.resample(audio, target_length)
+
+ # convert time interval to integer sampling point interval.
+ audio_vad = [[int(i*16000), int(j*16000)] for (i, j) in audio_vad]
+ self.video_path = video_file_path
+
+ # read video data
+ self.cap = cv2.VideoCapture(video_file_path)
+ w = self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)
+ h = self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
+ self.count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ self.fps = self.cap.get(cv2.CAP_PROP_FPS)
+ print('video %s info: w: {}, h: {}, count: {}, fps: {}'.format(w, h, self.count, self.fps) % self.video_path)
+
+ # initial vision models
+ self.visual_models = visual_models
+
+ # store facial feats along with the necessary information.
+ self.active_facial_embs = {
+ 'frameI':np.empty((0,), dtype=np.int32),
+ 'feat':np.empty((0, 512), dtype=np.float32),
+ 'faceI': np.empty((0,), dtype=np.int32),
+ 'face': [],
+ 'face_bbox': np.empty((0, 4), dtype=np.int32),
+ 'lip': [],
+ 'lip_bbox': np.empty((0, 4), dtype=np.int32),
+ }
+
+ self.audio_vad = audio_vad
+ self.out_video_path = out_video_path
+ self.out_feat_path = out_feat_path
+
+ self.min_track = conf['min_track']
+ self.num_failed_det = conf['num_failed_det']
+ self.crop_scale = conf['crop_scale']
+ self.min_face_size = conf['min_face_size']
+ self.face_det_stride = conf['face_det_stride']
+ self.shot_stride = conf['shot_stride']
+
+ if self.out_video_path is not None:
+ # save the active face detection results video (for debugging).
+ self.v_out = cv2.VideoWriter(out_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (int(w), int(h)))
+
+ # record the time spent by each module.
+ self.elapsed_time = {'faceTime':[], 'trackTime':[], 'cropTime':[],'asdTime':[], 'featTime':[], 'totalTime':[]}
+
+
+ def run(self):
+ frames, face_det_frames = [], []
+ for [audio_sample_st, audio_sample_ed] in self.audio_vad:
+ frame_st, frame_ed = int(audio_sample_st/640), int(audio_sample_ed/640) # 16000采样率/640=25fps,转换为视频的25fps帧数
+ num_frames = frame_ed - frame_st + 1
+ # go to frame 'frame_st'.
+ self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_st)
+ index = 0
+ for _ in range(num_frames):
+ ret, frame = self.cap.read()
+ if not ret:
+ break
+ if index % self.face_det_stride==0:
+ face_det_frames.append(frame)
+ frames.append(frame)
+ if (index + 1) % self.shot_stride==0:
+ audio = self.audio[(frame_st + index + 1 - self.shot_stride)*640:(frame_st + index + 1)*640]
+ self.process_one_shot(frames, face_det_frames, audio, frame_st + index + 1 - self.shot_stride)
+ frames, face_det_frames = [], []
+ index += 1
+ if len(frames) != 0:
+ audio = self.audio[(frame_st + index - len(frames))*640:(frame_st + index)*640]
+ self.process_one_shot(frames, face_det_frames, audio, frame_st + index - len(frames))
+ frames, face_det_frames = [], []
+
+ self.cap.release()
+ if self.out_video_path is not None:
+ self.v_out.release()
+
+ out_data = {
+ 'embeddings':self.active_facial_embs['feat'], # 'times': self.active_facial_embs['frameI']*0.04, # 25 fps
+ 'frameI': self.active_facial_embs['frameI'], # 说话人活跃的人脸帧索引
+ 'faceI': self.active_facial_embs['faceI'], # 存在人脸的帧索引
+ 'face': self.active_facial_embs['face'],
+ 'face_bbox': self.active_facial_embs['face_bbox'],
+ 'lip': self.active_facial_embs['lip'],
+ 'lip_bbox': self.active_facial_embs['lip_bbox'],
+ }
+ pickle.dump(out_data, open(self.out_feat_path, 'wb'))
+
+ # print elapsed time
+ all_elapsed_time = 0
+ for k in self.elapsed_time:
+ all_elapsed_time += sum(self.elapsed_time[k])
+ self.elapsed_time[k] = sum(self.elapsed_time[k])
+ elapsed_time_msg = 'The total time for %s is %.2fs, including' % (self.video_path, all_elapsed_time)
+ for k in self.elapsed_time:
+ elapsed_time_msg += ' %s %.2fs,'%(k, self.elapsed_time[k])
+ print(elapsed_time_msg[:-1]+'.')
+ try:
+ del out_data
+ except Exception:
+ pass
+
+ def process_one_shot(self, frames, face_det_frames, audio, frame_st=None):
+ curTime = time.time()
+ dets = self.face_detection(face_det_frames)
+ faceTime = time.time()
+
+ allTracks, vidTracks = [], []
+ allTracks.extend(self.track_shot(dets))
+ trackTime = time.time()
+
+ for ii, track in enumerate(allTracks):
+ vidTracks.append(self.crop_video(track, frames, audio))
+ cropTime = time.time()
+
+ scores = self.evaluate_asd(vidTracks)
+ asdTime = time.time()
+
+ active_facial_embs = self.evaluate_fr(frames, vidTracks, scores)
+ self.active_facial_embs['frameI'] = np.append(self.active_facial_embs['frameI'], active_facial_embs['frameI'] + frame_st)
+ self.active_facial_embs['feat'] = np.append(self.active_facial_embs['feat'], active_facial_embs['feat'], axis=0)
+ self.active_facial_embs['faceI'] = np.append(self.active_facial_embs['faceI'], active_facial_embs['faceI'] + frame_st)
+ self.active_facial_embs['face'].extend(active_facial_embs['face'])
+ self.active_facial_embs['face_bbox'] = np.vstack([self.active_facial_embs['face_bbox'], active_facial_embs['face_bbox']])
+ self.active_facial_embs['lip'].extend(active_facial_embs['lip'])
+ self.active_facial_embs['lip_bbox']= np.vstack([self.active_facial_embs['lip_bbox'], active_facial_embs['lip_bbox']])
+
+ featTime = time.time()
+ if self.out_video_path is not None:
+ self.visualization(frames, vidTracks, scores, active_facial_embs)
+
+ try:
+ del dets, allTracks, vidTracks, active_facial_embs
+ except Exception:
+ pass
+
+ self.elapsed_time['faceTime'].append(faceTime-curTime)
+ self.elapsed_time['trackTime'].append(trackTime-faceTime)
+ self.elapsed_time['cropTime'].append(cropTime-trackTime)
+ self.elapsed_time['asdTime'].append(asdTime-cropTime)
+ self.elapsed_time['featTime'].append(featTime-asdTime)
+ self.elapsed_time['totalTime'].append(featTime-curTime)
+
+ def face_detection(self, frames):
+ dets = []
+ for fidx, image in enumerate(frames):
+ image_input = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ bboxes, _, probs = self.visual_models.detect_faces(image_input, top_k=10, prob_threshold=0.9)
+ bboxes = torch.cat([bboxes, probs.reshape(-1, 1)], dim=-1)
+ dets.append([])
+ for bbox in bboxes:
+ frame_idex = fidx * self.face_det_stride
+ dets[-1].append({'frame':frame_idex, 'bbox':(bbox[:-1]).tolist(), 'conf':bbox[-1]})
+ return dets
+
+ def bb_intersection_over_union(self, boxA, boxB, evalCol=False):
+ # IOU Function to calculate overlap between two image
+ xA = max(boxA[0], boxB[0])
+ yA = max(boxA[1], boxB[1])
+ xB = min(boxA[2], boxB[2])
+ yB = min(boxA[3], boxB[3])
+ interArea = max(0, xB - xA) * max(0, yB - yA)
+ boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
+ boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
+ if evalCol == True:
+ iou = interArea / float(boxAArea)
+ else:
+ iou = interArea / float(boxAArea + boxBArea - interArea)
+ return iou
+
+ def track_shot(self, scene_faces):
+ # Face tracking
+ tracks = []
+ while True: # continuously search for consecutive faces.
+ track = []
+ for frame_faces in scene_faces:
+ for face in frame_faces:
+ if track == []:
+ track.append(face)
+ frame_faces.remove(face)
+ break
+ elif face['frame'] - track[-1]['frame'] <= self.num_failed_det: # the face does not interrupt for 'num_failed_det' frame.
+ iou = self.bb_intersection_over_union(face['bbox'], track[-1]['bbox'])
+ # minimum IOU between consecutive face.
+ if iou > 0.5:
+ track.append(face)
+ frame_faces.remove(face)
+ break
+ else:
+ break
+ if track == []:
+ break
+ elif len(track) > 1 and track[-1]['frame'] - track[0]['frame'] + 1 >= self.min_track:
+ frame_num = np.array([ f['frame'] for f in track ])
+ bboxes = np.array([np.array(f['bbox']) for f in track])
+ frameI = np.arange(frame_num[0], frame_num[-1]+1)
+ bboxesI = []
+ for ij in range(0, 4):
+ interpfn = interp1d(frame_num, bboxes[:,ij]) # missing boxes can be filled by interpolation.
+ bboxesI.append(interpfn(frameI))
+ bboxesI = np.stack(bboxesI, axis=1)
+ if max(np.mean(bboxesI[:,2]-bboxesI[:,0]), np.mean(bboxesI[:,3]-bboxesI[:,1])) > self.min_face_size: # need face size > min_face_size
+ tracks.append({'frame':frameI,'bbox':bboxesI})
+ return tracks
+
+ def crop_video(self, track, frames, audio):
+ # crop the face clips
+ crop_frames = []
+ dets = {'x':[], 'y':[], 's':[]}
+ for det in track['bbox']:
+ dets['s'].append(max((det[3]-det[1]), (det[2]-det[0]))/2)
+ dets['y'].append((det[1]+det[3])/2) # crop center x
+ dets['x'].append((det[0]+det[2])/2) # crop center y
+ for fidx, frame in enumerate(track['frame']):
+ cs = self.crop_scale
+ bs = dets['s'][fidx] # detection box size
+ bsi = int(bs * (1 + 2 * cs)) # pad videos by this amount
+ image = frames[frame]
+ frame = np.pad(image, ((bsi,bsi), (bsi,bsi), (0, 0)), 'constant', constant_values=(110, 110))
+ my = dets['y'][fidx] + bsi # BBox center Y
+ mx = dets['x'][fidx] + bsi # BBox center X
+ face = frame[int(my-bs):int(my+bs*(1+2*cs)),int(mx-bs*(1+cs)):int(mx+bs*(1+cs))]
+ crop_frames.append(cv2.resize(face, (224, 224)))
+ cropaudio = audio[track['frame'][0]*640:(track['frame'][-1]+1)*640]
+ return {'track':track, 'proc_track':dets, 'data':[crop_frames, cropaudio]}
+
+ def evaluate_asd(self, tracks):
+ # active speaker detection by pretrained TalkNet
+ all_scores = []
+ for ins in tracks:
+ video, audio = ins['data']
+ audio_feature = python_speech_features.mfcc(audio, 16000, numcep = 13, winlen = 0.025, winstep = 0.010)
+ video_feature = []
+ for frame in video:
+ face = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+ h0, w0 = face.shape
+ interp = cv2.INTER_CUBIC if (h0 < 224 or w0 < 224) else cv2.INTER_AREA
+ face = cv2.resize(face, (224,224), interpolation=interp)
+ # face = cv2.resize(face, (224,224))
+ face = face[int(112-(112/2)):int(112+(112/2)), int(112-(112/2)):int(112+(112/2))]
+ video_feature.append(face)
+ video_feature = np.array(video_feature)
+ length = min((audio_feature.shape[0] - audio_feature.shape[0] % 4) / 100, video_feature.shape[0] / 25)
+ audio_feature = audio_feature[:int(round(length * 100)),:]
+ video_feature = video_feature[:int(round(length * 25)),:,:]
+ audio_feature = np.expand_dims(audio_feature, axis=0).astype(np.float32)
+ video_feature = np.expand_dims(video_feature, axis=0).astype(np.float32)
+ score = self.visual_models.asd_score(audio_feature, video_feature)
+ all_score = np.asarray(score, dtype=np.float32)
+ all_scores.append(all_score)
+ try:
+ del audio_feature, video_feature, score
+ except Exception:
+ pass
+ return all_scores
+
+
+ def evaluate_fr(self, frames, tracks, scores):
+ SMOOTH_W = 4
+ ON_THRESHOLD = 0.0
+ OFF_THRESHOLD = -0.5
+ QUALITY_HIGH = 0.0
+ QUALITY_LOW = -0.3
+
+ # 先平滑每个 track 的 scores
+ smooth_scores_all = []
+ for score in scores:
+ s = np.asarray(score).flatten()
+ if s.size == 0:
+ smooth_scores_all.append(s)
+ continue
+ # 中值 + 简单移动平均
+ s_med = signal.medfilt(s, kernel_size=5 if len(s)>=5 else 3)
+ k = np.ones(5)/5
+ s_avg = np.convolve(s_med, k, mode='same')
+ smooth_scores_all.append(s_avg)
+
+ # aggregate faces per frame
+ faces = [[] for _ in range(len(frames))]
+ for tidx, track in enumerate(tracks):
+ score = smooth_scores_all[tidx]
+ for fidx, frame in enumerate(track['track']['frame'].tolist()):
+ s = score[max(fidx - SMOOTH_W, 0): min(fidx + SMOOTH_W+1, len(score))]
+ s = float(np.mean(s))
+ bbox = track['track']['bbox'][fidx]
+ bbox = bbox.astype(np.int32)
+ face = frames[frame][max(bbox[1],0):min(bbox[3],frames[frame].shape[0]),
+ max(bbox[0],0):min(bbox[2],frames[frame].shape[1])]
+ faces[frame].append({'track':tidx, 'score':s, 'facedata':face, 'bbox': bbox})
+
+ # per-frame decision
+ active_facial_embs = {
+ 'frameI': [],
+ 'trackI': [],
+ 'faceI': [],
+ 'face': [],
+ 'face_bbox': [],
+ 'feat': [],
+ 'lip': [],
+ 'lip_bbox': [],
+ }
+ # 这里做简单 per-frame decision: 选 score 最大的
+ for fidx in range(0, len(faces), max(1, self.face_det_stride)):
+ if len(faces[fidx]) == 0:
+ continue
+ # choose best candidate by score
+ best = max(faces[fidx], key=lambda x: x['score'])
+ res = self.visual_models.detect_lip(best['facedata'])
+ # 如果没有检测到嘴唇,跳过,会筛去低质量像素的人脸
+ if res is None or res.get('lip_crop') is None:
+ continue
+ # 只要该帧检测到一张或者多种人脸,就保存一个最有可能是说话人(best['facedata'])的人脸(不管说不说话)
+ active_facial_embs['faceI'].append(fidx)
+ active_facial_embs['face'].append(best['facedata']) # BGR ndarray
+ active_facial_embs['lip'].append(res.get('lip_crop')) # BGR ndarray
+ active_facial_embs['face_bbox'].append(best['bbox']) # 相对于整个一帧图片的脸的位置坐标
+ active_facial_embs['lip_bbox'].append(res.get('lip_bbox')) # 相对于脸框图的位置坐标
+ feature = self.visual_models.get_face_embedding(best['facedata'])
+ active_facial_embs['feat'].append(feature) # 完整面部特征
+
+
+ s = best['score']
+ if s < OFF_THRESHOLD:
+ continue
+ # 人脸质量评估(可选,开启后只会筛选评分更高的人脸帧)
+ # face_q_score = self.visual_models.face_quality_score(best['facedata'])
+ # if (face_q_score >= QUALITY_HIGH) or (face_q_score >= QUALITY_LOW and s >= ON_THRESHOLD):
+ if s >= OFF_THRESHOLD:
+ # feature, feature_normalized = self.visual_models.get_face_embedding(best['facedata']) # 仅保留模型认为在说话帧
+ active_facial_embs['frameI'].append(fidx)
+ active_facial_embs['trackI'].append(best['track'])
+
+ # 转 numpy
+ active_facial_embs['frameI'] = np.array(active_facial_embs['frameI'], dtype=np.int32)
+ active_facial_embs['trackI'] = np.array(active_facial_embs['trackI'], dtype=np.int32)
+ active_facial_embs['faceI'] = np.array(active_facial_embs['faceI'], dtype=np.int32)
+ active_facial_embs['face_bbox'] = np.array(active_facial_embs['face_bbox'], dtype=np.int32) if active_facial_embs['face_bbox'] else np.empty((0,4), np.int32)
+ active_facial_embs['lip_bbox'] = np.array(active_facial_embs['lip_bbox'], dtype=np.int32) if active_facial_embs['lip_bbox'] else np.empty((0,4), np.int32)
+ active_facial_embs['feat'] = np.vstack(active_facial_embs['feat']) if active_facial_embs['feat'] else np.empty((0,512), np.float32)
+ return active_facial_embs
+
+
+ def visualization(self, frames, tracks, scores, embs=None):
+ # 先聚合所有 track 在每帧的 bbox/score 信息(与原实现一致)
+ faces = [[] for _ in range(len(frames))]
+ for tidx, track in enumerate(tracks):
+ score = scores[tidx]
+ for fidx, frame in enumerate(track['track']['frame'].tolist()):
+ s = score[max(fidx - 2, 0): min(fidx + 3, len(score))] # 注意 len(score) 作为上界
+ s = np.mean(s)
+ faces[frame].append({'track':tidx, 'score':float(s),'bbox':track['track']['bbox'][fidx]})
+
+ # 构造已保存帧集合(相对于本 shot)
+ feat_set = set()
+ lip_bbox_dict = {} # 存储嘴唇边界框的字典
+ if embs is not None:
+ if 'frameI' in embs and embs['frameI'].size > 0:
+ trackI = embs.get('trackI')
+ feat_set = set((int(f), int(t)) for f, t in zip(embs['frameI'].tolist(), trackI.tolist()))
+
+ if 'lip_bbox' in embs and embs['lip_bbox'].size > 0:
+ for i, frame_idx in enumerate(embs['faceI']):
+ lip_bbox_dict[int(frame_idx)] = embs['lip_bbox'][i]
+
+ for fidx, image in enumerate(frames):
+ for face in faces[fidx]:
+ bbox = face['bbox']
+ x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+ # lip bbox
+ lip_bbox = None
+ if fidx in lip_bbox_dict:
+ lip_bbox = lip_bbox_dict[fidx]
+ lip_x1 = x1 + lip_bbox[0]
+ lip_y1 = y1 + lip_bbox[1]
+ lip_x2 = x1 + lip_bbox[2]
+ lip_y2 = y1 + lip_bbox[3]
+ if (fidx, face['track']) in feat_set:
+ # 绿色表示已保存, 蓝色表示嘴唇
+ cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
+ if lip_bbox is not None:
+ cv2.rectangle(image, (lip_x1, lip_y1), (lip_x2, lip_y2), (255, 0, 0), 2)
+ txt = round(face['score'], 2)
+ cv2.putText(image, '%s'%(txt), (x1, max(y1-6,0)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 1)
+ else:
+ # 红色表示未保存
+ cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255), 2)
+ if lip_bbox is not None:
+ cv2.rectangle(image, (lip_x1, lip_y1), (lip_x2, lip_y2), (255, 0, 0), 2)
+ txt = round(face['score'], 2)
+ cv2.putText(image, '%s'%(txt), (x1, max(y1-6,0)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,255), 1)
+
+ # 写入视频
+ self.v_out.write(image)
+
+
+ def close(self):
+ try:
+ if hasattr(self, "active_facial_embs"):
+ for k, v in self.active_facial_embs.items():
+ if isinstance(v, np.ndarray):
+ del v
+ elif isinstance(v, list):
+ v.clear()
+ self.active_facial_embs.clear()
+ except Exception as e:
+ print(f"[WARN] Error while closing VisionProcesser: {e}")
+ gc.collect()
+
+ def __del__(self):
+ self.close()
\ No newline at end of file
diff --git a/speaker_diarization/local/vision_tools/active_speaker_detection.py b/speaker_diarization/local/vision_tools/active_speaker_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..d15b249d806d858396ac16adde838f3ae5c610a7
--- /dev/null
+++ b/speaker_diarization/local/vision_tools/active_speaker_detection.py
@@ -0,0 +1,40 @@
+import os
+import numpy as np
+import onnxruntime
+
+class ASDTalknet:
+ """
+ Active speaker detection with TalkNet pretrained model.
+ Reference:
+ - https://github.com/TaoRuijie/TalkNet-ASD
+ """
+ def __init__(self, onnx_dir, device='cpu', device_id=0):
+ onnx_file_name = os.path.join(onnx_dir, 'asd.onnx')
+ assert os.path.exists(onnx_file_name), '%s does not exist. Please check if it has been downloaded accurately.' % onnx_file_name
+ self.ort_net = self.create_net(onnx_file_name, device, device_id)
+
+ def __call__(self, inputA, inputV):
+ ort_inputs = {self.ort_net.get_inputs()[0].name:inputA, self.ort_net.get_inputs()[1].name:inputV}
+ scores = self.ort_net.run(None, ort_inputs)[0]
+ return scores
+
+ def create_net(self, onnx_file_name, device='cpu', device_id=0):
+ options = onnxruntime.SessionOptions()
+ # set op_num_threads
+ options.intra_op_num_threads = 8
+ options.inter_op_num_threads = 8
+ # set providers
+ providers = ['CPUExecutionProvider']
+ if device == 'cuda':
+ providers.insert(0, ('CUDAExecutionProvider', {'device_id': device_id}))
+ ort_session = onnxruntime.InferenceSession(onnx_file_name, options, providers=providers)
+ return ort_session
+
+
+if __name__ == '__main__':
+ predictor = ASDTalknet('pretrained_models', 'cuda', 0)
+ inputA = np.random.randn(1, 100, 13).astype('float32')
+ inputV = np.random.randn(1, 25, 112, 112).astype('float32')
+ scores = predictor(inputA, inputV)
+ assert scores.shape == (25,)
+
\ No newline at end of file
diff --git a/speaker_diarization/local/vision_tools/api.py b/speaker_diarization/local/vision_tools/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f8d3df9d98350b0fd4bf327d2092d6eb9777f2b
--- /dev/null
+++ b/speaker_diarization/local/vision_tools/api.py
@@ -0,0 +1,308 @@
+"""
+Modified from face-alignment v1.4.1 api.py
+Original source: https://github.com/1adrianb/face-alignment
+License: BSD-3-Clause License
+"""
+import torch
+import warnings
+from enum import IntEnum
+from skimage import io
+import numpy as np
+from packaging import version
+from tqdm import tqdm
+
+from face_alignment.utils import *
+from face_alignment.folder_data import FolderData
+from face_alignment.detection import sfd
+
+class LandmarksType(IntEnum):
+ """Enum class defining the type of landmarks to detect.
+
+ ``TWO_D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
+ ``TWO_HALF_D`` - this points represent the projection of the 3D points into 3D
+ ``THREE_D`` - detect the points ``(x,y,z)``` in a 3D space
+
+ """
+ TWO_D = 1
+ TWO_HALF_D = 2
+ THREE_D = 3
+
+
+class NetworkSize(IntEnum):
+ # TINY = 1
+ # SMALL = 2
+ # MEDIUM = 3
+ LARGE = 4
+
+
+default_model_urls = {
+ '2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip',
+ '3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4-4a694010b9.zip',
+ 'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth-6c4283c0e0.zip',
+}
+
+models_urls = {
+ '1.6': {
+ '2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4_1.6-c827573f02.zip',
+ '3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4_1.6-ec5cf40a1d.zip',
+ 'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth_1.6-2aa3f18772.zip',
+ },
+ '1.5': {
+ '2DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/2DFAN4_1.5-a60332318a.zip',
+ '3DFAN-4': 'https://www.adrianbulat.com/downloads/python-fan/3DFAN4_1.5-176570af4d.zip',
+ 'depth': 'https://www.adrianbulat.com/downloads/python-fan/depth_1.5-bc10f98e39.zip',
+ },
+}
+
+
+class FaceAlignment:
+ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE, net_path=None,
+ device='cuda', dtype=torch.float32, flip_input=False, face_detector_kwargs=None, verbose=False):
+ self.device = device
+ self.flip_input = flip_input
+ self.landmarks_type = landmarks_type
+ self.verbose = verbose
+ self.dtype = dtype
+
+ if version.parse(torch.__version__) < version.parse('1.5.0'):
+ raise ImportError(f'Unsupported pytorch version detected. Minimum supported version of pytorch: 1.5.0\
+ Either upgrade (recommended) your pytorch setup, or downgrade to face-alignment 1.2.0')
+
+ network_size = int(network_size)
+ pytorch_version = torch.__version__
+ if 'dev' in pytorch_version:
+ pytorch_version = pytorch_version.rsplit('.', 2)[0]
+ else:
+ pytorch_version = pytorch_version.rsplit('.', 1)[0]
+
+ if 'cuda' in device:
+ torch.backends.cudnn.benchmark = True
+
+ # Get the face detector
+ # face_detector_module = __import__('face_alignment.detection.' + face_detector,
+ # globals(), locals(), [face_detector], 0)
+ face_detector_kwargs = face_detector_kwargs or {}
+ self.face_detector = sfd.FaceDetector(device=device, verbose=verbose, **face_detector_kwargs)
+
+ # Initialise the face alignemnt networks
+ if landmarks_type == LandmarksType.TWO_D:
+ network_name = '2DFAN-' + str(network_size)
+ else:
+ network_name = '3DFAN-' + str(network_size)
+ if net_path is None:
+ net_path = load_file_from_url(models_urls.get(pytorch_version, default_model_urls)[network_name])
+ self.face_alignment_net = torch.jit.load(net_path)
+
+ self.face_alignment_net.to(device, dtype=dtype)
+ self.face_alignment_net.eval()
+
+ # Initialiase the depth prediciton network
+ if landmarks_type == LandmarksType.THREE_D:
+ self.depth_prediciton_net = torch.jit.load(
+ load_file_from_url(models_urls.get(pytorch_version, default_model_urls)['depth']))
+
+ self.depth_prediciton_net.to(device, dtype=dtype)
+ self.depth_prediciton_net.eval()
+
+ def get_landmarks(self, image_or_path, detected_faces=None, return_bboxes=False, return_landmark_score=False):
+ """Deprecated, please use get_landmarks_from_image
+
+ Arguments:
+ image_or_path {string or numpy.array or torch.tensor} -- The input image or path to it
+
+ Keyword Arguments:
+ detected_faces {list of numpy.array} -- list of bounding boxes, one for each face found
+ in the image (default: {None})
+ return_bboxes {boolean} -- If True, return the face bounding boxes in addition to the keypoints.
+ return_landmark_score {boolean} -- If True, return the keypoint scores along with the keypoints.
+ """
+ return self.get_landmarks_from_image(image_or_path, detected_faces, return_bboxes, return_landmark_score)
+
+ @torch.no_grad()
+ def get_landmarks_from_image(self, image_or_path, detected_faces=None, return_bboxes=False,
+ return_landmark_score=False):
+ """Predict the landmarks for each face present in the image.
+
+ This function predicts a set of 68 2D or 3D images, one for each image present.
+ If detect_faces is None the method will also run a face detector.
+
+ Arguments:
+ image_or_path {string or numpy.array or torch.tensor} -- The input image or path to it.
+
+ Keyword Arguments:
+ detected_faces {list of numpy.array} -- list of bounding boxes, one for each face found
+ in the image (default: {None})
+ return_bboxes {boolean} -- If True, return the face bounding boxes in addition to the keypoints.
+ return_landmark_score {boolean} -- If True, return the keypoint scores along with the keypoints.
+
+ Return:
+ result:
+ 1. if both return_bboxes and return_landmark_score are False, result will be:
+ landmark
+ 2. Otherwise, result will be one of the following, depending on the actual value of return_* arguments.
+ (landmark, landmark_score, detected_face)
+ (landmark, None, detected_face)
+ (landmark, landmark_score, None )
+ """
+ image = get_image(image_or_path)
+
+ if detected_faces is None:
+ detected_faces = self.face_detector.detect_from_image(image.copy())
+
+ if len(detected_faces) == 0:
+ warnings.warn("No faces were detected.")
+ if return_bboxes or return_landmark_score:
+ return None, None, None
+ else:
+ return None
+
+ landmarks = []
+ landmarks_scores = []
+ for i, d in enumerate(detected_faces):
+ center = torch.tensor(
+ [d[2] - (d[2] - d[0]) / 2.0, d[3] - (d[3] - d[1]) / 2.0])
+ center[1] = center[1] - (d[3] - d[1]) * 0.12
+ scale = (d[2] - d[0] + d[3] - d[1]) / self.face_detector.reference_scale
+
+ inp = crop(image, center, scale)
+ inp = torch.from_numpy(inp.transpose(
+ (2, 0, 1))).float()
+
+ inp = inp.to(self.device, dtype=self.dtype)
+ inp.div_(255.0).unsqueeze_(0)
+
+ out = self.face_alignment_net(inp).detach()
+ if self.flip_input:
+ out += flip(self.face_alignment_net(flip(inp)).detach(), is_label=True)
+ out = out.to(device='cpu', dtype=torch.float32).numpy()
+
+ pts, pts_img, scores = get_preds_fromhm(out, center.numpy(), scale)
+ pts, pts_img = torch.from_numpy(pts), torch.from_numpy(pts_img)
+ pts, pts_img = pts.view(68, 2) * 4, pts_img.view(68, 2)
+ scores = scores.squeeze(0)
+
+ if self.landmarks_type == LandmarksType.THREE_D:
+ heatmaps = np.zeros((68, 256, 256), dtype=np.float32)
+ for i in range(68):
+ if pts[i, 0] > 0 and pts[i, 1] > 0:
+ heatmaps[i] = draw_gaussian(
+ heatmaps[i], pts[i], 2)
+ heatmaps = torch.from_numpy(
+ heatmaps).unsqueeze_(0)
+
+ heatmaps = heatmaps.to(self.device, dtype=self.dtype)
+ depth_pred = self.depth_prediciton_net(
+ torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1).to(dtype=torch.float32)
+ pts_img = torch.cat(
+ (pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)
+
+ landmarks.append(pts_img.numpy())
+ landmarks_scores.append(scores)
+
+ if not return_bboxes:
+ detected_faces = None
+ if not return_landmark_score:
+ landmarks_scores = None
+ if return_bboxes or return_landmark_score:
+ return landmarks, landmarks_scores, detected_faces
+ else:
+ return landmarks
+
+ @torch.no_grad()
+ def get_landmarks_from_batch(self, image_batch, detected_faces=None, return_bboxes=False,
+ return_landmark_score=False):
+ """Predict the landmarks for each face present in the image.
+
+ This function predicts a set of 68 2D or 3D images, one for each image in a batch in parallel.
+ If detect_faces is None the method will also run a face detector.
+
+ Arguments:
+ image_batch {torch.tensor} -- The input images batch
+
+ Keyword Arguments:
+ detected_faces {list of numpy.array} -- list of bounding boxes, one for each face found
+ in the image (default: {None})
+ return_bboxes {boolean} -- If True, return the face bounding boxes in addition to the keypoints.
+ return_landmark_score {boolean} -- If True, return the keypoint scores along with the keypoints.
+
+ Return:
+ result:
+ 1. if both return_bboxes and return_landmark_score are False, result will be:
+ landmarks
+ 2. Otherwise, result will be one of the following, depending on the actual value of return_* arguments.
+ (landmark, landmark_score, detected_face)
+ (landmark, None, detected_face)
+ (landmark, landmark_score, None )
+ """
+
+ if detected_faces is None:
+ detected_faces = self.face_detector.detect_from_batch(image_batch)
+
+ if len(detected_faces) == 0:
+ warnings.warn("No faces were detected.")
+ if return_bboxes or return_landmark_score:
+ return None, None, None
+ else:
+ return None
+
+ landmarks = []
+ landmarks_scores_list = []
+ # A batch for each frame
+ for i, faces in enumerate(detected_faces):
+ res = self.get_landmarks_from_image(
+ image_batch[i].cpu().numpy().transpose(1, 2, 0),
+ detected_faces=faces,
+ return_landmark_score=return_landmark_score,
+ )
+ if return_landmark_score:
+ landmark_set, landmarks_scores, _ = res
+ landmarks_scores_list.append(landmarks_scores)
+ else:
+ landmark_set = res
+ # Bacward compatibility
+ if landmark_set is not None:
+ landmark_set = np.concatenate(landmark_set, axis=0)
+ else:
+ landmark_set = []
+ landmarks.append(landmark_set)
+
+ if not return_bboxes:
+ detected_faces = None
+ if not return_landmark_score:
+ landmarks_scores_list = None
+ if return_bboxes or return_landmark_score:
+ return landmarks, landmarks_scores_list, detected_faces
+ else:
+ return landmarks
+
+ def get_landmarks_from_directory(self, path, extensions=['.jpg', '.png'], recursive=True, show_progress_bar=True,
+ return_bboxes=False, return_landmark_score=False):
+ """Scan a directory for images with a given extension type(s) and predict the landmarks for each
+ face present in the images found.
+
+ Arguments:
+ path {str} -- path to the target directory containing the images
+
+ Keyword Arguments:
+ extensions {list of str} -- list containing the image extensions considered (default: ['.jpg', '.png'])
+ recursive {boolean} -- If True, scans for images recursively (default: True)
+ show_progress_bar {boolean} -- If True displays a progress bar (default: True)
+ return_bboxes {boolean} -- If True, return the face bounding boxes in addition to the keypoints.
+ return_landmark_score {boolean} -- If True, return the keypoint scores along with the keypoints.
+ """
+ dataset = FolderData(path, self.face_detector.tensor_or_path_to_ndarray, extensions, recursive, self.verbose)
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2, prefetch_factor=4)
+
+ predictions = {}
+ for (image_path, image) in tqdm(dataloader, disable=not show_progress_bar):
+ image_path, image = image_path[0], image[0]
+ bounding_boxes = self.face_detector.detect_from_image(image)
+ if return_bboxes or return_landmark_score:
+ preds, bbox, score = self.get_landmarks_from_image(
+ image, bounding_boxes, return_bboxes=return_bboxes, return_landmark_score=return_landmark_score)
+ predictions[image_path] = (preds, bbox, score)
+ else:
+ preds = self.get_landmarks_from_image(image, bounding_boxes)
+ predictions[image_path] = preds
+
+ return predictions
diff --git a/speaker_diarization/local/vision_tools/face_detection.py b/speaker_diarization/local/vision_tools/face_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ec465171a982ab89db4b57113b5ab4822da7406
--- /dev/null
+++ b/speaker_diarization/local/vision_tools/face_detection.py
@@ -0,0 +1,212 @@
+import torch
+import cv2
+import os
+import numpy as np
+import onnxruntime
+
+
+def area_of(left_top, right_bottom) -> torch.Tensor:
+ """Compute the areas of rectangles given two corners.
+
+ Args:
+ left_top (N, 2): left top corner.
+ right_bottom (N, 2): right bottom corner.
+
+ Returns:
+ area (N): return the area.
+ """
+ hw = torch.clamp(right_bottom - left_top, min=0.0)
+ return hw[..., 0] * hw[..., 1]
+
+def iou_of(boxes0, boxes1, eps=1e-5):
+ """Return intersection-over-union (Jaccard index) of boxes.
+
+ Args:
+ boxes0 (N, 4): ground truth boxes.
+ boxes1 (N or 1, 4): predicted boxes.
+ eps: a small number to avoid 0 as denominator.
+ Returns:
+ iou (N): IoU values.
+ """
+ overlap_left_top = torch.max(boxes0[..., :2], boxes1[..., :2])
+ overlap_right_bottom = torch.min(boxes0[..., 2:], boxes1[..., 2:])
+
+ overlap_area = area_of(overlap_left_top, overlap_right_bottom)
+ area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
+ area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
+ return overlap_area / (area0 + area1 - overlap_area + eps)
+
+def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
+ """
+ Args:
+ box_scores (N, 5): boxes in corner-form and probabilities.
+ iou_threshold: intersection over union threshold.
+ top_k: keep top_k results. If k <= 0, keep all the results.
+ candidate_size: only consider the candidates with the highest scores.
+ Returns:
+ picked: a list of indexes of the kept boxes
+ """
+ scores = box_scores[:, -1]
+ boxes = box_scores[:, :-1]
+ picked = []
+ _, indexes = scores.sort(descending=True)
+ indexes = indexes[:candidate_size]
+ while len(indexes) > 0:
+ current = indexes[0]
+ picked.append(current.item())
+ if 0 < top_k == len(picked) or len(indexes) == 1:
+ break
+ current_box = boxes[current, :]
+ indexes = indexes[1:]
+ rest_boxes = boxes[indexes, :]
+ iou = iou_of(
+ rest_boxes,
+ current_box.unsqueeze(0),
+ )
+ indexes = indexes[iou <= iou_threshold]
+
+ return box_scores[picked, :]
+
+
+class Resize(object):
+ def __init__(self, size=(300, 300)):
+ self.size = size
+
+ def __call__(self, image, boxes=None, labels=None):
+ image = cv2.resize(image, (self.size[0],
+ self.size[1]))
+ return image, boxes, labels
+
+
+class SubtractMeans(object):
+ def __init__(self, mean):
+ self.mean = np.array(mean, dtype=np.float32)
+
+ def __call__(self, image, boxes=None, labels=None):
+ image = image.astype(np.float32)
+ image -= self.mean
+ return image.astype(np.float32), boxes, labels
+
+
+class ToTensor(object):
+ def __call__(self, cvimage, boxes=None, labels=None):
+ return torch.from_numpy(cvimage.astype(np.float32)).permute(2, 0, 1), boxes, labels
+
+
+class Compose(object):
+ """Composes several augmentations together.
+ Args:
+ transforms (List[Transform]): list of transforms to compose.
+ Example:
+ >>> augmentations.Compose([
+ >>> transforms.CenterCrop(10),
+ >>> transforms.ToTensor(),
+ >>> ])
+ """
+
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, img, boxes=None, labels=None):
+ for t in self.transforms:
+ img, boxes, labels = t(img, boxes, labels)
+ return img, boxes, labels
+
+
+class PredictionTransform:
+ def __init__(self, size, mean=0.0, std=1.0):
+ self.transform = Compose([
+ Resize(size),
+ SubtractMeans(mean),
+ lambda img, boxes=None, labels=None: (img / std, boxes, labels),
+ ToTensor()
+ ])
+
+ def __call__(self, image):
+ image, _, _ = self.transform(image)
+ return image
+
+
+class Config:
+ image_size = [320, 240]
+ image_mean_test = np.array([127, 127, 127])
+ image_std = 128.0
+
+
+class Predictor:
+ """
+ Face detection with pretrained model.
+ Reference:
+ - https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB
+ """
+ def __init__(
+ self,
+ onnx_dir,
+ device='cpu',
+ device_id=0,
+ iou_threshold=0.3,
+ filter_threshold=0.01,
+ candidate_size=200,
+ ):
+ onnx_file_name = os.path.join(onnx_dir, 'version-RFB-320.onnx')
+ assert os.path.exists(onnx_file_name), \
+ '%s does not exist. Please check if it has been downloaded accurately.' % onnx_file_name
+ self.ort_net = self.create_net(onnx_file_name, device, device_id)
+ self.transform = PredictionTransform(
+ Config.image_size, Config.image_mean_test, Config.image_std)
+ self.iou_threshold = iou_threshold
+ self.filter_threshold = filter_threshold
+ self.candidate_size = candidate_size
+ self.device = device
+
+ def __call__(self, image, top_k=-1, prob_threshold=None):
+ height, width, _ = image.shape
+ image = self.transform(image)
+ images = image.unsqueeze(0).numpy()
+ # net inference
+ inputs = {self.ort_net.get_inputs()[0].name:images}
+ scores, boxes = self.ort_net.run(None, inputs)
+ boxes = torch.from_numpy(boxes[0])
+ scores = torch.from_numpy(scores[0])
+ if not prob_threshold:
+ prob_threshold = self.filter_threshold
+ picked_box_probs = []
+ picked_labels = []
+ for class_index in range(1, scores.size(1)):
+ probs = scores[:, class_index]
+ mask = probs > prob_threshold
+ probs = probs[mask]
+ if probs.size(0) == 0:
+ continue
+ subset_boxes = boxes[mask, :]
+ box_probs = torch.cat([subset_boxes, probs.reshape(-1, 1)], dim=1)
+ box_probs = hard_nms(box_probs, self.iou_threshold, top_k, self.candidate_size)
+ picked_box_probs.append(box_probs)
+ picked_labels.extend([class_index] * box_probs.size(0))
+ if not picked_box_probs:
+ return torch.tensor([]), torch.tensor([]), torch.tensor([])
+ picked_box_probs = torch.cat(picked_box_probs)
+ picked_box_probs[:, 0] *= width
+ picked_box_probs[:, 1] *= height
+ picked_box_probs[:, 2] *= width
+ picked_box_probs[:, 3] *= height
+ return picked_box_probs[:, :4], torch.tensor(picked_labels), picked_box_probs[:, 4]
+
+ def create_net(self, onnx_file_name, device='cpu', device_id=0):
+ options = onnxruntime.SessionOptions()
+ # set op_num_threads
+ options.intra_op_num_threads = 8
+ options.inter_op_num_threads = 8
+ # set providers
+ providers = ['CPUExecutionProvider']
+ if device == 'cuda':
+ providers.insert(0, ('CUDAExecutionProvider', {'device_id': device_id}))
+ ort_session = onnxruntime.InferenceSession(onnx_file_name, options, providers=providers)
+ return ort_session
+
+
+if __name__ == '__main__':
+ predictor_det = Predictor('pretrained_models', 'cuda', '0')
+ image_input = np.random.randn(1920, 1080, 3).astype('float32')
+ bboxes, _, probs = predictor_det(image_input, top_k=10, prob_threshold=0.9)
+
\ No newline at end of file
diff --git a/speaker_diarization/local/vision_tools/face_quality_assessment.py b/speaker_diarization/local/vision_tools/face_quality_assessment.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1d9578d0e7ae43f271aa21f614ac58beea9ca79
--- /dev/null
+++ b/speaker_diarization/local/vision_tools/face_quality_assessment.py
@@ -0,0 +1,45 @@
+import cv2
+import os
+import numpy as np
+import onnxruntime
+
+class FaceQualityAssess:
+ """
+ Face quality assessment with pretrained model.
+ Reference:
+ - https://modelscope.cn/models/iic/cv_manual_face-quality-assessment_fqa
+ """
+ def __init__(self, onnx_dir, device='cpu', device_id=0):
+ onnx_file_name = os.path.join(onnx_dir, 'fqa.onnx')
+ assert os.path.exists(onnx_file_name), '%s does not exist. Please check if it has been downloaded accurately.' % onnx_file_name
+ self.ort_net = self.create_net(onnx_file_name, device, device_id)
+
+ def __call__(self, img):
+ img = img[:, :, ::-1] # bgr to rgb
+ img = cv2.resize(img, (112, 112))
+ img = np.transpose(img, axes=(2, 0, 1))
+ img = (img / 255. - 0.5) / 0.5
+ img = np.expand_dims(img.astype(np.float32), 0)
+ ort_inputs = {self.ort_net.get_inputs()[0].name:img}
+ result = self.ort_net.run(None, ort_inputs)[0]
+ score = np.mean(result)
+ return score
+
+ def create_net(self, onnx_file_name, device='cpu', device_id=0):
+ options = onnxruntime.SessionOptions()
+ # set op_num_threads
+ options.intra_op_num_threads = 8
+ options.inter_op_num_threads = 8
+ # set providers
+ providers = ['CPUExecutionProvider']
+ if device == 'cuda':
+ providers.insert(0, ('CUDAExecutionProvider', {'device_id': device_id}))
+ ort_session = onnxruntime.InferenceSession(onnx_file_name, options, providers=providers)
+ return ort_session
+
+
+if __name__ == '__main__':
+ predictor = FaceQualityAssess('pretrained_models', 'cuda', '0')
+ input = np.random.randn(221, 196, 3).astype('float32')
+ output = predictor(input)
+ print(output)
diff --git a/speaker_diarization/local/vision_tools/face_recognition.py b/speaker_diarization/local/vision_tools/face_recognition.py
new file mode 100644
index 0000000000000000000000000000000000000000..05ba833ff492148fbc4b146e1454720bd76b2776
--- /dev/null
+++ b/speaker_diarization/local/vision_tools/face_recognition.py
@@ -0,0 +1,48 @@
+import cv2
+import os
+import numpy as np
+import onnxruntime
+
+
+class FaceRecIR101:
+ """
+ Face embeddings extraction with CurricularFace pretrained model.
+ Reference:
+ - https://modelscope.cn/models/iic/cv_ir101_facerecognition_cfglint
+ """
+ def __init__(self, onnx_dir, device='cpu', device_id=0):
+ onnx_file_name = os.path.join(onnx_dir, 'face_recog_ir101.onnx')
+ assert os.path.exists(onnx_file_name), '%s does not exist. Please check if it has been downloaded accurately.' % onnx_file_name
+ self.ort_net = self.create_net(onnx_file_name, device, device_id)
+
+ def __call__(self, img):
+ img = img[:, :, ::-1] # bgr to rgb
+ img = cv2.resize(img, (112, 112))
+ img = np.transpose(img, axes=(2, 0, 1))
+ img = (img / 255. - 0.5) / 0.5
+ img = np.expand_dims(img.astype(np.float32), 0)
+
+ ort_inputs = {self.ort_net.get_inputs()[0].name:img}
+ emb = self.ort_net.run(None, ort_inputs)[0] # 未归一化的信息保留了强度信息,有更多的表情强度细节
+ # emb_normalized = emb / np.sqrt(np.sum(emb**2, -1, keepdims=True)) # 归一化得到的信息消除了强度信息,更侧重说话人身份信息
+ # return emb, emb_normalized
+ return emb
+
+ def create_net(self, onnx_file_name, device='cpu', device_id=0):
+ options = onnxruntime.SessionOptions()
+ # set op_num_threads
+ options.intra_op_num_threads = 8
+ options.inter_op_num_threads = 8
+ # set providers
+ providers = ['CPUExecutionProvider']
+ if device == 'cuda':
+ providers.insert(0, ('CUDAExecutionProvider', {'device_id': device_id}))
+ ort_session = onnxruntime.InferenceSession(onnx_file_name, options, providers=providers)
+ return ort_session
+
+
+if __name__ == '__main__':
+ predictor = FaceRecIR101('pretrained_models', 'cuda', 0)
+ input = np.random.randn(315, 244, 3).astype('float32')
+ output = predictor(input)
+ print(output.shape)
diff --git a/speaker_diarization/local/vision_tools/lip_detection.py b/speaker_diarization/local/vision_tools/lip_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0be3d044be4be9a4e243b290b529655754a184f
--- /dev/null
+++ b/speaker_diarization/local/vision_tools/lip_detection.py
@@ -0,0 +1,87 @@
+import os
+import cv2
+from .api import FaceAlignment, LandmarksType
+import numpy as np
+
+class LipDetector:
+ """
+ 在 face crop 上检测唇部位置,
+ 基于修改的 face_alignment api 调用 FAN 模型实现。
+ """
+
+ def __init__(
+ self,
+ model_dir=None,
+ device='cpu',
+ device_id=0,
+ landmarks_type=LandmarksType.TWO_D,
+ ):
+ # 设置设备字符串
+ device_str = 'cpu'
+ if device == 'cuda':
+ device_str = f'cuda:{device_id}'
+
+ if model_dir is not None:
+ model_path = os.path.join(model_dir, 'fun_2d.pth')
+ net_path = os.path.join(model_dir, 'fun_2d.zip') # 使用预下载模型避免长时间下载
+ print(f"Loading FAN model from {model_path} on {device_str}...")
+ else:
+ model_path = None
+ self.fa = FaceAlignment(
+ landmarks_type = landmarks_type,
+ device = device_str,
+ net_path = net_path,
+ face_detector_kwargs = {'path_to_detector': model_path},
+ )
+
+
+ def detect_lip(self, face_img):
+ """
+ face_img: BGR image of the face crop (tight face crop, e.g. 224x224)
+ 返回 dict:
+ { 'lip_bbox': (x1,y1,x2,y2) (relative to face_img) or None,
+ 'lip_crop': np.ndarray or None,
+ 'kps': np.ndarray (N,2) or None }
+ """
+ H, W = face_img.shape[:2]
+
+ # 1) 使用 FAN 模型检测关键点
+ try:
+ # 转换颜色空间 BGR -> RGB
+ rgb_img = cv2.cvtColor(face_img, cv2.COLOR_BGR2RGB)
+
+ # 预测关键点
+ preds = self.fa.get_landmarks(rgb_img)
+
+ if preds is not None and len(preds) > 0:
+ # 获取第一个检测到的人脸关键点 (68个点)
+ kps = preds[0]
+
+ # 提取嘴唇关键点 (48-68点)
+ mouth_kps = kps[48:68]
+
+ # 计算边界框
+ min_xy = mouth_kps.min(axis=0)
+ max_xy = mouth_kps.max(axis=0)
+
+ # 添加 padding
+ pad = 0.18 * (max_xy - min_xy)
+ x1 = int(max(0, min_xy[0] - pad[0]))
+ y1 = int(max(0, min_xy[1] - pad[1]))
+ x2 = int(min(W - 1, max_xy[0] + pad[0]))
+ y2 = int(min(H - 1, max_xy[1] + pad[1]))
+
+ # 裁剪嘴唇区域
+ lip_bbox_array = np.array([x1, y1, x2, y2], dtype=np.int32)
+ lip_crop = face_img[y1:y2, x1:x2].copy()
+
+ return {
+ 'lip_bbox': lip_bbox_array,
+ 'lip_crop': lip_crop,
+ 'kps': mouth_kps
+ }
+
+ except Exception as e:
+ print(f"FAN detection failed: {e}")
+ return None
+
\ No newline at end of file
diff --git a/speaker_diarization/path.sh b/speaker_diarization/path.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d2d0234adfb69e703e9bf9139a7bc1c9f7b7bd6f
--- /dev/null
+++ b/speaker_diarization/path.sh
@@ -0,0 +1,3 @@
+export PATH=$PWD:$PATH
+export PYTHONPATH=../:$PYTHONPATH
+export OMP_NUM_THREADS=1
diff --git a/speaker_diarization/run.py b/speaker_diarization/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a117c7a4ea5a388b8e147929c5ec34908e96f29
--- /dev/null
+++ b/speaker_diarization/run.py
@@ -0,0 +1,614 @@
+#!/usr/bin/env python3
+import os
+import json
+import pickle
+import torch
+import gc
+import time
+import queue
+from pathlib import Path
+import argparse
+import numpy as np
+from contextlib import contextmanager
+from pydub import AudioSegment
+from typing import Dict, Any, Optional, Callable
+try:
+ from modelscope.pipelines import pipeline
+ from modelscope.utils.constant import Tasks
+except ImportError:
+ raise ImportError("Please install modelscope: pip install modelscope")
+from speaker_diarization.local.utils.utils import circle_pad
+from speaker_diarization.local.utils.config import yaml_config_loader, build_config
+from speaker_diarization.local.utils.builder import build
+from speaker_diarization.local.utils.fileio import load_audio
+import speaker_diarization.local.vision_tools.face_detection as face_detection
+import speaker_diarization.local.vision_tools.active_speaker_detection as active_speaker_detection
+import speaker_diarization.local.vision_tools.face_recognition as face_recognition
+import speaker_diarization.local.vision_tools.face_quality_assessment as face_quality_assessment
+import speaker_diarization.local.vision_tools.lip_detection as lip_detection
+from speaker_diarization.local.vision_processer import VisionProcesser
+
+class ModelPool:
+ def __init__(self, creator: Callable, pool_size: int = 1):
+ self._q = queue.Queue(maxsize=pool_size)
+ for _ in range(pool_size):
+ self._q.put(creator())
+ @contextmanager
+ def borrow(self, timeout: Optional[float] = None):
+ try:
+ inst = self._q.get(timeout=timeout)
+ except queue.Empty:
+ raise RuntimeError(f"Timeout ({timeout}s) when borrowing model instance")
+ try:
+ yield inst
+ finally:
+ self._q.put(inst)
+
+
+class GlobalModels:
+ _instance = None
+ _initialized = False
+ def __new__(cls, *args, **kwargs):
+ if cls._instance is None:
+ cls._instance = super(GlobalModels, cls).__new__(cls)
+ return cls._instance
+
+ def __init__(
+ self,
+ hf_token: Optional[str] = None,
+ config_path: Optional[str] = None,
+ pretrained_dir: Optional[str] = None,
+ device: Optional[str] = None,
+ device_id: int = 0,
+ pool_sizes: Optional[Dict[str, int]] = None,
+ batch_size: int = 32,
+ preload: bool = True,
+ ):
+ if hasattr(self, "initialized"):
+ return
+ self.hf_token = hf_token
+ self.config_path = config_path
+ self.conf = yaml_config_loader(config_path)
+ self.pretrained_dir = Path(pretrained_dir) if pretrained_dir else None
+ self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
+ self.device_str = "cuda" if self.device.type == "cuda" else "cpu"
+ self.device_id = device_id
+ self.batch_size = batch_size
+ self.pool_sizes = pool_sizes or {}
+ self.visual_pools: Dict[str, ModelPool] = {}
+ self.audio_models: Dict[str, Any] = {
+ "segmentation": None,
+ "vad_pipeline": None,
+ "feature_extractor": None,
+ "embedding_model": None,
+ }
+ if preload:
+ self.preload()
+ self.initialized = True
+
+ def preload(self):
+ """预加载所有模型(音频 + 视觉)"""
+ if not all(self.audio_models.values()) and self.hf_token and self.config_path and self.pretrained_dir:
+ self._init_audio_models()
+ if not self.visual_pools and self.pretrained_dir:
+ self._init_visual_pools()
+
+ def _init_audio_models(self):
+ """初始化音频模型"""
+ if all(self.audio_models.values()):
+ return
+ start_time = time.time()
+
+ # 1. Pyannote Segmentation
+ print("[INFO] Loading segmentation model (overlap detection)...")
+ self.audio_models["segmentation"] = None
+
+ # 2. VAD: ModelScope FSMN-VAD
+ print("[INFO] Loading VAD model...")
+ vad_model_path = self.pretrained_dir / "speech_fsmn_vad"
+ self.audio_models["vad_pipeline"] = pipeline(
+ task=Tasks.voice_activity_detection,
+ model=str(vad_model_path),
+ device=self.device_str,
+ )
+
+ # 3. Speaker Embedding: CAMPPlus
+ print("[INFO] Loading CAMPPlus speaker embedding model...")
+ feature_extractor = build('feature_extractor', self.conf)
+ embedding_model = build('embedding_model', self.conf)
+
+ ckpt = self.pretrained_dir / "speech_campplus" / "campplus_cn_en_common.pt"
+ state_dict = torch.load(ckpt, map_location=self.device)
+ embedding_model.load_state_dict(state_dict)
+ embedding_model.eval().to(self.device)
+ self.audio_models["feature_extractor"] = feature_extractor
+ self.audio_models["embedding_model"] = embedding_model
+
+ print(f"[SUCCESS] Audio models loaded in {time.time() - start_time:.2f}s.")
+
+ def _init_visual_pools(self):
+ """初始化视觉模型池"""
+ if self.visual_pools:
+ return
+ print("[INFO] Initializing visual model pools...")
+ self.visual_pools['face'] = ModelPool(
+ lambda: face_detection.Predictor(self.pretrained_dir, self.device_str, self.device_id),
+ pool_size=self.pool_sizes.get('face', 1)
+ )
+ self.visual_pools['asd'] = ModelPool(
+ lambda: active_speaker_detection.ASDTalknet(self.pretrained_dir, self.device_str, self.device_id),
+ pool_size=self.pool_sizes.get('asd', 1)
+ )
+ self.visual_pools['fr'] = ModelPool(
+ lambda: face_recognition.FaceRecIR101(self.pretrained_dir, self.device_str, self.device_id),
+ pool_size=self.pool_sizes.get('fr', 1)
+ )
+ self.visual_pools['fq'] = ModelPool(
+ lambda: face_quality_assessment.FaceQualityAssess(self.pretrained_dir, self.device_str, self.device_id),
+ pool_size=self.pool_sizes.get('fq', 1)
+ )
+ self.visual_pools['lip'] = ModelPool(
+ lambda: lip_detection.LipDetector(self.pretrained_dir, self.device_str, self.device_id),
+ pool_size=self.pool_sizes.get('lip', 1)
+ )
+ print("[SUCCESS] Visual model pools initialized.")
+
+ # === 音频模型获取接口 ===
+ def get_segmentation_model(self):
+ if self.audio_models["segmentation"] is None:
+ raise RuntimeError("Segmentation model not loaded. Call preload() first.")
+ return self.audio_models["segmentation"]
+
+ def get_vad_pipeline(self):
+ if self.audio_models["vad_pipeline"] is None:
+ raise RuntimeError("VAD pipeline not loaded.")
+ return self.audio_models["vad_pipeline"]
+
+ def get_embedding_components(self):
+ if self.audio_models["feature_extractor"] is None or self.audio_models["embedding_model"] is None:
+ raise RuntimeError("Embedding models not loaded.")
+ return self.audio_models["feature_extractor"], self.audio_models["embedding_model"]
+
+ # === 视觉模型调用接口 ===
+ def detect_faces(self, image, top_k=10, prob_threshold=0.9, borrow_timeout=1000):
+ with self.visual_pools['face'].borrow(timeout=borrow_timeout) as model:
+ return model(image, top_k=top_k, prob_threshold=prob_threshold)
+
+ def asd_score(self, audio_feature, video_feature, borrow_timeout=1000):
+ with self.visual_pools['asd'].borrow(timeout=borrow_timeout) as model:
+ return model(audio_feature, video_feature)
+
+ def get_face_embedding(self, face_image, borrow_timeout=2000):
+ with self.visual_pools['fr'].borrow(timeout=borrow_timeout) as model:
+ return model(face_image)
+
+ def face_quality_score(self, face_image, borrow_timeout=1000):
+ with self.visual_pools['fq'].borrow(timeout=borrow_timeout) as model:
+ return model(face_image)
+
+ def detect_lip(self, face_image, borrow_timeout=3000):
+ with self.visual_pools['lip'].borrow(timeout=borrow_timeout) as model:
+ return model.detect_lip(face_image)
+
+ def release(self):
+ for k in self.audio_models:
+ self.audio_models[k] = None
+ for _, pool in self.visual_pools.items():
+ del pool
+ self.visual_pools.clear()
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+model_pool: Optional[GlobalModels] = None
+
+# =======================
+# 工具函数
+# =======================
+
+def extract_audio_from_video(video_path: str, wav_path: str, sample_rate: int = 16000):
+ """Extract mono 16kHz WAV from video."""
+ print(f"[INFO] Extracting audio from {video_path} to {wav_path}")
+ audio = AudioSegment.from_file(video_path)
+ audio = audio.set_frame_rate(sample_rate).set_channels(1)
+ audio.export(wav_path, format="wav")
+
+
+# def detect_overlap(wav_path: str, threshold: float = 0.5) -> bool:
+# """Detect speaker overlap using preloaded segmentation model."""
+# print("[INFO] Running overlap detection...")
+# model = model_pool.get_segmentation_model()
+# device = model_pool.device
+
+# inference = Inference(
+# model,
+# duration=model.specifications.duration,
+# step=0.1 * model.specifications.duration,
+# skip_aggregation=True,
+# batch_size=model_pool.batch_size,
+# device=device,
+# )
+# try:
+# segmentations = inference({"audio": Path(wav_path)})
+# frame_windows = inference.model.receptive_field
+
+# # Aggregate and count active speakers
+# count_feat = Inference.aggregate(
+# np.sum(segmentations, axis=-1, keepdims=True),
+# frame_windows,
+# hamming=False,
+# missing=0.0,
+# skip_average=False,
+# )
+# count_feat.data = np.rint(count_feat.data).astype(np.uint8)
+# count_data = count_feat.data.squeeze()
+# sliding_window = count_feat.sliding_window
+# total_overlap_duration = 0.0
+# current_start = None
+# for i, val in enumerate(count_data):
+# timestamp = sliding_window[i].start
+# if val >= 2:
+# if current_start is None:
+# current_start = timestamp
+# else:
+# if current_start is not None:
+# current_end = timestamp
+# duration = current_end - current_start
+# if duration >= threshold:
+# total_overlap_duration += duration
+# current_start = None
+
+# if current_start is not None:
+# current_end = sliding_window[-1].end
+# duration = current_end - current_start
+# if duration >= threshold:
+# total_overlap_duration += duration
+# has_overlap = total_overlap_duration > 0
+# return has_overlap
+
+# finally:
+# del inference, segmentations
+# gc.collect()
+# if torch.cuda.is_available():
+# torch.cuda.empty_cache()
+
+
+def run_vad(wav_path: str, out_file: str):
+ """Run VAD using preloaded model."""
+ print("[INFO] Running voice activity detection...")
+ vad_pipeline = model_pool.get_vad_pipeline()
+ result = vad_pipeline(wav_path)[0]
+ vad_time = [[round(v[0] / 1000, 3), round(v[1] / 1000, 3)] for v in result['value']]
+
+ basename = Path(wav_path).stem
+ json_dict = {}
+ for start, end in vad_time:
+ seg_id = f"{basename}_{start}_{end}"
+ json_dict[seg_id] = {
+ "file": wav_path,
+ "start": start,
+ "stop": end
+ }
+ os.makedirs(Path(out_file).parent, exist_ok=True)
+ with open(out_file, 'w') as f:
+ json.dump(json_dict, f, indent=2)
+ print(f"[INFO] VAD saved to {out_file}")
+ return json_dict
+
+
+def generate_subsegments(vad_json_path: str, out_file: str, dur: float = 1.5, shift: float = 0.75):
+ """Generate overlapping subsegments from VAD output."""
+ print("[INFO] Generating sub-segments...")
+ with open(vad_json_path, 'r') as f:
+ vad_json = json.load(f)
+
+ subseg_json = {}
+ for segid in vad_json:
+ wavid = segid.rsplit('_', 2)[0]
+ st = vad_json[segid]['start']
+ ed = vad_json[segid]['stop']
+ subseg_st = st
+ while subseg_st + dur < ed + shift:
+ subseg_ed = min(subseg_st + dur, ed)
+ item = vad_json[segid].copy()
+ item.update({
+ 'start': round(subseg_st, 2),
+ 'stop': round(subseg_ed, 2)
+ })
+ subsegid_new = f"{wavid}_{round(subseg_st, 2)}_{round(subseg_ed, 2)}"
+ subseg_json[subsegid_new] = item
+ subseg_st += shift
+
+ os.makedirs(Path(out_file).parent, exist_ok=True)
+ with open(out_file, 'w') as f:
+ json.dump(subseg_json, f, indent=2)
+ print(f"[INFO] Subsegments saved to {out_file}")
+
+
+def merge_overlap_region(vad_time_list):
+ if not vad_time_list:
+ return []
+ vad_time_list.sort(key=lambda x: x[0])
+ out_vad_time_list = []
+ for t in vad_time_list:
+ if len(out_vad_time_list) == 0 or t[0] > out_vad_time_list[-1][1]:
+ out_vad_time_list.append(t[:])
+ else:
+ out_vad_time_list[-1][1] = max(out_vad_time_list[-1][1], t[1])
+ return out_vad_time_list
+
+def create_debug_path(debug_dir, name):
+ if not debug_dir:
+ return None
+ path = Path(debug_dir) / f"{name}_DEBUG.mp4"
+ path.parent.mkdir(parents=True, exist_ok=True)
+ return str(path)
+
+def make_rttms(seg_list, out_rttm, rec_id):
+ """
+ Merge overlapping segments and write RTTM format.
+ seg_list: list of [(start_time, end_time), label]
+ """
+ new_seg_list = []
+ for i, seg in enumerate(seg_list):
+ seg_st, seg_ed = float(seg[0][0]), float(seg[0][1])
+ cluster_id = int(seg[1]) + 1 # 1-indexed
+
+ if not new_seg_list:
+ new_seg_list.append([rec_id, seg_st, seg_ed, cluster_id])
+ else:
+ last = new_seg_list[-1]
+ if cluster_id == last[3]: # Same speaker
+ if seg_st > last[2]:
+ new_seg_list.append([rec_id, seg_st, seg_ed, cluster_id])
+ else:
+ last[2] = max(last[2], seg_ed) # Extend end time
+ else: # Different speaker
+ if seg_st < last[2]: # Overlap → split at midpoint
+ mid = (last[2] + seg_st) / 2
+ last[2] = mid
+ seg_st = mid
+ new_seg_list.append([rec_id, seg_st, seg_ed, cluster_id])
+
+ line_str = "SPEAKER {} 1 {:.3f} {:.3f} {:d} \n"
+ with open(out_rttm, 'w') as f:
+ for seg in new_seg_list:
+ f.write(line_str.format(seg[0], seg[1], seg[2] - seg[1], seg[3]))
+ print(f"[INFO] RTTM saved to {out_rttm}")
+
+
+def extract_wav_embeddings(subseg_json_path: str, wav_emb_path: str):
+ """Extract embeddings using preloaded embedding models."""
+ print("[INFO] Extracting speaker embeddings...")
+ device = model_pool.device
+ batch_size = model_pool.batch_size
+ feature_extractor, embedding_model = model_pool.get_embedding_components()
+
+ with open(subseg_json_path, 'r') as f:
+ subseg_json = json.load(f)
+ if not subseg_json:
+ print("[WARNING] No segments found. Skipping embedding extraction.")
+ return
+ all_keys = list(subseg_json.keys())
+ if Path(wav_emb_path).exists():
+ print(f"[INFO] Embedding already exists: {wav_emb_path}, skipping.")
+ return
+
+ wav_path = subseg_json[all_keys[0]]['file']
+ wav = load_audio(wav_path, obj_fs=16000)
+
+ wavs = []
+ times = []
+ for key in subseg_json:
+ start = int(subseg_json[key]['start'] * 16000)
+ end = int(subseg_json[key]['stop'] * 16000)
+ wavs.append(wav[0, start:end]) # mono
+ times.append([subseg_json[key]['start'], subseg_json[key]['stop']])
+
+ max_len = max(w.shape[0] for w in wavs)
+ wavs = [circle_pad(w, max_len) for w in wavs]
+ wavs_tensor = torch.stack(wavs).unsqueeze(1) # (B, 1, T)
+
+ embeddings = []
+ with torch.no_grad():
+ for i in range(0, len(wavs_tensor), batch_size):
+ batch = wavs_tensor[i:i + batch_size].to(device)
+ feats = torch.vmap(feature_extractor)(batch)
+ embs_batch = embedding_model(feats).cpu()
+ embeddings.append(embs_batch)
+
+ embeddings = torch.cat(embeddings, dim=0).numpy()
+
+ result = {
+ 'embeddings': embeddings,
+ 'times': times
+ }
+ with open(wav_emb_path, 'wb') as f:
+ pickle.dump(result, f)
+ print(f"[INFO] Embeddings saved to {wav_emb_path}")
+
+def extract_visual_embeddings(
+ vad_data: json,
+ video_path: str,
+ wav_path: str,
+ face_emb_pkl: str,
+ debug_dir:str
+):
+ rec_id = video_path.stem
+ subset = {k: v for k, v in vad_data.items() if k.rsplit('_', 2)[0] == rec_id}
+ if len(subset) == 0:
+ print(f"[WARNING] No VAD segments for {rec_id}.")
+ return None
+ rec_vad_time_list = [[v['start'], v['stop']] for v in subset.values()]
+ rec_vad_time_list = merge_overlap_region(rec_vad_time_list)
+ debug_video = create_debug_path(debug_dir, rec_id)
+
+ try:
+ vp = VisionProcesser(
+ video_file_path = video_path,
+ audio_file_path = wav_path,
+ audio_vad = rec_vad_time_list,
+ out_feat_path = face_emb_pkl,
+ visual_models = model_pool,
+ conf = model_pool.conf,
+ out_video_path=debug_video
+ )
+ vp.run()
+ except Exception as e:
+ print(f"[ERROR] Failed to process {video_path}: {e}")
+ raise
+ finally:
+ if 'vp' in locals():
+ vp.close()
+
+def audio_only_cluster(audio_embs_file, rttm_file, rec_id, config):
+ print("[INFO] Running audio-only clustering...")
+ cluster = build('audio_cluster', config)
+ if not os.path.exists(audio_embs_file):
+ print(f"[ERROR] Audio embedding file not found: {audio_embs_file}")
+ return False
+
+ with open(audio_embs_file, 'rb') as f:
+ stat_obj = pickle.load(f)
+ embeddings = stat_obj['embeddings']
+ times = stat_obj['times']
+ # cluster
+ labels = cluster(embeddings)
+ # output rttm
+ new_labels = np.zeros(len(labels), dtype=int)
+ uniq = np.unique(labels)
+ for i in range(len(uniq)):
+ new_labels[labels==uniq[i]] = i
+ seg_list = [(i,j) for i, j in zip(times, new_labels)]
+ make_rttms(seg_list, rttm_file, rec_id)
+ return True
+
+
+def audio_visual_cluster(audio_embs_file, visual_embs_file, rttm_file, rec_id, config):
+ print("[INFO] Running audio-visual joint clustering...")
+ cluster = build('cluster', config)
+ if not os.path.exists(audio_embs_file):
+ print(f"[ERROR] Audio embedding file not found: {audio_embs_file}")
+ return False
+ if not os.path.exists(visual_embs_file):
+ print(f"[ERROR] Visual embedding file not found: {visual_embs_file}")
+ return False
+
+ # Load audio embeddings
+ with open(audio_embs_file, 'rb') as f:
+ a_data = pickle.load(f)
+ audio_embeddings = a_data['embeddings']
+ audio_times = a_data['times']
+
+ # Load visual embeddings
+ with open(visual_embs_file, 'rb') as f:
+ v_data = pickle.load(f)
+ visual_embeddings = v_data['embeddings']
+ frameI = v_data['frameI']
+ faceI = v_data['faceI']
+ visual_times = frameI * 0.04
+ frame_indices = [np.where(faceI == frame)[0][0] for frame in frameI]
+ speak_embeddings = visual_embeddings[frame_indices]
+ visual_embeddings_normalized = speak_embeddings / np.sqrt(np.sum(speak_embeddings**2, axis=-1, keepdims=True))
+
+ labels = cluster(audio_embeddings, visual_embeddings_normalized, audio_times, visual_times, config)
+ # output rttm
+ new_labels = np.zeros(len(labels), dtype=int)
+ uniq = np.unique(labels)
+ for i in range(len(uniq)):
+ new_labels[labels==uniq[i]] = i
+ seg_list = [(i,j) for i, j in zip(audio_times, new_labels)]
+ make_rttms(seg_list, rttm_file, rec_id)
+ return True
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Process a single video for speaker embedding extraction.")
+ parser.add_argument("--video", type=str, required=True, help="Path to input MP4 video file")
+ parser.add_argument("--work_dir", type=str, required=True, help="Working directory to save intermediate files")
+ parser.add_argument("--hf_token", type=str, required=True, help="HuggingFace access token for pyannote")
+ parser.add_argument("--config", default="diar.yaml", help="YAML config file")
+ parser.add_argument("--pretrained", type=str, required=True, help="Path to local pretrained models")
+ parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
+ parser.add_argument("--device", type=str, default="cpu", help="Device to use: 'cuda' or 'cpu'.")
+ parser.add_argument("--jointcluster", action="store_true", help="Use audio-visual joint clustering. If not set, use audio-only clustering.")
+ parser.add_argument("--debug_dir", type=str, default="", help="Optional: save debug video")
+ args = parser.parse_args()
+
+ global model_pool
+ model_pool = GlobalModels(
+ hf_token = args.hf_token,
+ config_path = args.config,
+ pretrained_dir= args.pretrained,
+ device= args.device,
+ pool_sizes = {"face": 1, "asd": 8, "fr": 3},
+ batch_size = args.batch_size,
+ preload = True
+ )
+
+ video_path = Path(args.video)
+ if not os.path.exists(video_path):
+ raise FileNotFoundError(f"Video not found: {video_path}")
+ work_dir = Path(args.work_dir)
+ work_dir.mkdir(parents=True, exist_ok=True)
+ rec_id = video_path.stem
+ wav_path = work_dir / f"{rec_id}.wav"
+ vad_json = work_dir / "vad.json"
+ subseg_json = work_dir / "subseg.json"
+ wav_emb_pkl = work_dir / "audio.pkl"
+ face_emb_pkl = work_dir / "face.pkl"
+ rttm_file = work_dir / f"{rec_id}.rttm"
+
+ # Pipeline Start
+ infer_start = time.time()
+
+ # 1. Extract audio
+ extract_audio_from_video(video_path, wav_path)
+
+ # # 2. Overlap detection
+ # if detect_overlap(str(wav_path), threshold=1.0):
+ # print("[WARNING] Speaker overlap detected. Skipping this video.")
+ # os.remove(wav_path)
+ # return
+
+ # 3. VAD
+ vad_data = run_vad(str(wav_path), str(vad_json))
+
+ # 4. Sub-segment
+ generate_subsegments(str(vad_json), str(subseg_json), dur=1.5, shift=0.75)
+
+ # 5. Extract audio embeddings
+ extract_wav_embeddings(str(subseg_json), str(wav_emb_pkl))
+
+ # 6. Extract visual embeddings
+ extract_visual_embeddings(vad_data, video_path, str(wav_path), str(face_emb_pkl), args.debug_dir)
+
+ # 7. Cluster audio and visual embeddings
+ config = build_config(args.config)
+ if args.jointcluster and face_emb_pkl.exists():
+ success = audio_visual_cluster(
+ str(wav_emb_pkl),
+ str(face_emb_pkl),
+ str(rttm_file),
+ rec_id,
+ config
+ )
+ else:
+ print("[INFO] Visual embeddings not found, using audio-only mode.")
+ success = audio_only_cluster(
+ str(wav_emb_pkl),
+ str(rttm_file),
+ rec_id,
+ config
+ )
+
+ inference_time = time.time() - infer_start
+ if success:
+ print("✅ PROCESSING COMPLETED")
+ else:
+ print("[FAILED] Clustering failed.")
+ print(f"Inference Time: {inference_time:.2f}s")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/speaker_diarization/speakerlab/models/campplus/DTDNN.py b/speaker_diarization/speakerlab/models/campplus/DTDNN.py
new file mode 100644
index 0000000000000000000000000000000000000000..33028f66b1b38df076344f27cc3eae297fd24376
--- /dev/null
+++ b/speaker_diarization/speakerlab/models/campplus/DTDNN.py
@@ -0,0 +1,112 @@
+from collections import OrderedDict
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from speakerlab.models.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, BasicResBlock, get_nonlinear
+
+
+class FCM(nn.Module):
+ def __init__(self,
+ block=BasicResBlock,
+ num_blocks=[2, 2],
+ m_channels=32,
+ feat_dim=80):
+ super(FCM, self).__init__()
+ self.in_planes = m_channels
+ self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(m_channels)
+
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
+ self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)
+
+ self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(m_channels)
+ self.out_channels = m_channels * (feat_dim // 8)
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = x.unsqueeze(1)
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = F.relu(self.bn2(self.conv2(out)))
+
+ shape = out.shape
+ out = out.reshape(shape[0], shape[1]*shape[2], shape[3])
+ return out
+
+class CAMPPlus(nn.Module):
+ def __init__(self,
+ feat_dim=80,
+ embedding_size=512,
+ growth_rate=32,
+ bn_size=4,
+ init_channels=128,
+ config_str='batchnorm-relu',
+ memory_efficient=True):
+ super(CAMPPlus, self).__init__()
+
+ self.head = FCM(feat_dim=feat_dim)
+ channels = self.head.out_channels
+
+ self.xvector = nn.Sequential(
+ OrderedDict([
+
+ ('tdnn',
+ TDNNLayer(channels,
+ init_channels,
+ 5,
+ stride=2,
+ dilation=1,
+ padding=-1,
+ config_str=config_str)),
+ ]))
+ channels = init_channels
+ for i, (num_layers, kernel_size,
+ dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
+ block = CAMDenseTDNNBlock(num_layers=num_layers,
+ in_channels=channels,
+ out_channels=growth_rate,
+ bn_channels=bn_size * growth_rate,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ config_str=config_str,
+ memory_efficient=memory_efficient)
+ self.xvector.add_module('block%d' % (i + 1), block)
+ channels = channels + num_layers * growth_rate
+ self.xvector.add_module(
+ 'transit%d' % (i + 1),
+ TransitLayer(channels,
+ channels // 2,
+ bias=False,
+ config_str=config_str))
+ channels //= 2
+
+ self.xvector.add_module(
+ 'out_nonlinear', get_nonlinear(config_str, channels))
+
+ self.xvector.add_module('stats', StatsPool())
+ self.xvector.add_module(
+ 'dense',
+ DenseLayer(channels * 2, embedding_size, config_str='batchnorm_'))
+
+ for m in self.modules():
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.kaiming_normal_(m.weight.data)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
+ x = self.head(x)
+ x = self.xvector(x)
+ return x
diff --git a/speaker_diarization/speakerlab/models/campplus/classifier.py b/speaker_diarization/speakerlab/models/campplus/classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..c29d4c312b47111a02ff93d6d7826291dda4f0c3
--- /dev/null
+++ b/speaker_diarization/speakerlab/models/campplus/classifier.py
@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from speakerlab.models.campplus.layers import DenseLayer
+
+
+class CosineClassifier(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ num_blocks=0,
+ inter_dim=512,
+ out_neurons=1000,
+ ):
+
+ super().__init__()
+ self.blocks = nn.ModuleList()
+
+ for index in range(num_blocks):
+ self.blocks.append(
+ DenseLayer(input_dim, inter_dim, config_str='batchnorm')
+ )
+ input_dim = inter_dim
+
+ self.weight = nn.Parameter(
+ torch.FloatTensor(out_neurons, input_dim)
+ )
+ nn.init.xavier_uniform_(self.weight)
+
+ def forward(self, x):
+ # x: [B, dim]
+ for layer in self.blocks:
+ x = layer(x)
+
+ # normalized
+ x = F.linear(F.normalize(x), F.normalize(self.weight))
+ return x
+
+class LinearClassifier(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ num_blocks=0,
+ inter_dim=512,
+ out_neurons=1000,
+ ):
+
+ super().__init__()
+ self.blocks = nn.ModuleList()
+
+ self.nonlinear = nn.ReLU(inplace=True)
+ for index in range(num_blocks):
+ self.blocks.append(
+ DenseLayer(input_dim, inter_dim, bias=True)
+ )
+ input_dim = inter_dim
+
+ self.linear = nn.Linear(input_dim, out_neurons, bias=True)
+
+ def forward(self, x):
+ # x: [B, dim]
+ x = self.nonlinear(x)
+ for layer in self.blocks:
+ x = layer(x)
+ x = self.linear(x)
+ return x
diff --git a/speaker_diarization/speakerlab/models/campplus/layers.py b/speaker_diarization/speakerlab/models/campplus/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc3de0bbb94a042c577a371d3e8eb000ec81a730
--- /dev/null
+++ b/speaker_diarization/speakerlab/models/campplus/layers.py
@@ -0,0 +1,250 @@
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from torch import nn
+
+
+def get_nonlinear(config_str, channels):
+ nonlinear = nn.Sequential()
+ for name in config_str.split('-'):
+ if name == 'relu':
+ nonlinear.add_module('relu', nn.ReLU(inplace=True))
+ elif name == 'prelu':
+ nonlinear.add_module('prelu', nn.PReLU(channels))
+ elif name == 'batchnorm':
+ nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels))
+ elif name == 'batchnorm_':
+ nonlinear.add_module('batchnorm',
+ nn.BatchNorm1d(channels, affine=False))
+ else:
+ raise ValueError('Unexpected module ({}).'.format(name))
+ return nonlinear
+
+def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
+ mean = x.mean(dim=dim)
+ std = x.std(dim=dim, unbiased=unbiased)
+ stats = torch.cat([mean, std], dim=-1)
+ if keepdim:
+ stats = stats.unsqueeze(dim=dim)
+ return stats
+
+
+class StatsPool(nn.Module):
+ def forward(self, x):
+ return statistics_pooling(x)
+
+
+class TDNNLayer(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ bias=False,
+ config_str='batchnorm-relu'):
+ super(TDNNLayer, self).__init__()
+ if padding < 0:
+ assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
+ kernel_size)
+ padding = (kernel_size - 1) // 2 * dilation
+ self.linear = nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias)
+ self.nonlinear = get_nonlinear(config_str, out_channels)
+
+ def forward(self, x):
+ x = self.linear(x)
+ x = self.nonlinear(x)
+ return x
+
+
+class CAMLayer(nn.Module):
+ def __init__(self,
+ bn_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ bias,
+ reduction=2):
+ super(CAMLayer, self).__init__()
+ self.linear_local = nn.Conv1d(bn_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias)
+ self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
+ self.relu = nn.ReLU(inplace=True)
+ self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ y = self.linear_local(x)
+ context = x.mean(-1, keepdim=True)+self.seg_pooling(x)
+ context = self.relu(self.linear1(context))
+ m = self.sigmoid(self.linear2(context))
+ return y*m
+
+ def seg_pooling(self, x, seg_len=100, stype='avg'):
+ if stype == 'avg':
+ seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
+ elif stype == 'max':
+ seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
+ else:
+ raise ValueError('Wrong segment pooling type.')
+ shape = seg.shape
+ seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
+ seg = seg[..., :x.shape[-1]]
+ return seg
+
+
+class CAMDenseTDNNLayer(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ bn_channels,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ bias=False,
+ config_str='batchnorm-relu',
+ memory_efficient=False):
+ super(CAMDenseTDNNLayer, self).__init__()
+ assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
+ kernel_size)
+ padding = (kernel_size - 1) // 2 * dilation
+ self.memory_efficient = memory_efficient
+ self.nonlinear1 = get_nonlinear(config_str, in_channels)
+ self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
+ self.nonlinear2 = get_nonlinear(config_str, bn_channels)
+ self.cam_layer = CAMLayer(bn_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias)
+
+ def bn_function(self, x):
+ return self.linear1(self.nonlinear1(x))
+
+ def forward(self, x):
+ if self.training and self.memory_efficient:
+ x = cp.checkpoint(self.bn_function, x)
+ else:
+ x = self.bn_function(x)
+ x = self.cam_layer(self.nonlinear2(x))
+ return x
+
+
+class CAMDenseTDNNBlock(nn.ModuleList):
+ def __init__(self,
+ num_layers,
+ in_channels,
+ out_channels,
+ bn_channels,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ bias=False,
+ config_str='batchnorm-relu',
+ memory_efficient=False):
+ super(CAMDenseTDNNBlock, self).__init__()
+ for i in range(num_layers):
+ layer = CAMDenseTDNNLayer(in_channels=in_channels + i * out_channels,
+ out_channels=out_channels,
+ bn_channels=bn_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ bias=bias,
+ config_str=config_str,
+ memory_efficient=memory_efficient)
+ self.add_module('tdnnd%d' % (i + 1), layer)
+
+ def forward(self, x):
+ for layer in self:
+ x = torch.cat([x, layer(x)], dim=1)
+ return x
+
+
+class TransitLayer(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ bias=True,
+ config_str='batchnorm-relu'):
+ super(TransitLayer, self).__init__()
+ self.nonlinear = get_nonlinear(config_str, in_channels)
+ self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
+
+ def forward(self, x):
+ x = self.nonlinear(x)
+ x = self.linear(x)
+ return x
+
+
+class DenseLayer(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ bias=False,
+ config_str='batchnorm-relu'):
+ super(DenseLayer, self).__init__()
+ self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
+ self.nonlinear = get_nonlinear(config_str, out_channels)
+
+ def forward(self, x):
+ if len(x.shape) == 2:
+ x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
+ else:
+ x = self.linear(x)
+ x = self.nonlinear(x)
+ return x
+
+
+class BasicResBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(BasicResBlock, self).__init__()
+ self.conv1 = nn.Conv2d(in_planes,
+ planes,
+ kernel_size=3,
+ stride=(stride, 1),
+ padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=(stride, 1),
+ bias=False),
+ nn.BatchNorm2d(self.expansion * planes))
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
diff --git a/speaker_diarization/speakerlab/models/resnet/ResNet.py b/speaker_diarization/speakerlab/models/resnet/ResNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..943989a3be5a4aef8d16f9ebb7c6a25b0985b6a0
--- /dev/null
+++ b/speaker_diarization/speakerlab/models/resnet/ResNet.py
@@ -0,0 +1,108 @@
+""" ResNet implementation is adapted from https://github.com/wenet-e2e/wespeaker.
+ Reference: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
+"""
+
+import torch
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+import speakerlab.models.eres2net.pooling_layers as pooling_layers
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(self.expansion * planes))
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self,
+ block=BasicBlock,
+ num_blocks=[3, 4, 6, 3],
+ m_channels=32,
+ feat_dim=40,
+ embedding_size=128,
+ pooling_func='TSTP',
+ two_emb_layer=True):
+ super(ResNet, self).__init__()
+ self.in_planes = m_channels
+ self.feat_dim = feat_dim
+ self.embedding_size = embedding_size
+ self.stats_dim = int(feat_dim / 8) * m_channels * 8
+ self.two_emb_layer = two_emb_layer
+
+ self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(m_channels)
+
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
+ self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
+ self.layer3 = self._make_layer(block, m_channels * 4, num_blocks[2], stride=2)
+ self.layer4 = self._make_layer(block, m_channels * 8, num_blocks[3], stride=2)
+
+ self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
+ self.pool = getattr(pooling_layers, pooling_func)(
+ in_dim=self.stats_dim * block.expansion)
+ self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
+ if self.two_emb_layer:
+ self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
+ self.seg_2 = nn.Linear(embedding_size, embedding_size)
+ else:
+ self.seg_bn_1 = nn.Identity()
+ self.seg_2 = nn.Identity()
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
+ x = x.unsqueeze_(1)
+ out = F.relu(self.bn1(self.conv1(x)))
+ out1 = self.layer1(out)
+ out2 = self.layer2(out1)
+ out3 = self.layer3(out2)
+ out = self.layer4(out3)
+ stats = self.pool(out)
+
+ embed_a = self.seg_1(stats)
+ if self.two_emb_layer:
+ out = F.relu(embed_a)
+ out = self.seg_bn_1(out)
+ embed_b = self.seg_2(out)
+ return embed_b
+ else:
+ return embed_a
+
+
+if __name__ == '__main__':
+
+ x = torch.zeros(10, 300, 80)
+ model = ResNet(feat_dim=80, embedding_size=192, pooling_func='TSTP')
+ model.eval()
+ out = model(x)
+ print(out.shape) # torch.Size([10, 192])
+
+ num_params = sum(param.numel() for param in model.parameters())
+ print("{} M".format(num_params / 1e6)) # 6.34M
diff --git a/speaker_diarization/speakerlab/models/talknet/attentionLayer.py b/speaker_diarization/speakerlab/models/talknet/attentionLayer.py
new file mode 100644
index 0000000000000000000000000000000000000000..17853119aea22050ada569b2c0d1bbfa918b9074
--- /dev/null
+++ b/speaker_diarization/speakerlab/models/talknet/attentionLayer.py
@@ -0,0 +1,35 @@
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.nn import MultiheadAttention
+
+class attentionLayer(nn.Module):
+
+ def __init__(self, d_model, nhead, dropout=0.1):
+ super(attentionLayer, self).__init__()
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
+
+ self.linear1 = nn.Linear(d_model, d_model * 4)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_model * 4, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = F.relu
+
+ def forward(self, src, tar):
+ # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
+ src = src.transpose(0, 1) # B, T, C -> T, B, C
+ tar = tar.transpose(0, 1) # B, T, C -> T, B, C
+ src2 = self.self_attn(tar, src, src, attn_mask=None,
+ key_padding_mask=None)[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ src = src.transpose(0, 1) # T, B, C -> B, T, C
+ return src
diff --git a/speaker_diarization/speakerlab/models/talknet/audioEncoder.py b/speaker_diarization/speakerlab/models/talknet/audioEncoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..262a45dbcce1d12f76902b7a1688549e66a9193d
--- /dev/null
+++ b/speaker_diarization/speakerlab/models/talknet/audioEncoder.py
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class SEBasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
+ super(SEBasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.se = SELayer(planes, reduction)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.relu(out)
+ out = self.bn1(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.se(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+ return out
+
+class SELayer(nn.Module):
+ def __init__(self, channel, reduction=8):
+ super(SELayer, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction),
+ nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y
+
+class audioEncoder(nn.Module):
+ def __init__(self, layers, num_filters, **kwargs):
+ super(audioEncoder, self).__init__()
+ block = SEBasicBlock
+ self.inplanes = num_filters[0]
+
+ self.conv1 = nn.Conv2d(1, num_filters[0] , kernel_size=7, stride=(2, 1), padding=3,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(num_filters[0])
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, num_filters[0], layers[0])
+ self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2))
+ self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2))
+ self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(1, 1))
+ out_dim = num_filters[3] * block.expansion
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = torch.mean(x, dim=2, keepdim=True)
+ x = x.view((x.size()[0], x.size()[1], -1))
+ x = x.transpose(1, 2)
+
+ return x
diff --git a/speaker_diarization/speakerlab/models/talknet/talknet.py b/speaker_diarization/speakerlab/models/talknet/talknet.py
new file mode 100644
index 0000000000000000000000000000000000000000..eac2116d72eea39d8a03fc517fa28af20f4f033f
--- /dev/null
+++ b/speaker_diarization/speakerlab/models/talknet/talknet.py
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+from speakerlab.models.talknet.audioEncoder import audioEncoder
+from speakerlab.models.talknet.visualEncoder import visualFrontend, visualTCN, visualConv1D
+from speakerlab.models.talknet.attentionLayer import attentionLayer
+
+class talkNetModel(nn.Module):
+ """
+ TalkNet model for active speaker detection task.
+ Reference:
+ - Is someone talking? TalkNet: Audio-visual active speaker detection Model.
+ - https://github.com/TaoRuijie/TalkNet-ASD
+ """
+ def __init__(self):
+ super(talkNetModel, self).__init__()
+ # Visual Temporal Encoder
+ self.visualFrontend = visualFrontend() # Visual Frontend
+ self.visualTCN = visualTCN() # Visual Temporal Network TCN
+ self.visualConv1D = visualConv1D() # Visual Temporal Network Conv1d
+
+ # Audio Temporal Encoder
+ self.audioEncoder = audioEncoder(layers = [3, 4, 6, 3], num_filters = [16, 32, 64, 128])
+
+ # Audio-visual Cross Attention
+ self.crossA2V = attentionLayer(d_model = 128, nhead = 8)
+ self.crossV2A = attentionLayer(d_model = 128, nhead = 8)
+
+ # Audio-visual Self Attention
+ self.selfAV = attentionLayer(d_model = 256, nhead = 8)
+
+ # Classifier
+ self.fcAV = nn.Linear(256, 2)
+ self.fcA = nn.Linear(128, 2)
+ self.fcV = nn.Linear(128, 2)
+
+ def visual_frontend(self, x):
+ B, T, W, H = x.shape
+ x = x.view(B*T, 1, 1, W, H)
+ x = (x / 255 - 0.4161) / 0.1688
+ x = self.visualFrontend(x)
+ x = x.view(B, T, 512)
+ x = x.transpose(1,2)
+ x = self.visualTCN(x)
+ x = self.visualConv1D(x)
+ x = x.transpose(1,2)
+ return x
+
+ def audio_frontend(self, x):
+ x = x.unsqueeze(1).transpose(2, 3)
+ x = self.audioEncoder(x)
+ return x
+
+ def cross_attention(self, x1, x2):
+ x1_c = self.crossA2V(src = x1, tar = x2)
+ x2_c = self.crossV2A(src = x2, tar = x1)
+ return x1_c, x2_c
+
+ def audio_visual_backend(self, x1, x2):
+ x = torch.cat((x1,x2), 2)
+ x = self.selfAV(src = x, tar = x)
+ return x
+
+ def forward(self, audioX, visualX):
+ audioX = self.audio_frontend(audioX)
+ visualX = self.visual_frontend(visualX)
+ audioX, visualX = self.cross_attention(audioX, visualX)
+ audio_visualX = self.audio_visual_backend(audioX, visualX)
+
+ return self.fcAV(audio_visualX), self.fcA(audioX), self.fcV(visualX)
diff --git a/speaker_diarization/speakerlab/models/talknet/visualEncoder.py b/speaker_diarization/speakerlab/models/talknet/visualEncoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..be9d954bcfdf4aedcf63324dc90ad743e54a65c1
--- /dev/null
+++ b/speaker_diarization/speakerlab/models/talknet/visualEncoder.py
@@ -0,0 +1,163 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ResNetLayer(nn.Module):
+
+ """
+ A ResNet layer used to build the ResNet network.
+ Architecture:
+ --> conv-bn-relu -> conv -> + -> bn-relu -> conv-bn-relu -> conv -> + -> bn-relu -->
+ | | | |
+ -----> downsample ------> ------------------------------------->
+ """
+
+ def __init__(self, inplanes, outplanes, stride):
+ super(ResNetLayer, self).__init__()
+ self.conv1a = nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1a = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
+ self.conv2a = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.downsample = nn.Sequential()
+ if stride != 1:
+ self.downsample = nn.Conv2d(inplanes, outplanes, kernel_size=(1,1), stride=stride, bias=False)
+ self.outbna = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
+
+ self.conv1b = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1b = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
+ self.conv2b = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.outbnb = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001)
+
+ def forward(self, inputBatch):
+ batch = F.relu(self.bn1a(self.conv1a(inputBatch)))
+ batch = self.conv2a(batch)
+ residualBatch = self.downsample(inputBatch)
+ batch = batch + residualBatch
+ intermediateBatch = batch
+ batch = F.relu(self.outbna(batch))
+
+ batch = F.relu(self.bn1b(self.conv1b(batch)))
+ batch = self.conv2b(batch)
+ residualBatch = intermediateBatch
+ batch = batch + residualBatch
+ outputBatch = F.relu(self.outbnb(batch))
+ return outputBatch
+
+
+
+class ResNet(nn.Module):
+
+ """
+ An 18-layer ResNet architecture.
+ """
+
+ def __init__(self):
+ super(ResNet, self).__init__()
+ self.layer1 = ResNetLayer(64, 64, stride=1)
+ self.layer2 = ResNetLayer(64, 128, stride=2)
+ self.layer3 = ResNetLayer(128, 256, stride=2)
+ self.layer4 = ResNetLayer(256, 512, stride=2)
+ self.avgpool = nn.AvgPool2d(kernel_size=(4,4), stride=(1,1))
+
+ return
+
+
+ def forward(self, inputBatch):
+ batch = self.layer1(inputBatch)
+ batch = self.layer2(batch)
+ batch = self.layer3(batch)
+ batch = self.layer4(batch)
+ outputBatch = self.avgpool(batch)
+ return outputBatch
+
+
+class GlobalLayerNorm(nn.Module):
+ def __init__(self, channel_size):
+ super(GlobalLayerNorm, self).__init__()
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ self.gamma.data.fill_(1)
+ self.beta.data.zero_()
+
+ def forward(self, y):
+ mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) #[M, 1, 1]
+ var = (torch.pow(y-mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
+ gLN_y = self.gamma * (y - mean) / torch.pow(var + 1e-8, 0.5) + self.beta
+ return gLN_y
+
+class visualFrontend(nn.Module):
+
+ """
+ A visual feature extraction module. Generates a 512-dim feature vector per video frame.
+ Architecture: A 3D convolution block followed by an 18-layer ResNet.
+ """
+
+ def __init__(self):
+ super(visualFrontend, self).__init__()
+ self.frontend3D = nn.Sequential(
+ nn.Conv3d(1, 64, kernel_size=(5,7,7), stride=(1,2,2), padding=(2,3,3), bias=False),
+ nn.BatchNorm3d(64, momentum=0.01, eps=0.001),
+ nn.ReLU(),
+ nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1))
+ )
+ self.resnet = ResNet()
+ return
+
+
+ def forward(self, inputBatch):
+ inputBatch = inputBatch.transpose(0, 1).transpose(1, 2)
+ batchsize = inputBatch.shape[0]
+ batch = self.frontend3D(inputBatch)
+
+ batch = batch.transpose(1, 2)
+ batch = batch.reshape(batch.shape[0]*batch.shape[1], batch.shape[2], batch.shape[3], batch.shape[4])
+ outputBatch = self.resnet(batch)
+ outputBatch = outputBatch.reshape(batchsize, -1, 512)
+ outputBatch = outputBatch.transpose(1 ,2)
+ outputBatch = outputBatch.transpose(1, 2).transpose(0, 1)
+ return outputBatch
+
+class DSConv1d(nn.Module):
+ def __init__(self):
+ super(DSConv1d, self).__init__()
+ self.net = nn.Sequential(
+ nn.ReLU(),
+ nn.BatchNorm1d(512),
+ nn.Conv1d(512, 512, 3, stride=1, padding=1,dilation=1, groups=512, bias=False),
+ nn.PReLU(),
+ GlobalLayerNorm(512),
+ nn.Conv1d(512, 512, 1, bias=False),
+ )
+
+ def forward(self, x):
+ out = self.net(x)
+ return out + x
+
+class visualTCN(nn.Module):
+ def __init__(self):
+ super(visualTCN, self).__init__()
+ stacks = []
+ for x in range(5):
+ stacks += [DSConv1d()]
+ self.net = nn.Sequential(*stacks) # Visual Temporal Network V-TCN
+
+ def forward(self, x):
+ out = self.net(x)
+ return out
+
+class visualConv1D(nn.Module):
+ def __init__(self):
+ super(visualConv1D, self).__init__()
+ self.net = nn.Sequential(
+ nn.Conv1d(512, 256, 5, stride=1, padding=2),
+ nn.BatchNorm1d(256),
+ nn.ReLU(),
+ nn.Conv1d(256, 128, 1),
+ )
+
+ def forward(self, x):
+ out = self.net(x)
+ return out
diff --git a/speaker_diarization/speakerlab/process/augmentation.py b/speaker_diarization/speakerlab/process/augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3785150a4655c9468c8d1c77842b834f3fce15a3
--- /dev/null
+++ b/speaker_diarization/speakerlab/process/augmentation.py
@@ -0,0 +1,92 @@
+import torch
+import torchaudio
+from scipy import signal
+import numpy as np
+import random
+
+from speakerlab.utils.fileio import load_wav_scp
+
+def addreverb(wav, rir_wav):
+ # wav: [T,], rir_wav: [T,]
+ wav = wav.numpy()
+ rir_wav = rir_wav.numpy()
+ wav_len = wav.shape[0]
+ rir_wav = rir_wav / np.sqrt(np.sum(rir_wav**2))
+ out_wav = signal.convolve(wav, rir_wav,
+ mode='full')[:wav_len]
+
+ out_wav = out_wav / (np.max(np.abs(out_wav)) + 1e-6)
+ return torch.from_numpy(out_wav)
+
+def addnoise(wav, noise=None, snr_high=15, snr_low=0):
+ # wav: [T,], noise: [T,]
+ if noise is None:
+ noise = torch.randn_like(wav)
+ noise = noise.numpy()
+ wav = wav.numpy()
+
+ wav_len = wav.shape[0]
+ noise_len = noise.shape[0]
+ if noise_len >= wav_len:
+ start = random.randint(0, noise_len - wav_len)
+ noise = noise[start:start + wav_len]
+ else:
+ noise = noise.repeat(wav_len // noise_len + 1)
+ noise = noise[:wav_len]
+
+ wav_db = 10 * np.log10(np.mean(wav**2) + 1e-6)
+ noise_db = 10 * np.log10(np.mean(noise**2) + 1e-6)
+ noise_snr = random.uniform(snr_low, snr_high)
+ noise = np.sqrt(10**(
+ (wav_db - noise_db - noise_snr) / 10)) * noise
+ out_wav = wav + noise
+
+ out_wav = out_wav / (np.max(np.abs(out_wav)) + 1e-6)
+ return torch.from_numpy(out_wav)
+
+
+class NoiseReverbCorrupter(object):
+ def __init__(
+ self,
+ noise_prob=0.0,
+ reverb_prob=0.0,
+ noise_file=None,
+ reverb_file=None,
+ noise_snr_low=0,
+ noise_snr_high=15,
+ ):
+ if reverb_prob > 0.0:
+ if reverb_file is None:
+ raise ValueError('Reverb_file not be assigned.')
+ self.add_reverb = addreverb
+ self.reverb_data = load_wav_scp(reverb_file)
+ self.reverb_data_keys = list(self.reverb_data.keys())
+
+ if noise_prob > 0.0:
+ if noise_file is None:
+ raise ValueError('Noise_file not be assigned.')
+
+ self.add_noise = addnoise
+ self.noise_data = load_wav_scp(noise_file)
+ self.noise_data_keys = list(self.noise_data.keys())
+
+ self.reverb_prob = reverb_prob
+ self.noise_prob = noise_prob
+ self.noise_snr_low = noise_snr_low
+ self.noise_snr_high = noise_snr_high
+
+ def __call__(self, wav, fs=16000):
+ if self.reverb_prob > random.random():
+ reverb_path = self.reverb_data[random.choice(self.reverb_data_keys)]
+ reverb, fs_rir = torchaudio.load(reverb_path)
+ assert fs_rir == fs
+ wav = self.add_reverb(wav, reverb[0])
+ if self.noise_prob > random.random():
+ noise_path = self.noise_data[random.choice(self.noise_data_keys)]
+ noise, fs_noise = torchaudio.load(noise_path)
+ assert fs_noise == fs
+ wav = self.add_noise(
+ wav, noise[0],
+ snr_high=self.noise_snr_high,
+ snr_low=self.noise_snr_low,)
+ return wav
diff --git a/speaker_diarization/speakerlab/process/cluster.py b/speaker_diarization/speakerlab/process/cluster.py
new file mode 100644
index 0000000000000000000000000000000000000000..55555d3c41885ae23a51eb587dae3dc73c48b797
--- /dev/null
+++ b/speaker_diarization/speakerlab/process/cluster.py
@@ -0,0 +1,362 @@
+import numpy as np
+import scipy
+import sklearn
+from sklearn.cluster._kmeans import k_means
+from sklearn.metrics.pairwise import cosine_similarity
+
+import fastcluster
+from scipy.cluster.hierarchy import fcluster
+from scipy.spatial.distance import squareform
+
+try:
+ import umap, hdbscan
+except ImportError:
+ raise ImportError(
+ "Package \"umap\" or \"hdbscan\" not found. \
+ Please install them first by \"pip install umap-learn hdbscan\"."
+ )
+
+
+class SpectralCluster:
+ """A spectral clustering method using unnormalized Laplacian of affinity matrix.
+ This implementation is adapted from https://github.com/speechbrain/speechbrain.
+ """
+
+ def __init__(self, min_num_spks=1, max_num_spks=10, pval=0.02, min_pnum=6, oracle_num=None):
+ self.min_num_spks = min_num_spks
+ self.max_num_spks = max_num_spks
+ self.min_pnum = min_pnum
+ self.pval = pval
+ self.k = oracle_num
+
+ def __call__(self, X, **kwargs):
+ pval = kwargs.get('pval', None)
+ oracle_num = kwargs.get('speaker_num', None)
+
+ # Similarity matrix computation
+ sim_mat = self.get_sim_mat(X)
+
+ # Refining similarity matrix with pval
+ prunned_sim_mat = self.p_pruning(sim_mat, pval)
+
+ # Symmetrization
+ sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
+
+ # Laplacian calculation
+ laplacian = self.get_laplacian(sym_prund_sim_mat)
+
+ # Get Spectral Embeddings
+ emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
+
+ # Perform clustering
+ labels = self.cluster_embs(emb, num_of_spk)
+
+ return labels
+
+ def get_sim_mat(self, X):
+ # Cosine similarities
+ M = cosine_similarity(X, X)
+ return M
+
+ def p_pruning(self, A, pval=None):
+ if pval is None:
+ pval = self.pval
+ n_elems = int((1 - pval) * A.shape[0])
+ n_elems = min(n_elems, A.shape[0]-self.min_pnum)
+
+ # For each row in a affinity matrix
+ for i in range(A.shape[0]):
+ low_indexes = np.argsort(A[i, :])
+ low_indexes = low_indexes[0:n_elems]
+
+ # Replace smaller similarity values by 0s
+ A[i, low_indexes] = 0
+ return A
+
+ def get_laplacian(self, M):
+ M[np.diag_indices(M.shape[0])] = 0
+ D = np.sum(np.abs(M), axis=1)
+ D = np.diag(D)
+ L = D - M
+ return L
+
+ def get_spec_embs(self, L, k_oracle=None):
+ if k_oracle is None:
+ k_oracle = self.k
+
+ lambdas, eig_vecs = scipy.sparse.linalg.eigsh(L, k=min(self.max_num_spks+1, L.shape[0]), which='SM')
+
+ if k_oracle is not None:
+ num_of_spk = k_oracle
+ else:
+ lambda_gap_list = self.getEigenGaps(
+ lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
+ num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
+
+ emb = eig_vecs[:, :num_of_spk]
+ return emb, num_of_spk
+
+ def cluster_embs(self, emb, k):
+ # k-means
+ _, labels, _ = k_means(emb, k)
+ return labels
+
+ def getEigenGaps(self, eig_vals):
+ eig_vals_gap_list = []
+ for i in range(len(eig_vals) - 1):
+ gap = float(eig_vals[i + 1]) - float(eig_vals[i])
+ eig_vals_gap_list.append(gap)
+ return eig_vals_gap_list
+
+
+class UmapHdbscan:
+ """
+ Reference:
+ - Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
+ Emphasis On Topological Structure. ICASSP2022
+ """
+
+ def __init__(self, n_neighbors=20, n_components=60, min_samples=20, min_cluster_size=10, metric='euclidean'):
+ self.n_neighbors = n_neighbors
+ self.n_components = n_components
+ self.min_samples = min_samples
+ self.min_cluster_size = min_cluster_size
+ self.metric = metric
+
+ def __call__(self, X, **kwargs):
+ umap_X = umap.UMAP(
+ n_neighbors=self.n_neighbors,
+ min_dist=0.0,
+ n_components=min(self.n_components, X.shape[0]-2),
+ metric=self.metric,
+ ).fit_transform(X)
+ labels = hdbscan.HDBSCAN(min_samples=self.min_samples, min_cluster_size=self.min_cluster_size).fit_predict(umap_X)
+ return labels
+
+class AHCluster:
+ """
+ Agglomerative Hierarchical Clustering, a bottom-up approach which iteratively merges
+ the closest clusters until a termination condition is reached.
+ This implementation is adapted from https://github.com/BUTSpeechFIT/VBx.
+ """
+
+ def __init__(self, fix_cos_thr=0.4):
+ self.fix_cos_thr = fix_cos_thr
+
+ def __call__(self, X, **kwargs):
+ scr_mx = cosine_similarity(X)
+ scr_mx = squareform(-scr_mx, checks=False)
+ lin_mat = fastcluster.linkage(scr_mx, method='average', preserve_input='False')
+ adjust = abs(lin_mat[:, 2].min())
+ lin_mat[:, 2] += adjust
+ labels = fcluster(lin_mat, -self.fix_cos_thr + adjust, criterion='distance') - 1
+ return labels
+
+
+class CommonClustering:
+ """Perfom clustering for input embeddings and output the labels.
+ """
+
+ def __init__(self, cluster_type, cluster_line=40, mer_cos=None, min_cluster_size=4, **kwargs):
+ self.cluster_type = cluster_type
+ self.cluster_line = cluster_line
+ self.min_cluster_size = min_cluster_size
+ self.mer_cos = mer_cos
+ if self.cluster_type == 'spectral':
+ self.cluster = SpectralCluster(**kwargs)
+ elif self.cluster_type == 'umap_hdbscan':
+ kwargs['min_cluster_size'] = min_cluster_size
+ self.cluster = UmapHdbscan(**kwargs)
+ elif self.cluster_type == 'AHC':
+ self.cluster = AHCluster(**kwargs)
+ else:
+ raise ValueError(
+ '%s is not currently supported.' % self.cluster_type
+ )
+ if self.cluster_type != 'AHC':
+ self.cluster_for_short = AHCluster()
+ else:
+ self.cluster_for_short = self.cluster
+
+ def __call__(self, X, **kwargs):
+ # clustering and return the labels
+ assert len(X.shape) == 2, 'Shape of input should be [N, C]'
+ if X.shape[0] <= 1:
+ return np.zeros(X.shape[0], dtype=int)
+ if X.shape[0] < self.cluster_line:
+ labels = self.cluster_for_short(X)
+ else:
+ labels = self.cluster(X, **kwargs)
+
+ # remove extremely minor cluster
+ labels = self.filter_minor_cluster(labels, X, self.min_cluster_size)
+ # merge similar speaker
+ if self.mer_cos is not None:
+ labels = self.merge_by_cos(labels, X, self.mer_cos)
+
+ return labels
+
+ def filter_minor_cluster(self, labels, x, min_cluster_size):
+ cset = np.unique(labels)
+ csize = np.array([(labels == i).sum() for i in cset])
+ minor_idx = np.where(csize <= self.min_cluster_size)[0]
+ if len(minor_idx) == 0:
+ return labels
+
+ minor_cset = cset[minor_idx]
+ major_idx = np.where(csize > self.min_cluster_size)[0]
+ if len(major_idx) == 0:
+ return np.zeros_like(labels)
+ major_cset = cset[major_idx]
+ major_center = np.stack([x[labels == i].mean(0) \
+ for i in major_cset])
+ for i in range(len(labels)):
+ if labels[i] in minor_cset:
+ cos_sim = cosine_similarity(x[i][np.newaxis], major_center)
+ labels[i] = major_cset[cos_sim.argmax()]
+
+ return labels
+
+ def merge_by_cos(self, labels, x, cos_thr):
+ # merge the similar speakers by cosine similarity
+ assert cos_thr > 0 and cos_thr <= 1
+ while True:
+ cset = np.unique(labels)
+ if len(cset) == 1:
+ break
+ centers = np.stack([x[labels == i].mean(0) \
+ for i in cset])
+ affinity = cosine_similarity(centers, centers)
+ affinity = np.triu(affinity, 1)
+ idx = np.unravel_index(np.argmax(affinity), affinity.shape)
+ if affinity[idx] < cos_thr:
+ break
+ c1, c2 = cset[np.array(idx)]
+ labels[labels==c2]=c1
+ return labels
+
+
+class JointClustering:
+ """Perfom joint clustering for input audio and visual embeddings and output the labels.
+ """
+
+ def __init__(self, audio_cluster, vision_cluster):
+ self.audio_cluster = audio_cluster
+ self.vision_cluster = vision_cluster
+
+ def __call__(self, audioX, visionX, audioT, visionT, conf):
+ # audio-only and video-only clustering
+ alabels = self.audio_cluster(audioX)
+ vlabels = self.vision_cluster(visionX)
+
+ alabels = self.arrange_labels(alabels)
+ vlist, vspk_embs, vspk_dur = self.get_vlist_embs(audioX, alabels, vlabels, audioT, visionT, conf)
+
+ # modify alabels according to vlabels
+ aspk_num = alabels.max()+1
+ for i in range(aspk_num):
+ aspki_index = np.where(alabels==i)[0]
+ aspki_embs = audioX[alabels==i]
+
+ aspkiT_part = np.array(audioT)[alabels==i]
+ overlap_vspk = self.overlap_spks(self.cast_overlap(aspkiT_part), vlist, vspk_dur)
+ if len(overlap_vspk) > 1:
+ centers = np.stack([vspk_embs[s] for s in overlap_vspk])
+ distribute_labels = self.distribute_embs(aspki_embs, centers)
+ for j in range(distribute_labels.max()+1):
+ for loc in aspki_index[distribute_labels==j]:
+ alabels[loc] = overlap_vspk[j]
+ elif len(overlap_vspk) == 1:
+ for loc in aspki_index:
+ alabels[loc] = overlap_vspk[0]
+
+ alabels = self.arrange_labels(alabels)
+ return alabels
+
+ def overlap_spks(self, times, vlist, vspk_dur=None):
+ # get the vspk that overlaps with times.
+ overlap_dur = {}
+ for [a_st, a_ed] in times:
+ for [v_st, v_ed, v_id] in vlist:
+ if a_ed > v_st and v_ed > a_st:
+ if v_id not in overlap_dur:
+ overlap_dur[v_id]=0
+ overlap_dur[v_id] += min(a_ed, v_ed) - max(a_st, v_st)
+ vspk_list = []
+ for v_id, dur in overlap_dur.items():
+ # set the criteria for confirming overlap.
+ if (vspk_dur is None and dur > 0.5) or (vspk_dur is not None and dur > min(vspk_dur[v_id]*0.5, 0.5)):
+ vspk_list.append(v_id)
+ return vspk_list
+
+ def distribute_embs(self, embs, centers):
+ # embs: [n, D]. centers: [k, D]
+ norm_centers = centers / np.linalg.norm(centers, axis=1, keepdims=True)
+ norm_embs = embs / np.linalg.norm(embs, axis=1, keepdims=True)
+ similarity = np.matmul(norm_embs, norm_centers.T) # [n, k]
+ argsort = np.argsort(similarity, axis=-1)
+ return argsort[:, -1]
+
+ def get_vlist_embs(self, audioX, alabels, vlabels, audioT, visionT, conf):
+ assert len(vlabels) == len(visionT)
+ vlist = []
+ for i, ti in enumerate(visionT):
+ if len(vlist)==0 or vlabels[i] != vlist[-1][2] or ti - visionT[i-1] > conf.face_det_stride*0.04 + 1e-4:
+ if len(vlist) > 0 and vlist[-1][1] - vlist[-1][0] < 1e-4:
+ # remove too short intervals.
+ vlist.pop()
+ vlist.append([ti, ti, vlabels[i]])
+ else:
+ vlist[-1][1] = ti
+
+ # adjust vision labels
+ vlabels_arrange = self.arrange_labels([i[2] for i in vlist], a_st=alabels.max()+1)
+ vlist = [[i[0], i[1], j] for i, j in zip(vlist, vlabels_arrange)]
+
+ # get audio spk embs aligning with 'vlist'
+ vspk_embs = {}
+ for [v_st, v_ed, v_id] in vlist:
+ for i, [a_st, a_ed] in enumerate(audioT):
+ if a_ed >= v_st and v_ed >= a_st:
+ if min(a_ed, v_ed) - max(a_st, v_st) > 1:
+ if v_id not in vspk_embs:
+ vspk_embs[v_id] = []
+ vspk_embs[v_id].append(audioX[i])
+ for k in vspk_embs:
+ vspk_embs[k] = np.stack(vspk_embs[k]).mean(0)
+
+ vlist_new = []
+ for i in vlist:
+ if i[2] in vspk_embs:
+ vlist_new.append(i)
+ # get duration of v_spk
+ vspk_dur = {}
+ for i in vlist_new:
+ if i[2] not in vspk_dur:
+ vspk_dur[i[2]]=0
+ vspk_dur[i[2]] += i[1]-i[0]
+
+ return vlist_new, vspk_embs, vspk_dur
+
+ def cast_overlap(self, input_time):
+ if len(input_time)==0:
+ return input_time
+ output_time = []
+ for i in range(0, len(input_time)-1):
+ if i == 0 or output_time[-1][1] < input_time[i][0]:
+ output_time.append(input_time[i])
+ else:
+ output_time[-1][1] = input_time[i][1]
+ return output_time
+
+ def arrange_labels(self, labels, a_st=0):
+ # arrange labels in order from 0.
+ new_labels = []
+ labels_dict = {}
+ idx = a_st
+ for i in labels:
+ if i not in labels_dict:
+ labels_dict[i] = idx
+ idx += 1
+ new_labels.append(labels_dict[i])
+ return np.array(new_labels)
diff --git a/speaker_diarization/speakerlab/process/processor.py b/speaker_diarization/speakerlab/process/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cb7fde30907b7230eda34566a8225f6f3dae53b
--- /dev/null
+++ b/speaker_diarization/speakerlab/process/processor.py
@@ -0,0 +1,155 @@
+import random
+import pickle
+import torch
+import torchaudio
+import torch.nn.functional as F
+import torchaudio.compliance.kaldi as Kaldi
+
+from speakerlab.process.augmentation import NoiseReverbCorrupter
+from speakerlab.utils.fileio import load_data_csv
+
+
+class WavReader(object):
+ def __init__(self,
+ sample_rate = 16000,
+ duration: float = 3.0,
+ speed_pertub: bool = False,
+ lm: bool = True,
+ ):
+ self.duration = duration
+ self.sample_rate = sample_rate
+ self.speed_pertub = speed_pertub
+ self.lm = lm
+
+ def __call__(self, wav_path):
+ wav, sr = torchaudio.load(wav_path)
+ assert sr == self.sample_rate
+ wav = wav[0]
+
+ if self.speed_pertub and self.lm:
+ speeds = [1.0, 0.9, 1.1]
+ speed_idx = random.randint(0, 2)
+ if speed_idx > 0:
+ wav, _ = torchaudio.sox_effects.apply_effects_tensor(
+ wav.unsqueeze(0), self.sample_rate, [['speed', str(speeds[speed_idx])], ['rate', str(self.sample_rate)]])
+ else:
+ speed_idx = 0
+
+ wav = wav.squeeze(0)
+ data_len = wav.shape[0]
+
+ chunk_len = int(self.duration * sr)
+ if data_len >= chunk_len:
+ start = random.randint(0, data_len - chunk_len)
+ end = start + chunk_len
+ wav = wav[start:end]
+ else:
+ wav = F.pad(wav, (0, chunk_len - data_len))
+
+ return wav, speed_idx
+
+class SpkLabelEncoder(object):
+ def __init__(self, data_file):
+ self.lab2ind = {}
+ self.ind2lab = {}
+ self.starting_index = -1
+ self.load_from_csv(data_file)
+
+ def __call__(self, spk, speed_idx=0):
+ spkid = self.lab2ind[spk]
+ spkid = spkid + len(self.lab2ind) * speed_idx
+ return spkid
+
+ def load_from_csv(self, path):
+ self.data = load_data_csv(path)
+ for key in self.data:
+ self.add(self.data[key]['spk'])
+
+ def add(self, label):
+ if label in self.lab2ind:
+ return
+ index = self._next_index()
+ self.lab2ind[label] = index
+ self.ind2lab[index] = label
+
+ def _next_index(self):
+ self.starting_index += 1
+ return self.starting_index
+
+ def __len__(self):
+ return len(self.lab2ind)
+
+ def save(self, path, device=None):
+ with open(path, 'wb') as f:
+ pickle.dump(self.lab2ind, f)
+
+ def load(self, path, device=None):
+ self.lab2ind = {}
+ self.ind2lab = {}
+ with open(path, 'rb') as f:
+ self.lab2ind = pickle.load(f)
+ for label in self.lab2ind:
+ self.ind2lab[self.lab2ind[label]] = label
+
+
+class SpkVeriAug(object):
+ def __init__(
+ self,
+ aug_prob: float = 0.0,
+ noise_file: str = None,
+ reverb_file: str = None,
+ ):
+ self.aug_prob = aug_prob
+ if aug_prob > 0:
+ self.add_noise = NoiseReverbCorrupter(
+ noise_prob=1.0,
+ noise_file=noise_file,
+ )
+ self.add_rir = NoiseReverbCorrupter(
+ reverb_prob=1.0,
+ reverb_file=reverb_file,
+ )
+ self.add_rir_noise = NoiseReverbCorrupter(
+ noise_prob=1.0,
+ reverb_prob=1.0,
+ noise_file=noise_file,
+ reverb_file=reverb_file,
+ )
+
+ self.augmentations = [self.add_noise, self.add_rir, self.add_rir_noise]
+
+ def __call__(self, wav):
+ sample_rate = 16000
+ if self.aug_prob > random.random():
+ aug = random.choice(self.augmentations)
+ wav = aug(wav, sample_rate)
+
+ return wav
+
+
+class FBank(object):
+ def __init__(self,
+ n_mels,
+ sample_rate,
+ mean_nor: bool = False,
+ ):
+ self.n_mels = n_mels
+ self.sample_rate = sample_rate
+ self.mean_nor = mean_nor
+
+ def __call__(self, wav, dither=0):
+ sr = 16000
+ assert sr==self.sample_rate
+ if len(wav.shape) == 1:
+ wav = wav.unsqueeze(0)
+ # select single channel
+ if wav.shape[0] > 1:
+ wav = wav[0, :]
+ wav = wav.unsqueeze(0)
+ assert len(wav.shape) == 2 and wav.shape[0]==1
+ feat = Kaldi.fbank(wav, num_mel_bins=self.n_mels,
+ sample_frequency=sr, dither=dither)
+ # feat: [T, N]
+ if self.mean_nor:
+ feat = feat - feat.mean(0, keepdim=True)
+ return feat
diff --git a/speaker_diarization/speakerlab/utils/builder.py b/speaker_diarization/speakerlab/utils/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..baa0c7060733bc8f24015832f8976a3c74e15c91
--- /dev/null
+++ b/speaker_diarization/speakerlab/utils/builder.py
@@ -0,0 +1,88 @@
+import re
+import importlib
+from speakerlab.utils.config import Config
+
+
+def dynamic_import(import_path):
+ module_name, obj_name = import_path.rsplit('.', 1)
+ m = importlib.import_module(module_name)
+ return getattr(m, obj_name)
+
+def is_ref_type(value: str):
+ assert isinstance(value, str), 'Input value is not a str.'
+ if re.match('^<[a-zA-Z]\w*>$', value):
+ return True
+ else:
+ return False
+
+def is_built(ins):
+ if isinstance(ins, dict):
+ if 'obj' in ins and 'args' in ins:
+ return False
+ for i in ins.values():
+ if not is_built(i):
+ return False
+ elif isinstance(ins, str):
+ if '/' in ins: # reference may exist in a path string.
+ inss = ins.split('/')
+ return is_built(inss)
+ elif is_ref_type(ins):
+ return False
+ elif isinstance(ins, list):
+ for i in ins:
+ if not is_built(i):
+ return False
+ return True
+
+def deep_build(ins, config, build_space: set = None):
+ if is_built(ins):
+ return ins
+
+ if build_space is None:
+ build_space = set()
+
+ if isinstance(ins, list):
+ for i in range(len(ins)):
+ ins[i] = deep_build(ins[i], config, build_space)
+ return ins
+ elif isinstance(ins, dict):
+ if 'obj' in ins and 'args' in ins: # return a instantiated module.
+ obj = ins['obj']
+ args = ins['args']
+ assert isinstance(args, dict), f"Args for {obj} must be a dict."
+ args = deep_build(args, config, build_space)
+
+ module_cls = dynamic_import(obj)
+ mm = module_cls(**args)
+ return mm
+ else: # return a nomal dict.
+ for k in ins:
+ ins[k] = deep_build(ins[k], config, build_space)
+ return ins
+ elif isinstance(ins, str):
+ if '/' in ins: # reference may exist in a path string.
+ inss = ins.split('/')
+ inss = deep_build(inss, config, build_space)
+ ins = '/'.join(inss)
+ return ins
+ elif is_ref_type(ins):
+ ref = ins[1:-1]
+
+ if ref in build_space:
+ raise ValueError("Cross referencing is not allowed in config.")
+ build_space.add(ref)
+
+ assert hasattr(config, ref), f"Key name {ins} not found in config."
+ attr = getattr(config, ref)
+ attr = deep_build(attr, config, build_space)
+ setattr(config, ref, attr)
+
+ build_space.remove(ref)
+ return attr
+ else:
+ return ins
+ else:
+ return ins
+
+def build(name: str, config: Config):
+ return deep_build(f"<{name}>", config)
diff --git a/speaker_diarization/speakerlab/utils/config.py b/speaker_diarization/speakerlab/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..366865c071d4ad965c3f0d354f862ffb8b8612de
--- /dev/null
+++ b/speaker_diarization/speakerlab/utils/config.py
@@ -0,0 +1,49 @@
+import os
+import yaml
+
+class Config(object):
+ def __init__(self, conf_dict):
+ for key, value in conf_dict.items():
+ self.__dict__[key] = value
+
+
+def convert_to_yaml(overrides):
+ """Convert args to yaml for overrides"""
+ yaml_string = ""
+
+ # Handle '--arg=val' type args
+ joined_args = "=".join(overrides)
+ split_args = joined_args.split("=")
+
+ for arg in split_args:
+ if arg.startswith("--"):
+ yaml_string += "\n" + arg[len("--") :] + ":"
+ else:
+ yaml_string += " " + arg
+
+ return yaml_string.strip()
+
+
+def yaml_config_loader(conf_file, overrides=None):
+ with open(conf_file, "r") as fr:
+ conf_dict = yaml.load(fr, Loader=yaml.FullLoader)
+ if overrides is not None:
+ overrides = yaml.load(overrides, Loader=yaml.FullLoader)
+ conf_dict.update(overrides)
+ return conf_dict
+
+
+def build_config(config_file, overrides=None, copy=False):
+ if config_file.endswith(".yaml"):
+ if overrides is not None:
+ overrides = convert_to_yaml(overrides)
+ conf_dict = yaml_config_loader(config_file, overrides)
+ if copy and 'exp_dir' in conf_dict:
+ os.makedirs(conf_dict['exp_dir'], exist_ok=True)
+ saved_path = os.path.join(conf_dict['exp_dir'], 'config.yaml')
+ with open(saved_path, 'w') as f:
+ f.write(yaml.dump(conf_dict))
+ else:
+ raise ValueError("Unknown config file format")
+
+ return Config(conf_dict)
diff --git a/speaker_diarization/speakerlab/utils/epoch.py b/speaker_diarization/speakerlab/utils/epoch.py
new file mode 100644
index 0000000000000000000000000000000000000000..760cf3dbf2786ddafb162a73202b1f2eea1cccd1
--- /dev/null
+++ b/speaker_diarization/speakerlab/utils/epoch.py
@@ -0,0 +1,62 @@
+import logging
+logger = logging.getLogger(__name__)
+
+class EpochLogger(object):
+ def __init__(self, save_file, precision=2):
+ self.save_file = save_file
+ self.precision = precision
+
+ def item_to_string(self, key, value, prefix=None):
+ if isinstance(value, float) and 1.0 < value < 100.0:
+ value = f"{value:.{self.precision}f}"
+ elif isinstance(value, float):
+ value = f"{value:.{self.precision}e}"
+ if prefix is not None:
+ key = f"{prefix} {key}"
+ return f"{key}: {value}"
+
+ def stats_to_string(self, stats, prefix=None):
+ return ", ".join(
+ [self.item_to_string(k, v, prefix) for k, v in stats.items()]
+ )
+
+ def log_stats(
+ self,
+ stats_meta,
+ stats,
+ stage='train',
+ verbose=True,
+ ):
+ string = self.stats_to_string(stats_meta)
+ if stats is not None:
+ string += " - " + self.stats_to_string(stats, stage)
+
+ with open(self.save_file, "a") as fw:
+ print(string, file=fw)
+ if verbose:
+ logger.info(string)
+
+
+class EpochCounter(object):
+ def __init__(self, limit):
+ self.current = 0
+ self.limit = limit
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.current < self.limit:
+ self.current += 1
+ logger.info(f"Going into epoch {self.current}")
+ return self.current
+ raise StopIteration
+
+ def save(self, path, device=None):
+ with open(path, "w") as f:
+ f.write(str(self.current))
+
+ def load(self, path, device=None):
+ with open(path) as f:
+ saved_value = int(f.read())
+ self.current = saved_value
diff --git a/speaker_diarization/speakerlab/utils/fileio.py b/speaker_diarization/speakerlab/utils/fileio.py
new file mode 100644
index 0000000000000000000000000000000000000000..e958fac2160cd54327474a3949240dc0ec0283ef
--- /dev/null
+++ b/speaker_diarization/speakerlab/utils/fileio.py
@@ -0,0 +1,126 @@
+import csv
+import yaml
+import codecs
+import json
+import torch
+import torchaudio
+import numpy as np
+
+
+def load_yaml(yaml_path):
+ with open(yaml_path) as f:
+ config = yaml.load(f, Loader=yaml.FullLoader)
+ return config
+
+
+def load_data_csv(fpath):
+ with open(fpath, newline="") as f:
+ result = {}
+ reader = csv.DictReader(f, skipinitialspace=True)
+ for row in reader:
+ if 'ID' not in row:
+ raise KeyError(
+ "CSV file has to have an 'ID' field, with unique ids for all data points."
+ )
+
+ data_id = row["ID"]
+ del row["ID"]
+
+ if data_id in result:
+ raise ValueError(f"Duplicate id: {data_id}")
+ result[data_id] = row
+ return result
+
+
+def load_data_list(fpath):
+ with open(fpath) as f:
+ rows = [i.strip() for i in f.readlines()]
+ result = {idx: row for idx, row in enumerate(rows)}
+ return result
+
+
+def load_wav_scp(fpath):
+ with open(fpath) as f:
+ rows = [i.strip() for i in f.readlines()]
+ result = {i.split()[0]: i.split()[1] for i in rows}
+ return result
+
+
+def load_json_file(json_file):
+ with codecs.open(json_file, "r", encoding="utf-8") as fr:
+ data_dict = json.load(fr)
+ return data_dict
+
+
+def load_trans7time_list(filename):
+ """
+ trans7time: (spk_id, st, ed, content)
+ """
+ with open(filename, "r") as fr:
+ trans7time_list = []
+ lines = fr.readlines()
+ for line in lines:
+ trans7time_list.append(line.strip().split())
+ result_trans7time_list = []
+ for index, item in enumerate(trans7time_list):
+ if len(item) <= 2:
+ raise ValueError(f"filename {filename}: item - {index} = {item}")
+ if len(item) == 3:
+ st = float(item[1])
+ ed = float(item[2])
+ result_trans7time_list.append((
+ item[0], st, ed, ""
+ ))
+ else:
+ result_trans7time_list.append((
+ item[0], float(item[1]), float(item[2]), "".join(item[3:])
+ ))
+ return result_trans7time_list
+
+
+def write_json_file(json_file, data):
+ assert str(json_file).endswith(".json") or str(json_file).endswith(".JSON")
+ with codecs.open(json_file, "w", encoding="utf-8") as fw:
+ json.dump(data, fw, indent=2, ensure_ascii=False)
+
+
+def write_wav_scp(fpath, wav_scp):
+ with open(fpath, "w") as f:
+ for key, value in wav_scp.items():
+ f.write(f"{key} {value}\n")
+
+
+def write_trans7time_list(fpath, trans7time_list):
+ """
+ trans7time_list: [(spk_id, start_time, end_time, text)]
+ """
+ with open(fpath, 'w') as fw:
+ for spk_id, start_time, end_time, text in trans7time_list:
+ text = text.replace("\n", "").replace("\r", "")
+ fw.write(f'{spk_id} {start_time} {end_time} {text}\n')
+
+def load_audio(input, ori_fs=None, obj_fs=None):
+ if isinstance(input, str):
+ wav, fs = torchaudio.load(input)
+ wav = wav.mean(dim=0, keepdim=True)
+ if obj_fs is not None and fs != obj_fs:
+ wav = torchaudio.functional.resample(wav, orig_freq=fs, new_freq=obj_fs)
+ return wav
+ elif isinstance(input, np.ndarray) or isinstance(input, torch.Tensor):
+ wav = torch.from_numpy(input) if isinstance(input, np.ndarray) else input
+ if wav.dtype in (torch.int16, torch.int32, torch.int64):
+ wav = wav.type(torch.float32)
+ wav = wav / 32768
+ wav = wav.type(torch.float32)
+ assert wav.ndim <= 2
+ if wav.ndim == 2:
+ if wav.shape[0] > wav.shape[1]:
+ wav = torch.transpose(wav, 0, 1)
+ wav = wav.mean(dim=0, keepdim=True)
+ if wav.ndim == 1:
+ wav = wav.unsqueeze(0)
+ if ori_fs is not None and obj_fs is not None and ori_fs!=obj_fs:
+ wav = torchaudio.functional.resample(wav, orig_freq=ori_fs, new_freq=obj_fs)
+ return wav
+ else:
+ return input
diff --git a/speaker_diarization/speakerlab/utils/score_metrics.py b/speaker_diarization/speakerlab/utils/score_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebd6ce11629ce23ddfb99ff30781a4174da012f8
--- /dev/null
+++ b/speaker_diarization/speakerlab/utils/score_metrics.py
@@ -0,0 +1,188 @@
+"""
+This script computes the official performance metrics for the NIST 2016 SRE.
+The metrics include EER and DCFs (min/act).
+"""
+
+__author__ = "Omid Sadjadi"
+__email__ = "omid.sadjadi@nist.gov"
+__version__ = "4.1"
+
+import numpy as np
+from scipy.stats import norm
+import matplotlib.pyplot as plt
+import sys
+
+
+def compute_norm_counts(scores, edges, wghts=None):
+ """ computes normalized (and optionally weighted) score counts for the
+ bin edges.
+ """
+
+ if scores.size > 0:
+ score_counts = np.histogram(scores, bins=edges,
+ weights=wghts)[0].astype('f')
+ norm_counts = np.cumsum(score_counts) / score_counts.sum()
+ else:
+ norm_counts = None
+ return norm_counts
+
+
+def compute_pmiss_pfa(scores, labels, weights=None):
+ """ computes false positive rate (FPR) and false negative rate (FNR)
+ given trial socres and their labels. A weights option is also provided to
+ equalize the counts over score partitions (if there is such partitioning).
+ """
+
+ tgt_scores = scores[labels == 1] # target trial scores
+ imp_scores = scores[labels == 0] # impostor trial scores
+
+ resol = max(
+ [np.count_nonzero(labels == 0),
+ np.count_nonzero(labels == 1), 1.e6])
+ edges = np.linspace(np.min(scores), np.max(scores), resol)
+
+ if weights is not None:
+ tgt_weights = weights[labels == 1]
+ imp_weights = weights[labels == 0]
+ else:
+ tgt_weights = None
+ imp_weights = None
+
+ fnr = compute_norm_counts(tgt_scores, edges, tgt_weights)
+ fpr = 1 - compute_norm_counts(imp_scores, edges, imp_weights)
+
+ return fnr, fpr
+
+
+def compute_pmiss_pfa_rbst(scores, labels, weights=None):
+ """ computes false positive rate (FPR) and false negative rate (FNR)
+ given trial socres and their labels. A weights option is also provided to
+ equalize the counts over score partitions (if there is such partitioning).
+ """
+
+ sorted_ndx = np.argsort(scores)
+ labels = labels[sorted_ndx]
+ if weights is not None:
+ weights = weights[sorted_ndx]
+ else:
+ weights = np.ones((labels.shape), dtype='f8')
+
+ tgt_wghts = weights * (labels == 1).astype('f8')
+ imp_wghts = weights * (labels == 0).astype('f8')
+
+ fnr = np.cumsum(tgt_wghts) / np.sum(tgt_wghts)
+ fpr = 1 - np.cumsum(imp_wghts) / np.sum(imp_wghts)
+ return fnr, fpr
+
+
+def compute_eer(fnr, fpr, scores=None):
+ """ computes the equal error rate (EER) given FNR and FPR values calculated
+ for a range of operating points on the DET curve
+ """
+
+ diff_pm_fa = fnr - fpr
+ x1 = np.flatnonzero(diff_pm_fa >= 0)[0]
+ x2 = np.flatnonzero(diff_pm_fa < 0)[-1]
+ a = (fnr[x1] - fpr[x1]) / (fpr[x2] - fpr[x1] - (fnr[x2] - fnr[x1]))
+
+ if scores is not None:
+ score_sort = np.sort(scores)
+ return fnr[x1] + a * (fnr[x2] - fnr[x1]), score_sort[x1]
+
+ return fnr[x1] + a * (fnr[x2] - fnr[x1])
+
+
+def compute_c_norm(fnr, fpr, p_target, c_miss=1, c_fa=1):
+ """ computes normalized minimum detection cost function (DCF) given
+ the costs for false accepts and false rejects as well as a priori
+ probability for target speakers
+ """
+
+ c_det = min(c_miss * fnr * p_target + c_fa * fpr * (1 - p_target))
+ c_def = min(c_miss * p_target, c_fa * (1 - p_target))
+
+ return c_det / c_def
+
+
+def compute_c_dcf(fnr, fpr, p_target, c_miss=1, c_fa=1):
+ """ computes normalized minimum detection cost function (DCF) given
+ the costs for false accepts and false rejects as well as a priori
+ probability for target speakers
+ """
+
+ c_det = min(c_miss * fnr * p_target + c_fa * fpr * (1 - p_target))
+
+ return c_det
+
+
+def plot_det_curve(fnr, fpr, save_path=None):
+ """ plots the detection error trade-off (DET) curve
+ """
+
+ p_miss = norm.ppf(fnr)
+ p_fa = norm.ppf(fpr)
+
+ xytick = [
+ 0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1,
+ 0.2, 0.4
+ ]
+ xytick_labels = map(str, [x * 100 for x in xytick])
+
+ plt.plot(p_fa, p_miss, 'r')
+ plt.xticks(norm.ppf(xytick), xytick_labels)
+ plt.yticks(norm.ppf(xytick), xytick_labels)
+ plt.xlim(norm.ppf([0.00051, 0.5]))
+ plt.ylim(norm.ppf([0.00051, 0.5]))
+ plt.xlabel("false-alarm rate [%]", fontsize=12)
+ plt.ylabel("false-reject rate [%]", fontsize=12)
+ eer = compute_eer(fnr, fpr)
+ plt.plot(norm.ppf(eer), norm.ppf(eer), 'o')
+ plt.annotate(
+ "EER = %.2f%%" % (eer * 100),
+ xy=(norm.ppf(eer), norm.ppf(eer)),
+ xycoords='data',
+ xytext=(norm.ppf(eer + 0.05), norm.ppf(eer + 0.05)),
+ textcoords='data',
+ arrowprops=dict(arrowstyle="-|>",
+ connectionstyle="arc3, rad=+0.2",
+ fc="w"),
+ size=12,
+ va='center',
+ ha='center',
+ bbox=dict(boxstyle="round4", fc="w"),
+ )
+ plt.grid()
+ if save_path is not None:
+ plt.savefig(save_path)
+ plt.clf()
+ else:
+ plt.show()
+
+
+def compute_equalized_scores(max_tar_imp_counts, sc, labs, masks):
+
+ count_weights = []
+ scores = []
+ labels = []
+ for ix in range(len(masks)):
+ amask = masks[ix]
+ alabs = labs[amask]
+ num_targets = np.count_nonzero(alabs == 1)
+ num_non_targets = alabs.size - num_targets
+ labels.append(alabs)
+ scores.append(sc[amask])
+ tar_weight = max_tar_imp_counts[
+ 0] / num_targets if num_targets > 0 else 0
+ imp_weight = max_tar_imp_counts[
+ 1] / num_non_targets if num_non_targets > 0 else 0
+
+ acount_weights = np.empty(alabs.shape, dtype='f')
+ acount_weights[alabs == 1] = np.array([tar_weight] * num_targets)
+ acount_weights[alabs == 0] = np.array([imp_weight] * num_non_targets)
+ count_weights.append(acount_weights)
+
+ scores = np.hstack(scores)
+ labels = np.hstack(labels)
+ count_weights = np.hstack(count_weights)
+
+ return scores, labels, count_weights
diff --git a/speaker_diarization/speakerlab/utils/utils.py b/speaker_diarization/speakerlab/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2821a93b72ca430ce737b8d5b3233d2a360d7181
--- /dev/null
+++ b/speaker_diarization/speakerlab/utils/utils.py
@@ -0,0 +1,234 @@
+import sys
+import os
+import random
+import logging
+import yaml
+import numpy as np
+from contextlib import contextmanager
+import torch
+from speakerlab.utils.fileio import load_yaml
+
+def parse_config(config_file):
+ if config_file.endwith('.yaml'):
+ config = load_yaml(config_file)
+ else:
+ raise Exception("Other formats not currently supported.")
+ return config
+
+def set_seed(seed=0):
+ np.random.seed(seed)
+ random.seed(seed)
+
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ # torch.backends.cudnn.deterministic = True
+ # torch.backends.cudnn.benchmark = True
+
+def get_logger(fpath=None, fmt=None):
+ if fmt is None:
+ fmt = "%(asctime)s - %(levelname)s: %(message)s"
+ logging.basicConfig(level=logging.INFO, format=fmt)
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.INFO)
+ if fpath is not None:
+ handler = logging.FileHandler(fpath)
+ handler.setFormatter(logging.Formatter(fmt))
+ logger.addHandler(handler)
+ return logger
+
+def get_utt2spk_dict(utt2spk, suffix=''):
+ temp_dict={}
+ with open(utt2spk,'r') as utt2spk_f:
+ lines = utt2spk_f.readlines()
+ for i in lines:
+ i=i.strip().split()
+ if suffix == '' or suffix is None:
+ key_i = i[0]
+ value_spk = i[1]
+ else:
+ key_i = i[0]+'_'+suffix
+ value_spk = i[1]+'_'+suffix
+ if key_i in temp_dict:
+ raise ValueError('The key must be unique.')
+ temp_dict[key_i]=value_spk
+ return temp_dict
+
+def get_wavscp_dict(wavscp, suffix=''):
+ temp_dict={}
+ with open(wavscp, 'r') as wavscp_f:
+ lines = wavscp_f.readlines()
+ for i in lines:
+ i=i.strip().split()
+ if suffix == '' or suffix is None:
+ key_i = i[0]
+ else:
+ key_i = i[0]+'_'+suffix
+ value_path = i[1]
+ if key_i in temp_dict:
+ raise ValueError('The key must be unique.')
+ temp_dict[key_i]=value_path
+ return temp_dict
+
+def accuracy(x, target):
+ # x: [*, C], target: [*,]
+ _, pred = x.topk(1)
+ pred = pred.squeeze(-1)
+ acc = pred.eq(target).float().mean()
+ return acc*100
+
+def average_precision(scores, labels):
+ # scores: [N, ], labels: [N, ]
+ if torch.is_tensor(scores):
+ scores = scores.cpu().numpy()
+ if torch.is_tensor(labels):
+ labels = labels.cpu().numpy()
+ if isinstance(scores, list):
+ scores = np.array(scores)
+ if isinstance(labels, list):
+ labels = np.array(labels)
+ assert isinstance(scores, np.ndarray) and isinstance(
+ labels, np.ndarray), 'Input should be numpy.array.'
+ assert len(scores.shape)==1 and len(labels.shape)==1 and \
+ scores.shape[0]==labels.shape[0]
+
+ sort_idx = np.argsort(scores)[::-1]
+ scores = scores[sort_idx]
+ labels = labels[sort_idx]
+ tp_count = (labels==1).sum()
+ tp = labels.cumsum()
+ recall = tp / tp_count
+ precision = tp / (np.arange(len(labels)) + 1)
+
+ recall = np.concatenate([[0], recall, [1]])
+ precision = np.concatenate([[0], precision, [0]])
+
+ # Smooth precision to be monotonically decreasing.
+ for i in range(len(precision) - 2, -1, -1):
+ precision[i] = np.maximum(precision[i], precision[i + 1])
+
+ indices = np.where(recall[1:] != recall[:-1])[0] + 1
+ average_precision = np.sum(
+ (recall[indices] - recall[indices - 1]) * precision[indices])
+ return average_precision
+
+def load_params(dst_model, src_state, strict=True):
+ dst_state = {}
+ for k in src_state:
+ if k.startswith('module'):
+ dst_state[k[7:]] = src_state[k]
+ else:
+ dst_state[k] = src_state[k]
+ dst_model.load_state_dict(dst_state, strict=strict)
+ return dst_model
+
+def merge_vad(vad1: list, vad2: list):
+ intervals = vad1 + vad2
+ intervals.sort(key=lambda x: x[0])
+ merged = []
+ for interval in intervals:
+ if not merged or merged[-1][1] < interval[0]:
+ merged.append(interval)
+ else:
+ merged[-1][1] = max(merged[-1][1], interval[1])
+ return merged
+
+class AverageMeter(object):
+ def __init__(self, name, fmt=':f'):
+ self.name = name
+ self.fmt = fmt
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def __str__(self):
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+ return fmtstr.format(**self.__dict__)
+
+class AverageMeters(object):
+ def __init__(self, names: list = None, fmts: list = None):
+ self.cont = dict()
+ if names is None or fmts is None:
+ return
+ for name, fmt in zip(names, fmts):
+ self.cont[name] = AverageMeter(name, fmt)
+
+ def add(self, name, fmt=':f'):
+ self.cont[name] = AverageMeter(name, fmt)
+
+ def update(self, name, val, n=1):
+ self.cont[name].update(val, n)
+
+ def avg(self, name):
+ return self.cont[name].avg
+
+ def val(self, name):
+ return self.cont[name].val
+
+ def __str__(self):
+ return '\t'.join([str(s) for s in self.cont.values()])
+
+
+class ProgressMeter(object):
+ def __init__(self, num_batches, meters, prefix=""):
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+ self.meters = meters
+ self.prefix = prefix
+
+ def display(self, batch):
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
+ entries += [str(self.meters)]
+ return '\t'.join(entries)
+
+ def _get_batch_fmtstr(self, num_batches):
+ num_digits = len(str(num_batches // 1))
+ fmt = '{:' + str(num_digits) + 'd}'
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+@contextmanager
+def silent_print():
+ original_stdout = sys.stdout
+ sys.stdout = open(os.devnull, 'w')
+ try:
+ yield
+ finally:
+ sys.stdout.close()
+ sys.stdout = original_stdout
+
+def download_model_from_modelscope(model_id, model_revision=None, cache_dir=None):
+ from modelscope.hub.snapshot_download import snapshot_download
+ if cache_dir is None:
+ cache_dir = snapshot_download(
+ model_id,
+ revision=model_revision,
+ )
+ else:
+ cfg_file = os.path.join(cache_dir, model_id, 'configuration.json')
+ if not os.path.exists(cfg_file):
+ cache_dir = snapshot_download(
+ model_id,
+ revision=model_revision,
+ cache_dir=cache_dir,
+ )
+ else:
+ cache_dir = os.path.join(cache_dir, model_id)
+ return cache_dir
+
+def circle_pad(x: torch.Tensor, target_len, dim=0):
+ xlen = x.shape[dim]
+ if xlen >= target_len:
+ return x
+ n = int(np.ceil(target_len/xlen))
+ xcat = torch.cat([x for _ in range(n)], dim=dim)
+ return torch.narrow(xcat, dim, 0, target_len)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..86bc1a9111218dcc7f054b222721a8ca770a75d5
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,197 @@
+import os
+import shutil
+import json
+import re
+from pydub import AudioSegment
+from moviepy.video.io.VideoFileClip import VideoFileClip
+from speaker_diarization.local.vision_processer import VisionProcesser
+import wave
+
+
+# ==================== 工具函数 ====================
+
+def get_video_duration(video_path):
+ """获取视频时长(秒)"""
+ try:
+ clip = VideoFileClip(video_path)
+ duration = clip.duration
+ clip.close()
+ return duration
+ except Exception as e:
+ return 0.0
+
+def extract_audio_from_video(video_path: str, wav_path: str, sample_rate: int = 16000):
+ """Extract mono 16kHz WAV from video."""
+ print(f"[INFO] Extracting audio from {video_path} to {wav_path}")
+ audio = AudioSegment.from_file(video_path)
+ audio = audio.set_frame_rate(sample_rate).set_channels(1)
+ audio.export(wav_path, format="wav")
+
+
+def extract_visual_embeddings(frontend, vad_list, video_path, wav_path, pkl_path):
+ try:
+ vp = VisionProcesser(
+ video_file_path = video_path,
+ audio_file_path = wav_path,
+ audio_vad = vad_list,
+ out_feat_path = pkl_path,
+ visual_models = frontend,
+ conf = frontend.conf,
+ out_video_path=None
+ )
+ vp.run()
+ except Exception as e:
+ print(f"[ERROR] Failed to process {video_path}: {e}")
+ raise
+ finally:
+ if 'vp' in locals():
+ vp.close()
+ return
+
+
+def detect_video_type(video_path):
+ """【占位函数】检测视频类型"""
+ return "独白"
+
+def clip_video_segment(video_path, start_time, end_time, output_dir, clip_name):
+ """裁切视频片段"""
+ try:
+ video_clip = os.path.join(output_dir, f"{clip_name}.mp4")
+ audio_clip = os.path.join(output_dir, f"{clip_name}.wav")
+ clip = VideoFileClip(video_path).subclipped(start_time, end_time)
+ clip.write_videofile(video_clip, codec="libx264", audio_codec='aac', logger=None)
+ if clip.audio is not None:
+ clip.audio.write_audiofile(audio_clip, codec="pcm_s16le", logger=None)
+ else:
+ num_samples = int(16000 * (end_time - start_time))
+ with wave.open(audio_clip, 'wb') as wf:
+ wf.setnchannels(1)
+ wf.setsampwidth(2) # 16-bit
+ wf.setframerate(16000)
+ wf.writeframes(b'\x00' * (num_samples * 2))
+ clip.close()
+ return video_clip, audio_clip
+ except Exception as e:
+ return None
+
+def generate_jsonl_data(frontend, video_path, segments_data, work_dir, video_duration):
+ """生成 JSONL 格式数据"""
+ video_type = detect_video_type(video_path)
+
+ jsonl_items = []
+
+ for idx, seg in enumerate(segments_data):
+ utt_name = f"clip_{idx}"
+ start, end = max(0.0, float(seg['start']) - 0.1), min(float(seg['end']) + 0.1, video_duration)
+ duration = end - start
+
+
+ video_clip, audio_clip = clip_video_segment(
+ video_path, start, end,
+ work_dir, utt_name
+ )
+ if not video_clip or not audio_clip:
+ continue
+
+ pkl_path = os.path.join(work_dir, f"{utt_name}.pkl")
+
+ extract_visual_embeddings(
+ frontend,
+ vad_list = [[0.0, round(duration, 2)]],
+ video_path = video_clip,
+ wav_path = audio_clip,
+ pkl_path = pkl_path
+ )
+
+ ref_audio_path = audio_clip
+ if seg.get('ref_audio') and os.path.exists(seg['ref_audio']):
+ src = seg['ref_audio']
+ dst = os.path.join(work_dir, f"{utt_name}_ref.wav")
+ shutil.copy(src, dst)
+ ref_audio_path = dst
+
+ item = {
+ "messages": [
+ {"role": "text", "content": seg['text']},
+ {"role": "vocal", "content": ref_audio_path},
+ {"role": "video", "content": video_clip},
+ {"role": "face", "content": pkl_path},
+ {"role": "dialogue", "content": [{
+ "start": 0.0,
+ "duration": round(duration, 2),
+ "spk": "1",
+ "gender": seg['gender'],
+ "age": seg['age']
+ }]},
+ {"role": "clue", "content": seg['clue']}
+ ],
+ "utt": utt_name,
+ "type": video_type,
+ "speech_length": int(duration * 25),
+ "start": start,
+ "end": end
+ }
+ jsonl_items.append(item)
+
+ jsonl_path = os.path.join(work_dir, "input_data.jsonl")
+ with open(jsonl_path, 'w', encoding='utf-8') as f:
+ for item in jsonl_items:
+ f.write(json.dumps(item, ensure_ascii=False) + '\n')
+
+ return jsonl_path, jsonl_items
+
+def validate_timestamps(start, end, video_duration):
+ """验证时间戳合法性"""
+ errors = []
+ if start < 0:
+ errors.append(f"起始时间 ({start}s) 不能小于 0")
+ if end > video_duration:
+ errors.append(f"终止时间 ({end}s) 不能大于视频总时长 ({video_duration}s)")
+ duration = end - start
+ if duration <= 0:
+ errors.append(f"起始时间 ({start}s) 必须小于终止时间 ({end}s)")
+ if duration >= 0 and duration <= 2:
+ errors.append(f"配音时长 ({duration}s) 太短,必须大于 2s")
+ if duration >= 30:
+ errors.append(f"配音时长 ({duration}s) 太长,请小于 30s")
+ return errors
+
+#=== SRT 解析 ====
+def parse_srt_time(time_str: str) -> float:
+ """将 SRT 时间字符串 (HH:MM:SS,mmm) 转换为秒"""
+ time_str = time_str.strip().replace(',', '.')
+ h, m, s = time_str.split(':')
+ s_part, ms = s.split('.')
+ return int(h) * 3600 + int(m) * 60 + int(s_part) + int(ms) / 1000.0
+
+def parse_srt_content(srt_text: str) -> list:
+ """解析 SRT 文本"""
+ if not srt_text:
+ return []
+ lines = srt_text.replace('\r\n', '\n').replace('\r', '\n').strip().split('\n')
+ segments = []
+ n = len(lines)
+ i = 0
+ while i < n:
+ line = lines[i].strip()
+ if re.match(r'^\d+(\s+spk\d+)?$', line):
+ if i + 1 < n:
+ time_match = re.search(r'(\d{2}:\d{2}:\d{2}[,.]\d{3})\s*-->\s*(\d{2}:\d{2}:\d{2}[,.]\d{3})', lines[i+1])
+ if time_match:
+ start = parse_srt_time(time_match.group(1))
+ end = parse_srt_time(time_match.group(2))
+ # 收集文本,直到遇到下一个序号行或文件结束
+ text_parts = []
+ j = i + 2
+ while j < n and not re.match(r'^\d+(\s+spk\d+)?$', lines[j].strip()):
+ if lines[j].strip():
+ text_parts.append(lines[j].strip())
+ j += 1
+
+ text = ' '.join(text_parts)
+ if text:
+ segments.append({"start": start, "end": end, "text": text})
+ i = j # 跳过已处理的块,直接进入下一轮
+ continue
+ i += 1
+ return segments
\ No newline at end of file