xuan3986 commited on
Commit
03022ee
·
verified ·
1 Parent(s): 21a001f

Upload 111 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .gitignore +1 -0
  3. LICENSE +201 -0
  4. README.md +155 -14
  5. README_zh.md +153 -0
  6. app.py +415 -0
  7. data/ref.wav +3 -0
  8. data/sample.mp4 +3 -0
  9. decode_conf/decode.yaml +42 -0
  10. decode_conf/diar.yaml +51 -0
  11. decode_conf/ds_stage0_fp32.json +33 -0
  12. funcineforge/.DS_Store +0 -0
  13. funcineforge/__init__.py +7 -0
  14. funcineforge/auto/__init__.py +0 -0
  15. funcineforge/auto/auto_frontend.py +95 -0
  16. funcineforge/auto/auto_model.py +173 -0
  17. funcineforge/datasets/__init__.py +2 -0
  18. funcineforge/datasets/datasets.py +193 -0
  19. funcineforge/datasets/index_ds.py +151 -0
  20. funcineforge/download/__init__.py +0 -0
  21. funcineforge/download/download_model_from_hub.py +220 -0
  22. funcineforge/download/file.py +320 -0
  23. funcineforge/download/name_maps_from_hub.py +42 -0
  24. funcineforge/face/__init__.py +1 -0
  25. funcineforge/face/face_recognition.py +16 -0
  26. funcineforge/models/__init__.py +5 -0
  27. funcineforge/models/causal_hifigan.py +834 -0
  28. funcineforge/models/flow_matching_model.py +514 -0
  29. funcineforge/models/inference_model.py +116 -0
  30. funcineforge/models/language_model.py +274 -0
  31. funcineforge/models/modules/__init__.py +0 -0
  32. funcineforge/models/modules/dit_flow_matching/__init__.py +0 -0
  33. funcineforge/models/modules/dit_flow_matching/dit_model.py +208 -0
  34. funcineforge/models/modules/dit_flow_matching/dit_modules.py +622 -0
  35. funcineforge/models/modules/hifigan/__init__.py +14 -0
  36. funcineforge/models/modules/hifigan/activations.py +120 -0
  37. funcineforge/models/modules/hifigan/discriminator.py +299 -0
  38. funcineforge/models/modules/hifigan/generator.py +625 -0
  39. funcineforge/models/modules/hifigan/mel_spectrum.py +93 -0
  40. funcineforge/models/modules/hifigan/nsf_utils.py +253 -0
  41. funcineforge/models/specaug/__init__.py +0 -0
  42. funcineforge/models/specaug/mask_along_axis.py +204 -0
  43. funcineforge/models/specaug/specaug.py +103 -0
  44. funcineforge/models/specaug/time_warp.py +89 -0
  45. funcineforge/models/utils/__init__.py +2 -0
  46. funcineforge/models/utils/llm_decoding.py +178 -0
  47. funcineforge/models/utils/mask_along_axis.py +76 -0
  48. funcineforge/models/utils/masks.py +132 -0
  49. funcineforge/models/utils/nets_utils.py +734 -0
  50. funcineforge/tokenizer/__init__.py +1 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/ref.wav filter=lfs diff=lfs merge=lfs -text
37
+ data/sample.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pyc
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,14 +1,155 @@
1
- ---
2
- title: Fun CineForge Demo
3
- emoji: 🏢
4
- colorFrom: green
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 6.9.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: Fun-CineForge-zh-en-v1-0.5B
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### <p align="center">「English | [简体中文](./README_zh.md)」</p>
2
+
3
+ <p align="center">
4
+ <b>🎬 Fun-CineForge: A Unified Dataset Pipeline and Model for Zero-Shot Movie Dubbing<br>
5
+ in Diverse Cinematic Scenes</b>
6
+ </p>
7
+
8
+ <div align="center">
9
+
10
+ ![license](https://img.shields.io/github/license/modelscope/modelscope.svg)
11
+ <a href=""><img src="https://img.shields.io/badge/OS-Linux-orange.svg"></a>
12
+ <a href=""><img src="https://img.shields.io/badge/Python->=3.8-aff.svg"></a>
13
+ <a href=""><img src="https://img.shields.io/badge/Pytorch->=2.1-blue"></a>
14
+ </div>
15
+
16
+ <div align="center">
17
+ <h4><a href="#Dataset&Demo">Dataset & Demo</a>
18
+ |<a href="#Environment">Environment</a>
19
+ |<a href="#Dataset-Pipeline">Dataset Pipeline</a>
20
+ |<a href="#Dubbing-Model">Dubbing Model</a>
21
+ |<a href="#Recent-Updates">Recent Updates</a>
22
+ |<a href="#Publication">Publication</a>
23
+ |<a href="#Comminicate">Communicate</a>
24
+ </h4>
25
+ </div>
26
+
27
+ **Fun-CineForge** contains an end-to-end dataset pipeline for producing large-scale dubbing datasets and an MLLM-based dubbing model designed for diverse cinematic scenes. Using this pipeline, we constructed the first large-scale Chinese television dubbing dataset CineDub-CN, which includes rich annotations and diverse scenes. In monologue, narration, dialogue, and multi-speaker scenes, our dubbing model consistently outperforms state-of-the-art methods in terms of audio quality, lip-sync, timbre transition, and instruction following.
28
+
29
+ <a name="Dataset&Demo"></a>
30
+ ## Dataset & Demo 🎬
31
+ You can access [https://funcineforge.github.io/](https://funcineforge.github.io/) to get our CineDub-CN dataset samples and demo samples.
32
+
33
+ <a name="Environment"></a>
34
+ ## Environment Installation
35
+
36
+ Fun-CineForge relies on Conda and Python environments. Execute **setup.py** to automatically install the entire project environment and open-source model.
37
+
38
+ ```shell
39
+ # Conda
40
+ git clone git@github.com:FunAudioLLM/FunCineForge.git
41
+ conda create -n FunCineForge python=3.10 -y && conda activate FunCineForge
42
+ sudo apt-get install ffmpeg
43
+ # Initial settings
44
+ python setup.py
45
+ ```
46
+
47
+ <a name="Dataset-Pipeline"></a>
48
+ ## Dataset Pipeline 🔨
49
+
50
+ ### Data collection
51
+ If you want to produce your own data,
52
+ we recommend that you refer to the following requirements to collect the corresponding movies or television series.
53
+
54
+ 1. Video source: TV dramas or movies (not documentaries), with plenty of monologue or dialogue scenes and clear, unobstructed faces (e.g., no masks or veils).
55
+ 2. Speech Requirements: Standard pronunciation, clear articulation, prominent human voice. Avoid materials with strong dialects, excessive background noise, or strong colloquialism.
56
+ 3. Image Requirements: High resolution, clear facial details, sufficient lighting, avoiding extremely dark or strong backlit scenes.
57
+
58
+ ### How to use
59
+
60
+ - [1] Standardize video format and name; trim the beginning and end of long videos; extract the audio from the trimmed video. (default is to trim 10 seconds from both the beginning and end.)
61
+ ```shell
62
+ python normalize_trim.py --root datasets/raw_zh --intro 10 --outro 10
63
+ ```
64
+
65
+ - [2] [Speech Separation](./speech_separation/README.md). The audio is used to separate the vocals from the instrumental music.
66
+ ```shell
67
+ cd speech_separation
68
+ python run.py --root datasets/clean/zh --gpus 0 1 2 3
69
+ ```
70
+
71
+ - [3] [VideoClipper](./video_clip/README.md). For long videos, VideoClipper is used to obtain sentence-level subtitle files and clip the long video into segments based on timestamps. Now it supports bilingualism in both Chinese and English. Below is an example in Chinese. It is recommended to use gpu acceleration for English.
72
+ ```shell
73
+ cd video_clip
74
+ bash run.sh --stage 1 --stop_stage 2 --input datasets/raw_zh --output datasets/clean/zh --lang zh --device cpu
75
+ ```
76
+
77
+ - Enforce the video duration limit and check files for cleanup. (Without --execute, the files slated for deletion are only printed. After reviewing the list, add --execute to actually delete them.)
78
+ ```shell
79
+ python clean_video.py --root datasets/clean/zh
80
+ python clean_srt.py --root datasets/clean/zh --lang zh
81
+ ```
82
+
83
+ - [4] [Speaker Diarization](./speaker_diarization/README.md). Multimodal active speaker recognition obtains RTTM files; identifies the speaker's facial frames, extracts frame-level speaker face and lip raw data.
84
+ ```shell
85
+ cd speaker_diarization
86
+ bash run.sh --stage 1 --stop_stage 4 --hf_access_token hf_xxx --root datasets/clean/zh --gpus "0 1 2 3"
87
+ ```
88
+
89
+ - (Reference) Extract speech tokens based on the CosyVoice3 tokenizer for llm training.
90
+ ```shell
91
+ python speech_tokenizer.py --root datasets/clean/zh
92
+ ```
93
+
94
+ - [5] Multimodal CoT Correction. Based on general-purpose MLLMs, the system uses audio, ASR text, and RTTM files as input. It leverages Chain-of-Thought (CoT) reasoning to extract clues and corrects the results of the specialized models. It also annotates character age, gender, and vocal timbre. Experimental results show that this strategy reduces the CER from 4.53% to 0.94% and the speaker diarization error rate from 8.38% to 1.20%, achieving quality comparable to or even better than manual transcription. Adding the --resume enables breakpoint COT inference to prevent wasted resources from repeated COT inferences. Now supports both Chinese and English.
95
+ ```shell
96
+ python cot.py --root_dir datasets/clean/zh --lang zh --provider google --model gemini-3-pro-preview --api_key xxx --resume
97
+ python cot.py --root_dir datasets/clean/en --lang en --provider google --model gemini-3-pro-preview --api_key xxx --resume
98
+ ```
99
+
100
+ - The construction of the dataset retrieval file will read all production data, perform bidirectional verification of script content and speaker separation results.
101
+ ```shell
102
+ python build_datasets.py --root_zh datasets/clean/zh --root_en datasets/clean/en --out_dir datasets/clean --save
103
+ ```
104
+
105
+ <a name="Dubbing-Model"></a>
106
+ ## Dubbing Model ⚙️
107
+ We've open-sourced the inference code and the **infer.sh** script, and provided some test cases in the data folder for your experience. Inference requires a consumer-grade GPU. Run the following command:
108
+
109
+ ```shell
110
+ cd exps
111
+ bash infer.sh
112
+ ```
113
+
114
+ The API for multi-speaker dubbing from raw videos and SRT scripts is under development ...
115
+
116
+ <a name="Recent-Updates"></a>
117
+ ## Recent Updates 🚀
118
+ - 2025/12/18: Fun-CineForge dataset pipeline toolkit is online! 🔥
119
+ - 2026/01/19: Chinese demo samples and CineDub-CN dataset samples released. 🔥
120
+ - 2026/01/25: Fix some environmental and operational issues.
121
+ - 2026/02/09: Optimized the data pipeline and added support for English videos.
122
+ - 2026/03/05: English demo samples and CineDub-EN dataset samples released. 🔥
123
+ - 2026/03/16: Open source inference code and checkpoints. 🔥
124
+
125
+ <a name="Publication"></a>
126
+ ## Publication 📚
127
+ If you use our dataset or code, please cite the following paper:
128
+ <pre>
129
+ @misc{liu2026funcineforgeunifieddatasettoolkit,
130
+ title={FunCineForge: A Unified Dataset Toolkit and Model for Zero-Shot Movie Dubbing in Diverse Cinematic Scenes},
131
+ author={Jiaxuan Liu and Yang Xiang and Han Zhao and Xiangang Li and Zhenhua Ling},
132
+ year={2026},
133
+ eprint={2601.14777},
134
+ archivePrefix={arXiv},
135
+ primaryClass={cs.CV},
136
+ }
137
+ </pre>
138
+
139
+ <a name="Comminicate"></a>
140
+ ## Communicate 🍟
141
+ The Fun-CineForge open-source project is developed and maintained by the Tongyi Lab Speech Team and a student from NERCSLIP, University of Science and Technology of China.
142
+ We welcome you to participate in discussions on Fun-CineForge [GitHub Issues](https://github.com/FunAudioLLM/FunCineForge/issues) or contact us for collaborative development.
143
+ For any questions, you can contact the [developer](mailto:jxliu@mail.ustc.edu.cn).
144
+
145
+ ⭐ Hope you will support Fun-CineForge. Thank you.
146
+
147
+ ### Disclaimer
148
+
149
+ This repository contains research artifacts:
150
+
151
+ ⚠️ Currently not a commercial product of Tongyi Lab.
152
+
153
+ ⚠️ Released for academic research / cutting-edge exploration purposes
154
+
155
+ ⚠️ CineDub Dataset samples are subject to specific license terms.
README_zh.md ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### <p align="center">「[English](./README.md) | 简体中文」</p>
2
+
3
+ <p align="center">
4
+ <b>🎬 Fun-CineForge:一种用于多样化影视场景零样本配音的统一数据集管道和模型</b>
5
+ </p>
6
+
7
+ <div align="center">
8
+
9
+ ![license](https://img.shields.io/github/license/modelscope/modelscope.svg)
10
+ <a href=""><img src="https://img.shields.io/badge/OS-Linux-orange.svg"></a>
11
+ <a href=""><img src="https://img.shields.io/badge/Python->=3.8-aff.svg"></a>
12
+ <a href=""><img src="https://img.shields.io/badge/Pytorch->=2.1-blue"></a>
13
+ </div>
14
+
15
+ <div align="center">
16
+ <h4><a href="#数据集&样例">数据集 & 样例</a>
17
+ |<a href="#环境安装">环境安装</a>
18
+ |<a href="#数据集管道">数据集管道</a>
19
+ |<a href="#配音模型">配音模型</a>
20
+ |<a href="#近期更新">近期更新</a>
21
+ |<a href="#发表">发表</a>
22
+ |<a href="#社区交流">社区交流</a>
23
+ </h4>
24
+ </div>
25
+
26
+ **Fun-CineForge** 包含一个生产大规模配音数据集的端到端数据集管道,和一个基于多模态大模型的配音模型,该模型专为多样的电影场景而设计。利用该管道,我们构建了首个大规模中文电视剧配音数据集 CineDub-CN,该数据集包含丰富的标注和多样化的场景。在独白、旁白、对话和多说话人场景中,我们的配音模型在音频质量、唇形同步、音色转换和指令遵循等方面全部优于最先进的方法。
27
+
28
+ <a name="数据集&样例"></a>
29
+ ## 数据集 & 样例 🎬
30
+ 您可以访问此 [https://funcineforge.github.io/](https://funcineforge.github.io/) 获取我们的 CineDub-CN 数据集和 CineDub-EN 数据集样例和演示样例。
31
+
32
+ <a name="环境安装"></a>
33
+ ## 环境安装
34
+
35
+ Fun-CineForge 依赖 Conda 和 Python 环境。执行 **setup.py** 自动安装整个项目环境和开源模型。
36
+
37
+ ```shell
38
+ # Conda
39
+ git clone git@github.com:FunAudioLLM/FunCineForge.git
40
+ conda create -n FunCineForge python=3.10 -y && conda activate FunCineForge
41
+ sudo apt-get install ffmpeg
42
+ # 初始化设置
43
+ python setup.py
44
+ ```
45
+
46
+ <a name="数据集管道"></a>
47
+ ## 数据集管道 🔨
48
+
49
+ ### 数据收集
50
+ 如果您想自行生产数据,我们建议您参考下面的要求收集相应的电影或影视剧。
51
+
52
+ 1. 视频来源:电视剧或电影,非纪录片,人物独白或对话场景较多,人脸清晰且无遮挡(如无面罩、面纱)。
53
+ 2. 语音要求:发音标准,吐字清晰,人声突出。避免方言浓重、背景噪音过大或口语感过强的素材。
54
+ 3. 图片要求:高分辨率,面部细节清晰,光线充足,避免极端阴暗或强烈逆光的场景。
55
+
56
+ ### 使用方法
57
+
58
+ - [1] 将视频格式、名称标准化;裁剪长视频的片头片尾;提取裁剪后视频的音频。(默认是从起止各裁剪 10 秒。)
59
+ ```shell
60
+ python normalize_trim.py --root datasets/raw_zh --intro 10 --outro 10
61
+ ```
62
+
63
+ - [2] [Speech Separation](./speech_separation/README.md). 音频进行人声乐声分离。
64
+ ```shell
65
+ cd speech_separation
66
+ python run.py --root datasets/clean/zh --gpus 0 1 2 3
67
+ ```
68
+
69
+ - [3] [VideoClipper](./video_clip/README.md). 对于长视频,使用 VideoClipper 获取句子级别的字幕文件,并根据时间戳将长视频剪辑成片段。现在它支持中英双语。以下是中文示例。英文建议采用 gpu 加速处理。
70
+ ```shell
71
+ cd video_clip
72
+ bash run.sh --stage 1 --stop_stage 2 --input datasets/raw_zh --output datasets/clean/zh --lang zh --device cpu
73
+ ```
74
+
75
+ - 视频时长限制及清理检查。(若不使用--execute参数,则仅打印已预删除的文件。检查后,若需确认删除,请添加--execute参数。)
76
+ ```shell
77
+ python clean_video.py --root datasets/clean/zh
78
+ python clean_srt.py --root datasets/clean/zh --lang zh
79
+ ```
80
+
81
+ - [4] [Speaker Diarization](./speaker_diarization/README.md). 多模态主动说话人识别,得到 RTTM 文件;识别说话人的面部帧,提取帧级的说话人面部和唇部原始数据,从面部帧中识别说话帧,提取说话帧的面部特征。
82
+ ```shell
83
+ cd speaker_diarization
84
+ bash run.sh --stage 1 --stop_stage 4 --hf_access_token hf_xxx --root datasets/clean/zh --gpus "0 1 2 3"
85
+ ```
86
+
87
+ - (参考)基于 CosyVoice3 tokenizer 提取 speech tokens 用于大模型训练。
88
+ ```shell
89
+ python speech_tokenizer.py --root datasets/clean/zh
90
+ ```
91
+
92
+ - [5] 多模态思维链校正。该系统基于通用多模态大模型,以音频、ASR 抄本和 RTTM 文件为输入,利用思维链推理来提取线索,并校正专用模型的结果,并标注人物年龄、性别和音色。实验结果表明,该策略将词错率从4.53% 降低到 0.94%,说话人识别错误率从 8.38% 降低到 1.20%,其质量可与人工转录相媲美,甚至更优。添加--resume选项可启用断点思维链推理,以避免重复思维链推理造成的资源浪费。现支持中英文。
93
+ ```shell
94
+ python cot.py --root_dir datasets/clean/zh --lang zh --provider google --model gemini-3-pro-preview --api_key xxx --resume
95
+ python cot.py --root_dir datasets/clean/en --lang en --provider google --model gemini-3-pro-preview --api_key xxx --resume
96
+ ```
97
+
98
+ - 数据集检索文件的构建会读取生产的所有数据,双向校验脚本内容和说话人分离结果。
99
+ ```shell
100
+ python build_datasets.py --root_zh datasets/clean/zh --root_en datasets/clean/en --out_dir datasets/clean --save
101
+ ```
102
+
103
+ <a name="Dubbing-Model"></a>
104
+ ## 配音模型 ⚙️
105
+ 我们开源了推理代码和 **infer.sh** 脚本,在 data 文件夹中提供了一些测试样例,以供体验。推理需要一张消费级 GPU。按下面的命令运行:
106
+
107
+ ```shell
108
+ cd exps
109
+ bash infer.sh
110
+ ```
111
+
112
+ 从原始视频和 SRT 脚本进行多人配音的 API 调用接口在开发中 ...
113
+
114
+ <a name="近期更新"></a>
115
+ ## 近期更新 🚀
116
+ - 2025/12/18:Fun-CineForge 数据集管道工具包上线!🔥
117
+ - 2026/01/19:发布中文演示样例和 CineDub-CN 数据集样例。 🔥
118
+ - 2026/01/25:修复了一些环境和运行问题。
119
+ - 2026/02/09:优化了数据管道,新增支持英文视频的能力。
120
+ - 2026/03/05:发布英文演示样例和 CineDub-EN 数据集样例。 🔥
121
+ - 2026/03/16:开源推理代码和 checkpoints。 🔥
122
+
123
+ <a name="发表"></a>
124
+ ## 发表 📚
125
+ 如果您使用了我们的数据集或代码,请引用以下论文:
126
+ <pre>
127
+ @misc{liu2026funcineforgeunifieddatasettoolkit,
128
+ title={FunCineForge: A Unified Dataset Toolkit and Model for Zero-Shot Movie Dubbing in Diverse Cinematic Scenes},
129
+ author={Jiaxuan Liu and Yang Xiang and Han Zhao and Xiangang Li and Zhenhua Ling},
130
+ year={2026},
131
+ eprint={2601.14777},
132
+ archivePrefix={arXiv},
133
+ primaryClass={cs.CV},
134
+ }
135
+ </pre>
136
+
137
+
138
+ <a name="社区交流"></a>
139
+ ## 社区交流 🍟
140
+ Fun-CineForge 开源项目由通义实验室语音团队和中国科学技术大学 NERCSLIP 学生开发并维护,我们欢迎您在 Fun-CineForge [GitHub Issues](https://github.com/FunAudioLLM/FunCineForge/issues) 参与问题讨论,或联系我们合作开发。
141
+ 有任何问题您可以联系[开发者](mailto:jxliu@mail.ustc.edu.cn)。
142
+
143
+ ⭐ 希望您支持 Fun-CineForge,谢谢。
144
+
145
+ ### 免责声明
146
+
147
+ 该仓库包含的研究成果:
148
+
149
+ ⚠️ 目前非通义实验室商业化产品
150
+
151
+ ⚠️ 供学术研究/前沿探索用途
152
+
153
+ ⚠️ 数据集样例受特定许可条款约束
app.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import json
4
+ import torch
5
+ import gradio as gr
6
+ import typing
7
+ import time
8
+ import shutil
9
+ from moviepy.video.io.VideoFileClip import VideoFileClip, AudioFileClip
10
+ from moviepy.audio.AudioClip import CompositeAudioClip
11
+ from modelscope import snapshot_download
12
+ from utils import get_video_duration, generate_jsonl_data, validate_timestamps, parse_srt_content
13
+ # 尝试导入模型库
14
+ from funcineforge import AutoFrontend
15
+ from speaker_diarization.run import GlobalModels
16
+ snapshot_download(
17
+ repo_id="FunAudioLLM/Fun-CineForge",
18
+ revision='v1.0.0',
19
+ local_dir='pretrained_models',
20
+ ignore_patterns=[
21
+ "*.md",
22
+ ".git*",
23
+ "funcineforge_zh_en/llm/config.yaml"
24
+ ],
25
+ repo_type="model",
26
+ )
27
+
28
+
29
+ # ==================== 配置区域 ====================
30
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
31
+ SERVER_PORT = 7860
32
+ TEMP_DIR = "temp_workdir"
33
+ CONFIG_FRONTEND = "decode_conf/diar.yaml"
34
+ CONFIG_MODEL = "decode_conf/decode.yaml"
35
+ PRETRAIN = "pretrained_models"
36
+ MAX_SEGMENTS = 8 # UI 片段数上限
37
+ DEFAULT_VIDEO_PATH="data/sample.mp4"
38
+ DEFAULT_AUDIO_PATH="data/ref.wav"
39
+ DEFAULT_TEXT = "我军无粮,利在急战。今乘魏兵新败,不敢出兵,出其不意,乘机退去,方可平安无事。"
40
+ DEFAULT_CLUE = "一位中年男性以沉稳但略带担忧的语调,分析我军无粮急战的困境与敌军心败状态。他随即提出一种撤退方案,整体流露出对战局的担忧和谋求生路。"
41
+ # 全局模型实例(延迟加载)
42
+ model_pool: typing.Optional[GlobalModels] = None
43
+ engine = None
44
+
45
+ def init_engine():
46
+ """延迟加载模型,避免启动时卡住"""
47
+ global engine
48
+ engine = AutoFrontend(PRETRAIN, CONFIG_MODEL, TEMP_DIR, DEVICE)
49
+ return engine
50
+
51
+ def init_frontend_models():
52
+ global model_pool
53
+ model_pool = GlobalModels(
54
+ hf_token = None,
55
+ config_path = CONFIG_FRONTEND,
56
+ pretrained_dir= PRETRAIN,
57
+ device = DEVICE,
58
+ pool_sizes = {"face": 1, "asd": 1, "fr": 1},
59
+ batch_size = 1,
60
+ preload = True
61
+ )
62
+ return model_pool
63
+
64
+ # ==================== Gradio UI 逻辑 ====================
65
+
66
+ def create_segments_ui():
67
+ segments = []
68
+ accordions = []
69
+ for i in range(MAX_SEGMENTS):
70
+ with gr.Accordion(f"🎬 配音片段 {i + 1}", open=(i == 0), visible=(i == 0)) as acc:
71
+ accordions.append(acc)
72
+ with gr.Row():
73
+ text_input = gr.Textbox(label="📝 配音文本内容", placeholder="输入台词...", lines=2, scale=3, elem_id=f"text_{i}")
74
+ clue_input = gr.Textbox(label="💡 线索描述", placeholder="一位中年男性角色语气沉稳且坚定,流露出对自身忠诚的强烈自信与决心。整体情感是忠贞不渝的承诺和不容置疑的信念。", lines=2, scale=3, elem_id=f"clue_{i}")
75
+ with gr.Row():
76
+ start_time = gr.Number(label="⏱️ 起始时间 (s)", value=0.0 + i*5, precision=2, scale=2, elem_id=f"start_{i}")
77
+ end_time = gr.Number(label="⏱️ 终止时间 (s)", value=5.0 + i*5, precision=2, scale=2, elem_id=f"end_{i}")
78
+ with gr.Row():
79
+ age_input = gr.Dropdown(label="👤 年龄", choices=["儿童", "青年", "中年", "中老年", "老年", "不确定"], value="不确定", scale=2, elem_id=f"age_{i}")
80
+ gender_input = gr.Dropdown(label="👤 性别", choices=["男", "女", "不确定"], value="不确定", scale=2, elem_id=f"gender_{i}")
81
+ with gr.Row():
82
+ ref_audio = gr.Audio(label="🎤 参考语音 (可选,默认以视频原声作为参考音频)", sources=["upload"], type="filepath", scale=4,elem_id=f"audio_{i}")
83
+ load_audio_btn = gr.Button("📂 加载示例音频", size="sm", variant="secondary", scale=1) if i == 0 else None
84
+ with gr.Row():
85
+ enable_check = gr.Checkbox(label="启用此片段", value=(i == 0), scale=1, elem_id=f"enable_{i}")
86
+
87
+ segments.append({
88
+ "accordion": acc, "text": text_input, "clue": clue_input, "start": start_time, "end": end_time,
89
+ "age": age_input, "gender": gender_input, "audio": ref_audio,
90
+ "enable": enable_check, "index": i, "load_audio_btn": load_audio_btn})
91
+ return segments, accordions
92
+
93
+ def add_segment_fn(current_count):
94
+ """点击加号:显示下一个片段,到达上限则禁用按钮"""
95
+ if current_count >= MAX_SEGMENTS:
96
+ return [current_count] + [gr.update() for _ in range(MAX_SEGMENTS)] + [gr.update(interactive=False, value=f"已达上限 ({MAX_SEGMENTS})")]
97
+
98
+ new_count = current_count + 1
99
+ vis = [gr.update(visible=(i < new_count)) for i in range(MAX_SEGMENTS)]
100
+ btn = gr.update(interactive=(new_count < MAX_SEGMENTS), value="➕新片段")
101
+ return [new_count] + vis + [btn]
102
+
103
+ def load_srt_fn(srt_file, current_count):
104
+ empty_fields = [gr.update() for _ in range(MAX_SEGMENTS * 4)]
105
+ empty_vis = [gr.update() for _ in range(MAX_SEGMENTS)]
106
+ if not srt_file:
107
+ return [current_count] + empty_fields + empty_vis + [gr.update()]
108
+ try:
109
+ with open(srt_file, 'r', encoding='utf-8-sig') as f:
110
+ content = f.read()
111
+ except Exception as e:
112
+ gr.Warning(f"读取 SRT 文件失败: {e}")
113
+ return [current_count] + empty_fields + empty_vis + [gr.update()]
114
+ parsed = parse_srt_content(content)
115
+ if not parsed:
116
+ print(" 未解析到有效字幕,请检查 SRT 格式")
117
+ return [current_count] + empty_fields + empty_vis + [gr.update()]
118
+ updates = []
119
+ for i in range(MAX_SEGMENTS):
120
+ if i < len(parsed):
121
+ seg = parsed[i]
122
+ updates.append(gr.update(value=seg['text']))
123
+ updates.append(gr.update(value=round(seg['start'], 2)))
124
+ updates.append(gr.update(value=round(seg['end'], 2)))
125
+ updates.append(gr.update(value=True))
126
+ else:
127
+ updates.append(gr.update(value=""))
128
+ updates.append(gr.update(value=0.0))
129
+ updates.append(gr.update(value=5.0 + i*5))
130
+ updates.append(gr.update(value=False))
131
+ new_count = min(len(parsed), MAX_SEGMENTS)
132
+ vis = [gr.update(visible=(i < new_count)) for i in range(MAX_SEGMENTS)]
133
+ btn = gr.update(interactive=(new_count < MAX_SEGMENTS))
134
+ if len(parsed) > MAX_SEGMENTS:
135
+ gr.Warning(f"SRT 包含 {len(parsed)} 个片段,已截取前 {MAX_SEGMENTS} 条")
136
+
137
+ return [new_count] + updates + vis + [btn]
138
+
139
+ def process_dubbing(video_file, *segment_inputs, progress=gr.Progress()):
140
+ """主推理流程"""
141
+ if not video_file:
142
+ return None, "❌ 请上传视频文件"
143
+
144
+ video_duration = get_video_duration(video_file)
145
+ if video_duration <= 0:
146
+ return None, "❌ 无法获取视频时长,请检查视频文件"
147
+
148
+ if os.path.exists(TEMP_DIR):
149
+ try:
150
+ shutil.rmtree(TEMP_DIR)
151
+ except Exception as e:
152
+ return None, f"❌ 清空临时目录失败:{e}"
153
+ os.makedirs(TEMP_DIR, exist_ok=True)
154
+
155
+ # 解析 segment_inputs
156
+ segments_data = []
157
+ for i in range(MAX_SEGMENTS):
158
+ base_idx = i * 8
159
+ enable = segment_inputs[base_idx + 7] # enable_check
160
+ if not enable: continue
161
+ text = segment_inputs[base_idx + 0]
162
+ if not text or not text.strip(): continue
163
+
164
+ clue = segment_inputs[base_idx + 1]
165
+ start = segment_inputs[base_idx + 2]
166
+ end = segment_inputs[base_idx + 3]
167
+ age = segment_inputs[base_idx + 4]
168
+ gender = segment_inputs[base_idx + 5]
169
+ ref_audio = segment_inputs[base_idx + 6]
170
+
171
+ errors = validate_timestamps(start, end, video_duration)
172
+ if errors:
173
+ return None, f"❌ 片段 {i+1} 时间戳错误:\n" + "\n".join(errors)
174
+
175
+ data = {
176
+ "text": str(text).strip(),
177
+ "clue": str(clue) if clue else "",
178
+ "start": float(start) if start else 0.0,
179
+ "end": float(end) if end else 0.0,
180
+ "age": str(age) if age else "不确定",
181
+ "gender": str(gender) if gender else "不确定",
182
+ "ref_audio": str(ref_audio) if ref_audio else ""
183
+ }
184
+
185
+ segments_data.append(data)
186
+
187
+ if not segments_data:
188
+ return None, "❌ 有效片段数据为空,请启用并填写至少一个片段"
189
+
190
+ try:
191
+ progress(0.1, desc="📋 预处理视频,生成 JSONL 数据...")
192
+ frontend = init_frontend_models()
193
+ jsonl_path, jsonl_items = generate_jsonl_data(frontend, video_file, segments_data, TEMP_DIR, video_duration)
194
+ report_lines = [f"✅ 任务完成!共生成 **{len(jsonl_items)}** 个片段数据。\n", "详细 JSONL 数据预览:**", "=" * 40]
195
+ for idx, item in enumerate(jsonl_items):
196
+ report_lines.extend([f"\n---片段 #{idx + 1} ---", json.dumps(item, ensure_ascii=False, indent=2), "-" * 40])
197
+ full_report = "\n".join(report_lines)
198
+
199
+ progress(0.3, desc="🔄 FunCineForge 模型加载中...")
200
+
201
+ eng = init_engine()
202
+ if eng and jsonl_items:
203
+ try:
204
+ progress(0.5, desc="🚀 FunCineForge 模型推理中...")
205
+ eng.inference(jsonl_path)
206
+
207
+ progress(0.8, desc="🎵 正在将配音语音粘贴回静音视频...")
208
+
209
+ output_wav_dir = os.path.join(TEMP_DIR, "wav")
210
+ final_video_path = os.path.join(TEMP_DIR, "dubbed_video.mp4")
211
+
212
+ if not os.path.exists(output_wav_dir):
213
+ return None, f"⚠️ 未找到音频输出目录:{output_wav_dir}"
214
+
215
+ wav_files = sorted([f for f in os.listdir(output_wav_dir) if f.endswith('.wav')])
216
+ if not wav_files:
217
+ return None, f"⚠️ 未生成任何音频文件:{output_wav_dir}"
218
+
219
+ time_mapping = {}
220
+ for item in jsonl_items:
221
+ for wf in wav_files:
222
+ if wf.startswith(item['utt']):
223
+ time_mapping[wf] = float(item['start'])
224
+ break
225
+
226
+ original_clip = VideoFileClip(video_file)
227
+ video_duration = original_clip.duration
228
+ is_silent = original_clip.audio is None
229
+ video_only = original_clip if is_silent else original_clip.without_audio()
230
+ audio_clips = []
231
+ for wav_file, start_time in time_mapping.items():
232
+ wav_path = os.path.join(output_wav_dir, wav_file)
233
+ audio_clip = AudioFileClip(wav_path).with_start(start_time)
234
+ audio_clips.append(audio_clip)
235
+
236
+ final_audio = CompositeAudioClip(audio_clips)
237
+ if final_audio.duration < video_duration:
238
+ final_audio = final_audio.with_duration(video_duration)
239
+ final_clip = video_only.with_audio(final_audio)
240
+ final_clip.write_videofile(
241
+ final_video_path,
242
+ codec='libx264',
243
+ audio_codec='aac',
244
+ preset='veryfast',
245
+ threads=8,
246
+ fps=original_clip.fps,
247
+ logger=None
248
+ )
249
+ original_clip.close(); video_only.close()
250
+ for ac in audio_clips: ac.close()
251
+ if 'final_audio' in locals(): final_audio.close()
252
+ final_clip.close()
253
+
254
+ progress(1.0, desc="✅ 配音完成")
255
+ return final_video_path, full_report
256
+ except Exception as e:
257
+ import traceback; traceback.print_exc()
258
+ if "index out of range" in str(e):
259
+ return None, f"⚠️ 模型推理失败。错误:{str(e)},建议补齐输入的线索描述和说话人属性"
260
+ else:
261
+ return None, f"⚠️ 模型推理失败。错误:{str(e)}"
262
+ else:
263
+ time.sleep(1)
264
+ progress(1.0, desc="模拟完成")
265
+ return video_file, full_report
266
+
267
+ except Exception as e:
268
+ import traceback; traceback.print_exc()
269
+ return None, f"❌ 发生错误:{str(e)}"
270
+
271
+
272
+ # ==================== 主程序 ====================
273
+
274
+ def main():
275
+ os.makedirs(TEMP_DIR, exist_ok=True)
276
+ with gr.Blocks(
277
+ title="Fun-CineForge 影视配音平台",
278
+ theme=gr.themes.Soft(),
279
+ css="""
280
+ .segment-accordion { margin: 10px 0; }
281
+ .gr-button-primary { background: #1976d2; }
282
+ .gr-button-stop { background: #d32f2f; }
283
+ """
284
+ ) as demo:
285
+
286
+ gr.Markdown("""
287
+ # 🎬 Fun-CineForge
288
+
289
+ **工作流程:** 上传短视频 → 配音片段信息(或上传 .srt 字幕文件) → 上传参考音色(可选) → 预处理、模型加载和推理 → 输出配音视频
290
+ """)
291
+
292
+ with gr.Row():
293
+ with gr.Column(scale=1):
294
+ video_input = gr.Video(label="上传视频", sources=["upload"])
295
+ load_video_btn = gr.Button("📂 加载示例视频", variant="secondary", size="sm")
296
+ srt_input = gr.UploadButton("上传 SRT 字幕", file_types=[".srt"], size="sm", variant="secondary")
297
+ # with gr.Row(elem_classes=["srt-compact"]):
298
+ # srt_input = gr.File(label="上传 SRT 字幕", file_types=[".srt"], height="auto")
299
+ gr.Markdown("### 🎛️ 配音片段配置")
300
+
301
+ segments, accordions = create_segments_ui()
302
+ seg_count_state = gr.State(1) #🔑记录当前可见片段数
303
+ add_segment_btn = gr.Button("➕添加新片段", size="sm", variant="secondary")
304
+ submit_btn = gr.Button("🚀 开始生成配音", variant="stop", size="lg")
305
+
306
+ with gr.Column(scale=1):
307
+ video_output = gr.Video(label="📺 配音后视频", autoplay=True)
308
+
309
+ status_text = gr.Textbox(label="结果状态", interactive=False, lines=2)
310
+
311
+ gr.Markdown("""
312
+ ### 📝 使用说明
313
+ | 字段 | 说明 |
314
+ |------|------|
315
+ | 配音文本 | 该片段台词内容(支持中/英) |
316
+ | 线索描述 | 请参考样例格式,阐述配音要求,重点描述说话人的性别年龄、语气和情感 |
317
+ | 时间戳 | 起止时间戳 (可精确到毫秒),模型对时间戳敏感,建议紧邻有声区间。时长 ≤30s/片段 |
318
+ | 年龄/性别 | 说话人属性选项 |
319
+ | 参考语音 | 音色克隆参考 (可选) |
320
+
321
+ **⚠️ 注意:** 确保每个片段的时间戳不重叠,且时间戳不超过视频总时长。模型会根据片段的时��长度进行强制时间对齐,弱监督对齐唇部运动。
322
+ """)
323
+
324
+ # ==================== 事件绑定 ====================
325
+
326
+ # 收集所有片段组件作为输入
327
+ segment_inputs = []
328
+ for seg in segments:
329
+ segment_inputs.extend([
330
+ seg["text"],
331
+ seg["clue"],
332
+ seg["start"],
333
+ seg["end"],
334
+ seg["age"],
335
+ seg["gender"],
336
+ seg["audio"],
337
+ seg["enable"]
338
+ ])
339
+
340
+ srt_update_fields = []
341
+ for seg in segments:
342
+ srt_update_fields.extend([seg["text"], seg["start"], seg["end"], seg["enable"]])
343
+
344
+ # 动态添加片段
345
+ add_segment_btn.click(
346
+ fn=add_segment_fn,
347
+ inputs=[seg_count_state],
348
+ outputs=[seg_count_state] + accordions + [add_segment_btn]
349
+ )
350
+
351
+ # SRT 加载
352
+ srt_input.upload(
353
+ fn=load_srt_fn,
354
+ inputs=[srt_input, seg_count_state],
355
+ outputs=[seg_count_state] + srt_update_fields + accordions + [add_segment_btn]
356
+ )
357
+
358
+ # 主推理
359
+ submit_btn.click(
360
+ fn=process_dubbing,
361
+ inputs=[video_input] + segment_inputs,
362
+ outputs=[video_output, status_text]
363
+ )
364
+
365
+ # 视频上传联动时间戳
366
+ def update_timestamps(video):
367
+ if not video: return [gr.update() for _ in range(MAX_SEGMENTS * 2)]
368
+ dur = get_video_duration(video)
369
+ updates = []
370
+ for i in range(MAX_SEGMENTS):
371
+ updates.append(gr.update(value=0.0))
372
+ updates.append(gr.update(value=dur))
373
+ return updates
374
+
375
+ def load_default_video_fn():
376
+ return DEFAULT_VIDEO_PATH, DEFAULT_TEXT, DEFAULT_CLUE
377
+
378
+ def load_default_audio_fn():
379
+ return DEFAULT_AUDIO_PATH
380
+
381
+ load_video_btn.click(
382
+ fn=load_default_video_fn,
383
+ inputs=[],
384
+ outputs=[video_input, segments[0]["text"], segments[0]["clue"]]
385
+ ).then(
386
+ fn=update_timestamps,
387
+ inputs=[video_input],
388
+ outputs=[segment_inputs[i] for i in range(len(segment_inputs)) if i % 8 in [2, 3]]
389
+ )
390
+
391
+ video_input.change(
392
+ fn=update_timestamps,
393
+ inputs=[video_input],
394
+ outputs=[comp for pair in zip(segment_inputs[2::8], segment_inputs[3::8]) for comp in pair]
395
+ )
396
+
397
+ if segments and segments[0]["load_audio_btn"]:
398
+ segments[0]["load_audio_btn"].click(
399
+ fn=load_default_audio_fn,
400
+ inputs=[],
401
+ outputs=[segments[0]["audio"]]
402
+ )
403
+
404
+ # ==================== 启动服务 ====================
405
+
406
+ demo.launch(
407
+ server_name="0.0.0.0",
408
+ server_port=SERVER_PORT,
409
+ share=False,
410
+ show_error=True,
411
+ inbrowser=True,
412
+ )
413
+
414
+ if __name__ == "__main__":
415
+ main()
data/ref.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8420568976edb1cf17a63d9fa968aedaf3c0f68cca4dbf75a409876b96ad700b
3
+ size 788876
data/sample.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b901981a2213fc7f98cd6424869710e8396eb558ff2ff3e8ab5d52fe427e0ab6
3
+ size 2567737
decode_conf/decode.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: FunCineForgeInferModel
2
+ index_ds: FunCineForgeDS
3
+ xvec_model: pretrained_models/funcineforge_zh_en/camplus.onnx
4
+ model_conf: {}
5
+
6
+ dataset_conf:
7
+ # face is from the video, vocal is the reference audio, extract speaker ID and start-end timestamp from dialogue
8
+ load_meta_data_key: "text,clue,face,dialogue,vocal,video"
9
+ sos: 6561
10
+ eos: 6562
11
+ turn_of_speech: 6563
12
+ fill_token: 6564
13
+ ignore_id: -100
14
+ startofclue_token: 151646
15
+ endofclue_token: 151647
16
+ frame_shift: 25 # ms
17
+ timebook_size: 1500 # 60 * 25 = 1500
18
+ pangbai: 1500
19
+ dubai: 1501
20
+ duihua: 1502
21
+ duoren: 1503
22
+ male: 1504
23
+ female: 1505
24
+ child: 1506
25
+ youth: 1507
26
+ adult: 1508
27
+ middle: 1509
28
+ elderly: 1510
29
+ speaker_id_start: 1511
30
+
31
+
32
+ sampling: ras
33
+ lm_use_prompt: true
34
+ fm_use_prompt: true
35
+ use_llm_cache: true
36
+ seed: 0
37
+ max_length: 1500 # 60s * 25 fps
38
+ min_length: 50 # 2s * 25 fps
39
+ llm_dtype: fp32
40
+ fm_dtype: fp32
41
+ voc_dtype: fp32
42
+ batch_size: 1
decode_conf/diar.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Diarization config
2
+
3
+ fbank_dim: 80
4
+ embedding_size: 192
5
+
6
+ feature_extractor:
7
+ obj: speakerlab.process.processor.FBank
8
+ args:
9
+ n_mels: <fbank_dim>
10
+ sample_rate: <sample_rate>
11
+ mean_nor: True
12
+
13
+ embedding_model:
14
+ obj: speakerlab.models.campplus.DTDNN.CAMPPlus
15
+ args:
16
+ feat_dim: <fbank_dim>
17
+ embedding_size: <embedding_size>
18
+
19
+ # for visual embeddings extraction
20
+ min_track: 10
21
+ num_failed_det: 10
22
+ crop_scale: 0.4
23
+ min_face_size: 1
24
+ face_det_stride: 5 # 每5帧检测一次人脸
25
+ shot_stride: 50
26
+
27
+ # for clustering
28
+ audio_cluster:
29
+ obj: speakerlab.process.cluster.CommonClustering
30
+ args:
31
+ cluster_type: spectral
32
+ min_num_spks: 1
33
+ max_num_spks: 15
34
+ min_cluster_size: 1
35
+ oracle_num: null
36
+ pval: 0.032
37
+ mer_cos: 0.8
38
+
39
+ vision_cluster:
40
+ obj: speakerlab.process.cluster.CommonClustering
41
+ args:
42
+ cluster_type: AHC
43
+ cluster_line: 2
44
+ min_cluster_size: 1
45
+ fix_cos_thr: 0.25
46
+
47
+ cluster:
48
+ obj: speakerlab.process.cluster.JointClustering
49
+ args:
50
+ audio_cluster: <audio_cluster>
51
+ vision_cluster: <vision_cluster>
decode_conf/ds_stage0_fp32.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train_micro_batch_size_per_gpu": 1,
3
+ "gradient_accumulation_steps": 1,
4
+ "steps_per_print": 100,
5
+ "gradient_clipping": 5,
6
+ "fp16": {
7
+ "enabled": false,
8
+ "auto_cast": false,
9
+ "loss_scale": 0,
10
+ "initial_scale_power": 16,
11
+ "loss_scale_window": 1000,
12
+ "hysteresis": 2,
13
+ "consecutive_hysteresis": false,
14
+ "min_loss_scale": 1
15
+ },
16
+ "bf16": {
17
+ "enabled": false
18
+ },
19
+ "zero_force_ds_cpu_optimizer": false,
20
+ "zero_optimization": {
21
+ "stage": 0,
22
+ "offload_optimizer": {
23
+ "device": "none",
24
+ "pin_memory": true
25
+ },
26
+ "allgather_partitions": true,
27
+ "allgather_bucket_size": 5e8,
28
+ "overlap_comm": true,
29
+ "reduce_scatter": true,
30
+ "reduce_bucket_size": 5e8,
31
+ "contiguous_gradients" : true
32
+ }
33
+ }
funcineforge/.DS_Store ADDED
Binary file (8.2 kB). View file
 
funcineforge/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Initialize package."""
2
+
3
+ import os
4
+ from funcineforge.auto.auto_model import AutoModel
5
+ from funcineforge.auto.auto_frontend import AutoFrontend
6
+
7
+ os.environ["HYDRA_FULL_ERROR"] = "1"
funcineforge/auto/__init__.py ADDED
File without changes
funcineforge/auto/auto_frontend.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import logging
4
+ from omegaconf import OmegaConf
5
+ from funcineforge.utils.hinter import get_logger
6
+ from funcineforge.models.utils import dtype_map
7
+ from funcineforge.datasets import FunCineForgeDS
8
+
9
+ class AutoFrontend:
10
+ def __init__(
11
+ self,
12
+ ckpt_path: str,
13
+ config_path: str,
14
+ output_dir: str,
15
+ device: str = "cuda:0"
16
+ ):
17
+ self.logger = get_logger(log_level=logging.INFO, local_rank=1, world_size=1)
18
+ self.device = device
19
+ self.output_dir = output_dir
20
+ self.lm_model = None
21
+ self.fm_model = None
22
+ self.voc_model = None
23
+ self.model = None
24
+ self.index_ds_class = None
25
+
26
+ self.dataset_conf = None
27
+ self.kwargs = OmegaConf.load(config_path)
28
+
29
+ if device.startswith("cuda"):
30
+ try:
31
+ device_id = int(device.split(":")[-1])
32
+ torch.cuda.set_device(device_id)
33
+ except (ValueError, IndexError):
34
+ self.logger.warning(f"Invalid cuda device string {device}, defaulting to 0")
35
+ torch.cuda.set_device(0)
36
+ else:
37
+ self.logger.info(f"Running on CPU")
38
+
39
+
40
+ lm_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/llm/ds-model.pt.best/mp_rank_00_model_states.pt")
41
+ fm_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/flow/ds-model.pt.best/mp_rank_00_model_states.pt")
42
+ voc_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/vocoder/ds-model.pt.best/avg_5_removewn.pt")
43
+
44
+ lm_exp_dir, lm_model_name, lm_ckpt_id, _ = lm_ckpt_path.rsplit("/", 3)
45
+ self.logger.info(f"init LM model form {lm_ckpt_path}")
46
+
47
+ from funcineforge import AutoModel
48
+ self.lm_model = (AutoModel(
49
+ model=os.path.join(lm_exp_dir, lm_model_name),
50
+ init_param=lm_ckpt_path,
51
+ output_dir=None,
52
+ device=device,
53
+ ))
54
+ self.lm_model.model.to(dtype_map[self.kwargs.get("llm_dtype", "fp32")])
55
+
56
+ fm_exp_dir, fm_model_name, fm_ckpt_id, _ = fm_ckpt_path.rsplit("/", 3)
57
+ self.logger.info(f"build FM model form {fm_ckpt_path}")
58
+ self.fm_model = AutoModel(
59
+ model=os.path.join(fm_exp_dir, fm_model_name),
60
+ init_param=fm_ckpt_path,
61
+ output_dir=None,
62
+ device=device,
63
+ )
64
+ self.fm_model.model.to(dtype_map[self.kwargs.get("fm_dtype", "fp32")])
65
+
66
+ voc_exp_dir, voc_model_name, voc_ckpt_id, _ = voc_ckpt_path.rsplit("/", 3)
67
+ self.logger.info(f"build VOC model form {voc_ckpt_path}")
68
+ self.voc_model = AutoModel(
69
+ model=os.path.join(voc_exp_dir, voc_model_name),
70
+ init_param=voc_ckpt_path,
71
+ output_dir=None,
72
+ device=device,
73
+ )
74
+ self.voc_model.model.to(dtype_map[self.kwargs.get("voc_dtype", "fp32")])
75
+
76
+ self.logger.info(f"build inference model {self.kwargs.get('model')}")
77
+ self.kwargs["output_dir"] = output_dir
78
+ self.kwargs["tokenizer"] = None
79
+ self.model = AutoModel(
80
+ **self.kwargs,
81
+ lm_model=self.lm_model,
82
+ fm_model=self.fm_model,
83
+ voc_model=self.voc_model,
84
+ )
85
+ self.dataset_conf = self.kwargs.get("dataset_conf")
86
+
87
+ def inference(self, jsonl_path: str):
88
+ if not self.model:
89
+ raise RuntimeError("Model class not initialized.")
90
+
91
+ dataset = FunCineForgeDS(jsonl_path, **self.dataset_conf)
92
+ self.logger.info(f"Starting inference on {len(dataset)} items...")
93
+
94
+ self.model.inference(input=dataset, input_len=len(dataset))
95
+ self.logger.info("Inference finished.")
funcineforge/auto/auto_model.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import time
5
+ import torch
6
+ import logging
7
+ import os
8
+ from tqdm import tqdm
9
+ from funcineforge.utils.misc import deep_update
10
+ from funcineforge.utils.set_all_random_seed import set_all_random_seed
11
+ from funcineforge.utils.load_pretrained_model import load_pretrained_model
12
+ from funcineforge.download.download_model_from_hub import download_model
13
+ from funcineforge.tokenizer import FunCineForgeTokenizer
14
+ from funcineforge.face import FaceRecIR101
15
+ import importlib
16
+
17
+
18
+ def prepare_data_iterator(data_in, input_len):
19
+ """ """
20
+ data_list = []
21
+ key_list = []
22
+ for idx in range(input_len):
23
+ item = data_in[idx]
24
+ utt = item["utt"]
25
+ data_list.append(item)
26
+ key_list.append(utt)
27
+ return key_list, data_list
28
+
29
+
30
+ class AutoModel:
31
+
32
+ def __init__(self, **kwargs):
33
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
34
+ logging.basicConfig(level=log_level)
35
+ model, kwargs = self.build_model(**kwargs)
36
+ self.kwargs = kwargs
37
+ self.model = model
38
+ self.model_path = kwargs.get("model_path")
39
+
40
+ @staticmethod
41
+ def build_model(**kwargs):
42
+ assert "model" in kwargs
43
+ if "model_conf" not in kwargs:
44
+ logging.info("download models from {} or local dir".format(kwargs.get("hub", "ms")))
45
+ kwargs = download_model(**kwargs)
46
+
47
+ set_all_random_seed(kwargs.get("seed", 0))
48
+
49
+ device = kwargs.get("device", "cuda")
50
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
51
+ device = "cpu"
52
+ kwargs["batch_size"] = 1
53
+ kwargs["device"] = device
54
+
55
+ torch.set_num_threads(kwargs.get("ncpu", 4))
56
+
57
+ # build tokenizer
58
+ tokenizer = kwargs.get("tokenizer", None)
59
+ if tokenizer is not None:
60
+ tokenizer = FunCineForgeTokenizer(**kwargs.get("tokenizer_conf", {}))
61
+ kwargs["token_list"] = (
62
+ tokenizer.token_list if hasattr(tokenizer, "token_list") else None
63
+ )
64
+ kwargs["token_list"] = (
65
+ tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
66
+ )
67
+ vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
68
+ if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
69
+ vocab_size = tokenizer.get_vocab_size()
70
+ else:
71
+ vocab_size = -1
72
+ kwargs["tokenizer"] = tokenizer
73
+
74
+ # build face_encoder
75
+ face_encoder = kwargs.get("face_encoder", None)
76
+ if face_encoder is not None:
77
+ face_encoder = FaceRecIR101(**kwargs.get("face_encoder_conf", {}))
78
+ kwargs["face_encoder"] = face_encoder
79
+
80
+ model_conf = {}
81
+ model_class_name = kwargs["model"]
82
+ deep_update(model_conf, kwargs.get("model_conf", {}))
83
+ deep_update(model_conf, kwargs)
84
+ module = importlib.import_module("funcineforge.models")
85
+ model_class = getattr(module, model_class_name)
86
+ model = model_class(**model_conf, vocab_size=vocab_size)
87
+
88
+ # init_param
89
+ init_param = kwargs.get("init_param", None)
90
+ if init_param is not None and os.path.exists(init_param):
91
+ logging.info(f"Loading pretrained params from ckpt: {init_param}")
92
+ load_pretrained_model(
93
+ path=init_param,
94
+ model=model,
95
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
96
+ scope_map=kwargs.get("scope_map", []),
97
+ excludes=kwargs.get("excludes", None),
98
+ use_deepspeed=kwargs.get("train_conf", {}).get("use_deepspeed", False),
99
+ save_deepspeed_zero_fp32=kwargs.get("save_deepspeed_zero_fp32", True),
100
+ )
101
+
102
+ # fp16
103
+ if kwargs.get("fp16", False):
104
+ model.to(torch.float16)
105
+ elif kwargs.get("bf16", False):
106
+ model.to(torch.bfloat16)
107
+ model.to(device)
108
+
109
+ return model, kwargs
110
+
111
+ def __call__(self, *args, **cfg):
112
+ kwargs = self.kwargs
113
+ deep_update(kwargs, cfg)
114
+ res = self.model(*args, kwargs)
115
+ return res
116
+
117
+
118
+ def inference(self, input, input_len=None, model=None, kwargs=None, **cfg):
119
+ kwargs = self.kwargs if kwargs is None else kwargs
120
+ deep_update(kwargs, cfg)
121
+ model = self.model if model is None else model
122
+ model.eval()
123
+ batch_size = kwargs.get("batch_size", 1)
124
+ key_list, data_list = prepare_data_iterator(
125
+ input, input_len=input_len
126
+ )
127
+
128
+ speed_stats = {}
129
+ num_samples = len(data_list)
130
+ disable_pbar = self.kwargs.get("disable_pbar", False)
131
+ pbar = (
132
+ tqdm(colour="blue", total=num_samples, dynamic_ncols=True) if not disable_pbar else None
133
+ )
134
+ time_speech_total = 0.0
135
+ time_escape_total = 0.0
136
+ count = 0
137
+ log_interval = kwargs.get("log_interval", None)
138
+ for beg_idx in range(0, num_samples, batch_size):
139
+ end_idx = min(num_samples, beg_idx + batch_size)
140
+ data_batch = data_list[beg_idx:end_idx]
141
+ key_batch = key_list[beg_idx:end_idx]
142
+ batch = {"data_in": data_batch, "data_lengths": end_idx - beg_idx, "key": key_batch}
143
+
144
+ time1 = time.perf_counter()
145
+ with torch.no_grad():
146
+ res = model.inference(**batch, **kwargs)
147
+ if isinstance(res, (list, tuple)):
148
+ results = res[0] if len(res) > 0 else [{"text": ""}]
149
+ meta_data = res[1] if len(res) > 1 else {}
150
+ time2 = time.perf_counter()
151
+
152
+ batch_data_time = meta_data.get("batch_data_time", -1)
153
+ time_escape = time2 - time1
154
+ speed_stats["forward"] = f"{time_escape:0.3f}"
155
+ speed_stats["batch_size"] = f"{len(results)}"
156
+ speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
157
+ description = f"{speed_stats}, "
158
+ if pbar:
159
+ pbar.update(batch_size)
160
+ pbar.set_description(description)
161
+ else:
162
+ if log_interval is not None and count % log_interval == 0:
163
+ logging.info(
164
+ f"processed {count*batch_size}/{num_samples} samples: {key_batch[0]}"
165
+ )
166
+ time_speech_total += batch_data_time
167
+ time_escape_total += time_escape
168
+ count += 1
169
+
170
+ if pbar:
171
+ pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
172
+ torch.cuda.empty_cache()
173
+ return
funcineforge/datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .index_ds import FunCineForgeDS
2
+ from .datasets import FunCineForgeDataset
funcineforge/datasets/datasets.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import torch
3
+ import pickle
4
+ import numpy as np
5
+ from funcineforge.utils.hinter import hint_once
6
+ from funcineforge.datasets import FunCineForgeDS
7
+ from funcineforge.models import FunCineForgeSpecAug
8
+
9
class FunCineForgeDataset(torch.utils.data.Dataset):
    """
    Dataset for Mixed LM of FunCineForge.

    Wraps a ``FunCineForgeDS`` index dataset and turns each manifest entry into
    model-ready tensors: tokenized text (optionally prefixed with an emotion
    clue), speech codec tokens (optionally spec-augmented), dialogue
    time/speaker ids, per-frame face embeddings, plus the label/flag tensors
    the mixed language model trains on.
    """

    def __init__(
        self,
        path,
        index_ds: str = None,  # unused here; kept for config compatibility
        frontend=None,  # unused here; kept for config compatibility
        tokenizer=None,
        face_encoder=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.index_ds = FunCineForgeDS(path, **kwargs)
        self.tokenizer = tokenizer
        self.face_encoder = face_encoder

        # Padding values used by ``collator`` (int vs float tensors are padded differently).
        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value
        self.batch_size = kwargs.get("batch_size")
        self.batch_type = kwargs.get("batch_type")
        # Max number of retries when fetching/collating a sample.
        self.retry = kwargs.get("retry", 100)

        # self.kwargs = kwargs
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.batch_size_scale_ratio_max = kwargs.get("batch_size_scale_ratio_max", 1.5)
        self.batch_size_token_max = kwargs.get("batch_size_token_max", 2500)
        self.multiturn_num_max = kwargs.get("multiturn_num_max", 1)
        # Dimensionality of one face embedding vector.
        self.face_size = kwargs.get("face_size", 512)

        # Special LM token ids laid out right after the speech codec vocabulary.
        self.codebook_size = kwargs.get("codebook_size", 6561)
        self.sos = kwargs.get("sos", self.codebook_size)
        self.eos = kwargs.get("eos", self.codebook_size + 1)
        self.turn_of_speech = kwargs.get("turn_of_speech", self.codebook_size + 2)
        self.ignore_id = kwargs.get("ignore_id", -100)

        # Any non-None "specaug" value enables FunCineForgeSpecAug built from specaug_conf.
        specaug = kwargs.get("specaug", None)
        specaug_conf = kwargs.get("specaug_conf", {})
        if specaug is not None:
            specaug = FunCineForgeSpecAug(**specaug_conf)
        self.specaug = specaug

        self.set_invalid_xvec_zeros = kwargs.get("set_invalid_xvec_zeros", False)
        self.use_emotion_clue = kwargs.get("use_emotion_clue", False)
        logging.info(f"use_emotion_clue: {self.use_emotion_clue}")

    def get_source_len(self, index):
        # Length hook used by length-aware batch samplers.
        item = self.index_ds[index]
        source_len = self.index_ds.get_source_len(item)
        return source_len

    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def mixup_text_codec(self, text: torch.Tensor, aug_codec: torch.Tensor, timespk_ids: torch.Tensor, type_id: int):
        """Interleave text / dialogue / codec tokens into one LM sequence.

        Layout: [sos, text..., type_id, timespk_ids..., turn_of_speech, codec..., eos].

        Returns:
            (input_ids, labels, text_flag, codec_flag, timespk_flag) where
            ``labels`` masks everything up to and including ``turn_of_speech``
            with ``ignore_id`` so the loss covers only codec tokens and eos;
            the three float flags are mutually exclusive 0/1 masks over the
            text span, the type_id+timespk span, and the remainder.
        """
        text_len = text.shape[0]
        timespk_len = timespk_ids.shape[0]
        sequence = [self.sos, *text.tolist(), type_id, *timespk_ids.tolist(), self.turn_of_speech, *aug_codec.tolist(), self.eos]
        # sequence = [self.sos, *text.tolist(), type_id, self.turn_of_speech, *aug_codec.tolist(), self.eos]
        input_ids = torch.tensor(sequence, dtype=torch.int64)
        text_flag = torch.zeros(len(sequence), dtype=torch.float32)
        text_flag[1:text_len+1] = 1
        # timespk_flag covers type_id plus the dialogue time/speaker ids.
        timespk_flag = torch.zeros(len(sequence), dtype=torch.float32)
        timespk_flag[text_len+1:text_len+2+timespk_len] = 1
        # timespk_flag[text_len+1:text_len+2] = 1
        codec_flag = 1 - (text_flag + timespk_flag)
        labels = torch.tensor(sequence, dtype=torch.int64)
        # +3 accounts for sos, type_id and turn_of_speech.
        labels[:text_len+timespk_len+3] = self.ignore_id
        # labels[:text_len+3] = self.ignore_id

        return input_ids, labels, text_flag, codec_flag, timespk_flag

    def __getitem__(self, index):
        """Build one training example; on retry, re-draw a random index.

        NOTE(review): the loop body has no exception handling and always
        reaches ``break`` on the first iteration, so the retry path is
        currently unreachable — confirm whether a try/except was intended.
        """
        output = None
        for idx in range(self.retry):
            if idx == 0:
                index_cur = index
            else:
                index_cur = torch.randint(0, len(self.index_ds), ()).item()
            item = self.index_ds[index_cur]

            # clue + text
            text = item["text"]
            clue = "<|startofclue|>" + item["clue"] + "<|endofclue|>"
            if self.use_emotion_clue:
                text = clue + text
            text_ids = torch.tensor(self.tokenizer.encode(text), dtype=torch.int32)
            hint_once(f"raw text: {text}", "log_text")

            # speech tokens
            target_out = item["token"]
            codec = torch.from_numpy(np.load(target_out))
            codec_len = codec.shape[0]  # could instead use the dataset's precomputed speech_length
            aug_codec = codec.clone()
            if self.specaug is not None:  # aug_codec is a randomly-masked copy of the codec, for robustness
                aug_codec, _ = self.specaug(aug_codec.float().unsqueeze(0).unsqueeze(-1))
                aug_codec = aug_codec.squeeze(0).squeeze(-1).long()

            # dialogue
            timespk_ids = torch.from_numpy(item["timespk_ids"])

            # mixup
            type_id = item["type_id"]
            input_ids, labels, text_flag, codec_flag, timespk_flag = self.mixup_text_codec(
                text_ids, aug_codec, timespk_ids, type_id
            )

            # face: one embedding row per codec frame; each detected face
            # embedding is propagated over at most 5 consecutive frames.
            face_features = item["face"]
            face_emb = torch.zeros((codec_len, self.face_size), dtype=torch.float32)  # face_emb length equals codec_len
            with open(face_features, 'rb') as f:
                stat_obj = pickle.load(f)
            embeddings = stat_obj['embeddings']
            faceI = stat_obj['faceI']
            for emb, frameI in zip(embeddings, faceI):
                fi = int(frameI)
                if 0 <= fi < codec_len:
                    end = min(fi + 5, codec_len)
                    face_emb[fi:end] = torch.from_numpy(emb).expand(end - fi, -1)

            # attention_mask spans the full sequence: input_id = (sos, <|startofclue|>, clue, <|endofclue|>, text, type_id, timespk_ids, turn_of_speech, speech, eos)
            attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
            codec_len = torch.tensor([codec_len], dtype=torch.int32)
            output = {
                "input_ids": input_ids,
                "face_emb": face_emb,
                "attention_mask": attention_mask,
                "labels_ids": labels,
                "text_flag": text_flag,
                "codec_flag": codec_flag,
                "timespk_flag": timespk_flag,
                "codec_len": codec_len,
            }
            break
        return output

    def collator(self, samples: list = None):
        """Pad a list of ``__getitem__`` outputs into batch tensors.

        Int32/int64 tensors are padded with ``int_pad_value``, everything else
        with ``float_pad_value``. For token-based batching, the last sample is
        repeatedly dropped (and padding redone) while ``batch * max_len``
        exceeds ``batch_size_token_max``.
        """

        for idx in range(self.retry):
            badcase_flag = False  # NOTE(review): currently unused

            outputs = {}
            for sample in samples:
                if sample is None:
                    continue
                for key in sample.keys():
                    if key not in outputs:
                        outputs[key] = []
                    if isinstance(sample[key], (list, tuple)):
                        outputs[key].extend(sample[key])
                    else:
                        outputs[key].append(sample[key])

            for key, data_list in outputs.items():
                if isinstance(data_list[0], torch.Tensor):
                    if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:

                        pad_value = self.int_pad_value
                    else:
                        pad_value = self.float_pad_value

                    outputs[key] = torch.nn.utils.rnn.pad_sequence(
                        data_list, batch_first=True, padding_value=pad_value
                    )

            if self.batch_type != "example":
                b, t = outputs["input_ids"].shape
                if b > 1 and b * t > self.batch_size_token_max:
                    logging.info(
                        f"Warning, {idx}th, b*t: {b}*{t}={b * t} > batch_size_token_max: {self.batch_size_token_max}, drop last data"
                    )
                    samples = samples[:-1]
                    continue

            break

        return outputs
funcineforge/datasets/index_ds.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import logging
4
+ import numpy as np
5
+
6
+
7
class FunCineForgeDS(torch.utils.data.Dataset):
    """Index dataset over FunCineForge jsonl manifests.

    Each jsonl line describes one utterance with at least:
      - ``utt``: utterance id
      - ``type``: scene type ("旁白" narration / "独白" monologue /
        "对话" dialogue / "多人" multi-speaker)
      - ``messages``: list of ``{"role": ..., "content": ...}``; every role in
        the comma-separated, required ``load_meta_data_key`` kwarg must appear
      - ``speech_length`` / ``text_length`` / ``clue_length``: precomputed
        lengths used for filtering and batch-length estimation

    Special ids (type / gender / age / speaker) are laid out directly after
    the ``timebook_size`` timestamp vocabulary.
    """

    def __init__(self, data_jsonl: str, **kwargs):
        super().__init__()

        self.max_source_length = kwargs.get("max_source_length", None)
        self.max_text_length = kwargs.get("max_text_length", None)
        self.max_token_length = kwargs.get("max_token_length", None)
        self.ignore_id = kwargs.get("ignore_id", -100)
        # Timestamp frames per second of audio — assumed token frame rate; TODO confirm.
        self.frame_shift = kwargs.get("frame_shift", 25)
        self.timebook_size = kwargs.get("timebook_size", 1500)
        # Scene-type ids: narration / monologue / dialogue / multi-speaker.
        self.type_map = {"旁白": kwargs.get("pangbai", self.timebook_size),
                         "独白": kwargs.get("dubai", self.timebook_size + 1),
                         "对话": kwargs.get("duihua", self.timebook_size + 2),
                         "多人": kwargs.get("duoren", self.timebook_size + 3),}
        self.gender_map = {"男": kwargs.get("male", self.timebook_size + 4),
                           "male": kwargs.get("male", self.timebook_size + 4),
                           "女": kwargs.get("female", self.timebook_size + 5),
                           "female": kwargs.get("female", self.timebook_size + 5),}
        self.age_map = {"儿童": kwargs.get("child", self.timebook_size + 6),
                        "child": kwargs.get("child", self.timebook_size + 6),
                        "青年": kwargs.get("youth", self.timebook_size + 7),
                        "teenager": kwargs.get("youth", self.timebook_size + 7),
                        "中年": kwargs.get("adult", self.timebook_size + 8),
                        "adult": kwargs.get("adult", self.timebook_size + 8),
                        "中老年": kwargs.get("middle", self.timebook_size + 9),
                        "middle-aged": kwargs.get("middle", self.timebook_size + 9),
                        "老年": kwargs.get("elderly", self.timebook_size + 10),
                        "elderly": kwargs.get("elderly", self.timebook_size + 10)}
        self.speaker_id_start = kwargs.get("speaker_id_start", self.timebook_size + 11)

        # Required kwarg: comma-separated roles that must appear in "messages".
        load_meta_data_key = kwargs.get("load_meta_data_key").split(",")

        if not (data_jsonl.endswith(".jsonl") or data_jsonl.endswith(".json")):
            # data_jsonl is a plain-text list of jsonl paths, one per line
            with open(data_jsonl, encoding="utf-8") as fin:
                file_list = fin.readlines()
            logging.info(f"file_list: {file_list}")
        else:
            file_list = [data_jsonl]

        contents = []
        for file_json in file_list:
            with open(file_json.strip(), encoding="utf-8") as fin:
                for line in fin:
                    data_dict = json.loads(line.strip())
                    utt = data_dict["utt"]
                    data_type = data_dict.get("type")
                    # Fix: unknown/missing types previously fell back to the
                    # hard-coded literal 1500, which is wrong whenever
                    # timebook_size is overridden; use the configured base id.
                    type_id = self.type_map.get(data_type, self.timebook_size)
                    data = data_dict["messages"]
                    speech_length = data_dict.get("speech_length", -1)
                    # 2 for startofclue, endofclue
                    text_length = data_dict.get("text_length", -1) + data_dict.get("clue_length", -1) + 2
                    if self.max_token_length is not None and (speech_length > self.max_token_length or speech_length <= 0):
                        logging.info(
                            f"speech_length: {speech_length} > {self.max_token_length}, drop it: {data_dict}"
                        )
                        continue
                    if self.max_text_length is not None and (text_length > self.max_text_length or text_length <= 0):
                        logging.info(
                            f"text_length: {text_length} > {self.max_text_length}, drop it: {data_dict}"
                        )
                        continue

                    # Drop samples missing any required role.
                    skip_flag = None
                    roles = {item.get("role") for item in data}
                    for key in load_meta_data_key:
                        if key not in roles:
                            skip_flag = key
                            break
                    if skip_flag is not None:
                        logging.info(
                            f"doesn't have {skip_flag}, drop it: {data_dict}")
                        continue

                    contents_i = {}
                    timespk_ids_len = 0
                    for i, item in enumerate(data):
                        role = item["role"]
                        content = item["content"]
                        for key in load_meta_data_key:
                            if role == key:
                                if key == "dialogue":
                                    # Dialogue segments are flattened to id tuples.
                                    timespk_ids = self.timespk_to_codec(content)
                                    timespk_ids_len = len(timespk_ids)
                                    if timespk_ids_len == 0:
                                        logging.info(f"[WARNING] len of timespk_ids is 0: {data_dict}")
                                    contents_i["timespk_ids"] = timespk_ids
                                else:
                                    contents_i[role] = content
                    contents_i["utt"] = utt
                    contents_i["type_id"] = type_id
                    # face embs len = speech tokens len, so need * 2;
                    # 4: sos, tos, eos; type_id
                    contents_i["source_len"] = speech_length * 2 + text_length + timespk_ids_len + 4
                    contents_i["speech_len"] = speech_length
                    contents_i["text_len"] = text_length  # include clue_length
                    contents.append(contents_i)

        self.contents = contents

        logging.info("total_num of samplers: {}, {}".format(len(self.contents), data_jsonl))

    def timespk_to_codec(self, dialogue):
        """Flatten dialogue segments into id tuples.

        Emits (start, spk, gender, age, end) * n_parts; start/end are frame
        indices (seconds * frame_shift, offset by 1) and unmapped gender/age
        strings become ``ignore_id``.
        """
        # tuple tokens (start, spk, gender, age, end) * n_parts
        n_parts = len(dialogue)
        if n_parts == 0:
            return np.array([], dtype=np.int64)
        starts = np.array([part["start"] for part in dialogue])
        durations = np.array([part["duration"] for part in dialogue])
        speakers = np.array([int(part["spk"]) for part in dialogue])
        genders = [part["gender"] for part in dialogue]
        ages = [part["age"] for part in dialogue]

        start_idxs = (starts * self.frame_shift + 1).astype(np.int64)
        end_idxs = ((starts + durations) * self.frame_shift + 1).astype(np.int64)
        # Speakers are 1-based in the manifest.
        spk_ids = (self.speaker_id_start + speakers - 1).astype(np.int64)
        gender_ids = [self.gender_map.get(g, self.ignore_id) for g in genders]
        age_ids = [self.age_map.get(a, self.ignore_id) for a in ages]

        sequence = np.full(n_parts * 5, self.ignore_id, dtype=np.int64)
        sequence[0::5] = start_idxs
        sequence[1::5] = spk_ids
        sequence[2::5] = gender_ids
        sequence[3::5] = age_ids
        sequence[4::5] = end_idxs
        return sequence

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        # Returns the pre-parsed manifest dict for this sample.
        data = self.contents[index]

        return data

    def get_source_len(self, data_dict):
        # Total LM sequence length estimate for batch samplers.
        source_len = data_dict.get("source_len", 0)
        return source_len

    def get_target_len(self, data_dict):
        # Number of speech tokens.
        target_len = data_dict.get("speech_len", 0)
        return target_len
funcineforge/download/__init__.py ADDED
File without changes
funcineforge/download/download_model_from_hub.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from omegaconf import OmegaConf, DictConfig
4
+ from funcineforge.download.name_maps_from_hub import name_maps_ms, name_maps_hf, name_maps_openai
5
+
6
def download_model(**kwargs):
    """Resolve model kwargs for the configured hub.

    Dispatches on ``kwargs["hub"]`` ("ms" ModelScope default, "hf" Hugging
    Face, "openai" Whisper names) and returns kwargs with ``model_path``
    (and possibly ``model``) filled in.
    """
    hub = kwargs.get("hub", "ms")
    if hub == "ms":
        return download_from_ms(**kwargs)
    if hub == "hf":
        return download_from_hf(**kwargs)
    if hub == "openai":
        name_or_dir = kwargs.get("model")
        if os.path.exists(name_or_dir):
            # Local checkpoint directory: keep the path, load via the wrapper.
            kwargs["model_path"] = name_or_dir
            kwargs["model"] = "WhisperWarp"
        else:
            # Known alias -> openai model name; otherwise pass through as-is.
            kwargs["model_path"] = name_maps_openai.get(name_or_dir, name_or_dir)
    return kwargs
25
+
26
+
27
def download_from_ms(**kwargs):
    """Resolve a ModelScope model: download if needed, then merge its config.

    Maps short aliases through ``name_maps_ms``, downloads the snapshot when
    the path does not exist locally, then merges either the repo's
    ``configuration.json``-referenced config or its ``config.yaml`` into
    kwargs (kwargs win on conflict via OmegaConf.merge order).
    """
    model_or_path = kwargs.get("model")
    if model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    model_revision = kwargs.get("model_revision", "master")
    if not os.path.exists(model_or_path) and "model_path" not in kwargs:
        try:
            model_or_path = get_or_download_model_dir(
                model_or_path,
                model_revision,
                is_training=kwargs.get("is_training"),
                check_latest=kwargs.get("check_latest", True),
            )
        except Exception as e:
            # Best-effort: fall through with the unresolved name on failure.
            print(f"Download: {model_or_path} failed!: {e}")

    kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]

    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
            conf_json = json.load(f)

        # Resolve relative file references declared in configuration.json.
        cfg = {}
        if "file_path_metas" in conf_json:
            add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
        # cfg.update(kwargs)
        cfg = OmegaConf.merge(cfg, kwargs)
        if "config" in cfg:
            config = OmegaConf.load(cfg["config"])
            kwargs = OmegaConf.merge(config, cfg)
            kwargs["model"] = config["model"]
    elif os.path.exists(os.path.join(model_or_path, "config.yaml")):
        config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
        kwargs = OmegaConf.merge(config, kwargs)

        # Any init_param entry that does not exist on disk is replaced by the
        # repo's model.pt. NOTE(review): init_param_new starts from the raw
        # init_param value and appends — confirm the leading element/comma is
        # handled downstream.
        init_param = kwargs.get("init_param", "")
        if (
            isinstance(init_param, str)
            and not os.path.exists(init_param)
            or isinstance(init_param, (list, tuple))
        ):
            init_param_new = init_param
            if isinstance(init_param, str):
                init_param = init_param.split(",")
            for init_param_i in init_param:
                if not os.path.exists(init_param_i):
                    print(f"init_param: {init_param_i}, does not exist")
                    init_param_i = os.path.join(model_or_path, "model.pt")
                init_param_new = f"{init_param_new},{init_param_i}"
            kwargs["init_param"] = init_param_new
        # assert os.path.exists(kwargs["init_param"]), "init_param does not exist"
        # Pick up tokenizer / frontend assets shipped with the repo.
        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
        kwargs["model"] = config["model"]
        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
    if isinstance(kwargs, DictConfig):
        # Return a plain dict regardless of which merge path ran.
        kwargs = OmegaConf.to_container(kwargs, resolve=True)

    return kwargs
95
+
96
+
97
def download_from_hf(**kwargs):
    """Resolve a Hugging Face model: download if needed, then merge its config.

    Mirrors :func:`download_from_ms` but uses ``name_maps_hf`` / the HF hub,
    and only takes the config.yaml branch when a sibling ``model.pt`` exists.
    """
    model_or_path = kwargs.get("model")
    if model_or_path in name_maps_hf:
        model_or_path = name_maps_hf[model_or_path]
    model_revision = kwargs.get("model_revision", "master")
    if not os.path.exists(model_or_path) and "model_path" not in kwargs:
        try:
            model_or_path = get_or_download_model_dir_hf(
                model_or_path,
                model_revision,
                is_training=kwargs.get("is_training"),
                check_latest=kwargs.get("check_latest", True),
            )
        except Exception as e:
            # Best-effort: fall through with the unresolved name on failure.
            print(f"Download: {model_or_path} failed!: {e}")

    kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]

    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
            conf_json = json.load(f)

        # Resolve relative file references declared in configuration.json.
        cfg = {}
        if "file_path_metas" in conf_json:
            add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
        cfg = OmegaConf.merge(cfg, kwargs)
        # cfg.update(kwargs)
        if "config" in cfg:
            config = OmegaConf.load(cfg["config"])
            kwargs = OmegaConf.merge(config, cfg)
            kwargs["model"] = config["model"]
    elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
        os.path.join(model_or_path, "model.pt")
    ):
        config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
        kwargs = OmegaConf.merge(config, kwargs)
        init_param = os.path.join(model_or_path, "model.pt")
        kwargs["init_param"] = init_param
        # Pick up tokenizer / frontend assets shipped with the repo.
        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
        kwargs["model"] = config["model"]
        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
    if isinstance(kwargs, DictConfig):
        # Return a plain dict regardless of which merge path ran.
        kwargs = OmegaConf.to_container(kwargs, resolve=True)

    return kwargs
152
+
153
+
154
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg=None):
    """Recursively resolve relative file paths against a model directory.

    For every string value in ``file_path_metas`` that names an existing file
    under ``model_or_path``, store the absolute path in ``cfg`` under the same
    key; nested dicts are resolved into nested dicts. Missing files are
    silently skipped.

    Args:
        model_or_path: Root directory the relative paths are joined to.
        file_path_metas: Mapping of key -> relative path (or nested mapping).
        cfg: Output mapping, mutated in place and returned. Defaults to a new
            dict (fix: the original used a mutable default argument ``cfg={}``,
            which is shared across calls).

    Returns:
        dict: ``cfg`` with the resolved paths filled in.
    """
    if cfg is None:
        cfg = {}
    if isinstance(file_path_metas, dict):
        for k, v in file_path_metas.items():
            if isinstance(v, str):
                p = os.path.join(model_or_path, v)
                if os.path.exists(p):
                    cfg[k] = p
            elif isinstance(v, dict):
                if k not in cfg:
                    cfg[k] = {}
                add_file_root_path(model_or_path, v, cfg[k])
    return cfg
167
+
168
+
169
def get_or_download_model_dir(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
):
    """Get local model directory or download model if necessary.

    Args:
        model (str): Model id or path to a local model directory.
        model_revision (str, optional): Model version number.
        is_training (bool): Whether invoked from a trainer; only affects the
            user-agent reported to ModelScope.
        check_latest (bool): When ``model`` is a local path, verify it is the
            latest hub revision (best-effort).

    Returns:
        str: Path to the local model cache directory.
    """
    # Imported lazily so modelscope remains an optional dependency.
    from modelscope.hub.check_model import check_local_model_is_latest
    from modelscope.hub.snapshot_download import snapshot_download

    from modelscope.utils.constant import Invoke, ThirdParty

    key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE

    if os.path.exists(model) and check_latest:
        model_cache_dir = model if os.path.isdir(model) else os.path.dirname(model)
        try:
            check_local_model_is_latest(
                model_cache_dir, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funcineforge"}
            )
        except Exception:
            # Fix: narrowed from a bare ``except:`` that also swallowed
            # KeyboardInterrupt/SystemExit. The check stays best-effort.
            print("could not check the latest version")
    else:
        model_cache_dir = snapshot_download(
            model, revision=model_revision, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funcineforge"}
        )
    return model_cache_dir
202
+
203
+
204
def get_or_download_model_dir_hf(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
):
    """Fetch a Hugging Face model snapshot and return its local directory.

    Args:
        model (str): Model id or path to a local model directory.
        model_revision (str, optional): Model version number; currently
            unused — kept for signature parity with the ModelScope variant,
            as are ``is_training`` and ``check_latest``.

    Returns:
        str: Local snapshot directory resolved by ``huggingface_hub``.
    """
    # Imported lazily so huggingface_hub remains an optional dependency.
    from huggingface_hub import snapshot_download

    return snapshot_download(model)
funcineforge/download/file.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import contextlib
4
+ import os
5
+ import tempfile
6
+ from abc import ABCMeta, abstractmethod
7
+ from pathlib import Path
8
+ from typing import Generator, Union
9
+
10
+ import requests
11
+ from urllib.parse import urlparse
12
+
13
+
14
def download_from_url(url):
    """Download ``url`` into a fresh temporary directory and return the file path.

    The file keeps the URL's basename. Raises AssertionError if ``url`` has no
    scheme (nothing was downloaded). The caller owns (and should clean up) the
    temporary directory.
    """
    result = urlparse(url)
    file_path = None
    if result.scheme is not None and len(result.scheme) > 0:
        storage = HTTPStorage()
        # bytes
        data = storage.read(url)
        # Fix: the original used tempfile.TemporaryDirectory().name — that
        # directory is deleted as soon as the TemporaryDirectory object is
        # garbage-collected, which can remove the downloaded file out from
        # under the caller. mkdtemp() transfers ownership to us instead.
        work_dir = tempfile.mkdtemp()
        file_path = os.path.join(work_dir, os.path.basename(url))
        with open(file_path, "wb") as fb:
            fb.write(data)
    assert file_path is not None, f"failed to download: {url}"
    return file_path
+
30
+
31
class Storage(metaclass=ABCMeta):
    """Abstract interface for storage backends.

    Concrete backends implement the four primitives below: binary and text
    reads (``read`` / ``read_text``) and binary and text writes
    (``write`` / ``write_text``).
    """

    @abstractmethod
    def read(self, filepath: str):
        """Return the raw bytes stored at ``filepath``."""

    @abstractmethod
    def read_text(self, filepath: str):
        """Return the text stored at ``filepath``."""

    @abstractmethod
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Persist the bytes ``obj`` at ``filepath``."""

    @abstractmethod
    def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        """Persist the text ``obj`` at ``filepath`` using ``encoding``."""
54
+
55
+
56
class LocalStorage(Storage):
    """Storage backend for the local filesystem."""

    def read(self, filepath: Union[str, Path]) -> bytes:
        """Return the raw bytes of the file at ``filepath`` (opened 'rb')."""
        with open(filepath, "rb") as fh:
            return fh.read()

    def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
        """Return the contents of ``filepath`` decoded with ``encoding``."""
        with open(filepath, "r", encoding=encoding) as fh:
            return fh.read()

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write ``obj`` to ``filepath`` in binary mode.

        Missing parent directories are created automatically.
        """
        parent = os.path.dirname(filepath)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(filepath, "wb") as fh:
            fh.write(obj)

    def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        """Write the text ``obj`` to ``filepath`` with ``encoding``.

        Missing parent directories are created automatically.
        """
        parent = os.path.dirname(filepath)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(filepath, "w", encoding=encoding) as fh:
            fh.write(obj)

    @contextlib.contextmanager
    def as_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]:
        """Local files are already local — yield ``filepath`` unchanged."""
        yield filepath
129
+
130
+
131
class HTTPStorage(Storage):
    """Read-only storage backend for http/https URLs."""

    def read(self, url):
        # TODO: add a progress bar for large downloads.
        response = requests.get(url)
        response.raise_for_status()
        return response.content

    def read_text(self, url):
        response = requests.get(url)
        response.raise_for_status()
        return response.text

    @contextlib.contextmanager
    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download ``filepath`` to a temporary file and yield its path.

        The temporary file is removed when the ``with`` block exits.

        Examples:
            >>> storage = HTTPStorage()
            >>> with storage.as_local_path('http://path/to/file') as path:
            ...     pass  # use the local copy here
        """
        try:
            tmp = tempfile.NamedTemporaryFile(delete=False)
            tmp.write(self.read(filepath))
            tmp.close()
            yield tmp.name
        finally:
            os.remove(tmp.name)

    def write(self, obj: bytes, url: Union[str, Path]) -> None:
        raise NotImplementedError("write is not supported by HTTP Storage")

    def write_text(self, obj: str, url: Union[str, Path], encoding: str = "utf-8") -> None:
        raise NotImplementedError("write_text is not supported by HTTP Storage")
176
+
177
+
178
class OSSStorage(Storage):
    """OSS (object storage) backend — placeholder, not yet implemented.

    Every operation raises NotImplementedError; the class exists so the
    ``oss://`` scheme is recognized by :class:`File`.
    """

    def __init__(self, oss_config_file=None):
        # read from config file or env var
        raise NotImplementedError("OSSStorage.__init__ to be implemented in the future")

    def read(self, filepath):
        raise NotImplementedError("OSSStorage.read to be implemented in the future")

    def read_text(self, filepath, encoding="utf-8"):
        raise NotImplementedError("OSSStorage.read_text to be implemented in the future")

    @contextlib.contextmanager
    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath`` into a temporary file and yield
        its path; the temporary file is removed when the ``with`` block exits.

        NOTE(review): effectively unreachable today — ``self.read`` always
        raises NotImplementedError; kept for interface parity with
        HTTPStorage.
        """
        try:
            f = tempfile.NamedTemporaryFile(delete=False)
            f.write(self.read(filepath))
            f.close()
            yield f.name
        finally:
            os.remove(f.name)

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        raise NotImplementedError("OSSStorage.write to be implemented in the future")

    def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        raise NotImplementedError("OSSStorage.write_text to be implemented in the future")
222
+
223
+
224
# Lazily-populated cache of storage backend singletons, keyed by URI scheme
# (filled in by File._get_storage).
G_STORAGES = {}
225
+
226
+
227
class File(object):
    """Static facade that routes read/write calls to a storage backend
    chosen from the URI scheme (plain local path, ``http(s)://``, ``oss://``).

    Backend instances are created lazily and cached in the module-level
    ``G_STORAGES`` dict, one per scheme.
    """

    _prefix_to_storage: dict = {
        "oss": OSSStorage,
        "http": HTTPStorage,
        "https": HTTPStorage,
        "local": LocalStorage,
    }

    @staticmethod
    def _get_storage(uri):
        """Return (and lazily cache) the storage backend for ``uri``."""
        assert isinstance(uri, str), f"uri should be str type, but got {type(uri)}"

        if "://" not in uri:
            # local path
            storage_type = "local"
        else:
            prefix, _ = uri.split("://")
            storage_type = prefix

        assert storage_type in File._prefix_to_storage, (
            f"Unsupported uri {uri}, valid prefixs: " f"{list(File._prefix_to_storage.keys())}"
        )

        if storage_type not in G_STORAGES:
            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()

        return G_STORAGES[storage_type]

    @staticmethod
    def read(uri: str) -> bytes:
        """Read the resource at ``uri`` as raw bytes.

        Args:
            uri (str): Local path or ``scheme://`` URI to read.

        Returns:
            bytes: Expected bytes object.
        """
        storage = File._get_storage(uri)
        return storage.read(uri)

    @staticmethod
    def read_text(uri: Union[str, Path], encoding: str = "utf-8") -> str:
        """Read the resource at ``uri`` as text.

        NOTE(review): ``encoding`` is accepted but not forwarded to the
        backend (HTTPStorage.read_text takes no encoding parameter), so
        backends use their own default — confirm whether forwarding to
        local storage is intended.

        Returns:
            str: Expected text reading from ``uri``.
        """
        storage = File._get_storage(uri)
        return storage.read_text(uri)

    @staticmethod
    def write(obj: bytes, uri: Union[str, Path]) -> None:
        """Write raw bytes ``obj`` to ``uri`` (parents created by the backend
        where applicable)."""
        storage = File._get_storage(uri)
        return storage.write(obj, uri)

    @staticmethod
    def write_text(obj: str, uri: str, encoding: str = "utf-8") -> None:
        """Write the text ``obj`` to ``uri`` (see ``read_text`` note on the
        ``encoding`` parameter)."""
        storage = File._get_storage(uri)
        return storage.write_text(obj, uri)

    @staticmethod
    @contextlib.contextmanager
    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
        """Yield a local filesystem path for ``uri``, downloading if remote.

        Fix: added the missing ``@staticmethod`` decorator (all sibling
        methods have it). As a bare function-in-class, calling this on an
        instance would have bound ``uri`` to the instance; class-level calls
        are unaffected.
        """
        storage = File._get_storage(uri)
        with storage.as_local_path(uri) as local_path:
            yield local_path
funcineforge/download/name_maps_from_hub.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name_maps_ms = {
2
+ "paraformer": "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
3
+ "paraformer-zh": "iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
4
+ "paraformer-en": "iic/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
5
+ "paraformer-en-spk": "iic/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
6
+ "paraformer-zh-streaming": "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
7
+ "fsmn-vad": "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
8
+ "ct-punc": "iic/punc_ct-transformer_cn-en-common-vocab471067-large",
9
+ "ct-punc-c": "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
10
+ "fa-zh": "iic/speech_timestamp_prediction-v1-16k-offline",
11
+ "cam++": "iic/speech_campplus_sv_zh-cn_16k-common",
12
+ "Whisper-large-v3": "iic/Whisper-large-v3",
13
+ "Qwen-Audio": "Qwen/Qwen-Audio",
14
+ "emotion2vec_plus_large": "iic/emotion2vec_plus_large",
15
+ "emotion2vec_plus_base": "iic/emotion2vec_plus_base",
16
+ "emotion2vec_plus_seed": "iic/emotion2vec_plus_seed",
17
+ }
18
+
19
+ name_maps_hf = {
20
+ "paraformer": "funasr/paraformer-zh",
21
+ "paraformer-zh": "funasr/paraformer-zh",
22
+ "paraformer-en": "funasr/paraformer-zh",
23
+ "paraformer-zh-streaming": "funasr/paraformer-zh-streaming",
24
+ "fsmn-vad": "funasr/fsmn-vad",
25
+ "ct-punc": "funasr/ct-punc",
26
+ "ct-punc-c": "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
27
+ "fa-zh": "funasr/fa-zh",
28
+ "cam++": "funasr/campplus",
29
+ "iic/emotion2vec_plus_large": "emotion2vec/emotion2vec_plus_large",
30
+ "iic/emotion2vec_plus_base": "emotion2vec/emotion2vec_plus_base",
31
+ "iic/emotion2vec_plus_seed": "emotion2vec/emotion2vec_plus_seed",
32
+ }
33
+
34
+ name_maps_openai = {
35
+ "Whisper-base.en": "base.en",
36
+ "Whisper-base": "base",
37
+ "Whisper-large": "large",
38
+ "Whisper-large-v1": "large-v1",
39
+ "Whisper-large-v2": "large-v2",
40
+ "Whisper-large-v3": "large-v3",
41
+ "Whisper-large-v3-turbo": "turbo",
42
+ }
funcineforge/face/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .face_recognition import FaceRecIR101
funcineforge/face/face_recognition.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def FaceRecIR101(init_param_path, **kwargs):
2
+ """
3
+ Face embeddings extraction with CurricularFace pretrained model.
4
+ Reference:
5
+ - https://modelscope.cn/models/iic/cv_ir101_facerecognition_cfglint
6
+ """
7
+ import onnxruntime
8
+ options = onnxruntime.SessionOptions()
9
+ options.intra_op_num_threads = 8
10
+ options.inter_op_num_threads = 8
11
+ ort_session = onnxruntime.InferenceSession(
12
+ init_param_path,
13
+ sess_options=options,
14
+ providers=['CPUExecutionProvider']
15
+ )
16
+ return ort_session
funcineforge/models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .specaug.specaug import SpecAug as FunCineForgeSpecAug
2
+ from .language_model import FunCineForgeLM
3
+ from .causal_hifigan import CausalHifiGan
4
+ from .flow_matching_model import CosyVoiceFlowMatching
5
+ from .inference_model import FunCineForgeInferModel
funcineforge/models/causal_hifigan.py ADDED
@@ -0,0 +1,834 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 KaiHu
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """HIFI-GAN"""
5
+
6
+ from typing import Dict
7
+ from typing import Tuple, List
8
+
9
+ import numpy as np
10
+ from scipy.signal import get_window
11
+ import torch
12
+ import torchaudio
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+ from torch.nn.utils import remove_weight_norm
16
+ from torch.nn.utils.parametrize import remove_parametrizations
17
+ from torch.nn.utils.parametrizations import weight_norm
18
+ import logging
19
+ from funcineforge.utils.device_funcs import to_device
20
+ import os
21
+ from torch.nn.utils.rnn import pad_sequence
22
+ from funcineforge.models.utils import dtype_map
23
+ from funcineforge.models.modules.hifigan import init_weights
24
+ from funcineforge.models.modules.hifigan.activations import Snake
25
+
26
+
27
+ class LookRightConv1d(torch.nn.Conv1d):
28
+ def __init__(
29
+ self,
30
+ in_channels: int,
31
+ out_channels: int,
32
+ kernel_size: int,
33
+ stride: int = 1,
34
+ dilation: int = 1,
35
+ groups: int = 1,
36
+ bias: bool = True,
37
+ padding_mode: str = 'zeros',
38
+ device=None,
39
+ dtype=None
40
+ ) -> None:
41
+ super(LookRightConv1d, self).__init__(in_channels, out_channels,
42
+ kernel_size, stride,
43
+ padding=0, dilation=dilation,
44
+ groups=groups, bias=bias,
45
+ padding_mode=padding_mode,
46
+ device=device, dtype=dtype)
47
+ assert stride == 1
48
+ self.causal_padding = kernel_size - 1
49
+
50
+ def forward(self, x: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
51
+ if context.size(2) == 0:
52
+ x = F.pad(x, (0, self.causal_padding), value=0.0)
53
+ else:
54
+ assert context.size(2) == self.causal_padding
55
+ x = torch.concat([x, context], dim=2)
56
+ x = super(LookRightConv1d, self).forward(x)
57
+ return x
58
+
59
+ class LookLeftConv1d(torch.nn.Conv1d):
60
+ def __init__(
61
+ self,
62
+ in_channels: int,
63
+ out_channels: int,
64
+ kernel_size: int,
65
+ stride: int = 1,
66
+ dilation: int = 1,
67
+ groups: int = 1,
68
+ bias: bool = True,
69
+ padding_mode: str = 'zeros',
70
+ device=None,
71
+ dtype=None
72
+ ) -> None:
73
+ super(LookLeftConv1d, self).__init__(in_channels, out_channels,
74
+ kernel_size, stride,
75
+ padding=0, dilation=dilation,
76
+ groups=groups, bias=bias,
77
+ padding_mode=padding_mode,
78
+ device=device, dtype=dtype)
79
+ assert stride == 1 and dilation == 1
80
+ self.causal_padding = kernel_size - 1
81
+
82
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
83
+ if cache.size(2) == 0:
84
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
85
+ else:
86
+ assert cache.size(2) == self.causal_padding
87
+ x = torch.concat([cache, x], dim=2)
88
+ # NOTE 兼容kernel_size=1的情况
89
+ if self.causal_padding == 0:
90
+ cache_new = x[:, :, :0]
91
+ else:
92
+ cache_new = x[:, :, -self.causal_padding:]
93
+ x = super(LookLeftConv1d, self).forward(x)
94
+ return x, cache_new
95
+
96
+
97
+ class CausalConvRNNF0Predictor(nn.Module):
98
+ def __init__(self,
99
+ num_class: int = 1,
100
+ in_channels: int = 80,
101
+ cond_channels: int = 512
102
+ ):
103
+ super().__init__()
104
+
105
+ self.num_class = num_class
106
+ self.condnet = nn.Sequential(
107
+ weight_norm(
108
+ LookRightConv1d(in_channels, cond_channels, kernel_size=4)
109
+ ),
110
+ nn.ELU(),
111
+ weight_norm(
112
+ LookLeftConv1d(cond_channels, cond_channels, kernel_size=3)
113
+ ),
114
+ nn.ELU(),
115
+ weight_norm(
116
+ LookLeftConv1d(cond_channels, cond_channels, kernel_size=3)
117
+ ),
118
+ nn.ELU(),
119
+ weight_norm(
120
+ LookLeftConv1d(cond_channels, cond_channels, kernel_size=3)
121
+ ),
122
+ nn.ELU(),
123
+ weight_norm(
124
+ LookLeftConv1d(cond_channels, cond_channels, kernel_size=3)
125
+ ),
126
+ nn.ELU(),
127
+ )
128
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
129
+
130
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0, 0), finalize: bool = True) -> torch.Tensor:
131
+ if finalize is False:
132
+ x, context = x[:, :, :-self.condnet[0].causal_padding], x[:, :, -self.condnet[0].causal_padding:]
133
+ else:
134
+ x, context = x, x[:, :, :0]
135
+ x = self.condnet[0](x, context)
136
+ x = self.condnet[1](x)
137
+ if cache.size(0) != 0:
138
+ x, cache[0] = self.condnet[2](x, cache[0])
139
+ else:
140
+ x, _ = self.condnet[2](x)
141
+ x = self.condnet[3](x)
142
+ if cache.size(0) != 0:
143
+ x, cache[1] = self.condnet[4](x, cache[1])
144
+ else:
145
+ x, _ = self.condnet[4](x)
146
+ x = self.condnet[5](x)
147
+ if cache.size(0) != 0:
148
+ x, cache[2] = self.condnet[6](x, cache[2])
149
+ else:
150
+ x, _ = self.condnet[6](x)
151
+ x = self.condnet[7](x)
152
+ if cache.size(0) != 0:
153
+ x, cache[3] = self.condnet[8](x, cache[3])
154
+ else:
155
+ x, _ = self.condnet[8](x)
156
+ x = self.condnet[9](x)
157
+ x = x.transpose(1, 2)
158
+ x = torch.abs(self.classifier(x).squeeze(-1))
159
+ return x, cache
160
+
161
+ def init_cache(self, device):
162
+ return torch.zeros(4, 1, 512, 2).to(device)
163
+
164
+ def remove_weight_norm(self):
165
+ print('Removing weight norm...')
166
+ try:
167
+ remove_weight_norm(self.condnet[0])
168
+ remove_weight_norm(self.condnet[2])
169
+ remove_weight_norm(self.condnet[4])
170
+ remove_weight_norm(self.condnet[6])
171
+ remove_weight_norm(self.condnet[8])
172
+ except:
173
+ remove_parametrizations(self.condnet[0], 'weight')
174
+ remove_parametrizations(self.condnet[2], 'weight')
175
+ remove_parametrizations(self.condnet[4], 'weight')
176
+ remove_parametrizations(self.condnet[6], 'weight')
177
+ remove_parametrizations(self.condnet[8], 'weight')
178
+
179
+
180
+ class LookLeftConvTranspose1d(torch.nn.Conv1d):
181
+ def __init__(
182
+ self,
183
+ in_channels: int,
184
+ out_channels: int,
185
+ kernel_size: int,
186
+ stride: int = 1,
187
+ dilation: int = 1,
188
+ groups: int = 1,
189
+ bias: bool = True,
190
+ padding_mode: str = 'zeros',
191
+ device=None,
192
+ dtype=None
193
+ ) -> None:
194
+ super(LookLeftConvTranspose1d, self).__init__(in_channels, out_channels,
195
+ kernel_size, 1,
196
+ padding=0, dilation=dilation,
197
+ groups=groups, bias=bias,
198
+ padding_mode=padding_mode,
199
+ device=device, dtype=dtype)
200
+ assert dilation == 1 and stride != 1
201
+ self.causal_padding = kernel_size - 1
202
+ self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')
203
+
204
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
205
+ x = self.upsample(x)
206
+ if cache.size(2) == 0:
207
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
208
+ else:
209
+ assert cache.size(2) == self.causal_padding
210
+ x = torch.concat([cache, x], dim=2)
211
+ cache_new = x[:, :, -self.causal_padding:]
212
+ x = super(LookLeftConvTranspose1d, self).forward(x)
213
+ return x, cache_new
214
+
215
+
216
+ class LookLeftConv1dWithStride(torch.nn.Conv1d):
217
+ def __init__(
218
+ self,
219
+ in_channels: int,
220
+ out_channels: int,
221
+ kernel_size: int,
222
+ stride: int = 1,
223
+ dilation: int = 1,
224
+ groups: int = 1,
225
+ bias: bool = True,
226
+ padding_mode: str = 'zeros',
227
+ device=None,
228
+ dtype=None
229
+ ) -> None:
230
+ super(LookLeftConv1dWithStride, self).__init__(in_channels, out_channels,
231
+ kernel_size, stride,
232
+ padding=0, dilation=dilation,
233
+ groups=groups, bias=bias,
234
+ padding_mode=padding_mode,
235
+ device=device, dtype=dtype)
236
+ assert stride != 1 and dilation == 1
237
+ assert kernel_size % stride == 0
238
+ self.causal_padding = stride - 1
239
+
240
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
241
+ if cache.size(2) == 0:
242
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
243
+ else:
244
+ assert cache.size(2) == self.causal_padding
245
+ x = torch.concat([cache, x], dim=2)
246
+ cache_new = x[:, :, -self.causal_padding:]
247
+ x = super(LookLeftConv1dWithStride, self).forward(x)
248
+ return x, cache_new
249
+
250
+
251
+ class LookLeftConv1dWithDilation(torch.nn.Conv1d):
252
+ def __init__(
253
+ self,
254
+ in_channels: int,
255
+ out_channels: int,
256
+ kernel_size: int,
257
+ stride: int = 1,
258
+ dilation: int = 1,
259
+ groups: int = 1,
260
+ bias: bool = True,
261
+ padding_mode: str = 'zeros',
262
+ device=None,
263
+ dtype=None
264
+ ) -> None:
265
+ super(LookLeftConv1dWithDilation, self).__init__(in_channels, out_channels,
266
+ kernel_size, stride,
267
+ padding=0, dilation=dilation,
268
+ groups=groups, bias=bias,
269
+ padding_mode=padding_mode,
270
+ device=device, dtype=dtype)
271
+ # NOTE(lyuxiang.lx) 这个causal_padding仅在kernel_size为奇数时才成立
272
+ assert kernel_size // 2 * dilation * 2 == int((kernel_size * dilation - dilation) / 2) * 2
273
+ self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2
274
+
275
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
276
+ if cache.size(2) == 0:
277
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
278
+ else:
279
+ assert cache.size(2) == self.causal_padding
280
+ x = torch.concat([cache, x], dim=2)
281
+ cache_new = x[:, :, -self.causal_padding:]
282
+ x = super(LookLeftConv1dWithDilation, self).forward(x)
283
+ return x, cache_new
284
+
285
+
286
+ class ResBlock(torch.nn.Module):
287
+ """Residual block module in HiFiGAN/BigVGAN."""
288
+ def __init__(
289
+ self,
290
+ channels: int = 512,
291
+ kernel_size: int = 3,
292
+ dilations: List[int] = [1, 3, 5],
293
+ ):
294
+ super(ResBlock, self).__init__()
295
+ self.convs1 = nn.ModuleList()
296
+ self.convs2 = nn.ModuleList()
297
+
298
+ for dilation in dilations:
299
+ self.convs1.append(
300
+ weight_norm(
301
+ LookLeftConv1dWithDilation(
302
+ channels,
303
+ channels,
304
+ kernel_size,
305
+ 1,
306
+ dilation=dilation
307
+ ) if dilation != 1 else
308
+ LookLeftConv1d(
309
+ channels,
310
+ channels,
311
+ kernel_size,
312
+ 1,
313
+ dilation=dilation
314
+ )
315
+ )
316
+ )
317
+ self.convs2.append(
318
+ weight_norm(
319
+ LookLeftConv1d(
320
+ channels,
321
+ channels,
322
+ kernel_size,
323
+ 1,
324
+ dilation=1
325
+ )
326
+ )
327
+ )
328
+ self.convs1.apply(init_weights)
329
+ self.convs2.apply(init_weights)
330
+ self.activations1 = nn.ModuleList([
331
+ Snake(channels, alpha_logscale=False)
332
+ for _ in range(len(self.convs1))
333
+ ])
334
+ self.activations2 = nn.ModuleList([
335
+ Snake(channels, alpha_logscale=False)
336
+ for _ in range(len(self.convs2))
337
+ ])
338
+
339
+ def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)) -> torch.Tensor:
340
+ for idx in range(len(self.convs1)):
341
+ xt = self.activations1[idx](x)
342
+ xt, _ = self.convs1[idx](xt)
343
+ xt = self.activations2[idx](xt)
344
+ xt, _ = self.convs2[idx](xt)
345
+ x = xt + x
346
+ return x, cache
347
+
348
+ def remove_weight_norm(self):
349
+ for idx in range(len(self.convs1)):
350
+ try:
351
+ remove_weight_norm(self.convs1[idx])
352
+ remove_weight_norm(self.convs2[idx])
353
+ except:
354
+ remove_parametrizations(self.convs1[idx], 'weight')
355
+ remove_parametrizations(self.convs2[idx], 'weight')
356
+
357
+
358
+ class SineGen(torch.nn.Module):
359
+ """ Definition of sine generator
360
+ SineGen(samp_rate, harmonic_num = 0,
361
+ sine_amp = 0.1, noise_std = 0.003,
362
+ voiced_threshold = 0,
363
+ flag_for_pulse=False)
364
+ samp_rate: sampling rate in Hz
365
+ harmonic_num: number of harmonic overtones (default 0)
366
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
367
+ noise_std: std of Gaussian noise (default 0.003)
368
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
369
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
370
+ Note: when flag_for_pulse is True, the first time step of a voiced
371
+ segment is always sin(np.pi) or cos(0)
372
+ """
373
+
374
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
375
+ sine_amp=0.1, noise_std=0.003,
376
+ voiced_threshold=0,
377
+ flag_for_pulse=False):
378
+ super(SineGen, self).__init__()
379
+ self.sine_amp = sine_amp
380
+ self.noise_std = noise_std
381
+ self.harmonic_num = harmonic_num
382
+ self.dim = self.harmonic_num + 1
383
+ self.sampling_rate = samp_rate
384
+ self.voiced_threshold = voiced_threshold
385
+ self.flag_for_pulse = flag_for_pulse
386
+ self.upsample_scale = upsample_scale
387
+ self.rand_ini = torch.rand(1, 9)
388
+ self.rand_ini[:, 0] = 0
389
+ self.sine_waves = torch.rand(1, 300 * 24000, 9)
390
+
391
+ def _f02uv(self, f0):
392
+ # generate uv signal
393
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
394
+ return uv
395
+
396
+ def _f02sine(self, f0_values):
397
+ """ f0_values: (batchsize, length, dim)
398
+ where dim indicates fundamental tone and overtones
399
+ """
400
+ # convert to F0 in rad. The interger part n can be ignored
401
+ # because 2 * np.pi * n doesn't affect phase
402
+ rad_values = (f0_values / self.sampling_rate) % 1
403
+
404
+ # initial phase noise (no noise for fundamental component)
405
+ rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
406
+
407
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
408
+ if not self.flag_for_pulse:
409
+ # # for normal case
410
+
411
+ # # To prevent torch.cumsum numerical overflow,
412
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
413
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
414
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
415
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
416
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
417
+ # cumsum_shift = torch.zeros_like(rad_values)
418
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
419
+
420
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
421
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
422
+ scale_factor=1/self.upsample_scale,
423
+ mode="linear").transpose(1, 2)
424
+
425
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
426
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
427
+ # cumsum_shift = torch.zeros_like(rad_values)
428
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
429
+
430
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
431
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
432
+ scale_factor=self.upsample_scale, mode="nearest").transpose(1, 2)
433
+ sines = torch.sin(phase)
434
+
435
+ else:
436
+ # If necessary, make sure that the first time step of every
437
+ # voiced segments is sin(pi) or cos(0)
438
+ # This is used for pulse-train generation
439
+
440
+ # identify the last time step in unvoiced segments
441
+ uv = self._f02uv(f0_values)
442
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
443
+ uv_1[:, -1, :] = 1
444
+ u_loc = (uv < 1) * (uv_1 > 0)
445
+
446
+ # get the instantanouse phase
447
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
448
+ # different batch needs to be processed differently
449
+ for idx in range(f0_values.shape[0]):
450
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
451
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
452
+ # stores the accumulation of i.phase within
453
+ # each voiced segments
454
+ tmp_cumsum[idx, :, :] = 0
455
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
456
+
457
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
458
+ # within the previous voiced segment.
459
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
460
+
461
+ # get the sines
462
+ sines = torch.cos(i_phase * 2 * np.pi)
463
+ return sines
464
+
465
+ def forward(self, f0):
466
+ """ sine_tensor, uv = forward(f0)
467
+ input F0: tensor(batchsize=1, length, dim=1)
468
+ f0 for unvoiced steps should be 0
469
+ output sine_tensor: tensor(batchsize=1, length, dim)
470
+ output uv: tensor(batchsize=1, length, 1)
471
+ """
472
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
473
+ device=f0.device)
474
+ # fundamental component
475
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
476
+
477
+ # generate sine waveforms
478
+ sine_waves = self._f02sine(fn) * self.sine_amp
479
+
480
+ # generate uv signal
481
+ # uv = torch.ones(f0.shape)
482
+ # uv = uv * (f0 > self.voiced_threshold)
483
+ uv = self._f02uv(f0)
484
+
485
+ # noise: for unvoiced should be similar to sine_amp
486
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
487
+ # . for voiced regions is self.noise_std
488
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
489
+ noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)
490
+
491
+ # first: set the unvoiced part to 0 by uv
492
+ # then: additive noise
493
+ sine_waves = sine_waves * uv + noise
494
+ return sine_waves, uv, noise
495
+
496
+
497
+ class SourceModuleHnNSF(torch.nn.Module):
498
+ """ SourceModule for hn-nsf
499
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
500
+ add_noise_std=0.003, voiced_threshod=0)
501
+ sampling_rate: sampling_rate in Hz
502
+ harmonic_num: number of harmonic above F0 (default: 0)
503
+ sine_amp: amplitude of sine source signal (default: 0.1)
504
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
505
+ note that amplitude of noise in unvoiced is decided
506
+ by sine_amp
507
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
508
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
509
+ F0_sampled (batchsize, length, 1)
510
+ Sine_source (batchsize, length, 1)
511
+ noise_source (batchsize, length 1)
512
+ uv (batchsize, length, 1)
513
+ """
514
+
515
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
516
+ add_noise_std=0.003, voiced_threshod=0):
517
+ super(SourceModuleHnNSF, self).__init__()
518
+
519
+ self.sine_amp = sine_amp
520
+ self.noise_std = add_noise_std
521
+
522
+ # to produce sine waveforms
523
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
524
+ sine_amp, add_noise_std, voiced_threshod)
525
+
526
+ # to merge source harmonics into a single excitation
527
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
528
+ self.l_tanh = torch.nn.Tanh()
529
+ self.uv = torch.rand(1, 300 * 24000, 1)
530
+
531
+ def forward(self, x):
532
+ """
533
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
534
+ F0_sampled (batchsize, length, 1)
535
+ Sine_source (batchsize, length, 1)
536
+ noise_source (batchsize, length 1)
537
+ """
538
+ # source for harmonic branch
539
+ with torch.no_grad():
540
+ sine_wavs, uv, _ = self.l_sin_gen(x)
541
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
542
+
543
+ # source for noise branch, in the same shape as uv
544
+ noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
545
+ return sine_merge, noise, uv
546
+
547
+
548
+ class CausalHiFTGenerator(nn.Module):
549
+ """
550
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
551
+ https://arxiv.org/abs/2309.09493
552
+ """
553
+ def __init__(
554
+ self,
555
+ in_channels: int = 80,
556
+ base_channels: int = 512,
557
+ nb_harmonics: int = 8,
558
+ sampling_rate: int = 22050,
559
+ nsf_alpha: float = 0.1,
560
+ nsf_sigma: float = 0.003,
561
+ nsf_voiced_threshold: float = 10,
562
+ upsample_rates: List[int] = [8, 8],
563
+ upsample_kernel_sizes: List[int] = [16, 16],
564
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
565
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
566
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
567
+ source_resblock_kernel_sizes: List[int] = [7, 11],
568
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
569
+ lrelu_slope: float = 0.1,
570
+ audio_limit: float = 0.99,
571
+ f0_predictor: torch.nn.Module = None,
572
+ ):
573
+ super(CausalHiFTGenerator, self).__init__()
574
+
575
+ self.out_channels = 1
576
+ self.nb_harmonics = nb_harmonics
577
+ self.sampling_rate = sampling_rate
578
+ self.istft_params = istft_params
579
+ self.lrelu_slope = lrelu_slope
580
+ self.audio_limit = audio_limit
581
+
582
+ self.num_kernels = len(resblock_kernel_sizes)
583
+ self.num_upsamples = len(upsample_rates)
584
+ self.m_source = SourceModuleHnNSF(
585
+ sampling_rate=sampling_rate,
586
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
587
+ harmonic_num=nb_harmonics,
588
+ sine_amp=nsf_alpha,
589
+ add_noise_std=nsf_sigma,
590
+ voiced_threshod=nsf_voiced_threshold)
591
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"], mode='nearest')
592
+
593
+ self.conv_pre = weight_norm(
594
+ LookRightConv1d(in_channels, base_channels, 5, 1)
595
+ )
596
+
597
+ # Up
598
+ self.ups = nn.ModuleList()
599
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
600
+ self.ups.append(
601
+ weight_norm(
602
+ LookLeftConvTranspose1d(
603
+ base_channels // (2**i),
604
+ base_channels // (2**(i + 1)),
605
+ k,
606
+ u
607
+ )
608
+ )
609
+ )
610
+
611
+ # Down
612
+ self.source_downs = nn.ModuleList()
613
+ self.source_resblocks = nn.ModuleList()
614
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
615
+ downsample_cum_rates = np.cumprod(downsample_rates)
616
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
617
+ if u == 1:
618
+ self.source_downs.append(
619
+ LookLeftConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
620
+ )
621
+ else:
622
+ self.source_downs.append(
623
+ LookLeftConv1dWithStride(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
624
+ )
625
+
626
+ self.source_resblocks.append(
627
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
628
+ )
629
+
630
+ self.resblocks = nn.ModuleList()
631
+ for i in range(len(self.ups)):
632
+ ch = base_channels // (2**(i + 1))
633
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
634
+ self.resblocks.append(ResBlock(ch, k, d))
635
+
636
+ self.conv_post = weight_norm(LookLeftConv1d(ch, istft_params["n_fft"] + 2, 7, 1))
637
+ self.ups.apply(init_weights)
638
+ self.conv_post.apply(init_weights)
639
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
640
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
641
+ self.f0_predictor = f0_predictor
642
+ # f0回退3帧,hift回退5帧
643
+ self.context_size = 8
644
+
645
+ def remove_weight_norm(self):
646
+ print('Removing weight norm...')
647
+ for l in self.ups:
648
+ try:
649
+ remove_weight_norm(l)
650
+ except:
651
+ remove_parametrizations(l, 'weight')
652
+ for l in self.resblocks:
653
+ l.remove_weight_norm()
654
+ try:
655
+ remove_weight_norm(self.conv_pre)
656
+ remove_weight_norm(self.conv_post)
657
+ except:
658
+ remove_parametrizations(self.conv_pre, 'weight')
659
+ remove_parametrizations(self.conv_post, 'weight')
660
+ self.f0_predictor.remove_weight_norm()
661
+ for l in self.source_resblocks:
662
+ l.remove_weight_norm()
663
+
664
+ def _stft(self, x):
665
+ spec = torch.stft(
666
+ x,
667
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
668
+ return_complex=True)
669
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
670
+ return spec[..., 0], spec[..., 1]
671
+
672
+ def _istft(self, magnitude, phase):
673
+ magnitude = torch.clip(magnitude, max=1e2)
674
+ real = magnitude * torch.cos(phase)
675
+ img = magnitude * torch.sin(phase)
676
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
677
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
678
+ return inverse_transform
679
+
680
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(0, 0, 0), finalize: bool = True) -> torch.Tensor:
681
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
682
+ # NOTE(lyuxiang.lx) 回退4帧
683
+ if finalize is False:
684
+ s_stft_real, s_stft_imag = s_stft_real[:, :, :-int(480 * 4 / self.istft_params["hop_len"])], s_stft_imag[:, :, :-int(480 * 4 / self.istft_params["hop_len"])]
685
+ x = self.conv_pre(x[:, :, :-4], x[:, :, -4:])
686
+ else:
687
+ x = self.conv_pre(x)
688
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
689
+ for i in range(self.num_upsamples):
690
+ x = F.leaky_relu(x, self.lrelu_slope)
691
+ x, _ = self.ups[i](x)
692
+
693
+ if i == self.num_upsamples - 1:
694
+ x = self.reflection_pad(x)
695
+
696
+ # fusion
697
+ si, _ = self.source_downs[i](s_stft)
698
+ si, _ = self.source_resblocks[i](si)
699
+ x = x + si
700
+
701
+ xs = None
702
+ for j in range(self.num_kernels):
703
+ this_xs, _ = self.resblocks[i * self.num_kernels + j](x)
704
+ if xs is None:
705
+ xs = this_xs
706
+ else:
707
+ xs += this_xs
708
+ x = xs / self.num_kernels
709
+
710
+ x = F.leaky_relu(x)
711
+ x, _ = self.conv_post(x)
712
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
713
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
714
+
715
+ x = self._istft(magnitude, phase)
716
+ # NOTE(lyuxiang.lx) 回退1帧
717
+ if finalize is False:
718
+ x = x[:, :-480]
719
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
720
+ return x
721
+
722
+ @torch.inference_mode()
723
+ def inference(self, speech_feat: torch.Tensor, f0_cpu: bool = False, finalize: bool = True) -> torch.Tensor:
724
+ # mel->f0->source
725
+ if f0_cpu is True:
726
+ self.f0_predictor.to('cpu')
727
+ f0, _ = self.f0_predictor(speech_feat.cpu(), finalize=finalize)
728
+ f0 = f0.to(speech_feat.device)
729
+ else:
730
+ self.f0_predictor.to(speech_feat.device)
731
+ f0, _ = self.f0_predictor(speech_feat, finalize=finalize)
732
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
733
+ s, _, _ = self.m_source(s)
734
+ s = s.transpose(1, 2)
735
+ if finalize is False:
736
+ generated_speech = self.decode(speech_feat[:, :, :-3], s, finalize=finalize)
737
+ else:
738
+ generated_speech = self.decode(speech_feat, s, finalize=finalize)
739
+ return generated_speech, []
740
+
741
+
742
+ class CausalHifiGan(nn.Module):
743
+ """HIFIGAN-style vocoders (generator [stack of time-level-upsampling blocks] + discriminator).
744
+ NSF-HIFIGAN, HiFTNet Optional.
745
+ """
746
+
747
+ def __init__(
748
+ self,
749
+ CausalHiFTGenerator_conf: dict = {},
750
+ CausalConvRNNF0Predictor_conf: dict = {},
751
+ sample_rate: float = 24000,
752
+ **kwargs
753
+ ):
754
+ super().__init__()
755
+ self.generator = CausalHiFTGenerator(**CausalHiFTGenerator_conf)
756
+ self.generator.f0_predictor = CausalConvRNNF0Predictor(**CausalConvRNNF0Predictor_conf)
757
+ self.generator.remove_weight_norm()
758
+ self.sample_rate = sample_rate
759
+
760
+ def inference_prepare(
761
+ self,
762
+ data_in,
763
+ data_lengths=None,
764
+ key: list = None,
765
+ **kwargs,
766
+ ):
767
+ if kwargs.get("batch_size", 1) > 1:
768
+ raise NotImplementedError("batch decoding is not implemented")
769
+
770
+ feat_list = []
771
+ feat_len_list = []
772
+ for i, feat in enumerate(data_in):
773
+ if isinstance(feat, str) and os.path.exists(feat):
774
+ feat = np.load(feat)
775
+ if isinstance(feat, np.ndarray):
776
+ feat = torch.from_numpy(feat)
777
+
778
+ feat_list.append(feat)
779
+ feat_len_list.append(feat.shape[0])
780
+
781
+ batch = {
782
+ "x": pad_sequence(feat_list, batch_first=True),
783
+ "x_lengths": torch.tensor(feat_len_list, dtype=torch.int64),
784
+ }
785
+ batch = to_device(batch, kwargs["device"])
786
+
787
+ return batch
788
+
789
+ def inference(
790
+ self,
791
+ data_in,
792
+ data_lengths=None,
793
+ key: list = None,
794
+ f0_cpu: bool = True,
795
+ finalize: bool = True,
796
+ **kwargs,
797
+ ) -> torch.Tensor:
798
+ """Run inference.
799
+
800
+ Args:
801
+ x (torch.Tensor): input representation, B x T x C
802
+
803
+ Returns:
804
+ Dict[str, Tensor]:
805
+ * recon_speech (Tensor): Reconstructed waveform tensor (B, T_wav).
806
+
807
+ """
808
+ uttid = key[0]
809
+ batch = self.inference_prepare(data_in, data_lengths, key, **kwargs)
810
+ voc_dtype = dtype_map[kwargs.get("voc_dtype", "fp32")]
811
+ x = batch["x"].to(voc_dtype)
812
+ recon_speech = self.generator.inference(x.transpose(1, 2), f0_cpu=f0_cpu, finalize=finalize)[0].squeeze(1)
813
+ recon_speech = recon_speech.float()
814
+ logging.info(f"{uttid}: wav lengths {recon_speech.shape[1]}")
815
+
816
+ output_dir = kwargs.get("output_dir", None)
817
+ output_sr = kwargs.get("output_sr", None)
818
+ if output_dir is not None:
819
+ wav_out_dir = os.path.join(output_dir, "wav")
820
+ os.makedirs(wav_out_dir, exist_ok=True)
821
+ wav_sr = self.sample_rate
822
+ if output_sr is not None and output_sr != self.sample_rate:
823
+ recon_speech = torchaudio.functional.resample(
824
+ recon_speech,
825
+ orig_freq=self.sample_rate,
826
+ new_freq=output_sr
827
+ )
828
+ wav_sr = output_sr
829
+ torchaudio.save(
830
+ os.path.join(wav_out_dir, f"{key[0]}.wav"), recon_speech.cpu(),
831
+ sample_rate=wav_sr, encoding='PCM_S', bits_per_sample=16
832
+ )
833
+
834
+ return recon_speech
funcineforge/models/flow_matching_model.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Dict
6
+ import logging
7
+ from librosa.filters import mel as librosa_mel_fn
8
+ import torch.nn.functional as F
9
+ from funcineforge.models.utils.nets_utils import make_pad_mask
10
+ from funcineforge.utils.device_funcs import to_device
11
+ import numpy as np
12
+ from funcineforge.utils.load_utils import extract_campp_xvec
13
+ import time
14
+ from funcineforge.models.utils import dtype_map
15
+ from funcineforge.utils.hinter import hint_once
16
+ from funcineforge.models.utils.masks import add_optional_chunk_mask
17
+ from .modules.dit_flow_matching.dit_model import DiT
18
+
19
+
20
class Audio2Mel(nn.Module):
    """Waveform -> log-mel-spectrogram front end.

    Computes an STFT (``center=False`` with manual reflect padding so the
    frame count stays hop-aligned) and projects magnitudes onto a mel
    filterbank.  ``feat_type`` selects between log10-magnitude mels
    ("mag_log10") and natural-log power mels (default "power_log").
    """

    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        sampling_rate=22050,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=None,
        center=False,
        device='cuda',
        feat_type="power_log",
    ):
        super().__init__()
        ##############################################
        # FFT Parameters
        ##############################################
        window = torch.hann_window(win_length, device=device).float()
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
        )
        mel_basis = torch.from_numpy(mel_basis).float().to(device)
        # Buffers follow the module across devices but take no gradients.
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", window)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.mel_fmax = mel_fmax
        self.center = center
        self.feat_type = feat_type

    def forward(self, audioin):
        """audioin: (B, 1, T) waveform -> (B, n_mels, frames) mel features."""
        # Pad so frame count matches hop-aligned length with center=False.
        p = (self.n_fft - self.hop_length) // 2
        audio = F.pad(audioin, (p, p), "reflect").squeeze(1)
        fft = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        if self.feat_type == "mag_log10":
            power_spec = torch.sqrt(torch.pow(fft.imag, 2) + torch.pow(fft.real, 2))
            mel_output = torch.matmul(self.mel_basis, power_spec)
            return torch.log10(torch.clamp(mel_output, min=1e-5))
        power_spec = torch.pow(fft.imag, 2) + torch.pow(fft.real, 2)
        # 1e-9 keeps sqrt finite/differentiable on silent frames.
        mel_spec = torch.matmul(self.mel_basis, torch.sqrt(power_spec + 1e-9))
        return self.spectral_normalize(mel_spec)

    @classmethod
    def spectral_normalize(cls, spec, C=1, clip_val=1e-5):
        """Compress mel magnitudes: log(clamp(spec, clip_val) * C)."""
        output = cls.dynamic_range_compression(spec, C, clip_val)
        return output

    @classmethod
    def spectral_de_normalize_torch(cls, spec, C=1, clip_val=1e-5):
        """Invert :meth:`spectral_normalize` (clip_val has no inverse effect)."""
        # BUGFIX: ``dynamic_range_decompression`` previously took only
        # (x, C), so forwarding ``clip_val`` raised TypeError; the inverse
        # now tolerates (and ignores) it.
        output = cls.dynamic_range_decompression(spec, C, clip_val)
        return output

    @staticmethod
    def dynamic_range_compression(x, C=1, clip_val=1e-5):
        """log-compress ``x`` after clamping to ``clip_val``."""
        return torch.log(torch.clamp(x, min=clip_val) * C)

    @staticmethod
    def dynamic_range_decompression(x, C=1, clip_val=1e-5):
        """Inverse of compression; ``clip_val`` is accepted only for
        signature symmetry with the forward transform and is unused."""
        return torch.exp(x) / C
94
+
95
+
96
class LookaheadBlock(nn.Module):
    """Conv block with a small fixed right-context (look-ahead).

    ``conv1`` consumes ``pre_lookahead_len`` future frames (real streaming
    context from the caller, or zero padding when none is given); ``conv2``
    is left-padded so the block adds no further look-ahead.  A residual
    connection and a length mask are applied at the end.
    """

    def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        # kernel = lookahead + 1: each output frame sees itself plus
        # ``pre_lookahead_len`` future frames.
        self.conv1 = nn.Conv1d(
            in_channels, channels,
            kernel_size=pre_lookahead_len+1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, in_channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs, ilens, context: torch.Tensor = torch.zeros(0, 0, 0)):
        """
        inputs: (batch_size, seq_len, channels)
        ilens: (batch_size,) valid lengths
        context: (batch_size, pre_lookahead_len, channels) future frames for
            streaming; pass an empty tensor to zero-pad instead.
        NOTE(review): the default ``context`` tensor is built once at import
        time; it is never mutated here, but a None default would be safer.
        """
        outputs = inputs.transpose(1, 2).contiguous()
        context = context.transpose(1, 2).contiguous()
        # look ahead
        if context.size(2) == 0:
            outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0)
        else:
            assert context.size(2) == self.pre_lookahead_len
            outputs = torch.concat([outputs, context], dim=2)
        outputs = F.leaky_relu(self.conv1(outputs))
        # Left-pad by kernel_size-1 so conv2 stays causal.
        outputs = F.pad(outputs, (2, 0), mode='constant', value=0)
        outputs = self.conv2(outputs)
        outputs = outputs.transpose(1, 2).contiguous()

        mask = (~make_pad_mask(ilens).unsqueeze(-1).to(inputs.device))
        # residual connection
        outputs = (outputs + inputs) * mask

        return outputs, ilens
134
+
135
+
136
class CosyVoiceFlowMatching(nn.Module):
    """Conditional flow-matching model mapping speech tokens to mel features."""

    def __init__(
        self,
        codebook_size: int,
        model_size: int,
        xvec_size: int = 198,
        dit_conf: Dict = {},
        mel_feat_conf: Dict = {},
        prompt_conf: Dict = None,
        **kwargs):
        """
        Args:
            codebook_size: number of valid speech-token ids.
            model_size: hidden size of the look-ahead conv block.
            xvec_size: speaker-embedding (x-vector) dimension; ``None``
                disables the speaker projection layer.
            dit_conf: config forwarded to the DiT backbone.
            mel_feat_conf: config for the Audio2Mel extractor; must provide
                "n_mel_channels".
            prompt_conf: optional prompt-masking config (training-time only).
        """
        super().__init__()

        # feat related
        self.feat_token_ratio = kwargs.get("feat_token_ratio", None)
        try:
            self.mel_extractor = Audio2Mel(**mel_feat_conf)
            self.sample_rate = self.mel_extractor.sampling_rate
        except Exception as e:
            # BUGFIX: was a bare ``except:`` that silently swallowed even
            # KeyboardInterrupt/SystemExit; narrow the catch and log why we
            # fall back to a plain 24 kHz setup without a mel extractor.
            logging.warning(f"Audio2Mel init failed ({e}); falling back to sample_rate=24000 without mel extractor")
            self.mel_extractor = None
            self.sample_rate = 24000
        self.mel_norm_type = kwargs.get("mel_norm_type", None)
        self.num_mels = num_mels = mel_feat_conf["n_mel_channels"]
        self.token_rate = kwargs.get("token_rate", 25)
        self.model_dtype = kwargs.get("model_dtype", "fp32")
        self.codebook_size = codebook_size

        # condition related
        self.prompt_conf = prompt_conf
        if self.prompt_conf is not None:
            self.prompt_masker = self.build_prompt_masker()

        # codec related
        self.codec_embedder = nn.Embedding(codebook_size, num_mels)
        lookahead_length = kwargs.get("lookahead_length", 4)
        self.lookahead_conv1d = LookaheadBlock(num_mels, model_size, lookahead_length)

        # spk embed related
        if xvec_size is not None:
            self.xvec_proj = torch.nn.Linear(xvec_size, num_mels)

        # dit model related
        self.dit_conf = dit_conf
        self.dit_model = DiT(**dit_conf)

        self.training_cfg_rate = kwargs.get("training_cfg_rate", 0)
        self.only_mask_loss = kwargs.get("only_mask_loss", True)

        # NOTE: flow matching needs this much right (future) context.
        self.context_size = self.lookahead_conv1d.pre_lookahead_len
185
+
186
    def build_prompt_masker(self):
        """Build the masker that selects prompt regions during training.

        Only the "prefix" prompt type is implemented: a variable-width tail
        of the target is masked so the unmasked head acts as the prompt.

        Raises:
            NotImplementedError: for any other ``prompt_type``.
        """
        prompt_type = self.prompt_conf.get("prompt_type", "free")
        if prompt_type == "prefix":
            from funcineforge.models.utils.mask_along_axis import MaskTailVariableMaxWidth
            masker = MaskTailVariableMaxWidth(
                mask_width_ratio_range=self.prompt_conf["prompt_width_ratio_range"],
            )
        else:
            raise NotImplementedError

        return masker
197
+
198
+ @staticmethod
199
+ def norm_spk_emb(xvec):
200
+ xvec_mask = (~xvec.norm(dim=-1).isnan()) * (~xvec.norm(dim=-1).isinf())
201
+ xvec = xvec * xvec_mask.unsqueeze(-1)
202
+ xvec = xvec.mean(dim=1)
203
+ xvec = F.normalize(xvec, dim=1)
204
+
205
+ return xvec
206
+
207
    def select_target_prompt(self, y: torch.Tensor, y_lengths: torch.Tensor):
        """Return a mask marking the prompt (kept) region of ``y``.

        The mask is 1 over the prompt prefix and 0 over the masked tail,
        i.e. ``1, 1, 1, ..., 0, 0, 0`` along time.
        """
        cond_mask = self.prompt_masker(y, y_lengths, return_mask=True)

        return cond_mask
212
+
213
    @torch.no_grad()
    def normalize_mel_feat(self, feat, feat_lengths):
        """Normalize padded mel features (B, T, D) per utterance.

        "mean_std": zero-mean/unit-std over valid frames, pooling all mel
        bins together.  "min_max": rescale valid values and multiply by 3.
        """
        # feat in B,T,D
        if self.mel_norm_type == "mean_std":
            max_length = feat.shape[1]
            mask = (~make_pad_mask(feat_lengths, maxlen=max_length))
            mask = mask.unsqueeze(-1).to(feat)
            mean = ((feat * mask).sum(dim=(1, 2), keepdim=True) /
                    (mask.sum(dim=(1, 2), keepdim=True) * feat.shape[-1]))
            var = (((feat - mean)**2 * mask).sum(dim=(1, 2), keepdim=True) /
                   (mask.sum(dim=(1, 2), keepdim=True) * feat.shape[-1] - 1))  # -1 for unbiased estimation
            std = torch.sqrt(var)
            feat = (feat - mean) / std
            feat = feat * mask
            return feat
        if self.mel_norm_type == "min_max":
            bb, tt, dd = feat.shape
            mask = (~make_pad_mask(feat_lengths, maxlen=tt))
            mask = mask.unsqueeze(-1).to(feat)
            # NOTE(review): min/max are taken over ``feat * mask``, so padded
            # zeros can dominate the min whenever all valid values are
            # positive -- confirm this is intended.
            feat_min = (feat * mask).reshape([bb, tt * dd]).min(dim=1, keepdim=True).values.unsqueeze(-1)
            feat_max = (feat * mask).reshape([bb, tt * dd]).max(dim=1, keepdim=True).values.unsqueeze(-1)
            feat = (feat - feat_min) / (feat_max - feat_min)
            # noise ~ N(0, I), P(x >= 3sigma) = 0.001, 3 is enough.
            feat = (feat * 3) * mask  # yields values in [0, 3] (not [-3, 3]) -- verify intent
            return feat
        else:
            raise NotImplementedError
240
+
241
    @torch.no_grad()
    def extract_feat(self, y: torch.Tensor, y_lengths: torch.Tensor):
        """Extract (and optionally normalize) mel features from waveforms.

        Returns (feat, feat_lengths) with ``feat`` in (B, T, D).
        """
        mel_extractor = self.mel_extractor.float()
        feat = mel_extractor(y)
        feat = feat.transpose(1, 2)
        # NOTE(review): true division followed by an integer-dtype cast
        # truncates toward zero; a floor-div (//) is probably intended --
        # confirm frame alignment with the extractor's padding.
        feat_lengths = (y_lengths / self.mel_extractor.hop_length).to(y_lengths)
        if self.mel_norm_type is not None:
            feat = self.normalize_mel_feat(feat, feat_lengths)
        return feat, feat_lengths
250
+
251
    def load_data(self, contents: dict, **kwargs):
        """Assemble one decoding sample for the flow-matching model.

        Gathers the codec tokens, an optional prompt codec, and a speaker
        embedding derived either from a prompt wav (via CAM++ x-vector
        extraction) or from a precomputed ``.npy`` file.

        Returns:
            dict with keys codec/codec_lengths, prompt_codec(+lengths),
            prompt_wav(+lengths) (currently always None), xvec(+lengths).
        """
        fm_use_prompt = kwargs.get("fm_use_prompt", True)

        # codec
        codec = contents["codec"]
        if isinstance(codec, np.ndarray):
            codec = torch.from_numpy(codec)
            # codec = torch.from_numpy(codec)[None, :]
        codec_lengths = torch.tensor([codec.shape[1]], dtype=torch.int64)

        # prompt codec (optional)
        prompt_codec = kwargs.get("prompt_codec", None)
        prompt_codec_lengths = None
        if prompt_codec is not None and fm_use_prompt:
            # Accept a path to a .npy, a numpy array, or a tensor.
            if isinstance(prompt_codec, str) and os.path.exists(prompt_codec):
                prompt_codec = np.load(prompt_codec)
            if isinstance(prompt_codec, np.ndarray):
                prompt_codec = torch.from_numpy(prompt_codec)[None, :]
            prompt_codec_lengths = torch.tensor([prompt_codec.shape[1]], dtype=torch.int64)
        else:
            prompt_codec = None
        spk_emb = kwargs.get("spk_emb", None)
        spk_emb_lengths = None
        if spk_emb is not None:
            if isinstance(spk_emb, str) and os.path.exists(spk_emb):
                spk_emb = np.load(spk_emb)
            if isinstance(spk_emb, np.ndarray):
                spk_emb = torch.from_numpy(spk_emb)[None, :]
            spk_emb_lengths = torch.tensor([spk_emb.shape[1]], dtype=torch.int64)

        # prompt wav as condition: overrides spk_emb when a valid path is
        # given (either a precomputed .npy x-vector or a wav to extract from).
        prompt_wav = contents["vocal"]
        prompt_wav_lengths = None
        if prompt_wav is not None and fm_use_prompt and os.path.exists(prompt_wav):
            if prompt_wav.endswith(".npy"):
                spk_emb = np.load(prompt_wav)
                spk_emb_lengths = torch.tensor([spk_emb.shape[1]], dtype=torch.int64)
            else:
                spk_emb = extract_campp_xvec(prompt_wav, **kwargs)
                spk_emb = torch.from_numpy(spk_emb)
                spk_emb_lengths = torch.tensor([spk_emb.shape[1]], dtype=torch.int64)
            # prompt_wav = load_audio_text_image_video(prompt_wav, fs=self.sample_rate)
            # prompt_wav = prompt_wav[None, :]
            # prompt_wav_lengths = torch.tensor([prompt_wav.shape[1]], dtype=torch.int64)
        else:
            logging.info("[error] prompt_wav is None or not path or path not exists! Please provide the correct speaker embedding.")

        output = {
            "codec": codec,
            "codec_lengths": codec_lengths,
            "prompt_codec": prompt_codec,
            "prompt_codec_lengths": prompt_codec_lengths,
            "prompt_wav": None,
            "prompt_wav_lengths": None,
            "xvec": spk_emb,
            "xvec_lengths": spk_emb_lengths,
        }

        return output
310
+
311
+ @torch.no_grad()
312
+ def inference(
313
+ self,
314
+ data_in,
315
+ data_lengths=None,
316
+ key: list = None,
317
+ chunk_size: int = -1,
318
+ finalize: bool = True,
319
+ **kwargs,
320
+ ):
321
+ uttid = key[0]
322
+ if kwargs.get("batch_size", 1) > 1:
323
+ raise NotImplementedError("batch decoding is not implemented")
324
+ batch = self.load_data(data_in[0], **kwargs)
325
+ batch = to_device(batch, kwargs["device"])
326
+ batch.update({'finalize': finalize, 'chunk_size': chunk_size})
327
+ feat = self._inference(**batch, **kwargs)
328
+ feat = feat.float()
329
+ logging.info(f"{uttid}: feat lengths {feat.shape[1]}")
330
+
331
+ return feat
332
+
333
    @torch.no_grad()
    def _inference(
        self,
        codec, codec_lengths,
        prompt_codec=None, prompt_codec_lengths=None,
        prompt_wav=None, prompt_wav_lengths=None,
        xvec=None, xvec_lengths=None, chunk_size=-1, finalize=False,
        **kwargs
    ):
        """Core token->mel flow-matching pass.

        Embeds the (prompt +) codec tokens, applies the look-ahead conv,
        upsamples to the mel frame rate, then solves the probability-flow
        ODE conditioned on the speaker embedding and optional prompt mels.
        """
        fm_dtype = dtype_map[kwargs.get("fm_dtype", "fp32")]
        rand_xvec = None
        if xvec is not None:
            if xvec.dim() == 2:
                xvec = xvec.unsqueeze(1)
                xvec_lens = torch.ones_like(xvec_lengths)  # NOTE(review): unused
            # Pool, normalize, and project the x-vector to mel dimensionality.
            rand_xvec = self.norm_spk_emb(xvec)
            self.xvec_proj.to(fm_dtype)
            rand_xvec = self.xvec_proj(rand_xvec.to(fm_dtype))
            rand_xvec = rand_xvec.unsqueeze(1)

        # Drop token ids outside the codebook (e.g. special/EOS ids).
        if (codec >= self.codebook_size).any():
            new_codec = codec[codec < self.codebook_size].unsqueeze(0)
            logging.info(f"remove out-of-range token for FM: from {codec.shape[1]} to {new_codec.shape[1]}.")
            codec_lengths = codec_lengths - (codec.shape[1] - new_codec.shape[1])
            codec = new_codec
        if prompt_codec is not None:
            codec, codec_lengths = self.concat_prompt(prompt_codec, prompt_codec_lengths, codec, codec_lengths)
        # -1 marks padding; clamp for the embedding lookup, then zero it out.
        mask = (codec != -1).float().unsqueeze(-1)
        codec_emb = self.codec_embedder(torch.clamp(codec, min=0)) * mask

        self.lookahead_conv1d.to(fm_dtype)
        if finalize is True:
            # Final chunk: no future context, let the conv zero-pad.
            context = torch.zeros(1, 0, self.codec_embedder.embedding_dim).to(fm_dtype)
        else:
            # Streaming: hold back the last ``context_size`` frames as the
            # right context for the look-ahead convolution.
            codec_emb, context = codec_emb[:, :-self.context_size].to(fm_dtype), codec_emb[:, -self.context_size:].to(fm_dtype)
            codec_lengths = codec_lengths - self.context_size
        mu, _ = self.lookahead_conv1d(codec_emb, codec_lengths, context)
        # Upsample token-rate features to the mel frame rate.
        mu = mu.repeat_interleave(self.feat_token_ratio, dim=1)
        conditions = torch.zeros([mu.size(0), mu.shape[1], self.num_mels]).to(mu)
        # get conditions
        if prompt_wav is not None:
            if prompt_wav.ndim == 2:
                prompt_wav, prompt_wav_lengths = self.extract_feat(prompt_wav, prompt_wav_lengths)
            # NOTE: for the fmax-12k FM variant, consider interpolating the
            # prompt mel to 2x the token length instead of hard truncation.
            prompt_wav = prompt_wav.to(fm_dtype)
            for i, _len in enumerate(prompt_wav_lengths):
                conditions[i, :_len] = prompt_wav[i]

        feat_lengths = codec_lengths * self.feat_token_ratio
        # NOTE: add_optional_chunk_mask supports chunk sizes of -1/1/15/30.
        mask = add_optional_chunk_mask(mu, torch.ones([1, 1, mu.shape[1]]).to(mu).bool(), False, False, 0, chunk_size, -1)
        feat = self.solve_ode(mu, rand_xvec, conditions.to(fm_dtype), mask, **kwargs)

        if prompt_codec is not None and prompt_wav is not None:
            # Strip the prompt region from the generated features.
            feat, feat_lens = self.remove_prompt(None, prompt_wav_lengths, feat, feat_lengths)

        return feat
391
+
392
+ @staticmethod
393
+ def concat_prompt(prompt, prompt_lengths, text, text_lengths):
394
+ xs_list, x_len_list = [], []
395
+ for idx, (_prompt_len, _text_len) in enumerate(zip(prompt_lengths, text_lengths)):
396
+ xs_list.append(torch.concat([prompt[idx, :_prompt_len], text[idx, :_text_len]], dim=0))
397
+ x_len_list.append(_prompt_len + _text_len)
398
+
399
+ xs = torch.nn.utils.rnn.pad_sequence(xs_list, batch_first=True, padding_value=0.0)
400
+ x_lens = torch.tensor(x_len_list, dtype=torch.int64).to(xs.device)
401
+
402
+ return xs, x_lens
403
+
404
+ @staticmethod
405
+ def remove_prompt(prompt, prompt_lengths, padded, padded_lengths):
406
+ xs_list = []
407
+ for idx, (_prompt_len, _x_len) in enumerate(zip(prompt_lengths, padded_lengths)):
408
+ xs_list.append(padded[idx, _prompt_len: _x_len])
409
+
410
+ xs = torch.nn.utils.rnn.pad_sequence(xs_list, batch_first=True, padding_value=0.0)
411
+
412
+ return xs, padded_lengths - prompt_lengths
413
+
414
    def get_rand_noise(self, mu: torch.Tensor, **kwargs):
        """Return the ODE's initial noise, shaped like ``mu`` (B, T, D).

        With ``use_fixed_noise_infer`` a single noise buffer is lazily
        allocated once and sliced per call, making decoding deterministic
        across chunks; otherwise fresh Gaussian noise is drawn each call.
        """
        use_fixed_noise_infer = kwargs.get("use_fixed_noise_infer", True)
        max_len = kwargs.get("max_len", 50*300)
        if use_fixed_noise_infer:
            # NOTE(review): the reuse check compares the feature dim
            # (shape[2]); if mu.shape[1] ever exceeds ``max_len`` the slice
            # below silently truncates -- confirm max_len always bounds the
            # sequence length.
            if not hasattr(self, "rand_noise") or self.rand_noise is None or self.rand_noise.shape[2] < mu.shape[2]:
                self.rand_noise = torch.randn([1, max_len, mu.shape[2]]).to(mu)
                logging.info("init random noise for Flow")
            # return self.rand_noise[:, :mu.shape[1], :]
            # Tile the same noise across the batch dimension.
            return torch.concat([self.rand_noise[:, :mu.shape[1], :] for _ in range(mu.size(0))], dim = 0)
        else:
            return torch.randn_like(mu)
425
+
426
+ def solve_ode(self, mu, rand_xvec, conditions, mask, **kwargs):
427
+ fm_dtype = dtype_map[kwargs.get("fm_dtype", "fp32")]
428
+ temperature = kwargs.get("temperature", 1.0)
429
+ n_timesteps = kwargs.get("n_timesteps", 10)
430
+ infer_t_scheduler = kwargs.get("infer_t_scheduler", "cosine")
431
+ z = self.get_rand_noise(mu) * temperature
432
+ # print("z", z.size(), "mu", mu.size())
433
+ t_span = torch.linspace(0, 1, n_timesteps + 1).to(mu)
434
+ # print("t_span", t_span)
435
+ if infer_t_scheduler == 'cosine':
436
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
437
+ fm_time = time.time()
438
+ self.dit_model.to(fm_dtype)
439
+ feat = self.solve_euler(
440
+ z.to(fm_dtype), t_span=t_span.to(fm_dtype), mu=mu.to(fm_dtype), mask=mask,
441
+ spks=rand_xvec.to(fm_dtype), cond=conditions.to(fm_dtype), **kwargs
442
+ )
443
+ escape_time = (time.time() - fm_time) * 1000.0
444
+ logging.info(f"fm dec {n_timesteps} step time: {escape_time:.2f}, avg {escape_time/n_timesteps:.2f} ms")
445
+ return feat
446
+
447
    def solve_euler(self, x, t_span, mu, mask, spks=None, cond=None, **kwargs):
        """
        Fixed-step Euler solver for the flow ODE with classifier-free guidance.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: prompt-mel conditioning tensor, or a dict with a "prompt"
                entry for the structured-prompt variant.
        """
        inference_cfg_rate = kwargs.get("inference_cfg_rate", 0.7)
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        steps = 1
        z, bz = x, x.shape[0]
        while steps <= len(t_span) - 1:
            if inference_cfg_rate > 0:
                # Classifier-free guidance: evaluate conditional and
                # unconditional branches in one doubled batch; the second
                # half carries zeroed conditioning.
                x_in = torch.concat([x, x], dim=0)
                spks_in = torch.cat([spks, torch.zeros_like(spks)], dim=0)
                mask_in = torch.concat([mask, mask], dim=0)
                mu_in = torch.concat([mu, torch.zeros_like(mu)], dim=0)
                t_in = torch.concat([t.unsqueeze(0) for _ in range(mu_in.size(0))], dim=0)
                if isinstance(cond, torch.Tensor):
                    cond_in = torch.concat([cond, torch.zeros_like(cond)], dim=0)
                else:
                    # Dict prompt: zero the embedding half, keep its mask.
                    cond_in = dict(
                        prompt=[
                            torch.concat([cond["prompt"][0], torch.zeros_like(cond["prompt"][0])], dim=0),
                            torch.concat([cond["prompt"][1], cond["prompt"][1]], dim=0),
                        ]
                    )
            else:
                x_in, mask_in, mu_in, spks_in, t_in, cond_in = x, mask, mu, spks, t, cond

            # if spks is not None:
            #     cond_in = cond_in + spks

            infer_causal_mask_type = kwargs.get("infer_causal_mask_type", 0)
            chunk_mask_value = self.dit_model.causal_mask_type[infer_causal_mask_type]["prob_min"]
            hint_once(
                f"flow mask type: {infer_causal_mask_type}, mask_rank value: {chunk_mask_value}.",
                "chunk_mask_value"
            )
            dphi_dt = self.dit_model(
                x_in, cond_in, mu_in, spks_in, t_in,
                mask=mask_in,
                mask_rand=torch.ones_like(t_in).reshape(-1, 1, 1) * chunk_mask_value
            )
            if inference_cfg_rate > 0:
                # Guided velocity: (1 + w) * conditional - w * unconditional.
                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [bz, bz], dim=0)
                dphi_dt = ((1.0 + inference_cfg_rate) * dphi_dt -
                           inference_cfg_rate * cfg_dphi_dt)

            x = x + dt * dphi_dt
            t = t + dt
            # sol.append(x)
            if steps < len(t_span) - 1:
                dt = t_span[steps + 1] - t
            steps += 1

        return x
funcineforge/models/inference_model.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import logging
4
+ import numpy as np
5
+ import os
6
+ import torchaudio
7
+ import time
8
+ import shutil
9
+ from funcineforge.utils.set_all_random_seed import set_all_random_seed
10
+ from moviepy.video.io.VideoFileClip import VideoFileClip, AudioFileClip
11
+
12
+
13
class FunCineForgeInferModel(nn.Module):
    """End-to-end inference pipeline: LM -> flow matching -> vocoder.

    Chains three pre-built model wrappers (language model producing speech
    tokens, flow-matching model producing mel features, vocoder producing
    the waveform) and optionally muxes the result onto a silent input video.
    """

    def __init__(
        self,
        lm_model,
        fm_model,
        voc_model,
        **kwargs
    ):
        # NOTE(review): AutoModel is imported but never used here -- likely
        # kept to avoid a circular import at module load time; confirm.
        from funcineforge.auto.auto_model import AutoModel
        super().__init__()
        self.tokenizer = lm_model.kwargs["tokenizer"]
        self.frontend = fm_model.kwargs["frontend"]
        self.lm_model = lm_model.model
        self.fm_model = fm_model.model
        self.voc_model = voc_model.model
        mel_extractor = self.fm_model.mel_extractor
        if mel_extractor:
            self.mel_frame_rate = mel_extractor.sampling_rate // mel_extractor.hop_length
            self.sample_rate = mel_extractor.sampling_rate
        else:
            # No mel extractor configured: assume a fixed hop of 480 samples.
            self.mel_frame_rate = self.fm_model.sample_rate // 480
            self.sample_rate = self.fm_model.sample_rate

    @torch.no_grad()
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        **kwargs,
    ):
        """Generate audio (and optionally a muxed video) for one utterance.

        Returns ``([[wav]], {"batch_data_time": seconds})``; ``wav`` stays
        None when the LM produced no speech tokens.
        """
        uttid = key[0]
        logging.info(f"generating {uttid}")
        # text -> codec in [1, T]
        kwargs["tokenizer"] = self.tokenizer
        # The seed is reset before each stage so every stage is reproducible
        # regardless of how many random draws earlier stages consumed.
        set_all_random_seed(kwargs.get("random_seed", 0))
        lm_time = time.time()
        codec, hit_eos, states = self.lm_model.inference(data_in, data_lengths, key, **kwargs)
        logging.info(f"[llm time]: {((time.time()-lm_time)*1000):.2f} ms, [hit_eos]: {hit_eos}, [gen len]: {codec.shape[1]}, [speech tokens]: {codec[0].cpu().tolist()}")
        wav, batch_data_time = None, 1.0
        if codec.shape[1] > 0:
            fm_time = time.time()
            data_in[0]["codec"] = codec
            set_all_random_seed(kwargs.get("random_seed", 0))
            feat = self.fm_model.inference(data_in, data_lengths, key, **kwargs)
            # feat -> wav
            set_all_random_seed(kwargs.get("random_seed", 0))
            wav = self.voc_model.inference([feat[0]], data_lengths, key, **kwargs)
            # output save
            output_dir = kwargs.get("output_dir", None)
            if output_dir is not None:
                feat_out_dir = os.path.join(output_dir, "feat")
                os.makedirs(feat_out_dir, exist_ok=True)
                np.save(os.path.join(feat_out_dir, f"{key[0]}.npy"), feat[0].cpu().numpy())

                wav_out_dir = os.path.join(output_dir, "wav")
                os.makedirs(wav_out_dir, exist_ok=True)
                output_wav_path = os.path.join(wav_out_dir, f"{key[0]}.wav")
                torchaudio.save(
                    output_wav_path, wav.cpu(),
                    sample_rate=self.sample_rate, encoding='PCM_S', bits_per_sample=16
                )

                silent_video_path = data_in[0]["video"]
                if os.path.exists(silent_video_path):
                    video_out_dir = os.path.join(output_dir, "mp4")
                    video_gt_dir = os.path.join(output_dir, "gt")
                    os.makedirs(video_out_dir, exist_ok=True)
                    os.makedirs(video_gt_dir, exist_ok=True)
                    output_video_path = os.path.join(video_out_dir, f"{key[0]}.mp4")
                    copy_video_path = os.path.join(video_gt_dir, f"{key[0]}.mp4")
                    # Keep a copy of the original silent video for reference.
                    shutil.copy2(silent_video_path, copy_video_path)
                    self.merge_video_audio(
                        silent_video_path=silent_video_path,
                        wav_path=output_wav_path,
                        output_path=output_video_path,
                    )

            logging.info(f"fm_voc time: {((time.time()-fm_time)*1000):.2f} ms")

            batch_data_time = wav.shape[1] / self.voc_model.sample_rate

        return [[wav]], {"batch_data_time": batch_data_time}

    def merge_video_audio(self, silent_video_path, wav_path, output_path):
        """Mux ``wav_path`` onto the silent video, trimming audio to fit."""
        video_clip = VideoFileClip(silent_video_path)
        video_duration = video_clip.duration
        audio_clip = AudioFileClip(wav_path)
        audio_duration = audio_clip.duration

        if audio_duration >= video_duration:
            audio_clip = audio_clip.subclipped(0, video_duration)

        video_clip = video_clip.with_audio(audio_clip)
        video_clip.write_videofile(
            output_path,
            codec='libx264',
            audio_codec='aac',
            fps=video_clip.fps,
            logger=None
        )
        video_clip.close()
        audio_clip.close()
funcineforge/models/language_model.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+ from funcineforge.models.utils.llm_decoding import LLMDecoder
6
+ from funcineforge.utils.device_funcs import to_device
7
+ import numpy as np
8
+ from funcineforge.models.utils import dtype_map
9
+ from funcineforge.models import FunCineForgeSpecAug
10
+ from transformers import AutoModelForCausalLM
11
+ import pickle
12
+
13
+
14
+
15
+ class FunCineForgeLM(nn.Module):
16
    def __init__(
        self,
        llm: str = None,
        llm_conf: dict = None,
        input_size: int = 80,
        length_normalized_loss: bool = False,
        **kwargs,
    ):
        """Speech-token language model wrapping a pretrained causal LLM.

        Loads a HuggingFace causal LM (optionally frozen), then attaches
        codec / time-speaker embedding tables, a codec output head, and a
        face-embedding projection sized to the LLM hidden dimension.
        """
        super().__init__()

        # llm
        self.llm_conf = llm_conf
        self.llm = None

        init_param_path = llm_conf.get("init_param_path", "")
        llm_load_kwargs = llm_conf.get("load_kwargs", {})
        self.sample_rate = kwargs.get("sample_rate", 24000)
        self.token_rate = kwargs.get("token_rate", 25)

        # If LoRA weights were already merged into the checkpoint, disable
        # every LoRA path for inference.
        if kwargs.get("infer_lora_merged", False):
            llm_conf["use_qlora"] = False
            llm_conf["use_lora"] = False
            kwargs["infer_use_lora"] = False


        model = AutoModelForCausalLM.from_pretrained(
            init_param_path,
            load_in_8bit=None,
            device_map=None,
            use_cache=None,
            **llm_load_kwargs,
        )

        freeze = llm_conf.get("freeze", True)
        if freeze:
            for name, param in model.named_parameters():
                param.requires_grad = False
            model.eval()

        logging.info(f"use_lora: {llm_conf.get('use_lora', False)}, use_qlora: {llm_conf.get('use_qlora', False)}, infer_use_lora: {kwargs.get('infer_use_lora',False)}, infer_lora_merged: {kwargs.get('infer_lora_merged',False)}")

        if llm_conf.get("activation_checkpoint", False):
            model.gradient_checkpointing_enable()

        self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
        self.llm = model.to(dtype_map[self.llm_dtype])
        llm_dim = model.get_input_embeddings().weight.shape[-1]

        # Without LoRA the text LM head is unused (codec ids are decoded
        # through ``codec_head`` instead), so drop it to save memory.
        if (not llm_conf.get("use_lora", False)) and (not kwargs.get("infer_use_lora",False)):
            del self.llm.lm_head
        self.codec_unit = kwargs.get("codec_unit", 6761)
        self.timespk_unit = kwargs.get("timespk_unit", 1550)
        # padding_idx=0 for both embedding tables.
        self.codec_embed = nn.Embedding(self.codec_unit, llm_dim, 0)
        self.timespk_embed = nn.Embedding(self.timespk_unit, llm_dim, 0)
        self.codec_head = nn.Linear(llm_dim, self.codec_unit, bias=False)
        self.face_size = kwargs.get("face_size", 512)
        self.face_linear = nn.Linear(self.face_size, llm_dim)

        self.length_normalized_loss = length_normalized_loss
        self.ignore_id = kwargs.get("ignore_id", -100)

        specaug = kwargs.get("specaug", None)
        specaug_conf = kwargs.get("specaug_conf", {})
        if specaug is not None:
            specaug = FunCineForgeSpecAug(**specaug_conf)
        self.specaug = specaug
        rank = int(os.environ.get("RANK", 0))
        logging.info(f"rank: {rank}, model is builded.")
84
+
85
+
86
    def insert_face_embeddings(
        self, inputs_embeds, face_emb, attention_mask, labels_ids,
        codec_len, insert_pos, device
    ):
        """Insert face embeddings into ``inputs_embeds`` at a fixed position,
        updating ``attention_mask`` and ``labels_ids`` in step.

        Args:
            inputs_embeds: (batch_size, token_num, dims) input embeddings.
            face_emb: (batch_size, max_face_len, dims) face embeddings.
            attention_mask: (batch_size, token_num) attention mask.
            labels_ids: (batch_size, token_num) label ids.
            codec_len: (batch_size,) actual face-embedding length per sample.
            insert_pos: int, insertion position (right after the SOS token).
            device: target device for the newly allocated tensors.
        Returns:
            padded_inputs_embeds: inputs_embeds with face_emb inserted, padded.
            padded_attention_mask: updated attention mask.
            padded_labels: updated labels (face positions set to ignore_id).
        """
        batch_size, token_num, dims = inputs_embeds.shape
        max_face_len = face_emb.size(1)

        # Pre-compute the padded length of the merged sequence.
        new_max_length = token_num + max_face_len

        # Pre-allocate outputs to avoid per-sample concatenation.
        padded_inputs_embeds = torch.zeros(batch_size, new_max_length, dims, device=device)
        padded_attention_mask = torch.zeros(batch_size, new_max_length, device=device, dtype=attention_mask.dtype)
        padded_labels = torch.full((batch_size, new_max_length), self.ignore_id, device=device, dtype=labels_ids.dtype)

        for i in range(batch_size):
            current_face_len = codec_len[i].item()

            # Fill head / face segment / tail directly into the buffer.
            padded_inputs_embeds[i, :insert_pos] = inputs_embeds[i, :insert_pos]
            padded_inputs_embeds[i, insert_pos:insert_pos+current_face_len] = face_emb[i, :current_face_len]
            padded_inputs_embeds[i, insert_pos+current_face_len:token_num+current_face_len] = inputs_embeds[i, insert_pos:]

            # Same layout for the mask and the labels.
            padded_attention_mask[i, :insert_pos] = attention_mask[i, :insert_pos]
            padded_attention_mask[i, insert_pos:insert_pos+current_face_len] = 1
            padded_attention_mask[i, insert_pos+current_face_len:token_num+current_face_len] = attention_mask[i, insert_pos:]

            padded_labels[i, :insert_pos] = labels_ids[i, :insert_pos]
            padded_labels[i, insert_pos:insert_pos+current_face_len] = self.ignore_id
            padded_labels[i, insert_pos+current_face_len:token_num+current_face_len] = labels_ids[i, insert_pos:]

        return padded_inputs_embeds, padded_attention_mask, padded_labels
134
+
135
+
136
    def load_data(self, contents: dict, **kwargs):
        """Build the LM input sequence and role flags for one sample.

        Sequence layout: ``[sos, text_ids..., type_id, timespk_ids...,
        turn_of_speech]``.  Boolean flags mark which positions are text,
        time/speaker (including the ``type_id`` slot), or codec.  Face
        embeddings are loaded from a pickle and each detected face frame is
        held for up to 5 speech frames.
        """
        lm_use_prompt = kwargs.get("lm_use_prompt", True)
        tokenizer = kwargs.get("tokenizer")
        # text + clue
        text = contents["text"]
        clue = "<|startofclue|>" + contents["clue"] + "<|endofclue|>"
        if lm_use_prompt:
            text = clue + text
        text_ids = tokenizer.encode(text)
        text_len = len(text_ids)
        # timespk_ids
        timespk_ids = contents["timespk_ids"].tolist()
        type_id = contents["type_id"]
        # sequence
        sequence = [
            kwargs['dataset_conf']["sos"],
            *text_ids,
            type_id,
            *timespk_ids,
            kwargs['dataset_conf']["turn_of_speech"]
        ]
        input_ids = torch.tensor(sequence, dtype=torch.int64)

        # flag tensors
        text_flag = torch.zeros(len(sequence), dtype=torch.float32)
        timespk_flag = torch.zeros(len(sequence), dtype=torch.float32)
        codec_flag = torch.zeros(len(sequence), dtype=torch.float32)
        text_flag[1: text_len+1] = 1
        # Starts at text_len+1, so the type_id slot is counted as timespk.
        timespk_flag[text_len+1: -1] = 1
        codec_flag = 1 - text_flag - timespk_flag

        # face embs
        speech_len = contents["speech_len"]
        face_embs = torch.zeros((speech_len, self.face_size), dtype=torch.float32)
        face_path = contents.get("face")
        # NOTE(review): raises TypeError if "face" is missing -- confirm the
        # dataset always supplies it.
        with open(face_path, 'rb') as f:
            stat_obj = pickle.load(f)
        embeddings = stat_obj['embeddings']
        faceI = stat_obj['faceI']
        for emb, frameI in zip(embeddings, faceI):
            fi = int(frameI)
            if 0 <= fi < speech_len:
                # Hold each face embedding for up to 5 frames.
                end = min(fi + 5, speech_len)
                face_embs[fi:end] = torch.from_numpy(emb).expand(end - fi, -1)

        # batch dimension
        input_ids = input_ids[None, :]
        text_flag = text_flag[None, :]
        timespk_flag = timespk_flag[None, :]
        codec_flag = codec_flag[None, :]
        face_embs = face_embs[None, :, :]
        output = {
            "input_ids": input_ids,
            "face_embs": face_embs,
            "text_flag": text_flag > 0,
            "timespk_flag": timespk_flag > 0,
            "codec_flag": codec_flag > 0,
            "prompt_codec": None,  # you can add prompt codec here if needed
        }
        return output
196
+
197
def inference_prepare(self, data_in, **kwargs):
    """Assemble the LLM input embedding sequence for single-sample decoding.

    Embeds text/timespk/codec positions with their respective tables (selected
    by the boolean flags), projects face embeddings, and splices the face
    embeddings between the sos token and the rest of the sequence.

    Raises:
        NotImplementedError: if kwargs["batch_size"] > 1.
    """
    if kwargs.get("batch_size", 1) > 1:
        raise NotImplementedError("batch decoding is not implemented")
    output = self.load_data(data_in[0], **kwargs)
    batch = to_device(output, kwargs["device"])
    input_ids = batch["input_ids"]
    # zero-out any negative padding ids before embedding lookup
    input_ids = input_ids * (input_ids > 0)
    text_flag = batch["text_flag"]
    timespk_flag = batch["timespk_flag"]
    codec_flag = batch["codec_flag"]
    face_embs = batch["face_embs"]

    # LoRA-wrapped models nest the base model differently, hence the two paths
    if (kwargs.get("use_qlora",False) or kwargs.get("infer_use_lora",False)) and (not kwargs.get("infer_lora_merged",False)):
        text_embeds = self.llm.base_model.model.model.get_input_embeddings()(input_ids * text_flag) * text_flag.unsqueeze(-1)
    else:
        text_embeds = self.llm.model.get_input_embeddings()(input_ids * text_flag) * text_flag.unsqueeze(-1)
    timespk_embeds = self.timespk_embed(input_ids * timespk_flag) * timespk_flag.unsqueeze(-1)
    codec_embs = self.codec_embed(input_ids * codec_flag) * codec_flag.unsqueeze(-1)
    face_embs = self.face_linear(face_embs)

    # flags are disjoint, so summation selects exactly one embedding per position
    inputs_embeds = text_embeds + timespk_embeds + codec_embs

    inputs_embeds = torch.cat([
        inputs_embeds[:, 0:1, :],  # sos token
        face_embs,  # face embeddings
        inputs_embeds[:, 1:, :]  # inputs_embeds after sos
    ], dim=1)

    # optionally append a codec prompt (voice-cloning style conditioning)
    prompt_codec = batch.get("prompt_codec", None)
    if prompt_codec is not None:
        codec_emb = self.codec_embed(prompt_codec)
        inputs_embeds = torch.cat((inputs_embeds, codec_emb), dim=1)

    return inputs_embeds
231
+
232
@torch.no_grad()
def inference(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    **kwargs,
):
    """Autoregressively generate codec tokens for one utterance.

    Args:
        data_in: list with a single sample dict (see load_data).
        data_lengths: unused; kept for the common inference interface.
        key: list of utterance ids; key[0] names the saved .npy file.
        **kwargs: decoding options — "min_length"/"max_length", "llm_dtype",
            optional "output_dir", LoRA switches, and generator "states".

    Returns:
        (gen_codec, hit_eos, states) from the LLM decoder.
    """
    uttid = key[0]
    inputs_emb = self.inference_prepare(data_in, **kwargs)

    logging.info(f"{uttid}: min length: {kwargs['min_length']}, max length: {kwargs['max_length']}")

    dtype = dtype_map[kwargs.get("llm_dtype", "fp32")]
    # lazily build (and cache on self) the decoder wrapper on first call
    if not hasattr(self, "llm_generator"):
        llm_generator_conf = kwargs.get("dataset_conf", {})
        self.llm_generator = LLMDecoder(
            token_embeder=self.codec_embed,
            **llm_generator_conf
        ).to(dtype)

    # swap in the codec output head (path depends on LoRA wrapping)
    if (kwargs.get("use_qlora",False) or kwargs.get("infer_use_lora",False)) and (not kwargs.get("infer_lora_merged",False)):
        self.llm.base_model.model.lm_head = self.codec_head.to(dtype)
    else:
        self.llm.lm_head = self.codec_head.to(dtype)

    gen_codec, hit_eos, states = self.llm_generator(
        inputs_emb.to(dtype),
        self.llm,
        states=kwargs.get("states", {}),
        **kwargs
    )

    # optionally persist the generated codec tokens as <output_dir>/codec/<uttid>.npy
    output_dir = kwargs.get("output_dir", None)
    if output_dir is not None:
        output_dir = os.path.join(output_dir, "codec")
        os.makedirs(output_dir, exist_ok=True)
        np.save(
            os.path.join(output_dir, f"{key[0]}.npy"),
            gen_codec[0].cpu().numpy()
        )

    return gen_codec, hit_eos, states
funcineforge/models/modules/__init__.py ADDED
File without changes
funcineforge/models/modules/dit_flow_matching/__init__.py ADDED
File without changes
funcineforge/models/modules/dit_flow_matching/dit_model.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+ from einops import repeat
16
+ from x_transformers.x_transformers import RotaryEmbedding
17
+ from funcineforge.models.utils.masks import causal_block_mask
18
+
19
+ from .dit_modules import (
20
+ TimestepEmbedding,
21
+ ConvNeXtV2Block,
22
+ CausalConvPositionEmbedding,
23
+ DiTBlock,
24
+ AdaLayerNormZero_Final,
25
+ precompute_freqs_cis,
26
+ get_pos_embed_indices,
27
+ )
28
+
29
+
30
+ # Text embedding
31
+
32
+
33
class TextEmbedding(nn.Module):
    """Embed text tokens, optionally refined with sinusoidal positions and
    ConvNeXt-V2 blocks (enabled when conv_layers > 0)."""

    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        """Return text embeddings padded/cropped to seq_len.

        drop_text zeroes the tokens (classifier-free-guidance text dropout).
        """
        batch, text_len = text.shape[0], text.shape[1]
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        text = F.pad(text, (0, seq_len - text_len), value=0)

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text
71
+
72
+
73
+ # noised input audio and context mixing embedding
74
+
75
+
76
class InputEmbedding(nn.Module):
    """Project the concatenation [noised mel, cond mel, text embed, (speaker)]
    to the model width and add a causal convolutional position embedding."""

    def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
        super().__init__()
        spk_dim = 0 if spk_dim is None else spk_dim
        self.spk_dim = spk_dim
        # input width: noised mel + masked cond mel (mel_dim * 2) + text + optional speaker
        self.proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim)
        self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)

    def forward(
        self,
        x: float["b n d"],
        cond: float["b n d"],
        text_embed: float["b n d"],
        spks: float["b d"],
    ):
        to_cat = [x, cond, text_embed]
        if self.spk_dim > 0:
            # broadcast the per-utterance speaker vector over the time axis
            spks = repeat(spks, "b c -> b t c", t=x.shape[1])
            to_cat.append(spks)

        x = self.proj(torch.cat(to_cat, dim=-1))
        # residual add of the conv position embedding
        x = self.conv_pos_embed(x) + x
        return x
99
+
100
+
101
+ # Transformer backbone using DiT blocks
102
+
103
+
104
class DiT(nn.Module):
    """DiT backbone for flow matching: time-conditioned transformer blocks with
    rotary positions, optional long skip connection, and an optional mixture of
    block-causal attention masks."""

    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=80,
        mu_dim=None,
        long_skip_connection=False,
        spk_dim=None,
        **kwargs
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        if mu_dim is None:
            mu_dim = mel_dim
        self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
        )
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None

        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)
        # list of {"prob_min", "prob_max", "block_size"[, "ratio"]} dicts, or None
        self.causal_mask_type = kwargs.get("causal_mask_type", None)

    def build_mix_causal_mask(self, attn_mask, rand=None, ratio=None):
        """Per-sample, randomly mix block-causal masks into the padding mask.

        A sample whose random draw falls in [prob_min, prob_max) gets the
        corresponding block-causal mask (ANDed with its padding mask);
        otherwise the original mask is kept.
        """
        b, _, _, t = attn_mask.shape
        if rand is None:
            rand = torch.rand((b, 1, 1, 1), device=attn_mask.device, dtype=torch.float32)
        mixed_mask = attn_mask.clone()
        for item in self.causal_mask_type:
            prob_min, prob_max = item["prob_min"], item["prob_max"]
            _ratio = 1
            if "ratio" in item:
                _ratio = item["ratio"]
            # an explicit ratio argument overrides the per-item setting
            if ratio is not None:
                _ratio = ratio
            block_size = item["block_size"] * _ratio
            if block_size <= 0:
                # non-positive block size means "no causality": keep the padding mask
                causal_mask = attn_mask
            else:
                causal_mask = causal_block_mask(
                    t, block_size, attn_mask.device, torch.float32
                ).unsqueeze(0).unsqueeze(1)  # 1,1,T,T
            flag = (prob_min <= rand) & (rand < prob_max)
            mixed_mask = mixed_mask * (~flag) + (causal_mask * attn_mask) * flag

        return mixed_mask

    def forward(
        self,
        x: float["b n d"],  # nosied input audio
        cond: float["b n d"],  # masked cond audio
        mu: int["b nt d"],  # mu
        spks: float["b 1 d"],  # spk xvec
        time: float["b"] | float[""],  # time step
        return_hidden: bool = False,
        mask: bool["b 1 n"] | None = None,
        mask_rand: float["b 1 1"] = None,  # for mask flag type
        **kwargs,
    ):
        # NOTE(review): despite the Optional annotations, mask (and mask_rand
        # when causal_mask_type is set) must be provided — confirm callers.
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        x = self.input_embed(x, cond, mu, spks.squeeze(1))

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        if self.long_skip_connection is not None:
            residual = x

        mask = mask.unsqueeze(1)  # B,1,1,T
        if self.causal_mask_type is not None:
            mask = self.build_mix_causal_mask(mask, rand=mask_rand.unsqueeze(-1))

        for block in self.transformer_blocks:
            # mask-out padded values for amp training
            x = x * mask[:, 0, -1, :].unsqueeze(-1)
            x = block(x, t, mask=mask.bool(), rope=rope)

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        if return_hidden:
            # hidden states are not tracked in this implementation
            return output, None

        return output
funcineforge/models/modules/dit_flow_matching/dit_modules.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Optional
12
+ import math
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ import torchaudio
18
+
19
+ from x_transformers.x_transformers import apply_rotary_pos_emb
20
+
21
+
22
+ # raw wav to mel spec
23
class MelSpec(nn.Module):
    """Compute log-mel spectrograms from raw waveforms via torchaudio.

    The output is clamped at 1e-5 before log to avoid -inf on silence.
    """

    def __init__(
        self,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24_000,
        normalize=False,
        power=1,
        norm=None,
        center=True,
    ):
        super().__init__()
        self.n_mel_channels = n_mel_channels

        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate=target_sample_rate,
            n_fft=filter_length,
            win_length=win_length,
            hop_length=hop_length,
            n_mels=n_mel_channels,
            power=power,
            center=center,
            normalized=normalize,
            norm=norm,
        )

        # device probe: lets forward() detect which device the module lives on
        self.register_buffer("dummy", torch.tensor(0), persistent=False)

    def forward(self, inp):
        """inp: (b, nw) or (b, 1, nw) waveform -> (b, n_mels, frames) log-mel."""
        if len(inp.shape) == 3:
            inp = inp.squeeze(1)  # 'b 1 nw -> b nw'

        assert len(inp.shape) == 2

        # move the whole module to the input's device if needed
        if self.dummy.device != inp.device:
            self.to(inp.device)

        mel = self.mel_stft(inp)
        mel = mel.clamp(min=1e-5).log()
        return mel
65
+
66
+
67
+ # sinusoidal position embedding
68
+
69
+
70
class SinusPositionEmbedding(nn.Module):
    """Classic sinusoidal embedding: maps a 1-D tensor of positions/timesteps
    of shape (b,) to (b, dim), first half sines, second half cosines."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        half = self.dim // 2
        # geometric frequency ladder, identical to the transformer recipe
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(-step * torch.arange(half, device=x.device).float())
        args = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((args.sin(), args.cos()), dim=-1)
83
+
84
+
85
+ # convolutional position embedding
86
+
87
+
88
class ConvPositionEmbedding(nn.Module):
    """Two grouped Conv1d + Mish layers producing a position-aware feature map.

    Padded positions (mask=False) are zeroed both before and after the convs.
    """

    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        same_pad = kernel_size // 2
        stack = []
        for _ in range(2):
            stack.append(nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=same_pad))
            stack.append(nn.Mish())
        self.conv1d = nn.Sequential(*stack)

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.0)

        # convs operate channel-first
        out = self.conv1d(x.permute(0, 2, 1)).permute(0, 2, 1)

        if mask is not None:
            out = out.masked_fill(~mask, 0.0)

        return out
112
+
113
+
114
class CausalConvPositionEmbedding(nn.Module):
    """Causal variant of ConvPositionEmbedding: each conv is left-padded by
    kernel_size - 1 so position t never sees frames after t."""

    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.kernel_size = kernel_size
        self.conv1 = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
            nn.Mish(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
            nn.Mish(),
        )

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
        # zero out padded positions before convolving
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.0)

        x = x.permute(0, 2, 1)
        # left-only padding keeps the convolution causal
        x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
        x = self.conv1(x)
        x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
        x = self.conv2(x)
        out = x.permute(0, 2, 1)

        if mask is not None:
            out = out.masked_fill(~mask, 0.0)

        return out
144
+
145
+
146
+ # rotary positional embedding related
147
+
148
+
149
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
    """Precompute the rotary-embedding angle table.

    Returns a (end, 2 * (dim // 2)) tensor whose first half holds cos(angle)
    and second half sin(angle) for each position/frequency pair.
    """
    # NTK-aware base rescaling for longer sequences without fine-tuning
    # (proposed by reddit user bloc97):
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    base = theta * theta_rescale_factor ** (dim / (dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    return torch.cat([angles.cos(), angles.sin()], dim=-1)
161
+
162
+
163
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
    """Build integer position indices start[i] + round-down(j * scale),
    clamped to max_pos - 1 to avoid out-of-range lookups.

    start: (b,) long tensor of offsets; returns a (b, length) long tensor.
    """
    scale_t = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
    offsets = torch.arange(length, device=start.device, dtype=torch.float32)
    pos = start.unsqueeze(1) + (offsets.unsqueeze(0) * scale_t.unsqueeze(1)).long()
    # avoid extra long error.
    return pos.clamp(max=max_pos - 1)
173
+
174
+
175
+ # Global Response Normalization layer (Instance Normalization ?)
176
+
177
+
178
class GRN(nn.Module):
    """Global Response Normalization (ConvNeXt-V2).

    gamma and beta start at zero, so the layer is the identity at init.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        # keep the norm computation in fp32 under mixed precision
        with torch.cuda.amp.autocast(enabled=False):
            g = x.norm(p=2, dim=1, keepdim=True)
            n = g / (g.mean(dim=-1, keepdim=True) + 1e-6)
            return self.gamma * (x * n) + self.beta + x
189
+
190
+
191
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
192
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
193
+
194
+
195
class ConvNeXtV2Block(nn.Module):
    """ConvNeXt-V2 residual block: depthwise conv -> LayerNorm -> pointwise
    MLP with GELU and GRN, plus a residual connection.

    https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
    ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        # 'same' padding for the fixed kernel size of 7
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = x.transpose(1, 2)  # b n d -> b d n
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        # LayerNorm in fp32 under mixed precision
        with torch.cuda.amp.autocast(enabled=False):
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        return residual + x
225
+
226
+
227
+ # AdaLayerNormZero
228
+ # return with modulated x for attn input, and params for later mlp modulation
229
+
230
+
231
class AdaLayerNormZero(nn.Module):
    """adaLN-Zero: project the conditioning embedding to six modulation
    vectors, modulate the normalized input for attention, and return the
    remaining gates/shifts/scales for the later MLP modulation."""

    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        params = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params.chunk(6, dim=1)

        # LayerNorm + modulation in fp32 under mixed precision
        with torch.cuda.amp.autocast(enabled=False):
            x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
247
+
248
+
249
+ # AdaLayerNormZero for final layer
250
+ # return only with modulated x for attn input, cuz no more mlp modulation
251
+
252
+
253
class AdaLayerNormZero_Final(nn.Module):
    """Final-layer adaLN: only scale/shift modulation of the normalized input
    (no gates — there is no MLP after the last block)."""

    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        params = self.linear(self.silu(emb))
        scale, shift = params.chunk(2, dim=1)

        # LayerNorm + modulation in fp32 under mixed precision
        with torch.cuda.amp.autocast(enabled=False):
            x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x
269
+
270
+
271
+ # FeedForward
272
+
273
+
274
class FeedForward(nn.Module):
    """Standard transformer MLP: Linear -> GELU -> Dropout -> Linear.

    The nested Sequential layout (project_in as ff[0]) is kept so state_dict
    keys stay compatible with existing checkpoints.
    """

    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
        super().__init__()
        hidden = int(dim * mult)
        out_features = dim if dim_out is None else dim_out

        project_in = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(approximate=approximate))
        self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(hidden, out_features))

    def forward(self, x):
        return self.ff(x)
286
+
287
+
288
+ # Attention with possible joint part
289
+ # modified from diffusers/src/diffusers/models/attention_processor.py
290
+
291
+
292
class Attention(nn.Module):
    """Attention wrapper with optional joint (context) projections.

    The actual computation is delegated to `processor` (AttnProcessor or
    JointAttnProcessor). When context_dim is given, extra k/v (and optionally
    q/out) projections for the context stream are created.

    modified from diffusers/src/diffusers/models/attention_processor.py
    """

    def __init__(
        self,
        processor: JointAttnProcessor | AttnProcessor,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only=None,
    ):
        super().__init__()

        # fixed typo in the original message ("equires" -> "requires")
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.processor = processor

        self.dim = dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout

        self.context_dim = context_dim
        self.context_pre_only = context_pre_only

        # projections for the main (noised-input) stream
        self.to_q = nn.Linear(dim, self.inner_dim)
        self.to_k = nn.Linear(dim, self.inner_dim)
        self.to_v = nn.Linear(dim, self.inner_dim)

        # projections for the context stream (joint attention only)
        if self.context_dim is not None:
            self.to_k_c = nn.Linear(context_dim, self.inner_dim)
            self.to_v_c = nn.Linear(context_dim, self.inner_dim)
            if self.context_pre_only is not None:
                self.to_q_c = nn.Linear(context_dim, self.inner_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_out_c = nn.Linear(self.inner_dim, dim)

    def forward(
        self,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b n d"] = None,  # context c  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.Tensor:
        # joint path when a context stream is provided, plain path otherwise
        if c is not None:
            return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
        else:
            return self.processor(self, x, mask=mask, rope=rope)
347
+
348
+
349
+ # Attention processor
350
+
351
+
352
class AttnProcessor:
    """Self-attention processor: q/k/v projection, rotary embedding,
    scaled-dot-product attention with an optional padding mask, then the
    output projection. Padded positions are zeroed in the result."""

    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding
    ) -> torch.FloatTensor:
        batch_size = x.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # apply rotary position embedding
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)

            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)

        # attention: reshape to (b, heads, n, head_dim)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = mask
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)

        # zero out padded query positions in the output
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(-1)
            else:
                # 4-D mask: take the last row (query coverage) as the padding mask
                mask = mask[:, 0, -1].unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)

        return x
411
+
412
+
413
+ # Joint Attention processor for MM-DiT
414
+ # modified from diffusers/src/diffusers/models/attention_processor.py
415
+
416
+
417
class JointAttnProcessor:
    """Joint attention for MM-DiT: the noised-input stream and the context
    stream are projected separately, concatenated along the sequence axis,
    attended together, then split back.

    modified from diffusers/src/diffusers/models/attention_processor.py
    """

    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b nt d"] = None,  # context c, here text  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.FloatTensor:
        residual = x

        batch_size = c.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # `context` projections.
        c_query = attn.to_q_c(c)
        c_key = attn.to_k_c(c)
        c_value = attn.to_v_c(c)

        # apply rope for context and noised input independently
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
        if c_rope is not None:
            freqs, xpos_scale = c_rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
            c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)

        # attention over the concatenated [x ; c] sequence
        query = torch.cat([query, c_query], dim=1)
        key = torch.cat([key, c_key], dim=1)
        value = torch.cat([value, c_value], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = F.pad(mask, (0, c.shape[1]), value=True)  # no mask for c (text)
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # Split the attention outputs.
        x, c = (
            x[:, : residual.shape[1]],
            x[:, residual.shape[1] :],
        )

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)
        if not attn.context_pre_only:
            c = attn.to_out_c(c)

        # zero out padded positions of the x stream only
        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)
            # c = c.masked_fill(~mask, 0.)  # no mask for c (text)

        return x, c
498
+
499
+
500
+ # DiT Block
501
+
502
+
503
class DiTBlock(nn.Module):
    """Single DiT block: adaLN-Zero-modulated self-attention followed by a
    modulated feed-forward, both gated and residual."""

    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # attention
        attn_output = self.attn(x=norm, mask=mask, rope=rope)

        # process attention output for input x
        x = x + gate_msa.unsqueeze(1) * attn_output

        # norm + FF in fp32 under mixed precision
        with torch.cuda.amp.autocast(enabled=False):
            ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
            ff_output = self.ff(ff_norm)
            x = x + gate_mlp.unsqueeze(1) * ff_output

        return x
535
+
536
+
537
+ # MMDiT Block https://arxiv.org/abs/2403.03206
538
+
539
+
540
+ class MMDiTBlock(nn.Module):
541
+ r"""
542
+ modified from diffusers/src/diffusers/models/attention.py
543
+
544
+ notes.
545
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
546
+ _x: noised input related. (right part)
547
+ context_pre_only: last layer only do prenorm + modulation cuz no more ffn
548
+ """
549
+
550
+ def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
551
+ super().__init__()
552
+
553
+ self.context_pre_only = context_pre_only
554
+
555
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
556
+ self.attn_norm_x = AdaLayerNormZero(dim)
557
+ self.attn = Attention(
558
+ processor=JointAttnProcessor(),
559
+ dim=dim,
560
+ heads=heads,
561
+ dim_head=dim_head,
562
+ dropout=dropout,
563
+ context_dim=dim,
564
+ context_pre_only=context_pre_only,
565
+ )
566
+
567
+ if not context_pre_only:
568
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
569
+ self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
570
+ else:
571
+ self.ff_norm_c = None
572
+ self.ff_c = None
573
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
574
+ self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
575
+
576
+ def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
577
+ # pre-norm & modulation for attention input
578
+ if self.context_pre_only:
579
+ norm_c = self.attn_norm_c(c, t)
580
+ else:
581
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
582
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
583
+
584
+ # attention
585
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
586
+
587
+ # process attention output for context c
588
+ if self.context_pre_only:
589
+ c = None
590
+ else: # if not last layer
591
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
592
+
593
+ with torch.cuda.amp.autocast(enabled=False):
594
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
595
+ c_ff_output = self.ff_c(norm_c)
596
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
597
+
598
+ # process attention output for input x
599
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
600
+
601
+ with torch.cuda.amp.autocast(enabled=False):
602
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
603
+ x_ff_output = self.ff_x(norm_x)
604
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
605
+
606
+ return c, x
607
+
608
+
609
+ # time step conditioning embedding
610
+
611
+
612
class TimestepEmbedding(nn.Module):
    """Map a (b,) timestep tensor to a (b, dim) conditioning vector via a
    sinusoidal embedding followed by a small MLP."""

    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep: float["b"]):  # noqa: F821
        time_hidden = self.time_embed(timestep)
        # match the MLP's dtype (sinusoidal table is computed in fp32)
        time_hidden = time_hidden.to(timestep.dtype)
        time = self.time_mlp(time_hidden)  # b d
        return time
funcineforge/models/modules/hifigan/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
def get_padding(kernel_size, dilation=1):
    """Return the "same" padding for a stride-1 dilated 1-D convolution."""
    return dilation * (kernel_size - 1) // 2
4
+
5
+
6
def init_weights(m, mean=0.0, std=0.01):
    """Initialise convolutional modules' weights from N(mean, std).

    Non-convolutional modules (class name without "Conv") are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
10
+
11
+
12
+ from funcineforge.models.modules.hifigan.generator import HifiGenerator, NsfHifiGenerator, HiFTGenerator
13
+ from funcineforge.models.modules.hifigan.discriminator import MultipleDiscriminator
14
+ from funcineforge.models.modules.hifigan.nsf_utils import ConvRNNF0Predictor
funcineforge/models/modules/hifigan/activations.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
class Snake(nn.Module):
    """Sine-based periodic activation: ``snake(x) = x + sin^2(alpha * x) / alpha``.

    Introduced by Ziyin, Hartwig & Ueda (https://arxiv.org/abs/2006.08195).

    Shape:
        - Input:  (B, C, T)
        - Output: (B, C, T), unchanged

    Example:
        >>> act = Snake(256)
        >>> out = act(torch.randn(4, 256, 100))
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        """``in_features`` is the channel count C; ``alpha`` sets the frequency.

        With ``alpha_logscale`` the per-channel parameter is stored in log
        space and initialised to zeros (i.e. an effective alpha of 1);
        otherwise it is stored linearly and initialised to ``alpha``.
        Higher alpha means higher frequency.
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # one trainable alpha per channel
        initial = torch.zeros(in_features) if alpha_logscale else alpha * torch.ones(in_features)
        self.alpha = Parameter(initial)
        self.alpha.requires_grad = alpha_trainable

        # keeps the 1/alpha term finite when alpha approaches zero
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply the activation elementwise to ``x`` of shape (B, C, T)."""
        a = self.alpha[None, :, None]  # broadcast per-channel alpha to (1, C, 1)
        if self.alpha_logscale:
            a = torch.exp(a)
        return x + (1.0 / (a + self.no_div_by_zero)) * pow(sin(x * a), 2)
60
+
61
+
62
class SnakeBeta(nn.Module):
    """Snake variant with a separate magnitude: ``x + sin^2(alpha * x) / beta``.

    ``alpha`` controls the frequency and ``beta`` the magnitude of the
    periodic component. Modified from Ziyin, Hartwig & Ueda
    (https://arxiv.org/abs/2006.08195).

    Shape:
        - Input:  (B, C, T)
        - Output: (B, C, T), unchanged

    Example:
        >>> act = SnakeBeta(256)
        >>> out = act(torch.randn(4, 256, 100))
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        """``in_features`` is the channel count C.

        With ``alpha_logscale`` both parameters are stored in log space and
        initialised to zeros (effective value 1); otherwise they are stored
        linearly. NOTE: ``beta`` is initialised from the ``alpha`` argument as
        well — this mirrors the reference implementation.
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # one trainable (alpha, beta) pair per channel
        if alpha_logscale:
            self.alpha = Parameter(torch.zeros(in_features))
            self.beta = Parameter(torch.zeros(in_features))
        else:
            self.alpha = Parameter(alpha * torch.ones(in_features))
            self.beta = Parameter(alpha * torch.ones(in_features))
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        # keeps the 1/beta term finite when beta approaches zero
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply the activation elementwise to ``x`` of shape (B, C, T)."""
        a = self.alpha[None, :, None]  # (1, C, 1) broadcast over (B, C, T)
        b = self.beta[None, :, None]
        if self.alpha_logscale:
            a = torch.exp(a)
            b = torch.exp(b)
        return x + (1.0 / (b + self.no_div_by_zero)) * pow(sin(x * a), 2)
funcineforge/models/modules/hifigan/discriminator.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """hifigan based dicriminator implementation.
2
+
3
+ This code is modified from https://github.com/jik876/hifi-gan and https://github.com/kan-bayashi/ParallelWaveGAN.
4
+
5
+ """
6
+
7
+ import typing as tp
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.nn as nn
12
+ from torch.nn import Conv2d, AvgPool1d, Conv1d
13
+ from torch.nn.utils import weight_norm, spectral_norm
14
+
15
+ from funcineforge.models.modules.hifigan import get_padding
16
+
17
+
18
class DiscriminatorP(torch.nn.Module):
    """HiFi-GAN period discriminator.

    Folds a mono waveform into a 2-D (time // period, period) view and runs a
    stack of strided (k, 1) convolutions over the folded time axis.
    """

    def __init__(self, period, kernel_size=5, stride=3,
                 use_spectral_norm=False, lrelu_slope=0.1):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.lrelu_slope = lrelu_slope

        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        # channel progression of the strided stack; the padding intentionally
        # uses get_padding(5, 1) regardless of kernel_size (as in upstream
        # hifi-gan)
        channels = [1, 32, 128, 512, 1024]
        self.convs = nn.ModuleList([
            norm_f(Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1),
                          padding=(get_padding(5, 1), 0)))
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])
        self.convs.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))))
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """x: (B, 1, T) waveform -> (flattened score, per-layer feature maps)."""
        fmap = []

        # reflect-pad T up to a multiple of the period, then fold 1d -> 2d
        batch, ch, t = x.shape
        if t % self.period != 0:
            pad_amount = self.period - (t % self.period)
            x = F.pad(x, (0, pad_amount), "reflect")
            t = t + pad_amount
        x = x.view(batch, ch, t // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.lrelu_slope)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
71
+
72
+
73
class MultiPeriodDiscriminator(torch.nn.Module):
    """Ensemble of DiscriminatorP heads, one per period."""

    def __init__(self,
                 in_channels: int = 1,
                 periods: tp.List[int] = [2, 3, 5, 7, 11]):
        super(MultiPeriodDiscriminator, self).__init__()
        # in_channels is accepted for interface uniformity with the other
        # discriminator ensembles; DiscriminatorP operates on mono input.
        self.discriminators = nn.ModuleList(DiscriminatorP(period) for period in periods)

    def forward(self, x: torch.Tensor, return_intermediates: bool = True):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: one entry per period — a (score, feature_maps) tuple when
            ``return_intermediates``, otherwise the score only.
        """
        if return_intermediates:
            return [disc(x) for disc in self.discriminators]
        return [disc(x)[0] for disc in self.discriminators]
102
+
103
+
104
class DiscriminatorS(torch.nn.Module):
    """HiFi-GAN scale discriminator: grouped, strided 1-D convolution stack."""

    def __init__(self, use_spectral_norm=False, lrelu_slope=0.1):
        super(DiscriminatorS, self).__init__()
        self.lrelu_slope = lrelu_slope
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        # (in_ch, out_ch, kernel, stride, groups, padding) per layer
        layer_specs = [
            (1, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList([
            norm_f(Conv1d(c_in, c_out, k, s, groups=g, padding=p))
            for c_in, c_out, k, s, g, p in layer_specs
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """x: (B, 1, T) waveform -> (flattened score, per-layer feature maps)."""
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.lrelu_slope)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
131
+
132
+
133
class MultiScaleDiscriminator(torch.nn.Module):
    """Ensemble of DiscriminatorS heads at progressively downsampled scales.

    Fix: honour ``nb_scales`` — the parameter was previously accepted but
    ignored (exactly 3 scales were always built). The default ``nb_scales=3``
    reproduces the original behaviour: the first (raw-scale) discriminator
    uses spectral norm, the rest weight norm, and each scale after the first
    sees the input downsampled by an additional stride-2 average pooling.
    """

    def __init__(self, in_channels: int = 1, nb_scales: int = 3):
        super(MultiScaleDiscriminator, self).__init__()
        # in_channels kept for interface uniformity; DiscriminatorS is mono.
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=(i == 0)) for i in range(nb_scales)
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2) for _ in range(nb_scales - 1)
        ])

    def forward(self, x: torch.Tensor, return_intermediates: bool = True):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: one entry per scale — a (score, feature_maps) tuple when
            ``return_intermediates``, otherwise the score only.
        """
        outs = []
        for i, f in enumerate(self.discriminators):
            if i != 0:
                # halve the sampling rate for each additional scale
                x = self.meanpools[i - 1](x)
            outs.append(f(x) if return_intermediates else f(x)[0])
        return outs
165
+
166
+
167
class DiscriminatorR(nn.Module):
    """Resolution discriminator operating on a linear STFT magnitude."""

    def __init__(
        self,
        stft_params: tp.List[int],
        lrelu_slope: float = 0.1,
        use_spectral_norm: bool = False,
    ):
        super().__init__()

        self.stft_params = stft_params  # [n_fft, hop_length, win_length]
        self.lrelu_slope = lrelu_slope
        norm_f = weight_norm if not use_spectral_norm else spectral_norm

        # (in_ch, out_ch, kernel, stride, padding) per layer; stride (1, 2)
        # progressively compresses the time axis of the spectrogram
        layer_specs = [
            (1, 32, (3, 9), (1, 1), (1, 4)),
            (32, 32, (3, 9), (1, 2), (1, 4)),
            (32, 32, (3, 9), (1, 2), (1, 4)),
            (32, 32, (3, 9), (1, 2), (1, 4)),
            (32, 32, (3, 3), (1, 1), (1, 1)),
        ]
        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(c_in, c_out, k, stride=s, padding=p))
            for c_in, c_out, k, s, p in layer_specs
        ])
        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def spectrogram(self, x):
        """(B, 1, T) waveform -> (B, F, frames) STFT magnitude."""
        n_fft, hop_length, win_length = self.stft_params
        half_pad = int((n_fft - hop_length) / 2)
        x = F.pad(x, (half_pad, half_pad), mode='reflect').squeeze(1)
        spec = torch.stft(x, n_fft, hop_length=hop_length, win_length=win_length,
                          center=False, pad_mode='reflect', normalized=False,
                          onesided=True, return_complex=True)
        # magnitude = L2 norm over the (real, imag) pair -> [B, F, frames]
        return torch.norm(torch.view_as_real(spec), p=2, dim=-1)

    def forward(self, x):
        """x: (B, 1, T) waveform -> (flattened score, per-layer feature maps)."""
        fmap = []
        x = self.spectrogram(x).unsqueeze(1)
        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.lrelu_slope)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
215
+
216
+
217
class MultiResolutionDiscriminator(nn.Module):
    """Bank of DiscriminatorR heads, one per STFT resolution."""

    def __init__(
        self,
        in_channels: int,
        fft_sizes: tp.List[int] = [1024, 2048, 512],
        hop_sizes: tp.List[int] = [120, 240, 50],
        win_lengths: tp.List[int] = [600, 1200, 240],
        lrelu_slope: float = 0.1,
    ):
        super().__init__()
        # in_channels is accepted for interface uniformity; DiscriminatorR
        # consumes a mono waveform.
        self.discriminators = nn.ModuleList(
            DiscriminatorR([n_fft, hop, win], lrelu_slope)
            for n_fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths)
        )

    def forward(self, x: torch.Tensor, return_intermediates: bool = True):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: one entry per resolution — a (score, feature_maps) tuple
            when ``return_intermediates``, otherwise the score only.
        """
        if return_intermediates:
            return [disc(x) for disc in self.discriminators]
        return [disc(x)[0] for disc in self.discriminators]
252
+
253
+
254
class MultipleDiscriminator(nn.Module):
    """Container combining several GAN discriminator ensembles.

    Each entry of ``disc_conf_list`` is a dict with a ``name`` key selecting
    the discriminator type ('mpd' / 'msd' / 'mrd'); the remaining keys are
    forwarded to that class's constructor.

    Improvements over the original:
    - ``disc_conf_list=None`` (the declared default) now yields an empty
      ensemble instead of raising ``TypeError`` when iterating ``None``.
    - each config dict is shallow-copied before ``pop("name")``, so the
      caller's configuration is never mutated (the original popped the key
      and re-inserted it afterwards for config dumping).
    """

    def __init__(
        self,
        input_size: int = 1,
        disc_conf_list: tp.List[tp.Dict[str, tp.Any]] = None,
    ):
        super().__init__()

        self.support_disc_choices = dict(
            mpd=MultiPeriodDiscriminator,
            msd=MultiScaleDiscriminator,
            mrd=MultiResolutionDiscriminator,
        )

        self.discriminators = nn.ModuleList()
        self.discriminator_type_lst = []
        for conf in (disc_conf_list or []):
            assert "name" in conf, "disc_conf must have `name` attr to specific disc type."
            kwargs = dict(conf)  # copy so the caller's config stays intact
            disc_type = kwargs.pop("name")
            assert disc_type in self.support_disc_choices, \
                "Unsupported discriminator type, only support {}".format(
                    ",".join(self.support_disc_choices.keys())
                )

            disc_class = self.support_disc_choices[disc_type]
            self.discriminators.append(disc_class(in_channels=input_size, **kwargs))
            self.discriminator_type_lst.append(disc_type)

    def get_discriminator_type_lst(self) -> tp.List[str]:
        """Return the type names ('mpd'/'msd'/'mrd') of the wrapped discriminators, in order."""
        return self.discriminator_type_lst

    def forward(self, x, return_intermediates=True):
        """Run every discriminator ensemble on ``x`` and flatten their outputs.

        Each ensemble returns a list of per-head results, which are
        concatenated into a single flat list.
        """
        retval = []
        for disc in self.discriminators:
            out = disc(x, return_intermediates=return_intermediates)
            if isinstance(out, tuple):
                retval.append(out)
            elif isinstance(out, list):
                retval.extend(out)
            else:
                raise TypeError("The return value of discriminator must be tuple or list[tuple]")

        return retval
funcineforge/models/modules/hifigan/generator.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """hifigan based generator implementation.
2
+
3
+ This code is modified from https://github.com/jik876/hifi-gan
4
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
5
+ https://github.com/NVIDIA/BigVGAN
6
+
7
+ """
8
+
9
+ import typing as tp
10
+
11
+ import numpy as np
12
+ from scipy.signal import get_window
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from torch.nn import Conv1d, ConvTranspose1d
18
+ from torch.nn.utils import weight_norm
19
+ from torch.nn.utils import remove_weight_norm
20
+
21
+ from funcineforge.models.modules.hifigan import get_padding, init_weights
22
+ from funcineforge.models.modules.hifigan.activations import Snake, SnakeBeta
23
+ from funcineforge.models.modules.hifigan.nsf_utils import SourceModule, SourceModuleHnNSF
24
+
25
+
26
class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN.

    A stack of dilated 1-D convolutions with residual connections. Each
    branch is ``activation -> dilated conv`` and, when
    ``use_additional_convs``, additionally ``activation -> conv`` with
    dilation 1; the branch output is added back onto the input. The channel
    count and time length are preserved.

    Args:
        channels: number of input/output channels.
        kernel_size: kernel size of every convolution.
        dilations: dilation factor per residual branch.
        use_additional_convs: add a second (dilation-1) conv per branch.
        nonlinear_activation: "LeakyReLU", "Snake" or "SnakeBeta".
        nonlinear_activation_params: kwargs for the chosen activation.
    """
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: tp.List[int] = [1, 3, 5],
        use_additional_convs: bool = True,
        nonlinear_activation: str = "LeakyReLU",
        nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
    ):
        super(ResBlock, self).__init__()
        self.use_additional_convs = use_additional_convs

        self.convs1 = nn.ModuleList()
        if use_additional_convs:
            self.convs2 = nn.ModuleList()

        # one dilated conv (convs1) and optionally one plain conv (convs2)
        # per dilation; "same" padding keeps the time length fixed
        for dilation in dilations:
            self.convs1.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        padding=get_padding(kernel_size, dilation)
                    )
                )
            )

            if use_additional_convs:
                self.convs2.append(
                    weight_norm(
                        Conv1d(
                            channels,
                            channels,
                            kernel_size,
                            1,
                            dilation=1,
                            padding=get_padding(kernel_size, 1)
                        )
                    )
                )

        self.convs1.apply(init_weights)
        if use_additional_convs:
            self.convs2.apply(init_weights)

        # Activations are module instances (not functionals) so that Snake /
        # SnakeBeta per-channel parameters are registered and trained.
        if nonlinear_activation == "LeakyReLU":
            self.activations1 = nn.ModuleList([
                nn.LeakyReLU(nonlinear_activation_params["negative_slope"])
                for _ in range(len(self.convs1))
            ])
            if use_additional_convs:
                self.activations2 = nn.ModuleList([
                    nn.LeakyReLU(nonlinear_activation_params["negative_slope"])
                    for _ in range(len(self.convs2))
                ])

        elif nonlinear_activation == "Snake":
            self.activations1 = nn.ModuleList([
                Snake(channels, alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False))
                for _ in range(len(self.convs1))
            ])
            if use_additional_convs:
                self.activations2 = nn.ModuleList([
                    Snake(channels, alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False))
                    for _ in range(len(self.convs2))
                ])

        elif nonlinear_activation == "SnakeBeta":
            self.activations1 = nn.ModuleList([
                SnakeBeta(channels, alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False))
                for _ in range(len(self.convs1))
            ])
            if use_additional_convs:
                self.activations2 = nn.ModuleList([
                    SnakeBeta(channels, alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False))
                    for _ in range(len(self.convs2))
                ])

        else:
            raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, channels, T) -> (B, channels, T), residual branches summed in."""
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt = self.convs1[idx](xt)
            if self.use_additional_convs:
                xt = self.activations2[idx](xt)
                xt = self.convs2[idx](xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all convolutions (for inference export)."""
        for idx in range(len(self.convs1)):
            remove_weight_norm(self.convs1[idx])
            if self.use_additional_convs:
                remove_weight_norm(self.convs2[idx])
127
+
128
+
129
class HifiGenerator(nn.Module):
    """HiFi-GAN waveform generator.

    Upsamples an input feature sequence (e.g. a mel spectrogram) to a mono
    waveform via a stack of transposed convolutions; each stage is followed
    by a bank of ResBlocks with different kernel sizes whose outputs are
    averaged (multi-receptive-field fusion).

    Args:
        in_channels: number of input feature channels.
        base_channels: channel width before the first upsampling stage;
            halved after every stage.
        global_channels: size of an optional global conditioning vector
            ``g``; <= 0 disables global conditioning.
        upsample_rates / upsample_kernel_sizes: stride and kernel size of
            each ConvTranspose1d stage (total upsampling factor is the
            product of the rates).
        resblock_kernel_sizes / resblock_dilation_sizes: one ResBlock per
            kernel size at each stage, with the matching dilation list.
        resblock_nonlinear_activation(_params): activation used inside the
            ResBlocks.
        use_additional_convs: pass-through to ResBlock.
        cond_in_each_up_layer: additionally inject ``g`` after every
            upsampling stage (effective only when global_channels > 0).
        lrelu_slope: negative slope of the inter-stage LeakyReLU.
        act_pre_each_up_layer: apply a LeakyReLU before each upsampling conv.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        global_channels: int = -1,
        upsample_rates: tp.List[int] = [8, 8, 2, 2],
        upsample_kernel_sizes: tp.List[int] = [16, 16, 4, 4],
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_nonlinear_activation: str = "LeakyReLU",
        resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
        use_additional_convs: bool = True,
        cond_in_each_up_layer: bool = False,
        lrelu_slope: float = 0.1,
        act_pre_each_up_layer: bool = True
    ):
        super(HifiGenerator, self).__init__()

        self.out_channels = 1
        self.global_channels = global_channels
        self.use_additional_convs = use_additional_convs
        # per-stage conditioning only makes sense when a global vector exists
        self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
        self.lrelu_slope = lrelu_slope
        self.act_pre_each_up_layer = act_pre_each_up_layer

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # transposed-conv upsampling stack; channels are halved each stage
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # num_kernels ResBlocks per upsampling stage, stored flat; stage i's
        # blocks live at indices [i * num_kernels, (i + 1) * num_kernels)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
                                               resblock_nonlinear_activation,
                                               resblock_nonlinear_activation_params))

        if self.global_channels > 0:
            self.conv_global_cond = weight_norm(
                Conv1d(global_channels, base_channels, 1)
            )
            self.conv_global_cond.apply(init_weights)

            if self.cond_in_each_up_layer:
                self.conv_conds = nn.ModuleList()
                for i in range(len(self.ups)):
                    self.conv_conds.append(weight_norm(
                        nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
                    )
                self.conv_conds.apply(init_weights)

        # `ch` is the channel count after the last stage (loop variable above)
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def output_size(self):
        """Number of output channels (always 1: mono waveform)."""
        return self.out_channels

    def forward(self, x: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generate a waveform.

        Args:
            x: input features (B, in_channels, T).
            g: optional global conditioning (B, global_channels, 1).

        Returns:
            Waveform in [-1, 1] of shape (B, 1, T * prod(upsample_rates)).
        """
        # x in (B, in_channels, T), g in (B, global_channels, 1)
        x = self.conv_pre(x)
        if self.global_channels > 0 and g is not None:
            x = x + self.conv_global_cond(g)

        for i in range(self.num_upsamples):
            if self.act_pre_each_up_layer:
                x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if self.cond_in_each_up_layer and g is not None:
                x = x + self.conv_conds[i](g)

            # multi-receptive-field fusion: average this stage's ResBlocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all layers (for inference export)."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        if self.global_channels > 0:
            remove_weight_norm(self.conv_global_cond)
        if self.cond_in_each_up_layer:
            for l in self.conv_conds:
                remove_weight_norm(l)
246
+
247
+
248
class NsfHifiGenerator(nn.Module):
    """
    Neural Source Filter + HifiGan

    Like :class:`HifiGenerator`, but an F0-driven harmonic-plus-noise source
    signal is generated at the output sampling rate, downsampled to each
    upsampling stage's resolution, and added into the feature stream there
    (the NSF excitation path).

    Extra args compared to HifiGenerator:
        nb_harmonics: number of harmonic overtones in the source signal.
        sampling_rate: output audio sampling rate.
        nsf_alpha / nsf_sigma / nsf_voiced_threshold: source-module
            amplitude, noise std and voiced/unvoiced F0 threshold.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        global_channels: int = -1,
        nb_harmonics: int = 7,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: tp.List[int] = [8, 8, 2, 2],
        upsample_kernel_sizes: tp.List[int] = [16, 16, 4, 4],
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_nonlinear_activation: str = "LeakyReLU",
        resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
        use_additional_convs: bool = True,
        cond_in_each_up_layer: bool = False,
        lrelu_slope: float = 0.1,
        act_pre_each_up_layer: bool = True
    ):
        super(NsfHifiGenerator, self).__init__()

        self.out_channels = 1
        self.global_channels = global_channels
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.use_additional_convs = use_additional_convs
        # per-stage conditioning only makes sense when a global vector exists
        self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
        self.lrelu_slope = lrelu_slope
        self.act_pre_each_up_layer = act_pre_each_up_layer

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        # generates the excitation at the full output rate
        # (np.cumprod(upsample_rates)[-1] = total upsampling factor)
        self.source_module = SourceModule(nb_harmonics, np.cumprod(upsample_rates)[-1],
                                          sampling_rate, nsf_alpha, nsf_sigma, nsf_voiced_threshold)
        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
        # Down: bring the full-rate source signal to each stage's resolution.
        # Stage i needs the product of the *remaining* upsample rates as its
        # downsampling factor, hence the reversed cumulative products.
        self.source_downs = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, u in enumerate(downsample_cum_rates[::-1]):
            if (u == 1):
                self.source_downs.append(
                    weight_norm(Conv1d(1, base_channels // (2 ** (i + 1)), 1, 1))
                )
            else:
                self.source_downs.append(
                    weight_norm(Conv1d(1, base_channels // (2 ** (i + 1)), u*2, u, padding=(u//2)))
                )

        # num_kernels ResBlocks per upsampling stage, stored flat
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
                                               resblock_nonlinear_activation,
                                               resblock_nonlinear_activation_params))

        if self.global_channels > 0:
            self.conv_global_cond = weight_norm(
                Conv1d(global_channels, base_channels, 1)
            )
            self.conv_global_cond.apply(init_weights)

            if self.cond_in_each_up_layer:
                self.conv_conds = nn.ModuleList()
                for i in range(len(self.ups)):
                    self.conv_conds.append(weight_norm(
                        nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
                    )
                self.conv_conds.apply(init_weights)

        # `ch` is the channel count after the last stage (loop variable above)
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def output_size(self):
        """Number of output channels (always 1: mono waveform)."""
        return self.out_channels

    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
        """Turn an F0 contour (B, T) into the full-rate excitation signal."""
        return self.source_module(f0.unsqueeze(1))

    def forward(self, x: torch.Tensor, f0: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generate a waveform from features ``x`` and F0 contour ``f0``.

        Args:
            x: input features (B, in_channels, T).
            f0: frame-level F0 (B, T).
            g: optional global conditioning (B, global_channels, 1).

        Returns:
            Waveform in [-1, 1] of shape (B, 1, T * prod(upsample_rates)).
        """
        # x in (B, in_channels, T), f0 in (B, T), g in (B, global_channels, 1)

        s = self._f02source(f0)

        x = self.conv_pre(x)
        if self.global_channels > 0 and g is not None:
            x = x + self.conv_global_cond(g)

        for i in range(self.num_upsamples):
            if self.act_pre_each_up_layer:
                x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if self.cond_in_each_up_layer and g is not None:
                x = x + self.conv_conds[i](g)

            # fusion: add the (downsampled) NSF source into this stage
            x = x + self.source_downs[i](s)

            # multi-receptive-field fusion: average this stage's ResBlocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all layers (for inference export)."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        if self.global_channels > 0:
            remove_weight_norm(self.conv_global_cond)
        if self.cond_in_each_up_layer:
            for l in self.conv_conds:
                remove_weight_norm(l)
        self.source_module.remove_weight_norm()
        for l in self.source_downs:
            remove_weight_norm(l)
403
+
404
+
405
+ class HiFTGenerator(nn.Module):
406
+ """
407
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
408
+ https://arxiv.org/abs/2309.09493
409
+ """
410
+ def __init__(
411
+ self,
412
+ in_channels: int = 80,
413
+ base_channels: int = 512,
414
+ global_channels: int = -1,
415
+ nb_harmonics: int = 8,
416
+ sampling_rate: int = 22050,
417
+ nsf_alpha: float = 0.1,
418
+ nsf_sigma: float = 0.003,
419
+ nsf_voiced_threshold: float = 10,
420
+ upsample_rates: tp.List[int] = [8, 8],
421
+ upsample_kernel_sizes: tp.List[int] = [16, 16],
422
+ istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
423
+ resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
424
+ resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
425
+ resblock_nonlinear_activation: str = "Snake",
426
+ resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"alpha_logscale": False},
427
+ source_resblock_kernel_sizes: tp.List[int] = [7, 11],
428
+ source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
429
+ source_resblock_nonlinear_activation: str = "Snake",
430
+ source_resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"alpha_logscale": False},
431
+ use_additional_convs: bool = True,
432
+ cond_in_each_up_layer: bool = False,
433
+ lrelu_slope: float = 0.1,
434
+ act_pre_each_up_layer: bool = True,
435
+ audio_limit: float = 0.99,
436
+ ):
437
+ super(HiFTGenerator, self).__init__()
438
+
439
+ self.out_channels = 1
440
+ self.global_channels = global_channels
441
+ self.nb_harmonics = nb_harmonics
442
+ self.sampling_rate = sampling_rate
443
+ self.istft_params = istft_params
444
+ self.use_additional_convs = use_additional_convs
445
+ self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
446
+ self.lrelu_slope = lrelu_slope
447
+ self.act_pre_each_up_layer = act_pre_each_up_layer
448
+ self.audio_limit = audio_limit
449
+
450
+ self.num_kernels = len(resblock_kernel_sizes)
451
+ self.num_upsamples = len(upsample_rates)
452
+ self.m_source = SourceModuleHnNSF(
453
+ sampling_rate=sampling_rate,
454
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
455
+ harmonic_num=nb_harmonics,
456
+ sine_amp=nsf_alpha,
457
+ add_noise_std=nsf_sigma,
458
+ voiced_threshod=nsf_voiced_threshold)
459
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
460
+
461
+ self.conv_pre = weight_norm(
462
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
463
+ )
464
+
465
+ # Up
466
+ self.ups = nn.ModuleList()
467
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
468
+ self.ups.append(
469
+ weight_norm(
470
+ ConvTranspose1d(
471
+ base_channels // (2**i),
472
+ base_channels // (2**(i + 1)),
473
+ k,
474
+ u,
475
+ padding=(k - u) // 2,
476
+ )
477
+ )
478
+ )
479
+
480
+ # Down
481
+ self.source_downs = nn.ModuleList()
482
+ self.source_resblocks = nn.ModuleList()
483
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
484
+ downsample_cum_rates = np.cumprod(downsample_rates)
485
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
486
+ source_resblock_dilation_sizes)):
487
+ if u == 1:
488
+ self.source_downs.append(
489
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
490
+ )
491
+ else:
492
+ self.source_downs.append(
493
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u*2, u, padding=(u//2))
494
+ )
495
+
496
+ self.source_resblocks.append(
497
+ ResBlock(base_channels // (2 ** (i + 1)), k, d,
498
+ use_additional_convs, source_resblock_nonlinear_activation,
499
+ source_resblock_nonlinear_activation_params)
500
+ )
501
+
502
+ self.resblocks = nn.ModuleList()
503
+ for i in range(len(self.ups)):
504
+ ch = base_channels // (2**(i + 1))
505
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
506
+ self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
507
+ resblock_nonlinear_activation,
508
+ resblock_nonlinear_activation_params))
509
+
510
+ if self.global_channels > 0:
511
+ self.conv_global_cond = weight_norm(
512
+ Conv1d(global_channels, base_channels, 1)
513
+ )
514
+ self.conv_global_cond.apply(init_weights)
515
+
516
+ if self.cond_in_each_up_layer:
517
+ self.conv_conds = nn.ModuleList()
518
+ for i in range(len(self.ups)):
519
+ self.conv_conds.append(weight_norm(
520
+ nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
521
+ )
522
+ self.conv_conds.apply(init_weights)
523
+
524
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
525
+ self.ups.apply(init_weights)
526
+ self.conv_post.apply(init_weights)
527
+
528
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
529
+ window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
530
+ self.register_buffer("stft_window", window)
531
+
532
+ def output_size(self):
533
+ return self.out_channels
534
+
535
+ def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
536
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
537
+
538
+ har_source, _, _ = self.m_source(f0)
539
+ return har_source.transpose(1, 2)
540
+
541
    def forward(self, x: torch.Tensor, f0: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Synthesize a waveform from acoustic features and F0 (NSF-style HiFi-GAN).

        Args:
            x: acoustic features, (B, in_channels, T).
            f0: frame-level fundamental frequency in Hz, (B, T).
            g: optional global conditioning embedding, (B, global_channels, 1).

        Returns:
            Waveform from the inverse STFT, clamped to [-audio_limit, audio_limit].
        """
        # x in (B, in_channels, T), f0 in (B, T), g in (B, global_channels, 1)

        # Sample-rate harmonic source derived from F0 (neural source-filter scheme).
        s = self._f02source(f0)

        # STFT of the source; real and imaginary parts stacked along channels so
        # they can be downsampled and fused into each upsampling stage below.
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        if self.global_channels > 0 and g is not None:
            # Global conditioning injected once at the input resolution.
            x = x + self.conv_global_cond(g)

        for i in range(self.num_upsamples):
            if self.act_pre_each_up_layer:
                x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if self.cond_in_each_up_layer and g is not None:
                # Optionally re-inject the global embedding at every scale.
                x = x + self.conv_conds[i](g)

            if i == self.num_upsamples - 1:
                # Left-pad one frame (ReflectionPad1d((1, 0)) from __init__)
                # before the final fusion/ResBlock stage.
                x = self.reflection_pad(x)

            # fusion: downsample the source STFT to this scale and add it in.
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            # Average the outputs of the parallel multi-kernel ResBlocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        # First n_fft//2 + 1 channels parameterize the log-magnitude, the
        # remaining channels the phase for the inverse STFT.
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # sin keeps phase bounded; strictly redundant

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x
585
+
586
+ def remove_weight_norm(self):
587
+ print('Removing weight norm...')
588
+ for l in self.ups:
589
+ remove_weight_norm(l)
590
+ for l in self.resblocks:
591
+ l.remove_weight_norm()
592
+ remove_weight_norm(self.conv_pre)
593
+ remove_weight_norm(self.conv_post)
594
+ if self.global_channels > 0:
595
+ remove_weight_norm(self.conv_global_cond)
596
+ if self.cond_in_each_up_layer:
597
+ for l in self.conv_conds:
598
+ remove_weight_norm(l)
599
+ self.source_module.remove_weight_norm()
600
+ for l in self.source_downs:
601
+ remove_weight_norm(l)
602
+ for l in self.source_resblocks:
603
+ l.remove_weight_norm()
604
+
605
+ def _stft(self, x):
606
+ spec = torch.stft(
607
+ x,
608
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window,
609
+ return_complex=True)
610
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
611
+ return spec[...,0], spec[...,1]
612
+
613
+ def _istft(self, magnitude, phase):
614
+ magnitude = torch.clip(magnitude, max=1e2)
615
+ real = magnitude * torch.cos(phase)
616
+ img = magnitude * torch.sin(phase)
617
+ inverse_transform = torch.istft(
618
+ # torch.cat([real.unsqueeze(-1), img.unsqueeze(-1)], dim=-1),
619
+ torch.complex(real, img),
620
+ self.istft_params["n_fft"], self.istft_params["hop_len"],
621
+ self.istft_params["n_fft"], window=self.stft_window,
622
+ return_complex=False
623
+ )
624
+
625
+ return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
funcineforge/models/modules/hifigan/mel_spectrum.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.utils.data
3
+ import numpy as np
4
+ from librosa.filters import mel as librosa_mel_fn
5
+
6
+
7
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress a numpy magnitude spectrum: log(clip(x, clip_val) * C)."""
    clipped = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(clipped * C)
9
+
10
+
11
def dynamic_range_decompression(x, C=1):
    """Invert dynamic_range_compression (numpy): exp(x) / C."""
    expanded = np.exp(x)
    return expanded / C
13
+
14
+
15
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress a torch magnitude spectrum: log(clamp(x, clip_val) * C)."""
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
17
+
18
+
19
def dynamic_range_decompression_torch(x, C=1):
    """Invert dynamic_range_compression_torch: exp(x) / C."""
    expanded = torch.exp(x)
    return expanded / C
21
+
22
+
23
def spectral_normalize_torch(magnitudes):
    """Apply log dynamic-range compression to spectrogram magnitudes."""
    return dynamic_range_compression_torch(magnitudes)
26
+
27
+
28
def spectral_de_normalize_torch(magnitudes):
    """Undo the log dynamic-range compression applied by spectral_normalize_torch."""
    return dynamic_range_decompression_torch(magnitudes)
31
+
32
+
33
# Module-level caches shared by the spectrogram helpers below:
# mel_basis is keyed by "<fmax>_<device>" strings, hann_window by str(device).
mel_basis = {}
hann_window = {}
35
+
36
+
37
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram of waveform `y` (B, samples).

    Returns (B, num_mels, frames). The mel filterbank and Hann window are cached
    per (fmax, device) / device in the module-level dicts above.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # Bug fix: the cache test used the raw `fmax` while entries are stored under
    # "<fmax>_<device>" keys, so the filterbank and window were rebuilt on every call.
    key = str(fmax) + '_' + str(y.device)
    if key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[key] = torch.from_numpy(mel).float().to(y.device)
    if str(y.device) not in hann_window:
        # NOTE(review): window cached per device only; a different win_size on
        # the same device would reuse the old window (pre-existing behavior).
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    # Bug fix: PyTorch >= 2.0 requires return_complex=True for torch.stft;
    # view_as_real restores the trailing (real, imag) axis the math below expects.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size,
                      window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True,
                      return_complex=True)
    spec = torch.view_as_real(spec)

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
61
+
62
+
63
def power_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Log-compressed linear magnitude spectrogram of `y` (B, samples).

    Unlike mel_spectrogram the mel projection is NOT applied here; the mel
    basis is only cached so mel_from_power_spectrogram can use it later.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # Bug fix: cache membership used the raw `fmax` while entries are stored
    # under "<fmax>_<device>" keys, so the cache never hit and everything was
    # rebuilt on each call.
    key = str(fmax) + '_' + str(y.device)
    if key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[key] = torch.from_numpy(mel).float().to(y.device)
    if str(y.device) not in hann_window:
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    # Bug fix: PyTorch >= 2.0 requires return_complex=True for torch.stft;
    # view_as_real restores the trailing (real, imag) axis used below.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size,
                      window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True,
                      return_complex=True)
    spec = torch.view_as_real(spec)

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
    spec = spectral_normalize_torch(spec)

    return spec
85
+
86
+
87
def mel_from_power_spectrogram(spec, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Project a log-compressed linear spectrogram onto the cached mel basis.

    NOTE(review): assumes mel_basis was already populated for this
    (fmax, device) pair by a prior mel_spectrogram/power_spectrogram call;
    raises KeyError otherwise.
    """
    global mel_basis, hann_window
    linear = spectral_de_normalize_torch(spec)
    mel = torch.matmul(mel_basis[str(fmax) + '_' + str(spec.device)], linear)
    return spectral_normalize_torch(mel)
funcineforge/models/modules/hifigan/nsf_utils.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Neural Source Filter based modules implementation.
3
+
4
+ Neural source-filter waveform models for statistical parametric speech synthesis
5
+
6
+ """
7
+
8
+ import numpy as np
9
+ import typing as tp
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.nn.utils import weight_norm, remove_weight_norm
15
+ from torch.distributions.uniform import Uniform
16
+ from torch.distributions.normal import Normal
17
+
18
class SineGen(torch.nn.Module):
    """Harmonic sine-wave source generator for NSF vocoders.

    Generates ``harmonic_num + 1`` sinusoids (fundamental plus overtones) with
    random initial phases (the fundamental's phase is pinned to zero), gates
    them by a voiced/unvoiced decision derived from F0, and adds Gaussian
    noise (std ``noise_std`` in voiced regions, ``sine_amp / 3`` in unvoiced
    regions).

    Args:
        samp_rate: sampling rate in Hz.
        harmonic_num: number of harmonic overtones above the fundamental.
        sine_amp: amplitude of the sine waveform.
        noise_std: std of the additive Gaussian noise in voiced regions.
        voiced_threshold: F0 threshold for the voiced/unvoiced decision.
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sampling_rate = samp_rate
        self.harmonic_num = harmonic_num
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # Voiced wherever F0 exceeds the threshold; float mask in {0.0, 1.0}.
        return (f0 > self.voiced_threshold).type(torch.float32)

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: (sine_waves, uv, noise)
        """
        batch, length = f0.size(0), f0.size(-1)

        # Per-harmonic normalized frequency: f0 * (k + 1) / fs.
        F_mat = torch.zeros((batch, self.harmonic_num + 1, length)).to(f0.device)
        for k in range(self.harmonic_num + 1):
            F_mat[:, k:k + 1, :] = f0 * (k + 1) / self.sampling_rate

        # Instantaneous phase plus one random initial phase per harmonic
        # (fundamental fixed at zero phase).
        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(batch, self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        uv = self._f02uv(f0)

        # Voiced noise std = noise_std; unvoiced noise scaled so its maximum is
        # comparable to sine_amp (std = sine_amp / 3).
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # Zero the sines in unvoiced regions, then add the noise floor.
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
81
+
82
+
83
class SourceModuleHnNSF(torch.nn.Module):
    """Harmonic-plus-noise source module (hn-NSF).

    Wraps :class:`SineGen` to produce harmonic sinusoids, then merges the
    ``harmonic_num + 1`` harmonic channels into a single excitation via a
    learned linear projection followed by tanh. A separate noise source of the
    same shape as the U/V mask is returned for the noise branch.

    Args:
        sampling_rate: sampling rate in Hz.
        upsample_scale: accepted for interface compatibility; not referenced
            in this implementation.
        harmonic_num: number of harmonics above F0.
        sine_amp: amplitude of the sine source signal.
        add_noise_std: std of the additive Gaussian noise (unvoiced noise
            amplitude is decided by sine_amp).
        voiced_threshod: F0 threshold for the U/V decision.
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # Sine-waveform generator for the harmonic branch.
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # Merge all harmonics into a single excitation channel.
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        F0_sampled (batchsize, length, 1) ->
        (sine_merge (batchsize, length, 1), noise, uv)
        """
        # Harmonic branch: sine generation itself carries no gradients.
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            sine_wavs = sine_wavs.transpose(1, 2)
            uv = uv.transpose(1, 2)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # Noise branch, shaped like the U/V mask.
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
133
+
134
+
135
class SourceModule(torch.nn.Module):
    """Sinusoidal excitation source with a learned 1x1-conv merge.

    Builds harmonics directly from nearest-neighbor-upsampled F0: voiced
    regions receive ``alpha * sin(phase) + N(0, sigma)`` noise, unvoiced
    regions receive scaled noise only, and a weight-normed 1x1 Conv1d + tanh
    merges the ``nb_harmonics + 1`` channels into one.
    """

    def __init__(self,
                 nb_harmonics: int,
                 upsample_ratio: int,
                 sampling_rate: int,
                 alpha: float = 0.1,
                 sigma: float = 0.003,
                 voiced_threshold: float = 10
                 ):
        super(SourceModule, self).__init__()

        self.nb_harmonics = nb_harmonics
        self.upsample_ratio = upsample_ratio
        self.sampling_rate = sampling_rate
        self.alpha = alpha
        self.sigma = sigma
        self.voiced_threshold = voiced_threshold

        # Weight-normed 1x1 conv + tanh merging all harmonic channels.
        self.ffn = nn.Sequential(
            weight_norm(nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)),
            nn.Tanh())

    def f02uv(self, f0):
        # Voiced/unvoiced mask: 1.0 where F0 exceeds the threshold.
        return (f0 > self.voiced_threshold).type(torch.float32)

    def forward(self, f0):
        """
        :param f0: [B, 1, frame_len], Hz
        :return: [B, 1, sample_len]
        """
        with torch.no_grad():
            uv = self.f02uv(f0)
            f0_samples = F.interpolate(f0, scale_factor=(self.upsample_ratio), mode='nearest')
            uv_samples = F.interpolate(uv, scale_factor=(self.upsample_ratio), mode='nearest')

        batch = f0_samples.size(0)
        n_channels = self.nb_harmonics + 1

        # Per-harmonic normalized frequency: f0 * (k + 1) / fs.
        freq_mat = torch.zeros((batch, n_channels, f0_samples.size(-1))).to(f0_samples.device)
        for k in range(n_channels):
            freq_mat[:, k:k + 1, :] = f0_samples * (k + 1) / self.sampling_rate

        phase_mat = 2 * np.pi * (torch.cumsum(freq_mat, dim=-1) % 1)
        # Random initial phase per harmonic, fundamental pinned to zero.
        init_phase = Uniform(low=-np.pi, high=np.pi).sample(
            sample_shape=(f0.size(0), n_channels, 1)).to(freq_mat.device)
        init_phase[:, 0, :] = 0

        noise = Normal(loc=0., scale=self.sigma).sample(
            sample_shape=(batch, n_channels, f0_samples.size(-1))).to(freq_mat.device)

        voiced_src = self.alpha * torch.sin(phase_mat + init_phase) + noise
        unvoiced_src = self.alpha / 3 / self.sigma * noise

        excitation = voiced_src * uv_samples + unvoiced_src * (1 - uv_samples)
        return self.ffn(excitation)

    def remove_weight_norm(self):
        # Only the merge conv inside ffn carries weight normalization.
        remove_weight_norm(self.ffn[0])
193
+
194
+
195
class ConvRNNF0Predictor(nn.Module):
    """Frame-level F0 predictor: stacked weight-normed convs, optional GRU, linear head.

    Maps (B, in_channels, T) features to non-negative per-frame predictions
    (B, T) via a final abs() when num_class == 1.
    """

    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512,
                 use_cond_rnn: bool = True,
                 bidirectional_rnn: bool = False,
                 ):

        super().__init__()

        self.num_class = num_class
        self.use_cond_rnn = use_cond_rnn

        # Five weight-normed Conv1d(k=3) + ELU stages; built in the same order
        # as an explicit Sequential, so parameter initialization is unchanged.
        stages = []
        ch = in_channels
        for _ in range(5):
            stages.append(weight_norm(nn.Conv1d(ch, cond_channels, kernel_size=3, padding=1)))
            stages.append(nn.ELU())
            ch = cond_channels
        self.condnet = nn.Sequential(*stages)

        if self.use_cond_rnn:
            # Hidden size halved when bidirectional so the output stays cond_channels.
            self.rnn = nn.GRU(
                cond_channels,
                cond_channels // 2 if bidirectional_rnn else cond_channels,
                num_layers=1,
                batch_first=True,
                bidirectional=bidirectional_rnn,
            )

        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.condnet(x)
        if self.use_cond_rnn:
            feats, _ = self.rnn(feats.transpose(1, 2))
        else:
            feats = feats.transpose(1, 2)

        # abs() keeps the predictions non-negative.
        return torch.abs(self.classifier(feats).squeeze(-1))
251
+
252
+
253
+
funcineforge/models/specaug/__init__.py ADDED
File without changes
funcineforge/models/specaug/mask_along_axis.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from typing import Sequence
4
+ from typing import Union
5
+
6
+
7
def mask_along_axis(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
    fill_value: float = 0.0,
):
    """Randomly mask `num_mask` spans along axis `dim` of each batch element.

    Args:
        spec: (Batch, Length, Freq) or (Batch, Channel, Length, Freq)
        spec_lengths: (Length): not used by this implementation
        mask_width_range: span widths are drawn uniformly from this half-open range
    """

    original_shape = spec.size()
    if spec.dim() == 4:
        # (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
        spec = spec.view(-1, spec.size(2), spec.size(3))

    batch = spec.shape[0]
    axis_len = spec.shape[dim]

    # widths / starts: (Batch, num_mask, 1), drawn independently per sample.
    widths = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (batch, num_mask),
        device=spec.device,
    ).unsqueeze(2)
    starts = torch.randint(
        0, max(1, axis_len - widths.max()), (batch, num_mask), device=spec.device
    ).unsqueeze(2)

    # positions: (1, 1, axis_len); mask: (Batch, num_mask, axis_len).
    positions = torch.arange(axis_len, device=spec.device)[None, None, :]
    span_mask = (starts <= positions) * (positions < (starts + widths))
    # Union of the individual spans: (Batch, axis_len).
    span_mask = span_mask.any(dim=1)

    if dim == 1:
        span_mask = span_mask.unsqueeze(2)  # (Batch, Length, 1)
    elif dim == 2:
        span_mask = span_mask.unsqueeze(1)  # (Batch, 1, Freq)

    # Masked positions get fill_value, or the global mean when requested.
    value = fill_value if replace_with_zero else spec.mean()

    spec = spec.masked_fill(span_mask, value)
    return spec.view(*original_shape), spec_lengths
66
+
67
+
68
class MaskAlongAxis(torch.nn.Module):
    """Randomly mask fixed-width spans along the time or frequency axis."""

    def __init__(
        self,
        mask_width_range: Union[int, Sequence[int]] = (0, 30),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
        fill_value: float = 0.0,
    ):
        # A bare int means "widths from 0 up to that value".
        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: " f"{mask_width_range}",
            )

        assert mask_width_range[1] > mask_width_range[0]
        if isinstance(dim, str):
            name_to_axis = {"time": 1, "freq": 2}
            if dim not in name_to_axis:
                raise ValueError("dim must be int, 'time' or 'freq'")
            dim = name_to_axis[dim]
        # Human-readable axis label for extra_repr().
        self.mask_axis = {1: "time", 2: "freq"}.get(dim, "unknown")

        super().__init__()
        self.mask_width_range = mask_width_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero
        self.fill_value = fill_value

    def extra_repr(self):
        return (
            f"mask_width_range={self.mask_width_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Mask `spec` of shape (Batch, Length, Freq) along the configured axis."""
        return mask_along_axis(
            spec,
            spec_lengths,
            mask_width_range=self.mask_width_range,
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
            fill_value=self.fill_value,
        )
128
+
129
+
130
class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
    """Mask input spec along a specified axis with a length-dependent max width.

    Formula:
        max_width = max_width_ratio * seq_len
    """

    def __init__(
        self,
        mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
        fill_value: float = 0.0,
    ):
        # A bare float means "ratios from 0.0 up to that value".
        if isinstance(mask_width_ratio_range, float):
            mask_width_ratio_range = (0.0, mask_width_ratio_range)
        if len(mask_width_ratio_range) != 2:
            raise TypeError(
                f"mask_width_ratio_range must be a tuple of float and float values: "
                f"{mask_width_ratio_range}",
            )

        assert mask_width_ratio_range[1] > mask_width_ratio_range[0]
        if isinstance(dim, str):
            name_to_axis = {"time": 1, "freq": 2}
            if dim not in name_to_axis:
                raise ValueError("dim must be int, 'time' or 'freq'")
            dim = name_to_axis[dim]
        # Human-readable axis label for extra_repr().
        self.mask_axis = {1: "time", 2: "freq"}.get(dim, "unknown")

        super().__init__()
        self.mask_width_ratio_range = mask_width_ratio_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero
        self.fill_value = fill_value

    def extra_repr(self):
        return (
            f"mask_width_ratio_range={self.mask_width_ratio_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Mask `spec` (Batch, Length, Freq); span widths scale with the axis length."""
        seq_len = spec.shape[self.dim]
        lo = max([0, math.floor(seq_len * self.mask_width_ratio_range[0])])
        hi = min([seq_len, math.floor(seq_len * self.mask_width_ratio_range[1])])

        if hi > lo:
            return mask_along_axis(
                spec,
                spec_lengths,
                mask_width_range=(lo, hi),
                dim=self.dim,
                num_mask=self.num_mask,
                replace_with_zero=self.replace_with_zero,
                fill_value=self.fill_value,
            )
        # Sequence too short to derive a positive max width: no-op.
        return spec, spec_lengths
funcineforge/models/specaug/specaug.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SpecAugment module."""
2
+
3
+ from typing import Optional
4
+ from typing import Sequence
5
+ from typing import Union
6
+
7
+ from funcineforge.models.specaug.mask_along_axis import MaskAlongAxis
8
+ from funcineforge.models.specaug.mask_along_axis import MaskAlongAxisVariableMaxWidth
9
+ from funcineforge.models.specaug.time_warp import TimeWarp
10
+
11
+ import torch.nn as nn
12
+
13
+
14
+ class SpecAug(nn.Module):
15
+ """Implementation of SpecAug.
16
+
17
+ Reference:
18
+ Daniel S. Park et al.
19
+ "SpecAugment: A Simple Data
20
+ Augmentation Method for Automatic Speech Recognition"
21
+
22
+ .. warning::
23
+ When using cuda mode, time_warp doesn't have reproducibility
24
+ due to `torch.nn.functional.interpolate`.
25
+
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ apply_time_warp: bool = True,
31
+ time_warp_window: int = 5,
32
+ time_warp_mode: str = "bicubic",
33
+ apply_freq_mask: bool = True,
34
+ freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
35
+ num_freq_mask: int = 2,
36
+ apply_time_mask: bool = True,
37
+ time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
38
+ time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
39
+ num_time_mask: int = 2,
40
+ fill_value: float = 0.0,
41
+ ):
42
+ if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
43
+ raise ValueError("Either one of time_warp, time_mask, or freq_mask should be applied")
44
+ if (
45
+ apply_time_mask
46
+ and (time_mask_width_range is not None)
47
+ and (time_mask_width_ratio_range is not None)
48
+ ):
49
+ raise ValueError(
50
+ 'Either one of "time_mask_width_range" or '
51
+ '"time_mask_width_ratio_range" can be used'
52
+ )
53
+ super().__init__()
54
+ self.apply_time_warp = apply_time_warp
55
+ self.apply_freq_mask = apply_freq_mask
56
+ self.apply_time_mask = apply_time_mask
57
+
58
+ if apply_time_warp:
59
+ self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
60
+ else:
61
+ self.time_warp = None
62
+
63
+ if apply_freq_mask:
64
+ self.freq_mask = MaskAlongAxis(
65
+ dim="freq",
66
+ mask_width_range=freq_mask_width_range,
67
+ num_mask=num_freq_mask,
68
+ fill_value=fill_value,
69
+ )
70
+ else:
71
+ self.freq_mask = None
72
+
73
+ if apply_time_mask:
74
+ if time_mask_width_range is not None:
75
+ self.time_mask = MaskAlongAxis(
76
+ dim="time",
77
+ mask_width_range=time_mask_width_range,
78
+ num_mask=num_time_mask,
79
+ fill_value=fill_value,
80
+ )
81
+ elif time_mask_width_ratio_range is not None:
82
+ self.time_mask = MaskAlongAxisVariableMaxWidth(
83
+ dim="time",
84
+ mask_width_ratio_range=time_mask_width_ratio_range,
85
+ num_mask=num_time_mask,
86
+ fill_value=fill_value,
87
+ )
88
+ else:
89
+ raise ValueError(
90
+ 'Either one of "time_mask_width_range" or '
91
+ '"time_mask_width_ratio_range" should be used.'
92
+ )
93
+ else:
94
+ self.time_mask = None
95
+
96
+ def forward(self, x, x_lengths=None):
97
+ if self.time_warp is not None:
98
+ x, x_lengths = self.time_warp(x, x_lengths)
99
+ if self.freq_mask is not None:
100
+ x, x_lengths = self.freq_mask(x, x_lengths)
101
+ if self.time_mask is not None:
102
+ x, x_lengths = self.time_mask(x, x_lengths)
103
+ return x, x_lengths
funcineforge/models/specaug/time_warp.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Time warp module."""
2
+
3
+ import torch
4
+
5
+ from funcineforge.models.utils.nets_utils import pad_list
6
+
7
DEFAULT_TIME_WARP_MODE = "bicubic"


def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
    """Randomly stretch/compress the time axis around a random center.

    Args:
        x: (Batch, Time, Freq) or (Batch, Channel, Time, Freq)
        window: maximum shift of the warp point
        mode: torch.nn.functional.interpolate mode
    """

    original_shape = x.size()
    if x.dim() == 3:
        # bicubic interpolation needs a 4D tensor: add a channel axis.
        x = x[:, None]

    t = x.shape[2]
    if t - window <= window:
        # Too short to place a warp center; return unchanged.
        return x.view(*original_shape)

    center = torch.randint(window, t - window, (1,))[0]
    warped = torch.randint(center - window, center + window, (1,))[0] + 1

    # left: (Batch, Channel, warped, Freq); right: (Batch, Channel, t - warped, Freq)
    left = torch.nn.functional.interpolate(
        x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False
    )
    right = torch.nn.functional.interpolate(
        x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=False
    )

    if x.requires_grad:
        # Out-of-place concat keeps the autograd history intact.
        x = torch.cat([left, right], dim=-2)
    else:
        # In-place writes avoid an extra allocation during inference.
        x[:, :, :warped] = left
        x[:, :, warped:] = right

    return x.view(*original_shape)
48
+
49
+
50
class TimeWarp(torch.nn.Module):
    """Module wrapper around :func:`time_warp`.

    Args:
        window: time warp parameter
        mode: Interpolate mode
    """

    def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
        super().__init__()
        self.window = window
        self.mode = mode

    def extra_repr(self):
        return f"window={self.window}, mode={self.mode}"

    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
        """Warp (Batch, Time, Freq) input; per-sample when lengths differ.

        Args:
            x: (Batch, Time, Freq)
            x_lengths: (Batch,)
        """

        if x_lengths is None or all(le == x_lengths[0] for le in x_lengths):
            # Uniform lengths: the same warp can be applied to the whole batch.
            return time_warp(x, window=self.window, mode=self.mode), x_lengths

        # Lengths differ: warp each sample on its valid prefix, then re-pad.
        # (Batched variable-length warping has no obvious formulation.)
        warped_samples = []
        for idx in range(x.size(0)):
            piece = time_warp(
                x[idx][None, : x_lengths[idx]],
                window=self.window,
                mode=self.mode,
            )[0]
            warped_samples.append(piece)
        return pad_list(warped_samples, 0.0), x_lengths
funcineforge/models/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ import torch
2
+ dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
funcineforge/models/utils/llm_decoding.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import nullcontext
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Union
5
+ from funcineforge.utils.hinter import hint_once
6
+ import numpy as np
7
+ dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
8
+
9
+
10
class LLMDecoder(nn.Module):
    """Step-wise autoregressive decoder around a HuggingFace-style causal LM.

    Supports greedy, top-k, nucleus (top-p) and Repetition Aware Sampling
    (VALL-E 2) strategies, with optional KV-cache reuse between steps.
    """

    def __init__(self, **kwargs):
        super(LLMDecoder, self).__init__()
        # Token id(s) that terminate generation; normalized to a list.
        self.eos_token = kwargs["eos"]
        if isinstance(self.eos_token, int):
            self.eos_token = [self.eos_token]
        # Callable mapping a (1, 1) LongTensor of ids to input embeddings.
        self.token_embeder = kwargs["token_embeder"]
        # Optional overrides for ras_sampling hyper-parameters.
        self.ras_conf = kwargs.get("ras_conf", {})
        # Offset added to sampled ids before the embedding lookup.
        self.token_offset = kwargs.get("token_offset", 0)

    def nucleus_sampling(self, weighted_scores, top_p=0.8, top_k=25, beam_size=1):
        """Nucleus (top-p) sampling capped at ``top_k`` candidates.

        Args:
            weighted_scores: 1-D tensor of unnormalized scores/logits.
            top_p: cumulative probability mass to keep.
            top_k: hard cap on the number of candidates.
            beam_size: number of ids to draw (with replacement).

        Returns:
            LongTensor of ``beam_size`` sampled token ids.
        """
        prob, indices = [], []
        cum_prob = 0.0
        sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
        for i in range(len(sorted_idx)):
            # Keep candidates while both the top-p mass and the top-k count
            # allow it; the most probable token is always kept.
            if cum_prob < top_p and len(prob) < top_k:
                cum_prob += sorted_value[i]
                prob.append(sorted_value[i])
                indices.append(sorted_idx[i])
            else:
                break
        # torch.stack instead of torch.tensor(list-of-tensors): identical
        # values, but avoids the deprecated/slow element-wise copy path.
        prob = torch.stack(prob).to(weighted_scores)
        indices = torch.stack(indices).to(weighted_scores.device)
        sampling_ids = prob.multinomial(beam_size, replacement=True)
        top_ids = indices[sampling_ids]
        return top_ids

    def random_sampling(self, weighted_scores, beam_size=1):
        """Sample ids from the full softmax distribution."""
        top_ids = weighted_scores.softmax(dim=0).multinomial(beam_size, replacement=True)
        return top_ids

    # Repetition Aware Sampling in VALL-E 2
    def ras_sampling(
        self, weighted_scores, decoded_tokens, *,
        top_p=0.8, top_k=25, win_size=10, tau_r=0.1
    ):
        """Nucleus sampling with a random-sampling fallback when the drawn
        token repeats too often within the last ``win_size`` outputs."""
        if self.ras_conf is not None:
            top_p = self.ras_conf.get("top_p", top_p)
            top_k = self.ras_conf.get("top_k", top_k)
            win_size = self.ras_conf.get("win_size", win_size)
            tau_r = self.ras_conf.get("tau_r", tau_r)

        hint_once(f"using Repetition Aware Sampling: top_p: {top_p}, top_k: {top_k},win_size: {win_size}, tau_r: {tau_r}", "ras_sampling")
        top_ids = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
        rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(top_ids) == top_ids).sum().item()
        if rep_num >= win_size * tau_r:
            top_ids = self.random_sampling(weighted_scores)

        return top_ids

    def sampling_ids(
        self,
        weighted_scores: torch.Tensor,
        sampling: Union[bool, int, float] = True,
        decoded_tokens: list = None,
    ):
        """Dispatch on the ``sampling`` argument:

        * bool  -- True: multinomial over the full softmax; False: argmax.
        * int   -- top-k sampling with k = ``sampling``.
        * float -- nucleus (top-p) sampling with p = ``sampling``.
        * "ras" -- Repetition Aware Sampling (see :meth:`ras_sampling`).
        """
        # NOTE: bool must be tested before int (bool is an int subclass).
        if isinstance(sampling, bool):
            if sampling:
                top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
            else:
                top_ids = weighted_scores.topk(1)[1]
        elif isinstance(sampling, int):
            prob, indices = weighted_scores.softmax(dim=0).topk(sampling)
            sampling_ids = prob.multinomial(1, replacement=True)
            top_ids = indices[sampling_ids]
        elif isinstance(sampling, float):
            # Deduplicated: this branch used to inline a verbatim copy of
            # nucleus_sampling with top_k hard-coded to 25.
            top_ids = self.nucleus_sampling(weighted_scores, top_p=sampling, top_k=25, beam_size=1)
        elif isinstance(sampling, str) and sampling.lower() == "ras":
            top_ids = self.ras_sampling(weighted_scores, decoded_tokens=decoded_tokens)
        else:
            raise NotImplementedError(f"Not implemented for {type(sampling)} sampling")

        return top_ids

    def __call__(self, input_embeddings, llm, states, quantize=False, **kwargs):
        """Generate tokens autoregressively until EOS or ``max_length``.

        Args:
            input_embeddings: (1, T, D) prompt embeddings.
            llm: causal LM exposing an HF-style forward and ``lm_head``.
            states: dict carrying ``llm_cache`` (KV cache) between calls.
            quantize: when True the LM input is cast to bf16 and autocast
                is skipped.

        Returns:
            (out_tokens, hit_eos, states): (1, N) LongTensor of generated ids
            (EOS appended only when ``include_eos``), whether EOS was reached,
            and the updated states dict.
        """
        max_length = kwargs.get("max_length", 60 * 25)
        min_length = kwargs.get("min_length", 2 * 25)
        sampling = kwargs.get("sampling", True)
        device = kwargs.get("device", "cuda")
        llm_dtype = kwargs.get("llm_dtype", "fp32")
        use_llm_cache = kwargs.get("use_llm_cache", True)
        include_eos = kwargs.get("include_eos", False)
        custom_eos_token = kwargs.get("custom_eos_token", self.eos_token)
        avoid_token = kwargs.get("avoid_token", None)

        llm_cache = states.get("llm_cache", None)
        out_tokens, hit_eos = [], False
        for i in range(max_length):
            # NOTE(review): torch.cuda.amp.autocast is deprecated in recent
            # torch in favor of torch.amp.autocast("cuda", ...); kept as-is
            # for compatibility with the torch versions this repo targets.
            with torch.cuda.amp.autocast(
                enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
            ) if quantize is False else nullcontext():
                # default attention_mask is causal, no longer need manually construct
                # input_masks = torch.ones((1, input_embeddings.shape[1]), device=input_embeddings.device).to(torch.bool)

                if (kwargs.get("use_qlora", False) or kwargs.get("infer_use_lora", False)) and (not kwargs.get("infer_lora_merged", False)):
                    # Un-merged (Q)LoRA: bypass the PEFT wrapper to reach the
                    # base model's forward.
                    outputs = llm.base_model.model(
                        inputs_embeds=input_embeddings.to(torch.bfloat16) if quantize is True else input_embeddings,
                        output_hidden_states=True,
                        return_dict=True,
                        use_cache=use_llm_cache,
                        past_key_values=llm_cache,
                    )
                else:
                    outputs = llm(
                        inputs_embeds=input_embeddings.to(torch.bfloat16) if quantize is True else input_embeddings,
                        output_hidden_states=True,
                        return_dict=True,
                        use_cache=use_llm_cache,
                        past_key_values=llm_cache,
                    )
                lm_hidden_states = outputs.hidden_states[-1]
                h = llm.lm_head(lm_hidden_states[:, -1])
                logp = h.squeeze(0)
                if use_llm_cache:
                    llm_cache = outputs.past_key_values

            pred = torch.log_softmax(logp, dim=-1)

            def _suppress(token_ids):
                # Push the given ids to the dtype's minimum so they cannot be
                # sampled; fp16 min is used for bf16 (representable in both).
                floor = np.finfo(np.float16 if pred.dtype == torch.bfloat16 else np.float32).min
                for tok in token_ids:
                    pred[tok] = float(floor)

            if min_length is not None and i < min_length:
                # Too early to emit EOS.
                _suppress(custom_eos_token)
            if avoid_token is not None and len(avoid_token) > 0:
                _suppress(avoid_token)
            top_id = self.sampling_ids(pred, sampling, out_tokens)[0].item()

            if top_id in custom_eos_token:
                if include_eos:
                    out_tokens.append(top_id)
                hit_eos = True
                break

            out_tokens.append(top_id)
            if use_llm_cache:
                # With a KV cache only the newest token has to be fed back.
                input_embeddings = self.token_embeder(torch.tensor([[top_id]], dtype=torch.int64, device=device) + self.token_offset)
            else:
                input_embeddings = torch.cat([
                    input_embeddings,
                    self.token_embeder(torch.tensor([[top_id]], dtype=torch.int64, device=device) + self.token_offset)
                ], dim=1)

        out_tokens = torch.tensor([out_tokens], dtype=torch.int64, device=device)

        states = {"llm_cache": llm_cache}

        return out_tokens, hit_eos, states
funcineforge/models/utils/mask_along_axis.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Sequence
3
+ from typing import Union
4
+
5
+
6
+ class MaskTailVariableMaxWidth(torch.nn.Module):
7
+ def __init__(
8
+ self,
9
+ mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
10
+ replace_value: float = 0.0,
11
+ ):
12
+ super().__init__()
13
+ self.mask_width_ratio_range = mask_width_ratio_range
14
+ self.replace_value = replace_value
15
+
16
+ def extra_repr(self):
17
+ return (
18
+ f"mask_width_ratio_range={self.mask_width_ratio_range}, "
19
+ )
20
+
21
+ def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
22
+ bb, tt, _ = spec.shape
23
+
24
+ mask_width_ratio = torch.rand((bb, 1), device=spec.device)
25
+ ratio_st, ratio_ed = self.mask_width_ratio_range
26
+ mask_width_ratio = mask_width_ratio * (ratio_ed - ratio_st) + ratio_st
27
+ mask_length = (mask_width_ratio * spec_lengths.unsqueeze(1)).to(spec_lengths)
28
+
29
+ # mask_pos: (B, 1)
30
+ mask_start_pos = spec_lengths.unsqueeze(-1) - mask_length
31
+
32
+ aran = torch.arange(tt, device=spec.device)[None, :]
33
+ # mask: (Batch, L)
34
+ mask = aran < mask_start_pos
35
+ # (Batch, L) -> (Batch, L, 1)
36
+ mask = mask.unsqueeze(2)
37
+
38
+ return mask
39
+
40
+ class PrefixMaskVariableMaxWidth(torch.nn.Module):
41
+ def __init__(
42
+ self,
43
+ mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
44
+ replace_value: float = 0.0,
45
+ ):
46
+ super().__init__()
47
+ self.mask_width_ratio_range = mask_width_ratio_range
48
+ self.replace_value = replace_value
49
+
50
+ def extra_repr(self):
51
+ return (
52
+ f"mask_width_ratio_range={self.mask_width_ratio_range}, "
53
+ )
54
+
55
+ def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None, return_mask: bool = False):
56
+ bb, tt, _ = spec.shape
57
+
58
+ mask_width_ratio_range = torch.tensor(self.mask_width_ratio_range, dtype=torch.float32, device=spec.device)
59
+ mask_width_range = (mask_width_ratio_range * tt).long()
60
+ mask_length = torch.randint(
61
+ mask_width_range[0],
62
+ mask_width_range[1],
63
+ (bb, 1),
64
+ device=spec.device,
65
+ ).unsqueeze(2)
66
+
67
+ # mask_pos: (B, num_mask, 1)
68
+ mask_pos = tt - mask_length
69
+
70
+ aran = torch.arange(tt, device=spec.device)[None, None, :]
71
+ # mask: (Batch, num_mask, L)
72
+ mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
73
+ # Multiply masks: (Batch, num_mask, L) -> (Batch, L, 1)
74
+ mask = mask.any(dim=1).unsqueeze(2)
75
+
76
+ return mask
funcineforge/models/utils/masks.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    assert chunk_masks.dtype == torch.bool
    # NOTE(review): subsequent_chunk_mask in this file ignores
    # num_left_chunks (ONNX-friendly variant) -- confirm left-context
    # limiting is not required by callers.
    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
        # An all-False row would make attention softmax NaN; force it True
        # here and rely on downstream masking to drop these positions.
        print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!')
        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
    return chunk_masks
79
+
80
+
81
def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Chunk-wise causal mask (size, size) for a streaming encoder.

    Position i may attend to every position j inside its own chunk or any
    earlier chunk.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks (IGNORED -- see note)
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: boolean mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    # NOTE: this vectorized form meets ONNX export requirements, but it does
    # not support num_left_chunks (full left context is always visible).
    positions = torch.arange(size, device=device)
    # Exclusive end of the chunk each position belongs to; torch.div with
    # trunc rounding keeps the graph ONNX-exportable.
    chunk_end = (torch.div(positions, chunk_size, rounding_mode='trunc') + 1) * chunk_size
    return positions.unsqueeze(0) < chunk_end.unsqueeze(1)
113
+
114
def causal_block_mask(size, block_size=1, device="cpu", dtype=torch.bool):
    """Block-causal mask (size, size): each position sees its own block
    and every earlier block.

    :param int size: size of mask
    :param int block_size: block size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor
    >>> causal_block_mask(4, 2)
    [[1, 1, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 1],
     [1, 1, 1, 1]]
    """
    # `size` need not be a multiple of block_size; the final block is short.
    idx = torch.arange(size, device=device)
    # Exclusive end index of each position's block (trunc div keeps ints).
    block_end = (torch.div(idx, block_size, rounding_mode='trunc') + 1) * block_size
    visible = idx.unsqueeze(0) < block_end.unsqueeze(1)
    return visible.to(dtype)
funcineforge/models/utils/nets_utils.py ADDED
@@ -0,0 +1,734 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """Network related utility tools."""
4
+
5
+ import logging
6
+ from typing import Dict, List, Tuple
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
def to_device(m, x):
    """Send tensor into the device of the module.

    Args:
        m (torch.nn.Module or torch.Tensor): Object whose device is used.
        x (Tensor): Torch tensor to move.

    Returns:
        Tensor: ``x`` placed on the same device as ``m``.

    Raises:
        TypeError: if ``m`` is neither a Module nor a Tensor.
    """
    if isinstance(m, torch.nn.Module):
        # Modules have no .device attribute; use the first parameter's device.
        device = next(m.parameters()).device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        # Fixed message: was "torch.tensor, bot got".
        raise TypeError("Expected torch.nn.Module or torch.Tensor, " f"but got: {type(m)}")
    return x.to(device)
30
+
31
+
32
def pad_list(xs, pad_value):
    """Pad a list of variable-length tensors into one batched tensor.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    longest = max(x.size(0) for x in xs)
    # Allocate on the same dtype/device as the inputs, pre-filled with pad.
    padded = xs[0].new(len(xs), longest, *xs[0].size()[1:]).fill_(pad_value)
    for idx, x in enumerate(xs):
        padded[idx, : x.size(0)] = x
    return padded
60
+
61
+
62
def pad_list_all_dim(xs, pad_value):
    """Pad a list of tensors along EVERY dimension to the per-dim maximum.

    Args:
        xs (List): List of Tensors with equal rank but arbitrary sizes.
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, max_d0, max_d1, ...).

    Examples:
        >>> x = [torch.ones(2, 3), torch.ones(1, 2)]
        >>> pad_list_all_dim(x, 0).shape
        torch.Size([2, 2, 3])
    """
    n_batch = len(xs)
    num_dim = len(xs[0].shape)
    max_len_all_dim = [max(x.size(d) for x in xs) for d in range(num_dim)]
    pad = xs[0].new(n_batch, *max_len_all_dim).fill_(pad_value)

    for i in range(n_batch):
        # Generalized from the original 1/2/3-D special cases: build one
        # slice per dimension so tensors of any rank are supported.
        pad[(i,) + tuple(slice(0, s) for s in xs[i].shape)] = xs[i]

    return pad
102
+
103
+
104
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make mask tensor containing indices of padded part (True = padding).

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): Reference tensor; if set, the mask is
            broadcast/expanded to its shape along ``length_dim``.
        length_dim (int, optional): Dimension of ``xs`` that corresponds to
            time/length.
        maxlen (int, optional): Explicit mask width (only valid without xs).

    Returns:
        Tensor: Boolean mask; True where the position is beyond the length.

    Examples:
        >>> make_pad_mask([5, 3, 2])
        [[0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1],
         [0, 0, 1, 1, 1]]
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    batch = int(len(lengths))

    # Resolve the mask width: explicit maxlen, the reference tensor's size,
    # or the longest sequence in the batch.
    if maxlen is None:
        maxlen = xs.size(length_dim) if xs is not None else int(max(lengths))
    else:
        assert xs is None
        assert maxlen >= int(max(lengths))

    # (B, maxlen) grid of positions compared against per-sample limits.
    steps = torch.arange(0, maxlen, dtype=torch.int64).unsqueeze(0).expand(batch, maxlen)
    limits = steps.new(lengths).unsqueeze(-1)
    mask = steps >= limits

    if xs is not None:
        assert xs.size(0) == batch, (xs.size(0), batch)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # Insert singleton axes everywhere except batch and length_dim so the
        # (B, maxlen) mask broadcasts to xs' full shape.
        axes = tuple(slice(None) if d in (0, length_dim) else None for d in range(xs.dim()))
        mask = mask[axes].expand_as(xs).to(xs.device)
    return mask
219
+
220
+
221
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Logical complement of :func:`make_pad_mask`: valid (non-padded)
    positions are True.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): Reference tensor; if set, the mask is expanded
            to its shape along ``length_dim``.
        length_dim (int, optional): Dimension of ``xs`` that corresponds to
            time/length.

    Returns:
        Tensor: Boolean mask; True where the position is within the length.

    Examples:
        >>> make_non_pad_mask([5, 3, 2])
        [[1, 1, 1, 1, 1],
         [1, 1, 1, 0, 0],
         [1, 1, 0, 0, 0]]
    """
    return ~make_pad_mask(lengths, xs, length_dim)
308
+
309
+
310
def mask_by_length(xs, lengths, fill=0):
    """Mask tensor according to length.

    Args:
        xs (Tensor): Batch of input tensor (B, `*`).
        lengths (LongTensor or List): Batch of lengths (B,).
        fill (int or float): Value to fill masked part.

    Returns:
        Tensor: Batch of masked input tensor (B, `*`).

    Examples:
        >>> x = torch.arange(5).repeat(3, 1) + 1
        >>> mask_by_length(x, [5, 3, 2])
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 0, 0],
                [1, 2, 0, 0, 0]])
    """
    assert xs.size(0) == len(lengths)
    # new_full replaces the legacy `xs.data.new(...).fill_()` pattern: same
    # dtype/device result, without going through `.data` (autograd-unsafe).
    ret = xs.new_full(xs.size(), fill)
    for i, l in enumerate(lengths):
        ret[i, :l] = xs[i, :l]
    return ret
339
+
340
+
341
def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor,
            and dict.

    Returns:
        Tensor or ComplexTensor: Type converted inputs.

    Examples:
        >>> to_torch_tensor(np.ones(3, dtype=np.float32))
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
    """
    # Tensors pass through untouched.
    if isinstance(x, torch.Tensor):
        return x

    # numpy arrays: complex kinds become ComplexTensor, the rest share memory
    # via from_numpy. torch_complex is imported lazily (python3-only dep).
    if isinstance(x, np.ndarray):
        if x.dtype.kind == "c":
            from torch_complex.tensor import ComplexTensor

            return ComplexTensor(x)
        return torch.from_numpy(x)

    # {'real': ..., 'imag': ...} dicts become ComplexTensor.
    if isinstance(x, dict):
        from torch_complex.tensor import ComplexTensor

        if "real" not in x or "imag" not in x:
            raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
        return ComplexTensor(x["real"], x["imag"])

    error = (
        "x must be numpy.ndarray, torch.Tensor or a dict like "
        "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
        "but got {}".format(type(x))
    )
    # A bare ComplexTensor is accepted too, but only when the optional
    # torch_complex package is importable; otherwise reject outright.
    try:
        from torch_complex.tensor import ComplexTensor
    except Exception:
        raise ValueError(error)
    else:
        if isinstance(x, ComplexTensor):
            return x
        raise ValueError(error)
407
+
408
+
409
def get_subsample(train_args, mode, arch):
    """Parse the subsampling factors from the args for the specified `mode` and `arch`.

    Args:
        train_args: argument Namespace containing options.
        mode: one of ('asr', 'mt', 'st')
        arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')

    Returns:
        np.ndarray / List[np.ndarray]: subsampling factors.

    Raises:
        ValueError: for an unsupported (mode, arch) combination.
    """
    if arch == "transformer":
        return np.array([1])

    elif mode == "mt" and arch == "rnn":
        # +1 means input (+1) and layers outputs (train_args.elayer)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
        logging.warning("Subsampling is not performed for machine translation.")
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif (mode == "asr" and arch in ("rnn", "rnn-t")) or (mode == "st" and arch == "rnn"):
        # NOTE: the (mode == "mt" and arch == "rnn") clause that used to be
        # part of this condition was dead code -- fully handled above.
        subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                "Subsampling is not performed for vgg*. "
                "It is performed in max pooling layers at CNN."
            )
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif mode == "asr" and arch == "rnn_mix":
        subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                "Subsampling is not performed for vgg*. "
                "It is performed in max pooling layers at CNN."
            )
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif mode == "asr" and arch == "rnn_mulenc":
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int32)
            if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
                ss = train_args.subsample[idx].split("_")
                for j in range(min(train_args.elayers[idx] + 1, len(ss))):
                    subsample[j] = int(ss[j])
            else:
                logging.warning(
                    "Encoder %d: Subsampling is not performed for vgg*. "
                    "It is performed in max pooling layers at CNN.",
                    idx + 1,
                )
            logging.info("subsample: " + " ".join([str(x) for x in subsample]))
            subsample_list.append(subsample)
        return subsample_list

    else:
        raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
482
+
483
+
484
def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
    """Replace keys of old prefix with new prefix in state dict (in place).

    Args:
        old_prefix: prefix to strip from matching keys.
        new_prefix: prefix substituted in its place.
        state_dict: mapping mutated in place.
    """
    # Materialize the matching keys first: we mutate while iterating.
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
    if len(old_keys) > 0:
        logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
    for k in old_keys:
        v = state_dict.pop(k)
        # str.replace would rewrite *every* occurrence of old_prefix inside
        # the key (e.g. "enc.enc.w"); only the leading prefix must change.
        new_k = new_prefix + k[len(old_prefix):]
        state_dict[new_k] = v
+
496
class Swish(torch.nn.Module):
    """Swish / E-Swish activation.

    Computes ``(beta * x) * sigmoid(x)``; with ``beta == 1`` this reduces to
    the standard Swish (SiLU) activation.

    References:
        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
        E-swish variant: https://arxiv.org/abs/1801.07145.

    Args:
        beta: Beta parameter for E-Swish.
            (beta >= 1. If beta < 1, use standard Swish).
        use_builtin: Whether to use PyTorch function if available.

    """

    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
        super().__init__()

        self.beta = beta

        # The concrete callable is picked once at construction time.
        if beta > 1:
            # E-Swish: the input is scaled by beta before the sigmoid gate.
            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
        elif use_builtin:
            self.swish = torch.nn.SiLU()
        else:
            self.swish = lambda x: x * torch.sigmoid(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation element-wise."""
        return self.swish(x)
529
+
530
+
531
def get_activation(act):
    """Return a freshly constructed activation module for the given name.

    Args:
        act: One of "hardtanh", "tanh", "relu", "selu", "swish".

    Returns:
        An instance of the corresponding activation module.

    Raises:
        KeyError: If ``act`` is not a known activation name.

    """
    table = {
        "hardtanh": torch.nn.Hardtanh,
        "tanh": torch.nn.Tanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": Swish,
    }

    return table[act]()
543
+
544
+
545
class TooShortUttError(Exception):
    """Signals that an utterance is too short to be subsampled.

    Args:
        message: Error message to display.
        actual_size: The size that cannot pass the subsampling.
        limit: The size limit for subsampling.

    """

    def __init__(self, message: str, actual_size: int, limit: int) -> None:
        """Store the offending size alongside the base exception message."""
        super().__init__(message)

        self.limit = limit
        self.actual_size = actual_size
561
+
562
+
563
def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
    """Check if the input is too short for subsampling.

    Args:
        sub_factor: Subsampling factor for Conv2DSubsampling.
        size: Input size.

    Returns:
        : Whether an error should be sent.
        : Size limit for specified subsampling factor (-1 if size is fine).

    """
    # Per subsampling factor: (minimum acceptable size, reported limit).
    requirements = {2: (3, 7), 4: (7, 7), 6: (11, 11)}

    if sub_factor in requirements:
        minimum, limit = requirements[sub_factor]
        if size < minimum:
            return True, limit

    return False, -1
583
+
584
+
585
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
    """Get conv2D second layer parameters for given subsampling factor.

    Args:
        sub_factor: Subsampling factor (1/X).
        input_size: Input size.

    Returns:
        : Kernel size for second convolution.
        : Stride for second convolution.
        : Conv2DSubsampling output size.

    Raises:
        ValueError: If ``sub_factor`` is not one of 2, 4 or 6.

    """
    # Size after the first conv layer (kernel 3, stride 2), shared by all cases.
    after_first = (input_size - 1) // 2

    if sub_factor == 2:
        return 3, 1, after_first - 2
    if sub_factor == 4:
        return 3, 2, (after_first - 1) // 2
    if sub_factor == 6:
        return 5, 3, (after_first - 2) // 3

    raise ValueError("subsampling_factor parameter should be set to either 2, 4 or 6.")
606
+
607
+
608
def make_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = 0,
    device: torch.device = None,
) -> torch.Tensor:
    """Create chunk mask for the subsequent steps (size, size).

    Position ``i`` may attend within its own chunk plus ``left_chunk_size``
    chunks of left context. The returned mask is True where attention is
    NOT allowed.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        size: Size of the source mask.
        chunk_size: Number of frames in chunk.
        left_chunk_size: Size of the left context in chunks
            (negative means full left context; 0 means no left context).
        device: Device for the mask tensor.

    Returns:
        mask: Chunk mask. (size, size)

    """
    pos = torch.arange(size, device=device)
    chunk_idx = pos // chunk_size

    # Visible frame window [start, end) for each row, computed in one shot
    # instead of a per-row Python loop with slice assignment.
    end = torch.clamp((chunk_idx + 1) * chunk_size, max=size)
    if left_chunk_size < 0:
        start = torch.zeros_like(pos)
    else:
        start = torch.clamp((chunk_idx - left_chunk_size) * chunk_size, min=0)

    cols = pos.unsqueeze(0)  # (1, size)
    visible = (cols >= start.unsqueeze(1)) & (cols < end.unsqueeze(1))

    return ~visible
640
+
641
+
642
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create source mask for given lengths.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        lengths: Sequence lengths. (B,)

    Returns:
        : Mask for the sequence lengths, True at padded positions. (B, max_len)

    """
    longest = lengths.max()

    # Broadcast a 0..max_len-1 position row against per-sequence lengths:
    # positions at or beyond a sequence's length are padding.
    positions = torch.arange(longest).to(lengths)

    return positions.unsqueeze(0) >= lengths.unsqueeze(1)
660
+
661
+
662
def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get Transducer loss I/O.

    Args:
        labels: Label ID sequences. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.

    Returns:
        decoder_in: Decoder inputs (blank-prepended labels). (B, U)
        target: Target label ID sequences. (B, U)
        t_len: Time lengths. (B,)
        u_len: Label lengths. (B,)

    """

    def _pad_batch(seqs: List[torch.Tensor], fill: int = 0):
        """Stack variable-length sequences into one right-padded batch tensor.

        Args:
            seqs: Labels sequences. [B x (?)]
            fill: Padding value.

        Returns:
            Batch of padded labels sequences. (B, max_len, ...)

        """
        longest = max(s.size(0) for s in seqs)
        out = seqs[0].new_full((len(seqs), longest, *seqs[0].size()[1:]), fill)

        for row, s in enumerate(seqs):
            out[row, : s.size(0)] = s

        return out

    device = labels.device

    # Drop padding symbols before building decoder input / target.
    stripped = [seq[seq != ignore_id] for seq in labels]
    blank = labels[0].new([blank_id])

    decoder_in = _pad_batch(
        [torch.cat((blank, seq), dim=0) for seq in stripped], blank_id
    ).to(device)

    target = _pad_batch(stripped, blank_id).type(torch.int32).to(device)

    t_len = torch.IntTensor([int(n) for n in encoder_out_lens]).to(device)
    u_len = torch.IntTensor([seq.size(0) for seq in stripped]).to(device)

    return decoder_in, target, t_len, u_len
725
+
726
+
727
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
    shortfall = pad_len - t.size(dim)
    if shortfall == 0:
        # Already at the target length; return the tensor unchanged.
        return t

    filler_shape = list(t.shape)
    filler_shape[dim] = shortfall
    filler = torch.zeros(filler_shape, dtype=t.dtype, device=t.device)

    return torch.cat((t, filler), dim=dim)
funcineforge/tokenizer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .tokenizer import FunCineForgeTokenizer