Spaces:
Running on Zero
Running on Zero
Upload 111 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -0
- .gitignore +1 -0
- LICENSE +201 -0
- README.md +155 -14
- README_zh.md +153 -0
- app.py +415 -0
- data/ref.wav +3 -0
- data/sample.mp4 +3 -0
- decode_conf/decode.yaml +42 -0
- decode_conf/diar.yaml +51 -0
- decode_conf/ds_stage0_fp32.json +33 -0
- funcineforge/.DS_Store +0 -0
- funcineforge/__init__.py +7 -0
- funcineforge/auto/__init__.py +0 -0
- funcineforge/auto/auto_frontend.py +95 -0
- funcineforge/auto/auto_model.py +173 -0
- funcineforge/datasets/__init__.py +2 -0
- funcineforge/datasets/datasets.py +193 -0
- funcineforge/datasets/index_ds.py +151 -0
- funcineforge/download/__init__.py +0 -0
- funcineforge/download/download_model_from_hub.py +220 -0
- funcineforge/download/file.py +320 -0
- funcineforge/download/name_maps_from_hub.py +42 -0
- funcineforge/face/__init__.py +1 -0
- funcineforge/face/face_recognition.py +16 -0
- funcineforge/models/__init__.py +5 -0
- funcineforge/models/causal_hifigan.py +834 -0
- funcineforge/models/flow_matching_model.py +514 -0
- funcineforge/models/inference_model.py +116 -0
- funcineforge/models/language_model.py +274 -0
- funcineforge/models/modules/__init__.py +0 -0
- funcineforge/models/modules/dit_flow_matching/__init__.py +0 -0
- funcineforge/models/modules/dit_flow_matching/dit_model.py +208 -0
- funcineforge/models/modules/dit_flow_matching/dit_modules.py +622 -0
- funcineforge/models/modules/hifigan/__init__.py +14 -0
- funcineforge/models/modules/hifigan/activations.py +120 -0
- funcineforge/models/modules/hifigan/discriminator.py +299 -0
- funcineforge/models/modules/hifigan/generator.py +625 -0
- funcineforge/models/modules/hifigan/mel_spectrum.py +93 -0
- funcineforge/models/modules/hifigan/nsf_utils.py +253 -0
- funcineforge/models/specaug/__init__.py +0 -0
- funcineforge/models/specaug/mask_along_axis.py +204 -0
- funcineforge/models/specaug/specaug.py +103 -0
- funcineforge/models/specaug/time_warp.py +89 -0
- funcineforge/models/utils/__init__.py +2 -0
- funcineforge/models/utils/llm_decoding.py +178 -0
- funcineforge/models/utils/mask_along_axis.py +76 -0
- funcineforge/models/utils/masks.py +132 -0
- funcineforge/models/utils/nets_utils.py +734 -0
- funcineforge/tokenizer/__init__.py +1 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/ref.wav filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
data/sample.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pyc
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,14 +1,155 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
license:
|
| 11 |
-
|
| 12 |
-
--
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### <p align="center">「English | [简体中文](./README_zh.md)」</p>
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<b>🎬 Fun-CineForge: A Unified Dataset Pipeline and Model for Zero-Shot Movie Dubbing<br>
|
| 5 |
+
in Diverse Cinematic Scenes</b>
|
| 6 |
+
</p>
|
| 7 |
+
|
| 8 |
+
<div align="center">
|
| 9 |
+
|
| 10 |
+

|
| 11 |
+
<a href=""><img src="https://img.shields.io/badge/OS-Linux-orange.svg"></a>
|
| 12 |
+
<a href=""><img src="https://img.shields.io/badge/Python->=3.8-aff.svg"></a>
|
| 13 |
+
<a href=""><img src="https://img.shields.io/badge/Pytorch->=2.1-blue"></a>
|
| 14 |
+
</div>
|
| 15 |
+
|
| 16 |
+
<div align="center">
|
| 17 |
+
<h4><a href="#Dataset&Demo">Dataset & Demo</a>
|
| 18 |
+
|<a href="#Environment">Environment</a>
|
| 19 |
+
|<a href="#Dataset-Pipeline">Dataset Pipeline</a>
|
| 20 |
+
|<a href="#Dubbing-Model">Dubbing Model</a>
|
| 21 |
+
|<a href="#Recent-Updates">Recent Updates</a>
|
| 22 |
+
|<a href="#Publication">Publication</a>
|
| 23 |
+
|<a href="#Comminicate">Communicate</a>
|
| 24 |
+
</h4>
|
| 25 |
+
</div>
|
| 26 |
+
|
| 27 |
+
**Fun-CineForge** contains an end-to-end dataset pipeline for producing large-scale dubbing datasets and an MLLM-based dubbing model designed for diverse cinematic scenes. Using this pipeline, we constructed the first large-scale Chinese television dubbing dataset CineDub-CN, which includes rich annotations and diverse scenes. In monologue, narration, dialogue, and multi-speaker scenes, our dubbing model consistently outperforms state-of-the-art methods in terms of audio quality, lip-sync, timbre transition, and instruction following.
|
| 28 |
+
|
| 29 |
+
<a name="Dataset&Demo"></a>
|
| 30 |
+
## Dataset & Demo 🎬
|
| 31 |
+
You can access [https://funcineforge.github.io/](https://funcineforge.github.io/) to get our CineDub-CN dataset samples and demo samples.
|
| 32 |
+
|
| 33 |
+
<a name="Environment"></a>
|
| 34 |
+
## Environment Installation
|
| 35 |
+
|
| 36 |
+
Fun-CineForge relies on Conda and Python environments. Execute **setup.py** to automatically install the entire project environment and open-source model.
|
| 37 |
+
|
| 38 |
+
```shell
|
| 39 |
+
# Conda
|
| 40 |
+
git clone git@github.com:FunAudioLLM/FunCineForge.git
|
| 41 |
+
conda create -n FunCineForge python=3.10 -y && conda activate FunCineForge
|
| 42 |
+
sudo apt-get install ffmpeg
|
| 43 |
+
# Initial settings
|
| 44 |
+
python setup.py
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
<a name="Dataset-Pipeline"></a>
|
| 48 |
+
## Dataset Pipeline 🔨
|
| 49 |
+
|
| 50 |
+
### Data collection
|
| 51 |
+
If you want to produce your own data,
|
| 52 |
+
we recommend that you refer to the following requirements to collect the corresponding movies or television series.
|
| 53 |
+
|
| 54 |
+
1. Video source: TV dramas or movies, non-documentaries, with more monologues or dialogue scenes, clear and unobstructed faces (such as without masks and veils).
|
| 55 |
+
2. Speech Requirements: Standard pronunciation, clear articulation, prominent human voice. Avoid materials with strong dialects, excessive background noise, or strong colloquialism.
|
| 56 |
+
3. Image Requirements: High resolution, clear facial details, sufficient lighting, avoiding extremely dark or strong backlit scenes.
|
| 57 |
+
|
| 58 |
+
### How to use
|
| 59 |
+
|
| 60 |
+
- [1] Standardize video format and name; trim the beginning and end of long videos; extract the audio from the trimmed video. (default is to trim 10 seconds from both the beginning and end.)
|
| 61 |
+
```shell
|
| 62 |
+
python normalize_trim.py --root datasets/raw_zh --intro 10 --outro 10
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
- [2] [Speech Separation](./speech_separation/README.md). The audio is used to separate the vocals from the instrumental music.
|
| 66 |
+
```shell
|
| 67 |
+
cd speech_separation
|
| 68 |
+
python run.py --root datasets/clean/zh --gpus 0 1 2 3
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
- [3] [VideoClipper](./video_clip/README.md). For long videos, VideoClipper is used to obtain sentence-level subtitle files and clip the long video into segments based on timestamps. Now it supports bilingualism in both Chinese and English. Below is an example in Chinese. It is recommended to use gpu acceleration for English.
|
| 72 |
+
```shell
|
| 73 |
+
cd video_clip
|
| 74 |
+
bash run.sh --stage 1 --stop_stage 2 --input datasets/raw_zh --output datasets/clean/zh --lang zh --device cpu
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
- Video duration limit and check for cleanup. (Without --execute, only pre-deleted files will be printed. After checking, add --execute to confirm the deletion.)
|
| 78 |
+
```shell
|
| 79 |
+
python clean_video.py --root datasets/clean/zh
|
| 80 |
+
python clean_srt.py --root datasets/clean/zh --lang zh
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
- [4] [Speaker Diarization](./speaker_diarization/README.md). Multimodal active speaker recognition obtains RTTM files; identifies the speaker's facial frames, extracts frame-level speaker face and lip raw data.
|
| 84 |
+
```shell
|
| 85 |
+
cd speaker_diarization
|
| 86 |
+
bash run.sh --stage 1 --stop_stage 4 --hf_access_token hf_xxx --root datasets/clean/zh --gpus "0 1 2 3"
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
- (Reference) Extract speech tokens based on the CosyVoice3 tokenizer for llm training.
|
| 90 |
+
```shell
|
| 91 |
+
python speech_tokenizer.py --root datasets/clean/zh
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
- [5] Multimodal CoT Correction. Based on general-purpose MLLMs, the system uses audio, ASR text, and RTTM files as input. It leverages Chain-of-Thought (CoT) reasoning to extract clues and corrects the results of the specialized models. It also annotates character age, gender, and vocal timbre. Experimental results show that this strategy reduces the CER from 4.53% to 0.94% and the speaker diarization error rate from 8.38% to 1.20%, achieving quality comparable to or even better than manual transcription. Adding the --resume enables breakpoint COT inference to prevent wasted resources from repeated COT inferences. Now supports both Chinese and English.
|
| 95 |
+
```shell
|
| 96 |
+
python cot.py --root_dir datasets/clean/zh --lang zh --provider google --model gemini-3-pro-preview --api_key xxx --resume
|
| 97 |
+
python cot.py --root_dir datasets/clean/en --lang en --provider google --model gemini-3-pro-preview --api_key xxx --resume
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
- The construction of the dataset retrieval file will read all production data, perform bidirectional verification of script content and speaker separation results.
|
| 101 |
+
```shell
|
| 102 |
+
python build_datasets.py --root_zh datasets/clean/zh --root_en datasets/clean/en --out_dir datasets/clean --save
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
<a name="Dubbing-Model"></a>
|
| 106 |
+
## Dubbing Model ⚙️
|
| 107 |
+
We've open-sourced the inference code and the **infer.sh** script, and provided some test cases in the data folder for your experience. Inference requires a consumer-grade GPU. Run the following command:
|
| 108 |
+
|
| 109 |
+
```shell
|
| 110 |
+
cd exps
|
| 111 |
+
bash infer.sh
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
The API for multi-speaker dubbing from raw videos and SRT scripts is under development ...
|
| 115 |
+
|
| 116 |
+
<a name="Recent-Updates"></a>
|
| 117 |
+
## Recent Updates 🚀
|
| 118 |
+
- 2025/12/18: Fun-CineForge dataset pipeline toolkit is online! 🔥
|
| 119 |
+
- 2026/01/19: Chinese demo samples and CineDub-CN dataset samples released. 🔥
|
| 120 |
+
- 2026/01/25: Fix some environmental and operational issues.
|
| 121 |
+
- 2026/02/09: Optimized the data pipeline and added support for English videos.
|
| 122 |
+
- 2026/03/05: English demo samples and CineDub-EN dataset samples released. 🔥
|
| 123 |
+
- 2026/03/16: Open source inference code and checkpoints. 🔥
|
| 124 |
+
|
| 125 |
+
<a name="Publication"></a>
|
| 126 |
+
## Publication 📚
|
| 127 |
+
If you use our dataset or code, please cite the following paper:
|
| 128 |
+
<pre>
|
| 129 |
+
@misc{liu2026funcineforgeunifieddatasettoolkit,
|
| 130 |
+
title={FunCineForge: A Unified Dataset Toolkit and Model for Zero-Shot Movie Dubbing in Diverse Cinematic Scenes},
|
| 131 |
+
author={Jiaxuan Liu and Yang Xiang and Han Zhao and Xiangang Li and Zhenhua Ling},
|
| 132 |
+
year={2026},
|
| 133 |
+
eprint={2601.14777},
|
| 134 |
+
archivePrefix={arXiv},
|
| 135 |
+
primaryClass={cs.CV},
|
| 136 |
+
}
|
| 137 |
+
</pre>
|
| 138 |
+
|
| 139 |
+
<a name="Comminicate"></a>
|
| 140 |
+
## Communicate 🍟
|
| 141 |
+
The Fun-CineForge open-source project is developed and maintained by the Tongyi Lab Speech Team and a student from NERCSLIP, University of Science and Technology of China.
|
| 142 |
+
We welcome you to participate in discussions on Fun-CineForge [GitHub Issues](https://github.com/FunAudioLLM/FunCineForge/issues) or contact us for collaborative development.
|
| 143 |
+
For any questions, you can contact the [developer](mailto:jxliu@mail.ustc.edu.cn).
|
| 144 |
+
|
| 145 |
+
⭐ Hope you will support Fun-CineForge. Thank you.
|
| 146 |
+
|
| 147 |
+
### Disclaimer
|
| 148 |
+
|
| 149 |
+
This repository contains research artifacts:
|
| 150 |
+
|
| 151 |
+
⚠️ Currently not a commercial product of Tongyi Lab.
|
| 152 |
+
|
| 153 |
+
⚠️ Released for academic research / cutting-edge exploration purposes
|
| 154 |
+
|
| 155 |
+
⚠️ CineDub Dataset samples are subject to specific license terms.
|
README_zh.md
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### <p align="center">「[English](./README.md) | 简体中文」</p>
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<b>🎬 Fun-CineForge:一种用于多样化影视场景零样本配音的统一数据集管道和模型</b>
|
| 5 |
+
</p>
|
| 6 |
+
|
| 7 |
+
<div align="center">
|
| 8 |
+
|
| 9 |
+

|
| 10 |
+
<a href=""><img src="https://img.shields.io/badge/OS-Linux-orange.svg"></a>
|
| 11 |
+
<a href=""><img src="https://img.shields.io/badge/Python->=3.8-aff.svg"></a>
|
| 12 |
+
<a href=""><img src="https://img.shields.io/badge/Pytorch->=2.1-blue"></a>
|
| 13 |
+
</div>
|
| 14 |
+
|
| 15 |
+
<div align="center">
|
| 16 |
+
<h4><a href="#数据集&样例">数据集 & 样例</a>
|
| 17 |
+
|<a href="#环境安装">环境安装</a>
|
| 18 |
+
|<a href="#数据集管道">数据集管道</a>
|
| 19 |
+
|<a href="#配音模型">配音模型</a>
|
| 20 |
+
|<a href="#近期更新">近期更新</a>
|
| 21 |
+
|<a href="#发表">发表</a>
|
| 22 |
+
|<a href="#社区交流">社区交流</a>
|
| 23 |
+
</h4>
|
| 24 |
+
</div>
|
| 25 |
+
|
| 26 |
+
**Fun-CineForge** 包含一个生产大规模配音数据集的端到端数据集管道,和一个基于多模态大模型的配音模型,该模型专为多样的电影场景而设计。利用该管道,我们构建了首个大规模中文电视剧配音数据集 CineDub-CN,该数据集包含丰富的标注和多样化的场景。在独白、旁白、对话和多说话人场景中,我们的配音模型在音频质量、唇形同步、音色转换和指令遵循等方面全部优于最先进的方法。
|
| 27 |
+
|
| 28 |
+
<a name="数据集&样例"></a>
|
| 29 |
+
## 数据集 & 样例 🎬
|
| 30 |
+
您可以访问此 [https://funcineforge.github.io/](https://funcineforge.github.io/) 获取我们的 CineDub-CN 数据集和 CineDub-EN 数据集样例和演示样例。
|
| 31 |
+
|
| 32 |
+
<a name="环境安装"></a>
|
| 33 |
+
## 环境安装
|
| 34 |
+
|
| 35 |
+
Fun-CineForge 依赖 Conda 和 Python 环境。执行 **setup.py** 自动安装整个项目环境和开源模型。
|
| 36 |
+
|
| 37 |
+
```shell
|
| 38 |
+
# Conda
|
| 39 |
+
git clone git@github.com:FunAudioLLM/FunCineForge.git
|
| 40 |
+
conda create -n FunCineForge python=3.10 -y && conda activate FunCineForge
|
| 41 |
+
sudo apt-get install ffmpeg
|
| 42 |
+
# 初始化设置
|
| 43 |
+
python setup.py
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
<a name="数据集管道"></a>
|
| 47 |
+
## 数据集管道 🔨
|
| 48 |
+
|
| 49 |
+
### 数据收集
|
| 50 |
+
如果您想自行生产数据,我们建议您参考下面的要求收集相应的电影或影视剧。
|
| 51 |
+
|
| 52 |
+
1. 视频来源:电视剧或电影,非纪录片,人物独白或对话场景较多,人脸清晰且无遮挡(如无面罩、面纱)。
|
| 53 |
+
2. 语音要求:发音标准,吐字清晰,人声突出。避免方言浓重、背景噪音过大或口语感过强的素材。
|
| 54 |
+
3. 图片要求:高分辨率,面部细节清晰,光线充足,避免极端阴暗或强烈逆光的场景。
|
| 55 |
+
|
| 56 |
+
### 使用方法
|
| 57 |
+
|
| 58 |
+
- [1] 将视频格式、名称标准化;裁剪长视频的片头片尾;提取裁剪后视频的音频。(默认是从起止各裁剪 10 秒。)
|
| 59 |
+
```shell
|
| 60 |
+
python normalize_trim.py --root datasets/raw_zh --intro 10 --outro 10
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
- [2] [Speech Separation](./speech_separation/README.md). 音频进行人声乐声分离。
|
| 64 |
+
```shell
|
| 65 |
+
cd speech_separation
|
| 66 |
+
python run.py --root datasets/clean/zh --gpus 0 1 2 3
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
- [3] [VideoClipper](./video_clip/README.md). 对于长视频,使用 VideoClipper 获取句子级别的字幕文件,并根据时间戳将长视频剪辑成片段。现在它支持中英双语。以下是中文示例。英文建议采用 gpu 加速处理。
|
| 70 |
+
```shell
|
| 71 |
+
cd video_clip
|
| 72 |
+
bash run.sh --stage 1 --stop_stage 2 --input datasets/raw_zh --output datasets/clean/zh --lang zh --device cpu
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
- 视频时长限制及清理检查。(若不使用--execute参数,则仅打印已预删除的文件。检查后,若需确认删除,请添加--execute参数。)
|
| 76 |
+
```shell
|
| 77 |
+
python clean_video.py --root datasets/clean/zh
|
| 78 |
+
python clean_srt.py --root datasets/clean/zh --lang zh
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
- [4] [Speaker Diarization](./speaker_diarization/README.md). 多模态主动说话人识别,得到 RTTM 文件;识别说话人的面部帧,提取帧级的说话人面部和唇部原始数据,从面部帧中识别说话帧,提取说话帧的面部特征。
|
| 82 |
+
```shell
|
| 83 |
+
cd speaker_diarization
|
| 84 |
+
bash run.sh --stage 1 --stop_stage 4 --hf_access_token hf_xxx --root datasets/clean/zh --gpus "0 1 2 3"
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
- (参考)基于 CosyVoice3 tokenizer 提取 speech tokens 用于大模型训练。
|
| 88 |
+
```shell
|
| 89 |
+
python speech_tokenizer.py --root datasets/clean/zh
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
- [5] 多模态思维链校正。该系统基于通用多模态大模型,以音频、ASR 抄本和 RTTM 文件为输入,利用思维链推理来提取线索,并校正专用模型的结果,并标注人物年龄、性别和音色。实验结果表明,该策略将词错率从4.53% 降低到 0.94%,说话人识别错误率从 8.38% 降低到 1.20%,其质量可与人工转录相媲美,甚至更优。添加--resume选项可启用断点思维链推理,以避免重复思维链推理造成的资源浪费。现支持中英文。
|
| 93 |
+
```shell
|
| 94 |
+
python cot.py --root_dir datasets/clean/zh --lang zh --provider google --model gemini-3-pro-preview --api_key xxx --resume
|
| 95 |
+
python cot.py --root_dir datasets/clean/en --lang en --provider google --model gemini-3-pro-preview --api_key xxx --resume
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
- 数据集检索文件的构建会读取生产的所有数据,双向校验脚本内容和说话人分离结果。
|
| 99 |
+
```shell
|
| 100 |
+
python build_datasets.py --root_zh datasets/clean/zh --root_en datasets/clean/en --out_dir datasets/clean --save
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
<a name="配音模型"></a>
|
| 104 |
+
## 配音模型 ⚙️
|
| 105 |
+
我们开源了推理代码和 **infer.sh** 脚本,在 data 文件夹中提供了一些测试样例,以供体验。推理需要一张消费级 GPU。按下面的命令运行:
|
| 106 |
+
|
| 107 |
+
```shell
|
| 108 |
+
cd exps
|
| 109 |
+
bash infer.sh
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
从原始视频和 SRT 脚本进行多人配音的 API 调用接口在开发中 ...
|
| 113 |
+
|
| 114 |
+
<a name="近期更新"></a>
|
| 115 |
+
## 近期更新 🚀
|
| 116 |
+
- 2025/12/18:Fun-CineForge 数据集管道工具包上线!🔥
|
| 117 |
+
- 2026/01/19:发布中文演示样例和 CineDub-CN 数据集样例。 🔥
|
| 118 |
+
- 2026/01/25:修复了一些环境和运行问题。
|
| 119 |
+
- 2026/02/09:优化了数据管道,新增支持英文视频的能力。
|
| 120 |
+
- 2026/03/05:发布英文演示样例和 CineDub-EN 数据集样例。 🔥
|
| 121 |
+
- 2026/03/16:开源推理代码和 checkpoints。 🔥
|
| 122 |
+
|
| 123 |
+
<a name="发表"></a>
|
| 124 |
+
## 发表 📚
|
| 125 |
+
如果您使用了我们的数据集或代码,请引用以下论文:
|
| 126 |
+
<pre>
|
| 127 |
+
@misc{liu2026funcineforgeunifieddatasettoolkit,
|
| 128 |
+
title={FunCineForge: A Unified Dataset Toolkit and Model for Zero-Shot Movie Dubbing in Diverse Cinematic Scenes},
|
| 129 |
+
author={Jiaxuan Liu and Yang Xiang and Han Zhao and Xiangang Li and Zhenhua Ling},
|
| 130 |
+
year={2026},
|
| 131 |
+
eprint={2601.14777},
|
| 132 |
+
archivePrefix={arXiv},
|
| 133 |
+
primaryClass={cs.CV},
|
| 134 |
+
}
|
| 135 |
+
</pre>
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
<a name="社区交流"></a>
|
| 139 |
+
## 社区交流 🍟
|
| 140 |
+
Fun-CineForge 开源项目由通义实验室语音团队和中国科学技术大学 NERCSLIP 学生开发并维护,我们欢迎您在 Fun-CineForge [GitHub Issues](https://github.com/FunAudioLLM/FunCineForge/issues) 参与问题讨论,或联系我们合作开发。
|
| 141 |
+
有任何问题您可以联系[开发者](mailto:jxliu@mail.ustc.edu.cn)。
|
| 142 |
+
|
| 143 |
+
⭐ 希望您支持 Fun-CineForge,谢谢。
|
| 144 |
+
|
| 145 |
+
### 免责声明
|
| 146 |
+
|
| 147 |
+
该仓库包含的研究成果:
|
| 148 |
+
|
| 149 |
+
⚠️ 目前非通义实验室商业化产品
|
| 150 |
+
|
| 151 |
+
⚠️ 供学术研究/前沿探索用途
|
| 152 |
+
|
| 153 |
+
⚠️ 数据集样例受特定许可条款约束
|
app.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import typing
|
| 7 |
+
import time
|
| 8 |
+
import shutil
|
| 9 |
+
from moviepy.video.io.VideoFileClip import VideoFileClip, AudioFileClip
|
| 10 |
+
from moviepy.audio.AudioClip import CompositeAudioClip
|
| 11 |
+
from modelscope import snapshot_download
|
| 12 |
+
from utils import get_video_duration, generate_jsonl_data, validate_timestamps, parse_srt_content
|
| 13 |
+
# 尝试导入模型库
|
| 14 |
+
from funcineforge import AutoFrontend
|
| 15 |
+
from speaker_diarization.run import GlobalModels
|
| 16 |
+
# Download pretrained checkpoints from ModelScope into ./pretrained_models.
# NOTE(review): `repo_id`/`repo_type` kwargs require a recent modelscope
# release; older versions expect `model_id` — verify the pinned version.
snapshot_download(
    repo_id="FunAudioLLM/Fun-CineForge",
    revision='v1.0.0',
    local_dir='pretrained_models',
    ignore_patterns=[
        "*.md",
        ".git*",
        "funcineforge_zh_en/llm/config.yaml"
    ],
    repo_type="model",
)


# ==================== Configuration ====================
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
SERVER_PORT = 7860
# Working directory wiped and recreated on every dubbing request.
TEMP_DIR = "temp_workdir"
CONFIG_FRONTEND = "decode_conf/diar.yaml"
CONFIG_MODEL = "decode_conf/decode.yaml"
PRETRAIN = "pretrained_models"
MAX_SEGMENTS = 8  # upper bound on the number of segment panels in the UI
DEFAULT_VIDEO_PATH="data/sample.mp4"
DEFAULT_AUDIO_PATH="data/ref.wav"
DEFAULT_TEXT = "我军无粮,利在急战。今乘魏兵新败,不敢出兵,出其不意,乘机退去,方可平安无事。"
DEFAULT_CLUE = "一位中年男性以沉稳但略带担忧的语调,分析我军无粮急战的困境与敌军心败状态。他随即提出一种撤退方案,整体流露出对战局的担忧和谋求生路。"
# Global model instances (lazily loaded by init_engine / init_frontend_models).
model_pool: typing.Optional[GlobalModels] = None
engine: typing.Optional[AutoFrontend] = None
|
| 44 |
+
|
| 45 |
+
def init_engine():
    """Lazily create the global :class:`AutoFrontend` inference engine.

    The previous implementation rebuilt the engine on every call even though
    its docstring promised lazy loading; reloading the LM/flow/vocoder
    checkpoints per request is very expensive. Reuse the cached instance.

    Returns:
        AutoFrontend: the shared inference engine bound to ``DEVICE``.
    """
    global engine
    if engine is None:
        engine = AutoFrontend(PRETRAIN, CONFIG_MODEL, TEMP_DIR, DEVICE)
    return engine
|
| 50 |
+
|
| 51 |
+
def init_frontend_models():
    """Lazily create the global speaker-diarization frontend model pool.

    Rebuilding ``GlobalModels`` (face / ASD / face-recognition pools) on
    every request reloads all frontend checkpoints from disk; cache the
    pool in the module-level ``model_pool`` instead.

    Returns:
        GlobalModels: the shared, preloaded frontend model pool.
    """
    global model_pool
    if model_pool is None:
        model_pool = GlobalModels(
            hf_token=None,
            config_path=CONFIG_FRONTEND,
            pretrained_dir=PRETRAIN,
            device=DEVICE,
            pool_sizes={"face": 1, "asd": 1, "fr": 1},
            batch_size=1,
            preload=True,
        )
    return model_pool
|
| 63 |
+
|
| 64 |
+
# ==================== Gradio UI 逻辑 ====================
|
| 65 |
+
|
| 66 |
+
def create_segments_ui():
    """Build MAX_SEGMENTS accordion panels of dubbing-segment inputs.

    Only the first panel is open/visible initially; the rest are revealed by
    ``add_segment_fn``. Component creation order matters: the event wiring in
    ``main`` flattens these components positionally (8 inputs per segment).

    Returns:
        (segments, accordions): ``segments`` is a list of dicts mapping field
        names to Gradio components; ``accordions`` lists the panel containers.
    """
    segments = []
    accordions = []
    for i in range(MAX_SEGMENTS):
        with gr.Accordion(f"🎬 配音片段 {i + 1}", open=(i == 0), visible=(i == 0)) as acc:
            accordions.append(acc)
            with gr.Row():
                # Dubbing text and free-form "clue" description for the model.
                text_input = gr.Textbox(label="📝 配音文本内容", placeholder="输入台词...", lines=2, scale=3, elem_id=f"text_{i}")
                clue_input = gr.Textbox(label="💡 线索描述", placeholder="一位中年男性角色语气沉稳且坚定,流露出对自身忠诚的强烈自信与决心。整体情感是忠贞不渝的承诺和不容置疑的信念。", lines=2, scale=3, elem_id=f"clue_{i}")
            with gr.Row():
                # Default each segment to a non-overlapping 5-second window.
                start_time = gr.Number(label="⏱️ 起始时间 (s)", value=0.0 + i*5, precision=2, scale=2, elem_id=f"start_{i}")
                end_time = gr.Number(label="⏱️ 终止时间 (s)", value=5.0 + i*5, precision=2, scale=2, elem_id=f"end_{i}")
            with gr.Row():
                age_input = gr.Dropdown(label="👤 年龄", choices=["儿童", "青年", "中年", "中老年", "老年", "不确定"], value="不确定", scale=2, elem_id=f"age_{i}")
                gender_input = gr.Dropdown(label="👤 性别", choices=["男", "女", "不确定"], value="不确定", scale=2, elem_id=f"gender_{i}")
            with gr.Row():
                ref_audio = gr.Audio(label="🎤 参考语音 (可选,默认以视频原声作为参考音频)", sources=["upload"], type="filepath", scale=4,elem_id=f"audio_{i}")
                # Sample-audio loader only exists on the first segment.
                load_audio_btn = gr.Button("📂 加载示例音频", size="sm", variant="secondary", scale=1) if i == 0 else None
            with gr.Row():
                enable_check = gr.Checkbox(label="启用此片段", value=(i == 0), scale=1, elem_id=f"enable_{i}")

        segments.append({
            "accordion": acc, "text": text_input, "clue": clue_input, "start": start_time, "end": end_time,
            "age": age_input, "gender": gender_input, "audio": ref_audio,
            "enable": enable_check, "index": i, "load_audio_btn": load_audio_btn})
    return segments, accordions
|
| 92 |
+
|
| 93 |
+
def add_segment_fn(current_count):
    """Handle the add-segment button: reveal the next hidden panel.

    Returns, in order: the new visible-segment count, one visibility update
    per accordion, and an update for the add-button itself (disabled once
    the MAX_SEGMENTS cap is reached).
    """
    if current_count >= MAX_SEGMENTS:
        # Already at the cap: leave all panels untouched, lock the button.
        untouched = [gr.update() for _ in range(MAX_SEGMENTS)]
        locked_btn = gr.update(interactive=False, value=f"已达上限 ({MAX_SEGMENTS})")
        return [current_count] + untouched + [locked_btn]

    shown = current_count + 1
    visibility = [gr.update(visible=(idx < shown)) for idx in range(MAX_SEGMENTS)]
    button_state = gr.update(interactive=(shown < MAX_SEGMENTS), value="➕新片段")
    return [shown] + visibility + [button_state]
|
| 102 |
+
|
| 103 |
+
def load_srt_fn(srt_file, current_count):
    """Populate segment panels from an uploaded SRT subtitle file.

    Output order is a positional contract with the wiring in ``main``:
    [seg_count] + (text, start, end, enable) x MAX_SEGMENTS + accordion
    visibilities + add-button update. On any failure the current state is
    returned unchanged.
    """
    # "No change" placeholders for the 4 fields per segment and visibilities.
    empty_fields = [gr.update() for _ in range(MAX_SEGMENTS * 4)]
    empty_vis = [gr.update() for _ in range(MAX_SEGMENTS)]
    if not srt_file:
        return [current_count] + empty_fields + empty_vis + [gr.update()]
    try:
        # utf-8-sig transparently strips a BOM, common in SRT exports.
        with open(srt_file, 'r', encoding='utf-8-sig') as f:
            content = f.read()
    except Exception as e:
        gr.Warning(f"读取 SRT 文件失败: {e}")
        return [current_count] + empty_fields + empty_vis + [gr.update()]
    parsed = parse_srt_content(content)
    if not parsed:
        # NOTE(review): this prints to the server console only; a gr.Warning
        # would be more consistent with the other error paths.
        print(" 未解析到有效字幕,请检查 SRT 格式")
        return [current_count] + empty_fields + empty_vis + [gr.update()]
    updates = []
    for i in range(MAX_SEGMENTS):
        if i < len(parsed):
            # Fill and enable this panel from the parsed subtitle entry.
            seg = parsed[i]
            updates.append(gr.update(value=seg['text']))
            updates.append(gr.update(value=round(seg['start'], 2)))
            updates.append(gr.update(value=round(seg['end'], 2)))
            updates.append(gr.update(value=True))
        else:
            # Reset and disable panels beyond the parsed entries.
            updates.append(gr.update(value=""))
            updates.append(gr.update(value=0.0))
            updates.append(gr.update(value=5.0 + i*5))
            updates.append(gr.update(value=False))
    new_count = min(len(parsed), MAX_SEGMENTS)
    vis = [gr.update(visible=(i < new_count)) for i in range(MAX_SEGMENTS)]
    btn = gr.update(interactive=(new_count < MAX_SEGMENTS))
    if len(parsed) > MAX_SEGMENTS:
        gr.Warning(f"SRT 包含 {len(parsed)} 个片段,已截取前 {MAX_SEGMENTS} 条")

    return [new_count] + updates + vis + [btn]
|
| 138 |
+
|
| 139 |
+
def process_dubbing(video_file, *segment_inputs, progress=gr.Progress()):
    """Main inference pipeline: validate inputs, run the dubbing model,
    and composite the generated speech back onto the (muted) video.

    ``segment_inputs`` is the positional flattening produced in ``main``:
    8 components per segment in the order
    (text, clue, start, end, age, gender, ref_audio, enable).

    Returns:
        (video_path_or_None, status_markdown): the dubbed video path on
        success (or the input video in the no-engine fallback), plus a
        human-readable report; ``None`` and an error message on failure.
    """
    if not video_file:
        return None, "❌ 请上传视频文件"

    video_duration = get_video_duration(video_file)
    if video_duration <= 0:
        return None, "❌ 无法获取视频时长,请检查视频文件"

    # Start each request from a clean scratch directory.
    if os.path.exists(TEMP_DIR):
        try:
            shutil.rmtree(TEMP_DIR)
        except Exception as e:
            return None, f"❌ 清空临时目录失败:{e}"
    os.makedirs(TEMP_DIR, exist_ok=True)

    # Decode the flat segment_inputs tuple into per-segment dicts.
    segments_data = []
    for i in range(MAX_SEGMENTS):
        base_idx = i * 8
        enable = segment_inputs[base_idx + 7]  # enable_check
        if not enable: continue
        text = segment_inputs[base_idx + 0]
        if not text or not text.strip(): continue

        clue = segment_inputs[base_idx + 1]
        start = segment_inputs[base_idx + 2]
        end = segment_inputs[base_idx + 3]
        age = segment_inputs[base_idx + 4]
        gender = segment_inputs[base_idx + 5]
        ref_audio = segment_inputs[base_idx + 6]

        # Reject segments whose timestamps are inconsistent or out of range.
        errors = validate_timestamps(start, end, video_duration)
        if errors:
            return None, f"❌ 片段 {i+1} 时间戳错误:\n" + "\n".join(errors)

        data = {
            "text": str(text).strip(),
            "clue": str(clue) if clue else "",
            "start": float(start) if start else 0.0,
            "end": float(end) if end else 0.0,
            "age": str(age) if age else "不确定",
            "gender": str(gender) if gender else "不确定",
            "ref_audio": str(ref_audio) if ref_audio else ""
        }

        segments_data.append(data)

    if not segments_data:
        return None, "❌ 有效片段数据为空,请启用并填写至少一个片段"

    try:
        progress(0.1, desc="📋 预处理视频,生成 JSONL 数据...")
        frontend = init_frontend_models()
        jsonl_path, jsonl_items = generate_jsonl_data(frontend, video_file, segments_data, TEMP_DIR, video_duration)
        # Build the textual report shown next to the output video.
        report_lines = [f"✅ 任务完成!共生成 **{len(jsonl_items)}** 个片段数据。\n", "详细 JSONL 数据预览:**", "=" * 40]
        for idx, item in enumerate(jsonl_items):
            report_lines.extend([f"\n---片段 #{idx + 1} ---", json.dumps(item, ensure_ascii=False, indent=2), "-" * 40])
        full_report = "\n".join(report_lines)

        progress(0.3, desc="🔄 FunCineForge 模型加载中...")

        eng = init_engine()
        if eng and jsonl_items:
            try:
                progress(0.5, desc="🚀 FunCineForge 模型推理中...")
                # The engine writes one wav per segment into TEMP_DIR/wav.
                eng.inference(jsonl_path)

                progress(0.8, desc="🎵 正在将配音语音粘贴回静音视频...")

                output_wav_dir = os.path.join(TEMP_DIR, "wav")
                final_video_path = os.path.join(TEMP_DIR, "dubbed_video.mp4")

                if not os.path.exists(output_wav_dir):
                    return None, f"⚠️ 未找到音频输出目录:{output_wav_dir}"

                wav_files = sorted([f for f in os.listdir(output_wav_dir) if f.endswith('.wav')])
                if not wav_files:
                    return None, f"⚠️ 未生成任何音频文件:{output_wav_dir}"

                # Map each generated wav (matched by utt-id prefix) to the
                # segment's start time on the video timeline.
                time_mapping = {}
                for item in jsonl_items:
                    for wf in wav_files:
                        if wf.startswith(item['utt']):
                            time_mapping[wf] = float(item['start'])
                            break

                original_clip = VideoFileClip(video_file)
                video_duration = original_clip.duration
                is_silent = original_clip.audio is None
                # Drop the original soundtrack; the dub replaces it entirely.
                video_only = original_clip if is_silent else original_clip.without_audio()
                audio_clips = []
                for wav_file, start_time in time_mapping.items():
                    wav_path = os.path.join(output_wav_dir, wav_file)
                    audio_clip = AudioFileClip(wav_path).with_start(start_time)
                    audio_clips.append(audio_clip)

                # Overlay all dub clips at their offsets, padded to full length.
                final_audio = CompositeAudioClip(audio_clips)
                if final_audio.duration < video_duration:
                    final_audio = final_audio.with_duration(video_duration)
                final_clip = video_only.with_audio(final_audio)
                final_clip.write_videofile(
                    final_video_path,
                    codec='libx264',
                    audio_codec='aac',
                    preset='veryfast',
                    threads=8,
                    fps=original_clip.fps,
                    logger=None
                )
                # Release moviepy resources (readers/ffmpeg subprocesses).
                original_clip.close(); video_only.close()
                for ac in audio_clips: ac.close()
                if 'final_audio' in locals(): final_audio.close()
                final_clip.close()

                progress(1.0, desc="✅ 配音完成")
                return final_video_path, full_report
            except Exception as e:
                import traceback; traceback.print_exc()
                # Heuristic: index errors usually stem from missing clue /
                # speaker-attribute inputs, so suggest filling them in.
                if "index out of range" in str(e):
                    return None, f"⚠️ 模型推理失败。错误:{str(e)},建议补齐输入的线索描述和说话人属性"
                else:
                    return None, f"⚠️ 模型推理失败。错误:{str(e)}"
        else:
            # Fallback when the engine is unavailable: echo the input video.
            time.sleep(1)
            progress(1.0, desc="模拟完成")
            return video_file, full_report

    except Exception as e:
        import traceback; traceback.print_exc()
        return None, f"❌ 发生错误:{str(e)}"
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# ==================== 主程序 ====================
|
| 273 |
+
|
| 274 |
+
def main():
    """Build and launch the Gradio dubbing UI.

    Layout: the left column holds the input video, sample/SRT loaders and
    per-segment dubbing configuration; the right column shows the dubbed
    video, status text and usage notes. All event wiring (add-segment, SRT
    import, inference, timestamp sync, sample loaders) happens inside the
    Blocks context. Fixes a mojibake in the usage notes ("时��长度" →
    "时间长度") and removes dead commented-out code.
    """
    os.makedirs(TEMP_DIR, exist_ok=True)
    with gr.Blocks(
        title="Fun-CineForge 影视配音平台",
        theme=gr.themes.Soft(),
        css="""
        .segment-accordion { margin: 10px 0; }
        .gr-button-primary { background: #1976d2; }
        .gr-button-stop { background: #d32f2f; }
        """
    ) as demo:

        gr.Markdown("""
        # 🎬 Fun-CineForge

        **工作流程:** 上传短视频 → 配音片段信息(或上传 .srt 字幕文件) → 上传参考音色(可选) → 预处理、模型加载和推理 → 输出配音视频
        """)

        with gr.Row():
            with gr.Column(scale=1):
                video_input = gr.Video(label="上传视频", sources=["upload"])
                load_video_btn = gr.Button("📂 加载示例视频", variant="secondary", size="sm")
                srt_input = gr.UploadButton("上传 SRT 字幕", file_types=[".srt"], size="sm", variant="secondary")
                gr.Markdown("### 🎛️ 配音片段配置")

                segments, accordions = create_segments_ui()
                seg_count_state = gr.State(1)  # number of currently visible segment panels
                add_segment_btn = gr.Button("➕添加新片段", size="sm", variant="secondary")
                submit_btn = gr.Button("🚀 开始生成配音", variant="stop", size="lg")

            with gr.Column(scale=1):
                video_output = gr.Video(label="📺 配音后视频", autoplay=True)

                status_text = gr.Textbox(label="结果状态", interactive=False, lines=2)

                gr.Markdown("""
                ### 📝 使用说明
                | 字段 | 说明 |
                |------|------|
                | 配音文本 | 该片段台词内容(支持中/英) |
                | 线索描述 | 请参考样例格式,阐述配音要求,重点描述说话人的性别年龄、语气和情感 |
                | 时间戳 | 起止时间戳 (可精确到毫秒),模型对时间戳敏感,建议紧邻有声区间。时长 ≤30s/片段 |
                | 年龄/性别 | 说话人属性选项 |
                | 参考语音 | 音色克隆参考 (可选) |

                **⚠️ 注意:** 确保每个片段的时间戳不重叠,且时间戳不超过视频总时长。模型会根据片段的时间长度进行强制时间对齐,弱监督对齐唇部运动。
                """)

        # ==================== Event wiring ====================

        # Flatten all segment components as positional inputs for
        # process_dubbing (8 components per segment, order is a contract).
        segment_inputs = []
        for seg in segments:
            segment_inputs.extend([
                seg["text"],
                seg["clue"],
                seg["start"],
                seg["end"],
                seg["age"],
                seg["gender"],
                seg["audio"],
                seg["enable"]
            ])

        # Fields overwritten by SRT import: (text, start, end, enable) per segment.
        srt_update_fields = []
        for seg in segments:
            srt_update_fields.extend([seg["text"], seg["start"], seg["end"], seg["enable"]])

        # Reveal one more segment panel per click.
        add_segment_btn.click(
            fn=add_segment_fn,
            inputs=[seg_count_state],
            outputs=[seg_count_state] + accordions + [add_segment_btn]
        )

        # Populate panels from an uploaded SRT file.
        srt_input.upload(
            fn=load_srt_fn,
            inputs=[srt_input, seg_count_state],
            outputs=[seg_count_state] + srt_update_fields + accordions + [add_segment_btn]
        )

        # Main inference trigger.
        submit_btn.click(
            fn=process_dubbing,
            inputs=[video_input] + segment_inputs,
            outputs=[video_output, status_text]
        )

        # When a video is (re)loaded, reset every segment to [0, duration].
        def update_timestamps(video):
            if not video: return [gr.update() for _ in range(MAX_SEGMENTS * 2)]
            dur = get_video_duration(video)
            updates = []
            for _ in range(MAX_SEGMENTS):
                updates.append(gr.update(value=0.0))
                updates.append(gr.update(value=dur))
            return updates

        def load_default_video_fn():
            """Load the bundled sample video plus its demo text and clue."""
            return DEFAULT_VIDEO_PATH, DEFAULT_TEXT, DEFAULT_CLUE

        def load_default_audio_fn():
            """Load the bundled sample reference audio."""
            return DEFAULT_AUDIO_PATH

        load_video_btn.click(
            fn=load_default_video_fn,
            inputs=[],
            outputs=[video_input, segments[0]["text"], segments[0]["clue"]]
        ).then(
            fn=update_timestamps,
            inputs=[video_input],
            # Positions 2 and 3 of every 8-tuple are the start/end Numbers.
            outputs=[segment_inputs[i] for i in range(len(segment_inputs)) if i % 8 in [2, 3]]
        )

        video_input.change(
            fn=update_timestamps,
            inputs=[video_input],
            outputs=[comp for pair in zip(segment_inputs[2::8], segment_inputs[3::8]) for comp in pair]
        )

        if segments and segments[0]["load_audio_btn"]:
            segments[0]["load_audio_btn"].click(
                fn=load_default_audio_fn,
                inputs=[],
                outputs=[segments[0]["audio"]]
            )

    # ==================== Launch ====================

    demo.launch(
        server_name="0.0.0.0",
        server_port=SERVER_PORT,
        share=False,
        show_error=True,
        inbrowser=True,
    )

if __name__ == "__main__":
    main()
|
data/ref.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8420568976edb1cf17a63d9fa968aedaf3c0f68cca4dbf75a409876b96ad700b
|
| 3 |
+
size 788876
|
data/sample.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b901981a2213fc7f98cd6424869710e8396eb558ff2ff3e8ab5d52fe427e0ab6
|
| 3 |
+
size 2567737
|
decode_conf/decode.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: FunCineForgeInferModel
|
| 2 |
+
index_ds: FunCineForgeDS
|
| 3 |
+
xvec_model: pretrained_models/funcineforge_zh_en/camplus.onnx
|
| 4 |
+
model_conf: {}
|
| 5 |
+
|
| 6 |
+
dataset_conf:
|
| 7 |
+
# face is from the video, vocal is the reference audio, extract speaker ID and start-end timestamp from dialogue
|
| 8 |
+
load_meta_data_key: "text,clue,face,dialogue,vocal,video"
|
| 9 |
+
sos: 6561
|
| 10 |
+
eos: 6562
|
| 11 |
+
turn_of_speech: 6563
|
| 12 |
+
fill_token: 6564
|
| 13 |
+
ignore_id: -100
|
| 14 |
+
startofclue_token: 151646
|
| 15 |
+
endofclue_token: 151647
|
| 16 |
+
frame_shift: 25 # ms
|
| 17 |
+
timebook_size: 1500 # 60 * 25 = 1500
|
| 18 |
+
pangbai: 1500
|
| 19 |
+
dubai: 1501
|
| 20 |
+
duihua: 1502
|
| 21 |
+
duoren: 1503
|
| 22 |
+
male: 1504
|
| 23 |
+
female: 1505
|
| 24 |
+
child: 1506
|
| 25 |
+
youth: 1507
|
| 26 |
+
adult: 1508
|
| 27 |
+
middle: 1509
|
| 28 |
+
elderly: 1510
|
| 29 |
+
speaker_id_start: 1511
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
sampling: ras
|
| 33 |
+
lm_use_prompt: true
|
| 34 |
+
fm_use_prompt: true
|
| 35 |
+
use_llm_cache: true
|
| 36 |
+
seed: 0
|
| 37 |
+
max_length: 1500 # 60s * 25 fps
|
| 38 |
+
min_length: 50 # 2s * 25 fps
|
| 39 |
+
llm_dtype: fp32
|
| 40 |
+
fm_dtype: fp32
|
| 41 |
+
voc_dtype: fp32
|
| 42 |
+
batch_size: 1
|
decode_conf/diar.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Diarization config
|
| 2 |
+
|
| 3 |
+
fbank_dim: 80
|
| 4 |
+
embedding_size: 192
|
| 5 |
+
|
| 6 |
+
feature_extractor:
|
| 7 |
+
obj: speakerlab.process.processor.FBank
|
| 8 |
+
args:
|
| 9 |
+
n_mels: <fbank_dim>
|
| 10 |
+
sample_rate: <sample_rate>
|
| 11 |
+
mean_nor: True
|
| 12 |
+
|
| 13 |
+
embedding_model:
|
| 14 |
+
obj: speakerlab.models.campplus.DTDNN.CAMPPlus
|
| 15 |
+
args:
|
| 16 |
+
feat_dim: <fbank_dim>
|
| 17 |
+
embedding_size: <embedding_size>
|
| 18 |
+
|
| 19 |
+
# for visual embeddings extraction
|
| 20 |
+
min_track: 10
|
| 21 |
+
num_failed_det: 10
|
| 22 |
+
crop_scale: 0.4
|
| 23 |
+
min_face_size: 1
|
| 24 |
+
face_det_stride: 5 # 每5帧检测一次人脸
|
| 25 |
+
shot_stride: 50
|
| 26 |
+
|
| 27 |
+
# for clustering
|
| 28 |
+
audio_cluster:
|
| 29 |
+
obj: speakerlab.process.cluster.CommonClustering
|
| 30 |
+
args:
|
| 31 |
+
cluster_type: spectral
|
| 32 |
+
min_num_spks: 1
|
| 33 |
+
max_num_spks: 15
|
| 34 |
+
min_cluster_size: 1
|
| 35 |
+
oracle_num: null
|
| 36 |
+
pval: 0.032
|
| 37 |
+
mer_cos: 0.8
|
| 38 |
+
|
| 39 |
+
vision_cluster:
|
| 40 |
+
obj: speakerlab.process.cluster.CommonClustering
|
| 41 |
+
args:
|
| 42 |
+
cluster_type: AHC
|
| 43 |
+
cluster_line: 2
|
| 44 |
+
min_cluster_size: 1
|
| 45 |
+
fix_cos_thr: 0.25
|
| 46 |
+
|
| 47 |
+
cluster:
|
| 48 |
+
obj: speakerlab.process.cluster.JointClustering
|
| 49 |
+
args:
|
| 50 |
+
audio_cluster: <audio_cluster>
|
| 51 |
+
vision_cluster: <vision_cluster>
|
decode_conf/ds_stage0_fp32.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"train_micro_batch_size_per_gpu": 1,
|
| 3 |
+
"gradient_accumulation_steps": 1,
|
| 4 |
+
"steps_per_print": 100,
|
| 5 |
+
"gradient_clipping": 5,
|
| 6 |
+
"fp16": {
|
| 7 |
+
"enabled": false,
|
| 8 |
+
"auto_cast": false,
|
| 9 |
+
"loss_scale": 0,
|
| 10 |
+
"initial_scale_power": 16,
|
| 11 |
+
"loss_scale_window": 1000,
|
| 12 |
+
"hysteresis": 2,
|
| 13 |
+
"consecutive_hysteresis": false,
|
| 14 |
+
"min_loss_scale": 1
|
| 15 |
+
},
|
| 16 |
+
"bf16": {
|
| 17 |
+
"enabled": false
|
| 18 |
+
},
|
| 19 |
+
"zero_force_ds_cpu_optimizer": false,
|
| 20 |
+
"zero_optimization": {
|
| 21 |
+
"stage": 0,
|
| 22 |
+
"offload_optimizer": {
|
| 23 |
+
"device": "none",
|
| 24 |
+
"pin_memory": true
|
| 25 |
+
},
|
| 26 |
+
"allgather_partitions": true,
|
| 27 |
+
"allgather_bucket_size": 5e8,
|
| 28 |
+
"overlap_comm": true,
|
| 29 |
+
"reduce_scatter": true,
|
| 30 |
+
"reduce_bucket_size": 5e8,
|
| 31 |
+
"contiguous_gradients" : true
|
| 32 |
+
}
|
| 33 |
+
}
|
funcineforge/.DS_Store
ADDED
|
Binary file (8.2 kB). View file
|
|
|
funcineforge/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Initialize package."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from funcineforge.auto.auto_model import AutoModel
|
| 5 |
+
from funcineforge.auto.auto_frontend import AutoFrontend
|
| 6 |
+
|
| 7 |
+
os.environ["HYDRA_FULL_ERROR"] = "1"
|
funcineforge/auto/__init__.py
ADDED
|
File without changes
|
funcineforge/auto/auto_frontend.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import logging
|
| 4 |
+
from omegaconf import OmegaConf
|
| 5 |
+
from funcineforge.utils.hinter import get_logger
|
| 6 |
+
from funcineforge.models.utils import dtype_map
|
| 7 |
+
from funcineforge.datasets import FunCineForgeDS
|
| 8 |
+
|
| 9 |
+
class AutoFrontend:
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
ckpt_path: str,
|
| 13 |
+
config_path: str,
|
| 14 |
+
output_dir: str,
|
| 15 |
+
device: str = "cuda:0"
|
| 16 |
+
):
|
| 17 |
+
self.logger = get_logger(log_level=logging.INFO, local_rank=1, world_size=1)
|
| 18 |
+
self.device = device
|
| 19 |
+
self.output_dir = output_dir
|
| 20 |
+
self.lm_model = None
|
| 21 |
+
self.fm_model = None
|
| 22 |
+
self.voc_model = None
|
| 23 |
+
self.model = None
|
| 24 |
+
self.index_ds_class = None
|
| 25 |
+
|
| 26 |
+
self.dataset_conf = None
|
| 27 |
+
self.kwargs = OmegaConf.load(config_path)
|
| 28 |
+
|
| 29 |
+
if device.startswith("cuda"):
|
| 30 |
+
try:
|
| 31 |
+
device_id = int(device.split(":")[-1])
|
| 32 |
+
torch.cuda.set_device(device_id)
|
| 33 |
+
except (ValueError, IndexError):
|
| 34 |
+
self.logger.warning(f"Invalid cuda device string {device}, defaulting to 0")
|
| 35 |
+
torch.cuda.set_device(0)
|
| 36 |
+
else:
|
| 37 |
+
self.logger.info(f"Running on CPU")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
lm_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/llm/ds-model.pt.best/mp_rank_00_model_states.pt")
|
| 41 |
+
fm_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/flow/ds-model.pt.best/mp_rank_00_model_states.pt")
|
| 42 |
+
voc_ckpt_path = os.path.join(ckpt_path, "funcineforge_zh_en/vocoder/ds-model.pt.best/avg_5_removewn.pt")
|
| 43 |
+
|
| 44 |
+
lm_exp_dir, lm_model_name, lm_ckpt_id, _ = lm_ckpt_path.rsplit("/", 3)
|
| 45 |
+
self.logger.info(f"init LM model form {lm_ckpt_path}")
|
| 46 |
+
|
| 47 |
+
from funcineforge import AutoModel
|
| 48 |
+
self.lm_model = (AutoModel(
|
| 49 |
+
model=os.path.join(lm_exp_dir, lm_model_name),
|
| 50 |
+
init_param=lm_ckpt_path,
|
| 51 |
+
output_dir=None,
|
| 52 |
+
device=device,
|
| 53 |
+
))
|
| 54 |
+
self.lm_model.model.to(dtype_map[self.kwargs.get("llm_dtype", "fp32")])
|
| 55 |
+
|
| 56 |
+
fm_exp_dir, fm_model_name, fm_ckpt_id, _ = fm_ckpt_path.rsplit("/", 3)
|
| 57 |
+
self.logger.info(f"build FM model form {fm_ckpt_path}")
|
| 58 |
+
self.fm_model = AutoModel(
|
| 59 |
+
model=os.path.join(fm_exp_dir, fm_model_name),
|
| 60 |
+
init_param=fm_ckpt_path,
|
| 61 |
+
output_dir=None,
|
| 62 |
+
device=device,
|
| 63 |
+
)
|
| 64 |
+
self.fm_model.model.to(dtype_map[self.kwargs.get("fm_dtype", "fp32")])
|
| 65 |
+
|
| 66 |
+
voc_exp_dir, voc_model_name, voc_ckpt_id, _ = voc_ckpt_path.rsplit("/", 3)
|
| 67 |
+
self.logger.info(f"build VOC model form {voc_ckpt_path}")
|
| 68 |
+
self.voc_model = AutoModel(
|
| 69 |
+
model=os.path.join(voc_exp_dir, voc_model_name),
|
| 70 |
+
init_param=voc_ckpt_path,
|
| 71 |
+
output_dir=None,
|
| 72 |
+
device=device,
|
| 73 |
+
)
|
| 74 |
+
self.voc_model.model.to(dtype_map[self.kwargs.get("voc_dtype", "fp32")])
|
| 75 |
+
|
| 76 |
+
self.logger.info(f"build inference model {self.kwargs.get('model')}")
|
| 77 |
+
self.kwargs["output_dir"] = output_dir
|
| 78 |
+
self.kwargs["tokenizer"] = None
|
| 79 |
+
self.model = AutoModel(
|
| 80 |
+
**self.kwargs,
|
| 81 |
+
lm_model=self.lm_model,
|
| 82 |
+
fm_model=self.fm_model,
|
| 83 |
+
voc_model=self.voc_model,
|
| 84 |
+
)
|
| 85 |
+
self.dataset_conf = self.kwargs.get("dataset_conf")
|
| 86 |
+
|
| 87 |
+
def inference(self, jsonl_path: str):
|
| 88 |
+
if not self.model:
|
| 89 |
+
raise RuntimeError("Model class not initialized.")
|
| 90 |
+
|
| 91 |
+
dataset = FunCineForgeDS(jsonl_path, **self.dataset_conf)
|
| 92 |
+
self.logger.info(f"Starting inference on {len(dataset)} items...")
|
| 93 |
+
|
| 94 |
+
self.model.inference(input=dataset, input_len=len(dataset))
|
| 95 |
+
self.logger.info("Inference finished.")
|
funcineforge/auto/auto_model.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from funcineforge.utils.misc import deep_update
|
| 10 |
+
from funcineforge.utils.set_all_random_seed import set_all_random_seed
|
| 11 |
+
from funcineforge.utils.load_pretrained_model import load_pretrained_model
|
| 12 |
+
from funcineforge.download.download_model_from_hub import download_model
|
| 13 |
+
from funcineforge.tokenizer import FunCineForgeTokenizer
|
| 14 |
+
from funcineforge.face import FaceRecIR101
|
| 15 |
+
import importlib
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def prepare_data_iterator(data_in, input_len):
    """Split an indexable dataset into parallel key and item lists.

    Args:
        data_in: any object supporting ``data_in[i]`` returning a dict
            that contains an ``"utt"`` key.
        input_len: number of leading items to take.

    Returns:
        ``(key_list, data_list)`` where ``key_list[i] == data_list[i]["utt"]``.
    """
    items = [data_in[pos] for pos in range(input_len)]
    keys = [entry["utt"] for entry in items]
    return keys, items
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class AutoModel:
    """Builds a FunCineForge model from hub/local config and runs batch inference.

    The constructor resolves the model (downloading from a hub if needed),
    instantiates tokenizer/face-encoder helpers, loads pretrained weights,
    and moves the model to the selected device.
    """

    def __init__(self, **kwargs):
        # Configure root logging level from kwargs (default INFO).
        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
        logging.basicConfig(level=log_level)
        # build_model may rewrite kwargs (merged hub/local config), so keep its copy.
        model, kwargs = self.build_model(**kwargs)
        self.kwargs = kwargs
        self.model = model
        self.model_path = kwargs.get("model_path")

    @staticmethod
    def build_model(**kwargs):
        """Resolve config, construct the model class, load weights, set device/dtype.

        Returns:
            (model, kwargs): the built ``torch.nn.Module`` and the merged kwargs.
        """
        assert "model" in kwargs
        # No explicit model_conf means the caller gave a hub id / local dir:
        # resolve it (and its config files) via download_model.
        if "model_conf" not in kwargs:
            logging.info("download models from {} or local dir".format(kwargs.get("hub", "ms")))
            kwargs = download_model(**kwargs)

        set_all_random_seed(kwargs.get("seed", 0))

        # Fall back to CPU (and batch_size 1) when CUDA is unavailable or ngpu=0.
        device = kwargs.get("device", "cuda")
        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
            device = "cpu"
            kwargs["batch_size"] = 1
        kwargs["device"] = device

        torch.set_num_threads(kwargs.get("ncpu", 4))

        # build tokenizer
        # NOTE(review): a non-None "tokenizer" kwarg acts only as an *enable
        # flag* — the object itself is replaced by a fresh FunCineForgeTokenizer.
        tokenizer = kwargs.get("tokenizer", None)
        if tokenizer is not None:
            tokenizer = FunCineForgeTokenizer(**kwargs.get("tokenizer_conf", {}))
            kwargs["token_list"] = (
                tokenizer.token_list if hasattr(tokenizer, "token_list") else None
            )
            kwargs["token_list"] = (
                tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
            )
            vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
            if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
                vocab_size = tokenizer.get_vocab_size()
        else:
            # No tokenizer requested: signal "unknown vocab" to the model class.
            vocab_size = -1
        kwargs["tokenizer"] = tokenizer

        # build face_encoder
        # Same enable-flag pattern as the tokenizer above.
        face_encoder = kwargs.get("face_encoder", None)
        if face_encoder is not None:
            face_encoder = FaceRecIR101(**kwargs.get("face_encoder_conf", {}))
        kwargs["face_encoder"] = face_encoder

        # Merge model_conf first, then let top-level kwargs override it,
        # and resolve the model class by name from funcineforge.models.
        model_conf = {}
        model_class_name = kwargs["model"]
        deep_update(model_conf, kwargs.get("model_conf", {}))
        deep_update(model_conf, kwargs)
        module = importlib.import_module("funcineforge.models")
        model_class = getattr(module, model_class_name)
        model = model_class(**model_conf, vocab_size=vocab_size)

        # init_param: load pretrained weights when the checkpoint path exists.
        init_param = kwargs.get("init_param", None)
        if init_param is not None and os.path.exists(init_param):
            logging.info(f"Loading pretrained params from ckpt: {init_param}")
            load_pretrained_model(
                path=init_param,
                model=model,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                scope_map=kwargs.get("scope_map", []),
                excludes=kwargs.get("excludes", None),
                use_deepspeed=kwargs.get("train_conf", {}).get("use_deepspeed", False),
                save_deepspeed_zero_fp32=kwargs.get("save_deepspeed_zero_fp32", True),
            )

        # fp16 / bf16 casting (mutually exclusive, fp16 wins), then device move.
        if kwargs.get("fp16", False):
            model.to(torch.float16)
        elif kwargs.get("bf16", False):
            model.to(torch.bfloat16)
        model.to(device)

        return model, kwargs

    def __call__(self, *args, **cfg):
        """Forward call; per-call cfg is deep-merged into the stored kwargs.

        NOTE(review): ``cfg`` is merged into ``self.kwargs`` *in place*, so
        overrides persist across calls — confirm this is intended.
        """
        kwargs = self.kwargs
        deep_update(kwargs, cfg)
        # NOTE(review): kwargs is passed as a single positional dict here,
        # while ``inference`` below unpacks with ``**`` — confirm the model's
        # forward signature really expects a dict argument.
        res = self.model(*args, kwargs)
        return res

    def inference(self, input, input_len=None, model=None, kwargs=None, **cfg):
        """Run batched inference over ``input`` and report speed statistics.

        Args:
            input: indexable dataset of dicts with an "utt" key.
            input_len: number of items to process.
            model: optional model override (defaults to ``self.model``).
            kwargs: optional kwargs override (defaults to ``self.kwargs``).
            **cfg: per-call overrides deep-merged into kwargs.
        """
        kwargs = self.kwargs if kwargs is None else kwargs
        deep_update(kwargs, cfg)
        model = self.model if model is None else model
        model.eval()
        batch_size = kwargs.get("batch_size", 1)
        key_list, data_list = prepare_data_iterator(
            input, input_len=input_len
        )

        speed_stats = {}
        num_samples = len(data_list)
        disable_pbar = self.kwargs.get("disable_pbar", False)
        pbar = (
            tqdm(colour="blue", total=num_samples, dynamic_ncols=True) if not disable_pbar else None
        )
        time_speech_total = 0.0
        time_escape_total = 0.0
        count = 0
        log_interval = kwargs.get("log_interval", None)
        for beg_idx in range(0, num_samples, batch_size):
            end_idx = min(num_samples, beg_idx + batch_size)
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "data_lengths": end_idx - beg_idx, "key": key_batch}

            time1 = time.perf_counter()
            with torch.no_grad():
                res = model.inference(**batch, **kwargs)
                # NOTE(review): results/meta_data are only bound when res is a
                # list/tuple — a dict-returning model would raise NameError below.
                if isinstance(res, (list, tuple)):
                    results = res[0] if len(res) > 0 else [{"text": ""}]
                    meta_data = res[1] if len(res) > 1 else {}
            time2 = time.perf_counter()

            # batch_data_time defaults to -1, which makes the rtf below
            # negative/meaningless when the model does not report it.
            batch_data_time = meta_data.get("batch_data_time", -1)
            time_escape = time2 - time1
            speed_stats["forward"] = f"{time_escape:0.3f}"
            speed_stats["batch_size"] = f"{len(results)}"
            speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
            description = f"{speed_stats}, "
            if pbar:
                pbar.update(batch_size)
                pbar.set_description(description)
            else:
                if log_interval is not None and count % log_interval == 0:
                    logging.info(
                        f"processed {count*batch_size}/{num_samples} samples: {key_batch[0]}"
                    )
            time_speech_total += batch_data_time
            time_escape_total += time_escape
            count += 1

        if pbar:
            # NOTE(review): divides by accumulated batch_data_time — zero or
            # negative totals (missing meta_data) would break this line.
            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
        torch.cuda.empty_cache()
        return
|
funcineforge/datasets/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .index_ds import FunCineForgeDS
|
| 2 |
+
from .datasets import FunCineForgeDataset
|
funcineforge/datasets/datasets.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch
|
| 3 |
+
import pickle
|
| 4 |
+
import numpy as np
|
| 5 |
+
from funcineforge.utils.hinter import hint_once
|
| 6 |
+
from funcineforge.datasets import FunCineForgeDS
|
| 7 |
+
from funcineforge.models import FunCineForgeSpecAug
|
| 8 |
+
|
| 9 |
+
class FunCineForgeDataset(torch.utils.data.Dataset):
    """
    Dataset for Mixed LM of FunCineForge.

    Each item combines tokenized text (optionally prefixed with an emotion
    clue), speech codec tokens (optionally SpecAug-masked), dialogue
    time/speaker ids, and per-frame face embeddings into a single LM
    training sequence with flag masks and labels.
    """

    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        face_encoder=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        # Underlying jsonl-backed index dataset providing raw items.
        self.index_ds = FunCineForgeDS(path, **kwargs)
        self.tokenizer = tokenizer
        self.face_encoder = face_encoder

        # Padding values used by collator(): ints vs floats padded differently.
        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value
        self.batch_size = kwargs.get("batch_size")
        self.batch_type = kwargs.get("batch_type")
        # Max number of resample attempts in __getitem__/collator.
        self.retry = kwargs.get("retry", 100)

        # self.kwargs = kwargs
        self.max_token_length = kwargs.get("max_token_length", 1500)
        self.batch_size_scale_ratio_max = kwargs.get("batch_size_scale_ratio_max", 1.5)
        # Hard cap on batch_size * seq_len enforced in collator().
        self.batch_size_token_max = kwargs.get("batch_size_token_max", 2500)
        self.multiturn_num_max = kwargs.get("multiturn_num_max", 1)
        # Dimensionality of one face embedding vector.
        self.face_size = kwargs.get("face_size", 512)

        # Special token ids appended after the speech codec vocabulary.
        self.codebook_size = kwargs.get("codebook_size", 6561)
        self.sos = kwargs.get("sos", self.codebook_size)
        self.eos = kwargs.get("eos", self.codebook_size + 1)
        self.turn_of_speech = kwargs.get("turn_of_speech", self.codebook_size + 2)
        self.ignore_id = kwargs.get("ignore_id", -100)

        # "specaug" acts as an enable flag; the actual module is built
        # from specaug_conf.
        specaug = kwargs.get("specaug", None)
        specaug_conf = kwargs.get("specaug_conf", {})
        if specaug is not None:
            specaug = FunCineForgeSpecAug(**specaug_conf)
        self.specaug = specaug

        self.set_invalid_xvec_zeros = kwargs.get("set_invalid_xvec_zeros", False)
        self.use_emotion_clue = kwargs.get("use_emotion_clue", False)
        logging.info(f"use_emotion_clue: {self.use_emotion_clue}")

    def get_source_len(self, index):
        """Return the source sequence length of item *index* (for batching)."""
        item = self.index_ds[index]
        source_len = self.index_ds.get_source_len(item)
        return source_len

    def get_target_len(self, index):
        """Return the target (speech) length of item *index*."""
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def mixup_text_codec(self, text: torch.Tensor, aug_codec: torch.Tensor, timespk_ids: torch.Tensor, type_id: int):
        """Concatenate text, type id, time/speaker ids and codec into one LM sequence.

        Layout: [sos, text..., type_id, timespk_ids..., turn_of_speech, codec..., eos]

        Returns:
            (input_ids, labels, text_flag, codec_flag, timespk_flag) — labels
            mask everything before the codec span with ignore_id; the three
            float flags partition the sequence (text / timespk+type / rest).
        """
        text_len = text.shape[0]
        timespk_len = timespk_ids.shape[0]
        sequence = [self.sos, *text.tolist(), type_id, *timespk_ids.tolist(), self.turn_of_speech, *aug_codec.tolist(), self.eos]
        # sequence = [self.sos, *text.tolist(), type_id, self.turn_of_speech, *aug_codec.tolist(), self.eos]
        input_ids = torch.tensor(sequence, dtype=torch.int64)
        # text_flag marks the raw text span (positions 1..text_len).
        text_flag = torch.zeros(len(sequence), dtype=torch.float32)
        text_flag[1:text_len+1] = 1
        # timespk_flag covers type_id plus the timespk ids and turn_of_speech slot.
        timespk_flag = torch.zeros(len(sequence), dtype=torch.float32)
        timespk_flag[text_len+1:text_len+2+timespk_len] = 1
        # timespk_flag[text_len+1:text_len+2] = 1
        codec_flag = 1 - (text_flag + timespk_flag)
        labels = torch.tensor(sequence, dtype=torch.int64)
        # Only the codec span (and eos) contributes to the loss.
        labels[:text_len+timespk_len+3] = self.ignore_id
        # labels[:text_len+3] = self.ignore_id

        return input_ids, labels, text_flag, codec_flag, timespk_flag

    def __getitem__(self, index):
        """Assemble one training sample; on failure re-samples a random index.

        NOTE(review): no exception handling inside the loop, so the retry
        mechanism only re-runs if an earlier `continue` path is added —
        currently the first iteration always breaks.
        """
        output = None
        for idx in range(self.retry):
            if idx == 0:
                index_cur = index
            else:
                index_cur = torch.randint(0, len(self.index_ds), ()).item()
            item = self.index_ds[index_cur]

            # clue + text
            text = item["text"]
            clue = "<|startofclue|>" + item["clue"] + "<|endofclue|>"
            if self.use_emotion_clue:
                text = clue + text
            text_ids = torch.tensor(self.tokenizer.encode(text), dtype=torch.int32)
            hint_once(f"raw text: {text}", "log_text")

            # speech tokens
            target_out = item["token"]
            codec = torch.from_numpy(np.load(target_out))
            codec_len = codec.shape[0]  # could instead use the dataset's speech_length field
            aug_codec = codec.clone()
            if self.specaug is not None:  # aug_codec is a randomly-masked copy of codec for robustness
                aug_codec, _ = self.specaug(aug_codec.float().unsqueeze(0).unsqueeze(-1))
                aug_codec = aug_codec.squeeze(0).squeeze(-1).long()

            # dialogue
            timespk_ids = torch.from_numpy(item["timespk_ids"])

            # mixup
            type_id = item["type_id"]
            input_ids, labels, text_flag, codec_flag, timespk_flag = self.mixup_text_codec(
                text_ids, aug_codec, timespk_ids, type_id
            )

            # face
            face_features = item["face"]
            face_emb = torch.zeros((codec_len, self.face_size), dtype=torch.float32)  # face_emb length matches codec_len
            with open(face_features, 'rb') as f:
                stat_obj = pickle.load(f)
            embeddings = stat_obj['embeddings']
            faceI = stat_obj['faceI']
            # Each face embedding is broadcast over up to 5 consecutive frames
            # starting at its frame index (clamped to codec_len).
            for emb, frameI in zip(embeddings, faceI):
                fi = int(frameI)
                if 0 <= fi < codec_len:
                    end = min(fi + 5, codec_len)
                    face_emb[fi:end] = torch.from_numpy(emb).expand(end - fi, -1)

            # attention_mask covers the whole input sequence: input_id=(sos, <|startofclue|>, clue, <|endofclue|>, text, type_id, timespk_ids, turn_of_speech, speech, eos)
            attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
            codec_len = torch.tensor([codec_len], dtype=torch.int32)
            output = {
                "input_ids": input_ids,
                "face_emb": face_emb,
                "attention_mask": attention_mask,
                "labels_ids": labels,
                "text_flag": text_flag,
                "codec_flag": codec_flag,
                "timespk_flag": timespk_flag,
                "codec_len": codec_len,
            }
            break
        return output

    def collator(self, samples: list = None):
        """Pad a list of samples into batch tensors, dropping trailing samples
        while batch_size * seq_len exceeds batch_size_token_max (token batching)."""

        for idx in range(self.retry):
            badcase_flag = False

            outputs = {}
            for sample in samples:
                if sample is None:
                    continue
                for key in sample.keys():
                    if key not in outputs:
                        outputs[key] = []
                    if isinstance(sample[key], (list, tuple)):
                        outputs[key].extend(sample[key])
                    else:
                        outputs[key].append(sample[key])

            # Pad tensor fields; integer tensors use int_pad_value, floats
            # use float_pad_value.
            for key, data_list in outputs.items():
                if isinstance(data_list[0], torch.Tensor):
                    if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:

                        pad_value = self.int_pad_value
                    else:
                        pad_value = self.float_pad_value

                    outputs[key] = torch.nn.utils.rnn.pad_sequence(
                        data_list, batch_first=True, padding_value=pad_value
                    )

            if self.batch_type != "example":
                b, t = outputs["input_ids"].shape
                if b > 1 and b * t > self.batch_size_token_max:
                    logging.info(
                        f"Warning, {idx}th, b*t: {b}*{t}={b * t} > batch_size_token_max: {self.batch_size_token_max}, drop last data"
                    )
                    samples = samples[:-1]
                    continue

            break

        return outputs
|
funcineforge/datasets/index_ds.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import torch
|
| 3 |
+
import logging
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class FunCineForgeDS(torch.utils.data.Dataset):
    """Jsonl-backed index dataset: parses manifest lines into per-item dicts.

    Accepts either a single .json/.jsonl file or a text file listing jsonl
    paths (one per line). Items failing length/role checks are dropped at
    load time. Dialogue segments are encoded as flat "timespk" id sequences
    using an id space placed after the time codebook.
    """

    def __init__(self, data_jsonl: str, **kwargs):
        super().__init__()

        self.max_source_length = kwargs.get("max_source_length", None)
        self.max_text_length = kwargs.get("max_text_length", None)
        self.max_token_length = kwargs.get("max_token_length", None)
        self.ignore_id = kwargs.get("ignore_id", -100)
        # Frames per second used to convert seconds -> frame indices.
        self.frame_shift = kwargs.get("frame_shift", 25)
        # Time tokens occupy [0, timebook_size); category/gender/age/speaker
        # ids are laid out immediately after, in that order.
        self.timebook_size = kwargs.get("timebook_size", 1500)
        # NOTE: the Chinese keys below are runtime manifest values, not comments.
        self.type_map = {"旁白": kwargs.get("pangbai", self.timebook_size),
                         "独白": kwargs.get("dubai", self.timebook_size + 1),
                         "对话": kwargs.get("duihua", self.timebook_size + 2),
                         "多人": kwargs.get("duoren", self.timebook_size + 3),}
        self.gender_map = {"男": kwargs.get("male", self.timebook_size + 4),
                           "male": kwargs.get("male", self.timebook_size + 4),
                           "女": kwargs.get("female", self.timebook_size + 5),
                           "female": kwargs.get("female", self.timebook_size + 5),}
        self.age_map = {"儿童": kwargs.get("child", self.timebook_size + 6),
                        "child": kwargs.get("child", self.timebook_size + 6),
                        "青年": kwargs.get("youth", self.timebook_size + 7),
                        "teenager": kwargs.get("youth", self.timebook_size + 7),
                        "中年": kwargs.get("adult", self.timebook_size + 8),
                        "adult": kwargs.get("adult", self.timebook_size + 8),
                        "中老年": kwargs.get("middle", self.timebook_size + 9),
                        "middle-aged": kwargs.get("middle", self.timebook_size + 9),
                        "老年": kwargs.get("elderly", self.timebook_size + 10),
                        "elderly": kwargs.get("elderly", self.timebook_size + 10)}
        self.speaker_id_start = kwargs.get("speaker_id_start", self.timebook_size + 11)

        # Required roles within each item's "messages" list, comma-separated.
        # NOTE(review): no default — a missing "load_meta_data_key" kwarg
        # raises AttributeError here; confirm it is always provided by config.
        load_meta_data_key = kwargs.get("load_meta_data_key").split(",")

        if not (data_jsonl.endswith(".jsonl") or data_jsonl.endswith(".json")):
            # jsonl list file
            with open(data_jsonl, encoding="utf-8") as fin:
                file_list = fin.readlines()
                logging.info(f"file_list: {file_list}")
        else:
            file_list = [data_jsonl]

        contents = []
        for file_json in file_list:
            with open(file_json.strip(), encoding="utf-8") as fin:
                for line in fin:
                    data_dict = json.loads(line.strip())
                    utt = data_dict["utt"]
                    data_type = data_dict.get("type")
                    # NOTE(review): unknown types fall back to the literal
                    # 1500 rather than self.timebook_size — confirm intended.
                    type_id = self.type_map[data_type] if data_type in self.type_map else 1500
                    data = data_dict["messages"]
                    speech_length = data_dict.get("speech_length", -1)
                    # 2 for startofclue, endofclue
                    text_length = data_dict.get("text_length", -1) + data_dict.get("clue_length", -1) + 2
                    if self.max_token_length is not None and (speech_length > self.max_token_length or speech_length <= 0):
                        logging.info(
                            f"speech_length: {speech_length} > {self.max_token_length}, drop it: {data_dict}"
                        )
                        continue
                    if self.max_text_length is not None and (text_length > self.max_text_length or text_length <= 0):
                        logging.info(
                            f"text_length: {text_length} > {self.max_text_length}, drop it: {data_dict}"
                        )
                        continue

                    # Drop items missing any required role.
                    skip_flag = None
                    roles = {item.get("role") for item in data}
                    for key in load_meta_data_key:
                        if key not in roles:
                            skip_flag = key
                            break
                    if skip_flag is not None:
                        logging.info(
                            f"doesn't have {skip_flag}, drop it: {data_dict}")
                        continue

                    contents_i = {}
                    timespk_ids_len = 0
                    for i, item in enumerate(data):
                        role = item["role"]
                        content = item["content"]
                        for key in load_meta_data_key:
                            if role == key:
                                if key == "dialogue":
                                    # Dialogue segments are flattened into ids.
                                    timespk_ids = self.timespk_to_codec(content)
                                    timespk_ids_len = len(timespk_ids)
                                    if timespk_ids_len == 0:
                                        logging.info(f"[WARNING] len of timespk_ids is 0: {data_dict}")
                                    contents_i["timespk_ids"] = timespk_ids
                                else:
                                    contents_i[role] = content
                    contents_i["utt"] = utt
                    contents_i["type_id"] = type_id
                    # face embs len = speech tokens len, so need * 2;
                    # 4: sos, tos, eos; type_id
                    contents_i["source_len"] = speech_length * 2 + text_length + timespk_ids_len + 4
                    contents_i["speech_len"] = speech_length
                    contents_i["text_len"] = text_length  # include clue_length
                    contents.append(contents_i)

        self.contents = contents

        logging.info("total_num of samplers: {}, {}".format(len(self.contents), data_jsonl))


    def timespk_to_codec(self, dialogue):
        """Flatten dialogue segments into id tuples.

        Each segment contributes tuple tokens (start, spk, gender, age, end),
        concatenated over n_parts segments. Unknown gender/age map to
        ignore_id. Start/end times (seconds) become 1-based frame indices.
        """
        # tuple tokens (start, spk, gender, age, end) * n_parts
        n_parts = len(dialogue)
        if n_parts == 0:
            return np.array([], dtype=np.int64)
        starts = np.array([part["start"] for part in dialogue])
        durations = np.array([part["duration"] for part in dialogue])
        speakers = np.array([int(part["spk"]) for part in dialogue])
        genders = [part["gender"] for part in dialogue]
        ages = [part["age"] for part in dialogue]

        # Seconds -> frame indices (+1: frame ids are 1-based).
        start_idxs = (starts * self.frame_shift + 1).astype(np.int64)
        end_idxs = ((starts + durations) * self.frame_shift + 1).astype(np.int64)
        # Speaker numbers are 1-based in the manifest.
        spk_ids = (self.speaker_id_start + speakers - 1).astype(np.int64)
        gender_ids = [self.gender_map.get(g, self.ignore_id) for g in genders]
        age_ids = [self.age_map.get(a, self.ignore_id) for a in ages]

        # Interleave the five fields: positions i*5..i*5+4 hold segment i.
        sequence = np.full(n_parts * 5, self.ignore_id, dtype=np.int64)
        sequence[0::5] = start_idxs
        sequence[1::5] = spk_ids
        sequence[2::5] = gender_ids
        sequence[3::5] = age_ids
        sequence[4::5] = end_idxs
        return sequence

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        """Return the pre-parsed item dict at *index*."""

        data = self.contents[index]

        return data

    def get_source_len(self, data_dict):
        """Source length (text + face + speech + specials) precomputed in __init__."""
        source_len = data_dict.get("source_len", 0)
        return source_len

    def get_target_len(self, data_dict):
        """Target length = speech token count from the manifest."""
        target_len = data_dict.get("speech_len", 0)
        return target_len
|
funcineforge/download/__init__.py
ADDED
|
File without changes
|
funcineforge/download/download_model_from_hub.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from omegaconf import OmegaConf, DictConfig
|
| 4 |
+
from funcineforge.download.name_maps_from_hub import name_maps_ms, name_maps_hf, name_maps_openai
|
| 5 |
+
|
| 6 |
+
def download_model(**kwargs):
    """Resolve model assets from a hub and return the (possibly rewritten) kwargs.

    Supported hubs: "ms" (ModelScope, default), "hf" (Hugging Face), and
    "openai" (Whisper-style names or a local path). Any other hub value
    returns kwargs untouched.
    """
    hub = kwargs.get("hub", "ms")
    if hub == "ms":
        return download_from_ms(**kwargs)
    if hub == "hf":
        return download_from_hf(**kwargs)
    if hub == "openai":
        model_or_path = kwargs.get("model")
        if os.path.exists(model_or_path):
            # Local checkpoint: wrap it with the Whisper adapter class.
            kwargs["model_path"] = model_or_path
            kwargs["model"] = "WhisperWarp"
        else:
            # Known short name -> canonical id; unknown names pass through.
            kwargs["model_path"] = name_maps_openai.get(model_or_path, model_or_path)
    return kwargs
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def download_from_ms(**kwargs):
    """Resolve a ModelScope model id/path and merge its config into kwargs.

    Resolution order:
      1. Map short names via name_maps_ms; download if not a local path.
      2. If the model dir has configuration.json, resolve relative file
         paths and merge the referenced config (config wins over nothing,
         caller kwargs win over config via OmegaConf.merge order).
      3. Otherwise, if config.yaml exists, merge it and wire up checkpoint,
         tokenizer, cmvn, and jieba dict paths found in the model dir.
    """
    model_or_path = kwargs.get("model")
    if model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    model_revision = kwargs.get("model_revision", "master")
    # Download only when the path is not local and no explicit model_path given.
    if not os.path.exists(model_or_path) and "model_path" not in kwargs:
        try:
            model_or_path = get_or_download_model_dir(
                model_or_path,
                model_revision,
                is_training=kwargs.get("is_training"),
                check_latest=kwargs.get("check_latest", True),
            )
        except Exception as e:
            print(f"Download: {model_or_path} failed!: {e}")

    kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]

    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
            conf_json = json.load(f)

        cfg = {}
        if "file_path_metas" in conf_json:
            # Rewrite relative asset paths in the meta to absolute ones.
            add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
        # cfg.update(kwargs)
        # Later arguments win in OmegaConf.merge: caller kwargs override cfg.
        cfg = OmegaConf.merge(cfg, kwargs)
        if "config" in cfg:
            config = OmegaConf.load(cfg["config"])
            kwargs = OmegaConf.merge(config, cfg)
            # The config file decides the actual model class name.
            kwargs["model"] = config["model"]
    elif os.path.exists(os.path.join(model_or_path, "config.yaml")):
        config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
        kwargs = OmegaConf.merge(config, kwargs)

        init_param = kwargs.get("init_param", "")
        # Note operator precedence: (str and not exists) or (list/tuple).
        if (
            isinstance(init_param, str)
            and not os.path.exists(init_param)
            or isinstance(init_param, (list, tuple))
        ):
            init_param_new = init_param
            if isinstance(init_param, str):
                init_param = init_param.split(",")
            for init_param_i in init_param:
                if not os.path.exists(init_param_i):
                    print(f"init_param: {init_param_i}, does not exist")
                    # Fall back to the bundled checkpoint.
                    # NOTE(review): appending to init_param_new can leave a
                    # leading "," when the original init_param was "" — confirm
                    # downstream splitting tolerates empty entries.
                    init_param_i = os.path.join(model_or_path, "model.pt")
                    init_param_new = f"{init_param_new},{init_param_i}"
            kwargs["init_param"] = init_param_new
        # assert os.path.exists(kwargs["init_param"]), "init_param does not exist"
        # Wire up optional tokenizer/frontend assets if present in the dir.
        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
        kwargs["model"] = config["model"]
        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
    # Normalize back to a plain dict for downstream **kwargs use.
    if isinstance(kwargs, DictConfig):
        kwargs = OmegaConf.to_container(kwargs, resolve=True)

    return kwargs
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def download_from_hf(**kwargs):
    """Resolve a Hugging Face model id/path and merge its config into kwargs.

    Mirrors download_from_ms, with two differences: the hub downloader is
    get_or_download_model_dir_hf, and the config.yaml branch additionally
    requires model.pt to exist and always uses it as init_param.
    """
    model_or_path = kwargs.get("model")
    if model_or_path in name_maps_hf:
        model_or_path = name_maps_hf[model_or_path]
    model_revision = kwargs.get("model_revision", "master")
    # Download only when the path is not local and no explicit model_path given.
    if not os.path.exists(model_or_path) and "model_path" not in kwargs:
        try:
            model_or_path = get_or_download_model_dir_hf(
                model_or_path,
                model_revision,
                is_training=kwargs.get("is_training"),
                check_latest=kwargs.get("check_latest", True),
            )
        except Exception as e:
            print(f"Download: {model_or_path} failed!: {e}")

    kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]

    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
            conf_json = json.load(f)

        cfg = {}
        if "file_path_metas" in conf_json:
            # Rewrite relative asset paths in the meta to absolute ones.
            add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
        # Later arguments win in OmegaConf.merge: caller kwargs override cfg.
        cfg = OmegaConf.merge(cfg, kwargs)
        # cfg.update(kwargs)
        if "config" in cfg:
            config = OmegaConf.load(cfg["config"])
            kwargs = OmegaConf.merge(config, cfg)
            # The config file decides the actual model class name.
            kwargs["model"] = config["model"]
    elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
        os.path.join(model_or_path, "model.pt")
    ):
        config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
        kwargs = OmegaConf.merge(config, kwargs)
        # Unlike the ms variant, the bundled checkpoint is used unconditionally.
        init_param = os.path.join(model_or_path, "model.pt")
        kwargs["init_param"] = init_param
        # Wire up optional tokenizer/frontend assets if present in the dir.
        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
        kwargs["model"] = config["model"]
        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
    # Normalize back to a plain dict for downstream **kwargs use.
    if isinstance(kwargs, DictConfig):
        kwargs = OmegaConf.to_container(kwargs, resolve=True)

    return kwargs
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg=None):
    """Resolve relative file paths in ``file_path_metas`` against ``model_or_path``.

    For every string value naming a file that exists under ``model_or_path``,
    the joined absolute path is written into ``cfg`` under the same key.
    Nested dicts are resolved recursively into nested dicts in ``cfg``.

    Args:
        model_or_path (str): Root directory the relative paths are joined to.
        file_path_metas (dict): Mapping of config keys to relative file paths;
            values may be strings or nested dicts of the same shape.
        cfg (dict, optional): Output mapping to populate in place; a fresh
            dict is created when omitted.

    Returns:
        dict: ``cfg`` with the resolved paths filled in.
    """
    # Fix: the original signature used a mutable default (``cfg={}``), which
    # shared one dict across calls and leaked resolved paths between them.
    # (A stray debug ``print(file_path_metas)`` was also removed.)
    if cfg is None:
        cfg = {}
    if isinstance(file_path_metas, dict):
        for k, v in file_path_metas.items():
            if isinstance(v, str):
                p = os.path.join(model_or_path, v)
                # Only record paths that actually exist on disk.
                if os.path.exists(p):
                    cfg[k] = p
            elif isinstance(v, dict):
                if k not in cfg:
                    cfg[k] = {}
                add_file_root_path(model_or_path, v, cfg[k])
    return cfg
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def get_or_download_model_dir(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
):
    """Get local model directory or download model if necessary.

    Args:
        model (str): model id or path to a local model directory/file.
        model_revision (str, optional): model version number passed to the hub.
        is_training (bool): whether invoked from a trainer; only affects the
            user-agent invoke key reported to ModelScope.
        check_latest (bool): for an existing local path, also ask the hub
            whether a newer version exists (best effort, never fatal).

    Returns:
        str: path to the local model directory.
    """
    from modelscope.hub.check_model import check_local_model_is_latest
    from modelscope.hub.snapshot_download import snapshot_download

    from modelscope.utils.constant import Invoke, ThirdParty

    key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE

    if os.path.exists(model):
        # Fix: previously ``if os.path.exists(model) and check_latest`` sent a
        # local path with check_latest=False into snapshot_download; an
        # existing local model should always be used as-is.
        model_cache_dir = model if os.path.isdir(model) else os.path.dirname(model)
        if check_latest:
            try:
                check_local_model_is_latest(
                    model_cache_dir, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funcineforge"}
                )
            except Exception:
                # Fix: was a bare ``except:``; the freshness check stays best
                # effort (e.g. offline) and must never fail the caller.
                print("could not check the latest version")
    else:
        model_cache_dir = snapshot_download(
            model, revision=model_revision, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funcineforge"}
        )
    return model_cache_dir
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def get_or_download_model_dir_hf(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
):
    """Get the local model directory, downloading from the Hugging Face hub if needed.

    Args:
        model (str): model id or path to a local model directory.
        model_revision (str, optional): revision (branch, tag or commit hash).
            Fix: this was accepted but silently ignored; it is now forwarded.
        is_training (bool): unused; kept for signature parity with
            ``get_or_download_model_dir``.
        check_latest (bool): unused; kept for signature parity.

    Returns:
        str: path to the local snapshot directory.
    """
    from huggingface_hub import snapshot_download

    # ``snapshot_download`` returns the cached local path directly when the
    # snapshot is already present.
    model_cache_dir = snapshot_download(model, revision=model_revision)
    return model_cache_dir
|
funcineforge/download/file.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import os
|
| 5 |
+
import tempfile
|
| 6 |
+
from abc import ABCMeta, abstractmethod
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Generator, Union
|
| 9 |
+
|
| 10 |
+
import requests
|
| 11 |
+
from urllib.parse import urlparse
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def download_from_url(url):
    """Download ``url`` into a fresh temporary directory and return the local path.

    Args:
        url (str): HTTP/HTTPS URL of the file to fetch.

    Returns:
        str: path of the downloaded file, named after the URL's basename.
        The temporary directory is NOT auto-deleted; cleanup is the caller's
        responsibility.

    Raises:
        AssertionError: if ``url`` has no scheme, so nothing was downloaded.
    """
    result = urlparse(url)
    file_path = None
    if result.scheme is not None and len(result.scheme) > 0:
        storage = HTTPStorage()
        # Raw response body as bytes.
        data = storage.read(url)
        # Fix: the original used ``tempfile.TemporaryDirectory().name`` without
        # keeping a reference — the directory's finalizer could delete it as
        # soon as the object was garbage collected.  ``mkdtemp`` creates the
        # directory without a finalizer.
        work_dir = tempfile.mkdtemp()
        file_path = os.path.join(work_dir, os.path.basename(url))
        with open(file_path, "wb") as fb:
            fb.write(data)
    assert file_path is not None, f"failed to download: {url}"
    return file_path
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Storage(metaclass=ABCMeta):
    """Common interface for storage backends.

    Concrete backends supply byte-oriented (``read``/``write``) and
    text-oriented (``read_text``/``write_text``) accessors.
    """

    @abstractmethod
    def read(self, filepath: str):
        """Return the contents of ``filepath`` as a byte stream."""

    @abstractmethod
    def read_text(self, filepath: str):
        """Return the contents of ``filepath`` as text."""

    @abstractmethod
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Persist the byte string ``obj`` at ``filepath``."""

    @abstractmethod
    def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        """Persist the text ``obj`` at ``filepath`` using ``encoding``."""
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class LocalStorage(Storage):
    """Storage backend backed by the local filesystem."""

    def read(self, filepath: Union[str, Path]) -> bytes:
        """Return the raw bytes stored at ``filepath`` (opened with 'rb').

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: the file contents.
        """
        with open(filepath, "rb") as stream:
            return stream.read()

    def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
        """Return the text stored at ``filepath``, decoded with ``encoding``.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): Encoding used to open the file (default 'utf-8').

        Returns:
            str: the file contents.
        """
        with open(filepath, "r", encoding=encoding) as stream:
            return stream.read()

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write the bytes ``obj`` to ``filepath`` with 'wb' mode.

        The parent directory is created when it does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Destination path.
        """
        parent = os.path.dirname(filepath)
        if parent and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)

        with open(filepath, "wb") as stream:
            stream.write(obj)

    def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        """Write the text ``obj`` to ``filepath`` with 'w' mode.

        The parent directory is created when it does not exist.

        Args:
            obj (str): Text to be written.
            filepath (str or Path): Destination path.
            encoding (str): Encoding used to open the file (default 'utf-8').
        """
        parent = os.path.dirname(filepath)
        if parent and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)

        with open(filepath, "w", encoding=encoding) as stream:
            stream.write(obj)

    @contextlib.contextmanager
    def as_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]:
        """Yield ``filepath`` unchanged; exists only for API symmetry."""
        yield filepath
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class HTTPStorage(Storage):
    """Read-only storage backend for ``http``/``https`` URLs."""

    def read(self, url):
        """Return the response body of ``url`` as bytes."""
        # TODO @wenmeng.zwm add progress bar if file is too large
        response = requests.get(url)
        response.raise_for_status()
        return response.content

    def read_text(self, url):
        """Return the response body of ``url`` as text."""
        response = requests.get(url)
        response.raise_for_status()
        return response.text

    @contextlib.contextmanager
    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download ``filepath`` into a temporary file and yield its path.

        Decorated with :meth:`contextlib.contextmanager`: use it in a ``with``
        statement; the temporary file is removed when the block exits.

        Args:
            filepath (str): URL to download.

        Examples:
            >>> storage = HTTPStorage()
            >>> # After exiting the ``with`` clause the path is removed.
            >>> with storage.as_local_path('http://path/to/file') as path:
            ...     # do something here
        """
        try:
            tmp = tempfile.NamedTemporaryFile(delete=False)
            tmp.write(self.read(filepath))
            tmp.close()
            yield tmp.name
        finally:
            os.remove(tmp.name)

    def write(self, obj: bytes, url: Union[str, Path]) -> None:
        raise NotImplementedError("write is not supported by HTTP Storage")

    def write_text(self, obj: str, url: Union[str, Path], encoding: str = "utf-8") -> None:
        raise NotImplementedError("write_text is not supported by HTTP Storage")
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class OSSStorage(Storage):
    """Placeholder OSS (Object Storage Service) backend; not yet implemented."""

    def __init__(self, oss_config_file=None):
        # Configuration would come from a config file or environment variables.
        raise NotImplementedError("OSSStorage.__init__ to be implemented in the future")

    def read(self, filepath):
        raise NotImplementedError("OSSStorage.read to be implemented in the future")

    def read_text(self, filepath, encoding="utf-8"):
        raise NotImplementedError("OSSStorage.read_text to be implemented in the future")

    @contextlib.contextmanager
    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download ``filepath`` to a temporary file and yield its path.

        The temporary file is removed on exit from the ``with`` block.
        Currently unreachable until ``read`` (and ``__init__``) are
        implemented.
        """
        try:
            tmp = tempfile.NamedTemporaryFile(delete=False)
            tmp.write(self.read(filepath))
            tmp.close()
            yield tmp.name
        finally:
            os.remove(tmp.name)

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        raise NotImplementedError("OSSStorage.write to be implemented in the future")

    def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        raise NotImplementedError("OSSStorage.write_text to be implemented in the future")
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# Lazily-populated cache of storage backend instances, keyed by URI scheme
# ("local", "http", "https", "oss"); filled by ``File._get_storage``.
G_STORAGES = {}
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class File(object):
    """Static facade dispatching I/O to a scheme-specific storage backend.

    The backend is chosen from the URI prefix (``oss://``, ``http://``,
    ``https://``); any URI without ``://`` is treated as a local path.
    Backend instances are created lazily and cached in the module-level
    ``G_STORAGES`` dict.
    """

    _prefix_to_storage: dict = {
        "oss": OSSStorage,
        "http": HTTPStorage,
        "https": HTTPStorage,
        "local": LocalStorage,
    }

    @staticmethod
    def _get_storage(uri):
        """Return (and memoize in ``G_STORAGES``) the backend for ``uri``."""
        assert isinstance(uri, str), f"uri should be str type, but got {type(uri)}"

        if "://" not in uri:
            # local path
            storage_type = "local"
        else:
            prefix, _ = uri.split("://")
            storage_type = prefix

        assert storage_type in File._prefix_to_storage, (
            f"Unsupported uri {uri}, valid prefixs: " f"{list(File._prefix_to_storage.keys())}"
        )

        if storage_type not in G_STORAGES:
            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()

        return G_STORAGES[storage_type]

    @staticmethod
    def read(uri: str) -> bytes:
        """Read ``uri`` as bytes via its storage backend ('rb' semantics).

        Args:
            uri (str): Path or URL to read data from.

        Returns:
            bytes: Expected bytes object.
        """
        storage = File._get_storage(uri)
        return storage.read(uri)

    @staticmethod
    def read_text(uri: Union[str, Path], encoding: str = "utf-8") -> str:
        """Read ``uri`` as text via its storage backend ('r' semantics).

        Args:
            uri (str or Path): Path or URL to read data from.
            encoding (str): Accepted for API symmetry.
                NOTE(review): ``encoding`` is not forwarded to the backend
                (``HTTPStorage.read_text`` takes no encoding argument), so
                local reads always use the backend default of 'utf-8'.

        Returns:
            str: Expected text read from ``uri``.
        """
        storage = File._get_storage(uri)
        return storage.read_text(uri)

    @staticmethod
    def write(obj: bytes, uri: Union[str, Path]) -> None:
        """Write bytes ``obj`` to ``uri`` via its storage backend ('wb').

        Local backends create the parent directory when missing; HTTP/OSS
        backends raise ``NotImplementedError``.

        Args:
            obj (bytes): Data to be written.
            uri (str or Path): Destination path or URL.
        """
        storage = File._get_storage(uri)
        return storage.write(obj, uri)

    @staticmethod
    def write_text(obj: str, uri: str, encoding: str = "utf-8") -> None:
        """Write text ``obj`` to ``uri`` via its storage backend ('w').

        Args:
            obj (str): Text to be written.
            uri (str): Destination path or URL.
            encoding (str): Accepted for API symmetry.
                NOTE(review): not forwarded to the backend; local writes
                always use the backend default of 'utf-8'.
        """
        storage = File._get_storage(uri)
        return storage.write_text(obj, uri)

    @staticmethod
    @contextlib.contextmanager
    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
        """Yield a local filesystem path for ``uri``.

        Local URIs are yielded unchanged; remote backends download to a
        temporary file that is removed when the ``with`` block exits.

        Fix: the original lacked ``@staticmethod`` (unlike every sibling
        method), so ``File().as_local_path(uri)`` bound the instance as
        ``uri``; only the accidental ``File.as_local_path(uri)`` class-access
        form worked.
        """
        storage = File._get_storage(uri)
        with storage.as_local_path(uri) as local_path:
            yield local_path
|
funcineforge/download/name_maps_from_hub.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Short model aliases -> ModelScope repository ids.
name_maps_ms = {
    "paraformer": "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    "paraformer-zh": "iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    "paraformer-en": "iic/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-en-spk": "iic/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-zh-streaming": "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
    "fsmn-vad": "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    "ct-punc": "iic/punc_ct-transformer_cn-en-common-vocab471067-large",
    "ct-punc-c": "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    "fa-zh": "iic/speech_timestamp_prediction-v1-16k-offline",
    "cam++": "iic/speech_campplus_sv_zh-cn_16k-common",
    "Whisper-large-v3": "iic/Whisper-large-v3",
    "Qwen-Audio": "Qwen/Qwen-Audio",
    "emotion2vec_plus_large": "iic/emotion2vec_plus_large",
    "emotion2vec_plus_base": "iic/emotion2vec_plus_base",
    "emotion2vec_plus_seed": "iic/emotion2vec_plus_seed",
}

# Short model aliases -> Hugging Face repository ids.
# NOTE(review): "paraformer-en" maps to the *zh* repo here (unlike the
# ModelScope table above), and "paraformer-en-spk" has no entry — verify
# whether this is intentional.
name_maps_hf = {
    "paraformer": "funasr/paraformer-zh",
    "paraformer-zh": "funasr/paraformer-zh",
    "paraformer-en": "funasr/paraformer-zh",
    "paraformer-zh-streaming": "funasr/paraformer-zh-streaming",
    "fsmn-vad": "funasr/fsmn-vad",
    "ct-punc": "funasr/ct-punc",
    "ct-punc-c": "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    "fa-zh": "funasr/fa-zh",
    "cam++": "funasr/campplus",
    "iic/emotion2vec_plus_large": "emotion2vec/emotion2vec_plus_large",
    "iic/emotion2vec_plus_base": "emotion2vec/emotion2vec_plus_base",
    "iic/emotion2vec_plus_seed": "emotion2vec/emotion2vec_plus_seed",
}

# Whisper aliases -> model names used by the official openai-whisper package.
name_maps_openai = {
    "Whisper-base.en": "base.en",
    "Whisper-base": "base",
    "Whisper-large": "large",
    "Whisper-large-v1": "large-v1",
    "Whisper-large-v2": "large-v2",
    "Whisper-large-v3": "large-v3",
    "Whisper-large-v3-turbo": "turbo",
}
|
funcineforge/face/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .face_recognition import FaceRecIR101
|
funcineforge/face/face_recognition.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def FaceRecIR101(init_param_path, **kwargs):
    """
    Face embeddings extraction with CurricularFace pretrained model.
    Reference:
        - https://modelscope.cn/models/iic/cv_ir101_facerecognition_cfglint

    Args:
        init_param_path: path to the exported ONNX model file.
        **kwargs: optional session overrides (previously accepted but
            silently ignored).  Supported keys:
                intra_op_num_threads (int, default 8)
                inter_op_num_threads (int, default 8)
                providers (list, default ['CPUExecutionProvider'])

    Returns:
        An ``onnxruntime.InferenceSession`` ready for inference.
    """
    import onnxruntime
    options = onnxruntime.SessionOptions()
    # Fix: honour caller-supplied threading/provider settings; the defaults
    # preserve the original hard-coded behaviour exactly.
    options.intra_op_num_threads = kwargs.get("intra_op_num_threads", 8)
    options.inter_op_num_threads = kwargs.get("inter_op_num_threads", 8)
    ort_session = onnxruntime.InferenceSession(
        init_param_path,
        sess_options=options,
        providers=kwargs.get("providers", ['CPUExecutionProvider'])
    )
    return ort_session
|
funcineforge/models/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .specaug.specaug import SpecAug as FunCineForgeSpecAug
|
| 2 |
+
from .language_model import FunCineForgeLM
|
| 3 |
+
from .causal_hifigan import CausalHifiGan
|
| 4 |
+
from .flow_matching_model import CosyVoiceFlowMatching
|
| 5 |
+
from .inference_model import FunCineForgeInferModel
|
funcineforge/models/causal_hifigan.py
ADDED
|
@@ -0,0 +1,834 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 KaiHu
|
| 2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
| 3 |
+
|
| 4 |
+
"""HIFI-GAN"""
|
| 5 |
+
|
| 6 |
+
from typing import Dict
|
| 7 |
+
from typing import Tuple, List
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from scipy.signal import get_window
|
| 11 |
+
import torch
|
| 12 |
+
import torchaudio
|
| 13 |
+
from torch import nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch.nn.utils import remove_weight_norm
|
| 16 |
+
from torch.nn.utils.parametrize import remove_parametrizations
|
| 17 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 18 |
+
import logging
|
| 19 |
+
from funcineforge.utils.device_funcs import to_device
|
| 20 |
+
import os
|
| 21 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 22 |
+
from funcineforge.models.utils import dtype_map
|
| 23 |
+
from funcineforge.models.modules.hifigan import init_weights
|
| 24 |
+
from funcineforge.models.modules.hifigan.activations import Snake
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class LookRightConv1d(torch.nn.Conv1d):
    """Conv1d that pads on the RIGHT (or consumes explicit look-ahead frames).

    With no ``context`` the input is right-padded with ``kernel_size - 1``
    zeros, so each output frame combines the current and FUTURE input frames.
    In streaming mode, ``context`` supplies those future frames explicitly
    instead of zeros.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(LookRightConv1d, self).__init__(in_channels, out_channels,
                                              kernel_size, stride,
                                              padding=0, dilation=dilation,
                                              groups=groups, bias=bias,
                                              padding_mode=padding_mode,
                                              device=device, dtype=dtype)
        # Streaming framing here only supports stride 1.
        assert stride == 1
        # Frames of look-ahead the kernel needs beyond the current frame.
        self.causal_padding = kernel_size - 1

    def forward(self, x: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
        """Apply the convolution to ``x`` of shape (B, C, T).

        Args:
            x: input tensor (B, in_channels, T).
            context: optional look-ahead frames of shape
                (B, in_channels, ``kernel_size - 1``); zeros are appended
                when empty.  (The default is a module-load-time tensor, which
                is safe because it is never mutated.)

        Returns:
            Output tensor of shape (B, out_channels, T).
        """
        # Fix: the return annotation claimed ``Tuple[torch.Tensor,
        # torch.Tensor]`` although a single tensor is returned.
        if context.size(2) == 0:
            x = F.pad(x, (0, self.causal_padding), value=0.0)
        else:
            assert context.size(2) == self.causal_padding
            x = torch.concat([x, context], dim=2)
        x = super(LookRightConv1d, self).forward(x)
        return x
|
| 58 |
+
|
| 59 |
+
class LookLeftConv1d(torch.nn.Conv1d):
    """Causal Conv1d: pads on the LEFT (or consumes a streaming ``cache``).

    Each output frame depends only on the current and PAST input frames.
    The trailing ``kernel_size - 1`` input frames are returned as the cache
    for the next streaming chunk.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(LookLeftConv1d, self).__init__(in_channels, out_channels,
                                             kernel_size, stride,
                                             padding=0, dilation=dilation,
                                             groups=groups, bias=bias,
                                             padding_mode=padding_mode,
                                             device=device, dtype=dtype)
        # Streaming framing only supports undilated, stride-1 convolutions.
        assert stride == 1 and dilation == 1
        self.causal_padding = kernel_size - 1

    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the causal convolution on ``x`` of shape (B, C, T).

        Returns:
            The (B, out_channels, T) output together with the new cache
            holding the trailing ``kernel_size - 1`` input frames.
        """
        if cache.size(2) != 0:
            assert cache.size(2) == self.causal_padding
            x = torch.concat([cache, x], dim=2)
        else:
            x = F.pad(x, (self.causal_padding, 0), value=0.0)
        # NOTE: for kernel_size == 1 the cache is an empty slice.
        next_cache = x[:, :, -self.causal_padding:] if self.causal_padding != 0 else x[:, :, :0]
        out = super(LookLeftConv1d, self).forward(x)
        return out, next_cache
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class CausalConvRNNF0Predictor(nn.Module):
|
| 98 |
+
def __init__(self,
|
| 99 |
+
num_class: int = 1,
|
| 100 |
+
in_channels: int = 80,
|
| 101 |
+
cond_channels: int = 512
|
| 102 |
+
):
|
| 103 |
+
super().__init__()
|
| 104 |
+
|
| 105 |
+
self.num_class = num_class
|
| 106 |
+
self.condnet = nn.Sequential(
|
| 107 |
+
weight_norm(
|
| 108 |
+
LookRightConv1d(in_channels, cond_channels, kernel_size=4)
|
| 109 |
+
),
|
| 110 |
+
nn.ELU(),
|
| 111 |
+
weight_norm(
|
| 112 |
+
LookLeftConv1d(cond_channels, cond_channels, kernel_size=3)
|
| 113 |
+
),
|
| 114 |
+
nn.ELU(),
|
| 115 |
+
weight_norm(
|
| 116 |
+
LookLeftConv1d(cond_channels, cond_channels, kernel_size=3)
|
| 117 |
+
),
|
| 118 |
+
nn.ELU(),
|
| 119 |
+
weight_norm(
|
| 120 |
+
LookLeftConv1d(cond_channels, cond_channels, kernel_size=3)
|
| 121 |
+
),
|
| 122 |
+
nn.ELU(),
|
| 123 |
+
weight_norm(
|
| 124 |
+
LookLeftConv1d(cond_channels, cond_channels, kernel_size=3)
|
| 125 |
+
),
|
| 126 |
+
nn.ELU(),
|
| 127 |
+
)
|
| 128 |
+
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
| 129 |
+
|
| 130 |
+
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0, 0), finalize: bool = True) -> torch.Tensor:
|
| 131 |
+
if finalize is False:
|
| 132 |
+
x, context = x[:, :, :-self.condnet[0].causal_padding], x[:, :, -self.condnet[0].causal_padding:]
|
| 133 |
+
else:
|
| 134 |
+
x, context = x, x[:, :, :0]
|
| 135 |
+
x = self.condnet[0](x, context)
|
| 136 |
+
x = self.condnet[1](x)
|
| 137 |
+
if cache.size(0) != 0:
|
| 138 |
+
x, cache[0] = self.condnet[2](x, cache[0])
|
| 139 |
+
else:
|
| 140 |
+
x, _ = self.condnet[2](x)
|
| 141 |
+
x = self.condnet[3](x)
|
| 142 |
+
if cache.size(0) != 0:
|
| 143 |
+
x, cache[1] = self.condnet[4](x, cache[1])
|
| 144 |
+
else:
|
| 145 |
+
x, _ = self.condnet[4](x)
|
| 146 |
+
x = self.condnet[5](x)
|
| 147 |
+
if cache.size(0) != 0:
|
| 148 |
+
x, cache[2] = self.condnet[6](x, cache[2])
|
| 149 |
+
else:
|
| 150 |
+
x, _ = self.condnet[6](x)
|
| 151 |
+
x = self.condnet[7](x)
|
| 152 |
+
if cache.size(0) != 0:
|
| 153 |
+
x, cache[3] = self.condnet[8](x, cache[3])
|
| 154 |
+
else:
|
| 155 |
+
x, _ = self.condnet[8](x)
|
| 156 |
+
x = self.condnet[9](x)
|
| 157 |
+
x = x.transpose(1, 2)
|
| 158 |
+
x = torch.abs(self.classifier(x).squeeze(-1))
|
| 159 |
+
return x, cache
|
| 160 |
+
|
| 161 |
+
def init_cache(self, device):
|
| 162 |
+
return torch.zeros(4, 1, 512, 2).to(device)
|
| 163 |
+
|
| 164 |
+
def remove_weight_norm(self):
    """Strip weight normalization from the five weight-normed condnet convs.

    Tries the legacy hook-based ``remove_weight_norm`` first; when the module
    was normalized with the newer parametrization API that call raises
    ``ValueError``, so fall back to ``remove_parametrizations``.
    """
    print('Removing weight norm...')
    try:
        remove_weight_norm(self.condnet[0])
        remove_weight_norm(self.condnet[2])
        remove_weight_norm(self.condnet[4])
        remove_weight_norm(self.condnet[6])
        remove_weight_norm(self.condnet[8])
    except ValueError:
        # Newer torch applies weight_norm as a parametrization, not a hook.
        remove_parametrizations(self.condnet[0], 'weight')
        remove_parametrizations(self.condnet[2], 'weight')
        remove_parametrizations(self.condnet[4], 'weight')
        remove_parametrizations(self.condnet[6], 'weight')
        remove_parametrizations(self.condnet[8], 'weight')
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class LookLeftConvTranspose1d(torch.nn.Conv1d):
    """Causal upsampling block: nearest-neighbour upsample followed by a
    left-padded (look-left) Conv1d, with streaming-cache support.

    Despite the name this is not a true transposed convolution: ``stride``
    only sets the upsampling factor; the convolution itself runs at stride 1.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            dilation: int = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            device=None,
            dtype=None
    ) -> None:
        super().__init__(in_channels, out_channels, kernel_size, 1,
                         padding=0, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode,
                         device=device, dtype=dtype)
        assert dilation == 1 and stride != 1
        # Left context needed so the conv output stays aligned (causal).
        self.causal_padding = kernel_size - 1
        self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')

    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        """Upsample, prepend cached (or zero) left context, convolve.

        Returns the conv output and the new cache (the last
        ``causal_padding`` samples of the padded, upsampled input).
        """
        x = self.upsample(x)
        if cache.size(2) != 0:
            assert cache.size(2) == self.causal_padding
            x = torch.cat((cache, x), dim=2)
        else:
            # First chunk: no history yet — zero-pad on the left instead.
            x = F.pad(x, (self.causal_padding, 0), value=0.0)
        new_cache = x[:, :, -self.causal_padding:]
        out = super().forward(x)
        return out, new_cache
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class LookLeftConv1dWithStride(torch.nn.Conv1d):
    """Strided Conv1d that pads (or caches) ``stride - 1`` frames of left
    context so chunked streaming calls line up with full-sequence output."""

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            dilation: int = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            device=None,
            dtype=None
    ) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding=0, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode,
                         device=device, dtype=dtype)
        assert stride != 1 and dilation == 1
        # Kernel must tile evenly across strides for the caching to align.
        assert kernel_size % stride == 0
        self.causal_padding = stride - 1

    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prepend cached (or zero) left context, then run the strided conv.

        Returns the conv output and the new cache (last ``causal_padding``
        samples of the padded input).
        """
        if cache.size(2) != 0:
            assert cache.size(2) == self.causal_padding
            x = torch.cat((cache, x), dim=2)
        else:
            # First chunk: zero left context.
            x = F.pad(x, (self.causal_padding, 0), value=0.0)
        new_cache = x[:, :, -self.causal_padding:]
        out = super().forward(x)
        return out, new_cache
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class LookLeftConv1dWithDilation(torch.nn.Conv1d):
    """Dilated Conv1d made causal by left-padding (or caching) the full
    receptive-field context of ``(kernel_size - 1) * dilation`` samples."""

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            dilation: int = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            device=None,
            dtype=None
    ) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding=0, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode,
                         device=device, dtype=dtype)
        # NOTE(lyuxiang.lx): this causal_padding only holds when kernel_size is odd.
        assert kernel_size // 2 * dilation * 2 == int((kernel_size * dilation - dilation) / 2) * 2
        self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2

    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prepend cached (or zero) left context, then run the dilated conv.

        Returns the conv output and the new cache (last ``causal_padding``
        samples of the padded input).
        """
        if cache.size(2) != 0:
            assert cache.size(2) == self.causal_padding
            x = torch.cat((cache, x), dim=2)
        else:
            # First chunk: zero left context.
            x = F.pad(x, (self.causal_padding, 0), value=0.0)
        new_cache = x[:, :, -self.causal_padding:]
        out = super().forward(x)
        return out, new_cache
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN.

    Each stage is Snake -> (dilated) look-left conv -> Snake -> look-left
    conv, added back onto the input (pre-activation residual).
    """
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        for dilation in dilations:
            # Plain look-left conv when dilation == 1, dilated variant otherwise.
            self.convs1.append(
                weight_norm(
                    LookLeftConv1dWithDilation(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation
                    ) if dilation != 1 else
                    LookLeftConv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation
                    )
                )
            )
            self.convs2.append(
                weight_norm(
                    LookLeftConv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1
                    )
                )
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        self.activations1 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs1))
        ])
        self.activations2 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs2))
        ])

    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)) -> torch.Tensor:
        """Apply the residual stages.

        ``cache`` is accepted for interface symmetry with the streaming convs
        but is passed through unchanged (the inner convs run uncached here).
        """
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt, _ = self.convs1[idx](xt)
            xt = self.activations2[idx](xt)
            xt, _ = self.convs2[idx](xt)
            x = xt + x
        return x, cache

    def remove_weight_norm(self):
        """Strip weight norm from every conv pair.

        The legacy hook-based ``remove_weight_norm`` raises ``ValueError``
        when weight_norm was applied via the newer parametrization API; fall
        back to ``remove_parametrizations`` in that case (was a bare except).
        """
        for idx in range(len(self.convs1)):
            try:
                remove_weight_norm(self.convs1[idx])
                remove_weight_norm(self.convs2[idx])
            except ValueError:
                remove_parametrizations(self.convs1[idx], 'weight')
                remove_parametrizations(self.convs2[idx], 'weight')
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # dim = fundamental + harmonic overtones
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale
        # Fixed random initial phase per harmonic channel; index 0 (the
        # fundamental) gets no phase noise.
        # NOTE(review): the hardcoded 9 equals harmonic_num + 1 only when
        # harmonic_num == 8 — confirm if other configs are ever used.
        self.rand_ini = torch.rand(1, 9)
        self.rand_ini[:, 0] = 0
        # Pre-sampled uniform noise reused on every call, sized for up to
        # 300 s at 24 kHz. Plain tensors (not registered buffers), so they
        # are excluded from the state dict and make inference deterministic.
        self.sine_waves = torch.rand(1, 300 * 24000, 9)

    def _f02uv(self, f0):
        # Voiced/unvoiced flag: 1.0 where f0 exceeds the threshold.
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)

        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # Normal case: downsample the per-sample phase increments to
            # frame rate, integrate there (keeps torch.cumsum from
            # accumulating at sample rate and overflowing), then upsample
            # the accumulated phase back to sample rate.
            rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
                                                         scale_factor=1/self.upsample_scale,
                                                         mode="linear").transpose(1, 2)

            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
                                                    scale_factor=self.upsample_scale, mode="nearest").transpose(1, 2)
            sines = torch.sin(phase)

        else:
            # If necessary, make sure that the first time step of every
            # voiced segment is sin(pi) or cos(0)
            # This is used for pulse-train generation

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segments
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        # NOTE(review): f0_buf is allocated but never used below.
        f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
                             device=f0.device)
        # fundamental component plus overtones: f0 * [1, 2, ..., dim]
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))

        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        # std = self.sine_amp/3 -> max value ~ self.sine_amp
        # .       for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        # Reuse the fixed pre-sampled noise tensor (deterministic across calls).
        noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    sine_merge, noise, uv = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    sine_merge (batchsize, length, 1)
    noise (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()
        # Pre-sampled noise reused on every call (deterministic inference);
        # despite the name this is a fixed noise table, not the U/V flags.
        self.uv = torch.rand(1, 300 * 24000, 1)

    def forward(self, x):
        """
        sine_merge, noise, uv = forward(F0_sampled)
        F0_sampled (batchsize, length, 1)
        sine_merge (batchsize, length, 1)
        noise (batchsize, length, 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        # Collapse harmonic channels into a single bounded excitation.
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
        return sine_merge, noise, uv
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
class CausalHiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493

    Causal/streaming variant: mel features are decoded chunk-by-chunk, and a
    few trailing frames are "backed off" between chunks (see the NOTE
    comments) so that overlapping context stays consistent across calls.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: List[int] = [8, 8],
        upsample_kernel_sizes: List[int] = [16, 16],
        istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: List[int] = [3, 7, 11],
        resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        source_resblock_kernel_sizes: List[int] = [7, 11],
        source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
        lrelu_slope: float = 0.1,
        audio_limit: float = 0.99,
        f0_predictor: torch.nn.Module = None,
    ):
        super(CausalHiFTGenerator, self).__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # NSF source module producing the harmonic+noise excitation signal.
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        # F0 is predicted at frame rate; upsample it to sample rate.
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"], mode='nearest')

        self.conv_pre = weight_norm(
            LookRightConv1d(in_channels, base_channels, 5, 1)
        )

        # Up: mel frames -> higher time resolution, halving channels per stage.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    LookLeftConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u
                    )
                )
            )

        # Down: project the source STFT to each upsample stage's resolution.
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
            if u == 1:
                # Same resolution: 1x1 projection only.
                self.source_downs.append(
                    LookLeftConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    LookLeftConv1dWithStride(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock(ch, k, d))

        self.conv_post = weight_norm(LookLeftConv1d(ch, istft_params["n_fft"] + 2, 7, 1))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = f0_predictor
        # f0 predictor backs off 3 frames, HiFT backs off 5 frames.
        self.context_size = 8

    def remove_weight_norm(self):
        """Strip weight norm from all weight-normed submodules.

        The legacy hook-based ``remove_weight_norm`` raises ``ValueError``
        when weight_norm was applied via the newer parametrization API; fall
        back to ``remove_parametrizations`` in that case (was a bare except).
        """
        print('Removing weight norm...')
        for layer in self.ups:
            try:
                remove_weight_norm(layer)
            except ValueError:
                remove_parametrizations(layer, 'weight')
        for layer in self.resblocks:
            layer.remove_weight_norm()
        try:
            remove_weight_norm(self.conv_pre)
            remove_weight_norm(self.conv_post)
        except ValueError:
            remove_parametrizations(self.conv_pre, 'weight')
            remove_parametrizations(self.conv_post, 'weight')
        self.f0_predictor.remove_weight_norm()
        for layer in self.source_resblocks:
            layer.remove_weight_norm()

    def _stft(self, x):
        """STFT of the source signal; returns (real, imag) parts, (B, F, TT)."""
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        """Inverse STFT from predicted magnitude/phase back to a waveform."""
        # Clip magnitude to keep exp() output numerically safe.
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
        return inverse_transform

    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(0, 0, 0), finalize: bool = True) -> torch.Tensor:
        """Decode mel features ``x`` fused with NSF source ``s`` into audio.

        When ``finalize`` is False (intermediate chunk), trailing frames of
        both streams are held back as context for the next call.
        """
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        # NOTE(lyuxiang.lx) back off 4 frames.
        if finalize is False:
            s_stft_real, s_stft_imag = s_stft_real[:, :, :-int(480 * 4 / self.istft_params["hop_len"])], s_stft_imag[:, :, :-int(480 * 4 / self.istft_params["hop_len"])]
            x = self.conv_pre(x[:, :, :-4], x[:, :, -4:])
        else:
            x = self.conv_pre(x)
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x, _ = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # fusion with the downsampled source spectrum
            si, _ = self.source_downs[i](s_stft)
            si, _ = self.source_resblocks[i](si)
            x = x + si

            # Average the parallel multi-kernel residual branches.
            xs = None
            for j in range(self.num_kernels):
                this_xs, _ = self.resblocks[i * self.num_kernels + j](x)
                if xs is None:
                    xs = this_xs
                else:
                    xs += this_xs
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x, _ = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        # NOTE(lyuxiang.lx) back off 1 frame.
        if finalize is False:
            x = x[:, :-480]
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    @torch.inference_mode()
    def inference(self, speech_feat: torch.Tensor, f0_cpu: bool = False, finalize: bool = True) -> torch.Tensor:
        """Synthesize audio from mel features: mel -> f0 -> source -> decode.

        Args:
            speech_feat: mel features, (B, C, T).
            f0_cpu: run the F0 predictor on CPU (moves the predictor!).
            finalize: False for intermediate streaming chunks.

        Returns:
            Tuple of (waveform tensor, empty list).
        """
        # mel->f0->source
        if f0_cpu is True:
            self.f0_predictor.to('cpu')
            f0, _ = self.f0_predictor(speech_feat.cpu(), finalize=finalize)
            f0 = f0.to(speech_feat.device)
        else:
            self.f0_predictor.to(speech_feat.device)
            f0, _ = self.f0_predictor(speech_feat, finalize=finalize)
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        if finalize is False:
            # f0 predictor already consumed 3 extra context frames; drop them.
            generated_speech = self.decode(speech_feat[:, :, :-3], s, finalize=finalize)
        else:
            generated_speech = self.decode(speech_feat, s, finalize=finalize)
        return generated_speech, []
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
class CausalHifiGan(nn.Module):
    """HIFIGAN-style vocoders (generator [stack of time-level-upsampling blocks] + discriminator).
    NSF-HIFIGAN, HiFTNet Optional.

    Inference-only wrapper: builds the causal HiFT generator plus its F0
    predictor and immediately strips weight norm.
    """

    def __init__(
        self,
        CausalHiFTGenerator_conf: dict = {},
        CausalConvRNNF0Predictor_conf: dict = {},
        sample_rate: float = 24000,
        **kwargs
    ):
        super().__init__()
        self.generator = CausalHiFTGenerator(**CausalHiFTGenerator_conf)
        self.generator.f0_predictor = CausalConvRNNF0Predictor(**CausalConvRNNF0Predictor_conf)
        # Weight norm is only needed for training; fold it for inference.
        self.generator.remove_weight_norm()
        self.sample_rate = sample_rate

    def inference_prepare(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        **kwargs,
    ):
        """Load/convert input features and pack them into a padded batch.

        Each element of ``data_in`` may be a .npy path, an ndarray, or a
        tensor of shape (T, C). Returns a dict with keys "x" (B, T, C)
        and "x_lengths", moved to ``kwargs["device"]``.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        feat_list = []
        feat_len_list = []
        for feat in data_in:
            if isinstance(feat, str) and os.path.exists(feat):
                feat = np.load(feat)
            if isinstance(feat, np.ndarray):
                feat = torch.from_numpy(feat)

            feat_list.append(feat)
            feat_len_list.append(feat.shape[0])

        batch = {
            "x": pad_sequence(feat_list, batch_first=True),
            "x_lengths": torch.tensor(feat_len_list, dtype=torch.int64),
        }
        batch = to_device(batch, kwargs["device"])

        return batch

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        f0_cpu: bool = True,
        finalize: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Run inference.

        Args:
            data_in: input representations (see ``inference_prepare``).
            key: utterance ids; key[0] is used for logging and the output
                filename (must be non-empty despite the None default).
            f0_cpu: run the F0 predictor on CPU.
            finalize: False for intermediate streaming chunks.

        Returns:
            recon_speech (Tensor): reconstructed waveform tensor (B, T_wav),
            optionally resampled and saved to ``output_dir`` as 16-bit PCM.
        """
        uttid = key[0]
        batch = self.inference_prepare(data_in, data_lengths, key, **kwargs)
        voc_dtype = dtype_map[kwargs.get("voc_dtype", "fp32")]
        x = batch["x"].to(voc_dtype)
        recon_speech = self.generator.inference(x.transpose(1, 2), f0_cpu=f0_cpu, finalize=finalize)[0].squeeze(1)
        recon_speech = recon_speech.float()
        logging.info(f"{uttid}: wav lengths {recon_speech.shape[1]}")

        output_dir = kwargs.get("output_dir", None)
        output_sr = kwargs.get("output_sr", None)
        if output_dir is not None:
            wav_out_dir = os.path.join(output_dir, "wav")
            os.makedirs(wav_out_dir, exist_ok=True)
            wav_sr = self.sample_rate
            if output_sr is not None and output_sr != self.sample_rate:
                recon_speech = torchaudio.functional.resample(
                    recon_speech,
                    orig_freq=self.sample_rate,
                    new_freq=output_sr
                )
                wav_sr = output_sr
            torchaudio.save(
                os.path.join(wav_out_dir, f"{key[0]}.wav"), recon_speech.cpu(),
                sample_rate=wav_sr, encoding='PCM_S', bits_per_sample=16
            )

        return recon_speech
|
funcineforge/models/flow_matching_model.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from typing import Dict
|
| 6 |
+
import logging
|
| 7 |
+
from librosa.filters import mel as librosa_mel_fn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from funcineforge.models.utils.nets_utils import make_pad_mask
|
| 10 |
+
from funcineforge.utils.device_funcs import to_device
|
| 11 |
+
import numpy as np
|
| 12 |
+
from funcineforge.utils.load_utils import extract_campp_xvec
|
| 13 |
+
import time
|
| 14 |
+
from funcineforge.models.utils import dtype_map
|
| 15 |
+
from funcineforge.utils.hinter import hint_once
|
| 16 |
+
from funcineforge.models.utils.masks import add_optional_chunk_mask
|
| 17 |
+
from .modules.dit_flow_matching.dit_model import DiT
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Audio2Mel(nn.Module):
    """Compute (log-)mel spectrograms from raw audio via ``torch.stft``.

    The analysis window and mel filterbank are built once at construction
    time and registered as buffers so they move with the module across
    devices/dtypes.

    Args:
        n_fft: FFT size.
        hop_length: hop between adjacent frames, in samples.
        win_length: analysis window length, in samples.
        sampling_rate: sample rate the mel filterbank is built for.
        n_mel_channels: number of mel bands.
        mel_fmin: lowest filterbank frequency (Hz).
        mel_fmax: highest filterbank frequency (Hz); None = Nyquist.
        center: forwarded to ``torch.stft``; this module pads manually in
            ``forward``, hence the default False.
        device: device the window / filterbank buffers are created on.
        feat_type: "power_log" (default, natural-log mel) or "mag_log10"
            (base-10 log of the magnitude mel spectrum).
    """

    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        sampling_rate=22050,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=None,
        center=False,
        device='cuda',
        feat_type="power_log",
    ):
        super().__init__()
        ##############################################
        # FFT Parameters
        ##############################################
        window = torch.hann_window(win_length, device=device).float()
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
        )
        mel_basis = torch.from_numpy(mel_basis).float().to(device)
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", window)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.mel_fmax = mel_fmax
        self.center = center
        self.feat_type = feat_type

    def forward(self, audioin):
        """Return the log-mel spectrogram of ``audioin``.

        Args:
            audioin: waveform tensor; squeezed to (batch, samples) after a
                reflect pad that keeps the frame count aligned with
                ``hop_length``.
        Returns:
            (batch, n_mel_channels, frames) log-mel tensor.
        """
        # manual symmetric pad so frame centers line up despite center=False
        p = (self.n_fft - self.hop_length) // 2
        audio = F.pad(audioin, (p, p), "reflect").squeeze(1)
        fft = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        if self.feat_type == "mag_log10":
            power_spec = torch.sqrt(torch.pow(fft.imag, 2) + torch.pow(fft.real, 2))
            mel_output = torch.matmul(self.mel_basis, power_spec)
            return torch.log10(torch.clamp(mel_output, min=1e-5))
        power_spec = torch.pow(fft.imag, 2) + torch.pow(fft.real, 2)
        # 1e-9 keeps sqrt finite/differentiable at zero-energy bins
        mel_spec = torch.matmul(self.mel_basis, torch.sqrt(power_spec + 1e-9))
        return self.spectral_normalize(mel_spec)

    @classmethod
    def spectral_normalize(cls, spec, C=1, clip_val=1e-5):
        """Natural-log dynamic-range compression of a (mel) spectrogram."""
        output = cls.dynamic_range_compression(spec, C, clip_val)
        return output

    @classmethod
    def spectral_de_normalize_torch(cls, spec, C=1, clip_val=1e-5):
        """Inverse of ``spectral_normalize`` (exp decompression)."""
        output = cls.dynamic_range_decompression(spec, C, clip_val)
        return output

    @staticmethod
    def dynamic_range_compression(x, C=1, clip_val=1e-5):
        """log(clamp(x, clip_val) * C) — clamp avoids log(0)."""
        return torch.log(torch.clamp(x, min=clip_val) * C)

    @staticmethod
    def dynamic_range_decompression(x, C=1, clip_val=1e-5):
        """exp(x) / C.

        BUGFIX: ``clip_val`` is now accepted (and ignored) so the three-arg
        call in ``spectral_de_normalize_torch`` no longer raises TypeError;
        the parameter exists only for symmetry with the compression path.
        """
        return torch.exp(x) / C
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class LookaheadBlock(nn.Module):
    """1-D conv block that mixes in a small, fixed amount of future context.

    ``conv1`` consumes ``pre_lookahead_len`` future frames — supplied either
    by right zero-padding (offline) or by an explicit ``context`` tensor
    (streaming). ``conv2`` is strictly causal (left-padded), and the result
    is added back to the input as a residual, masked to the valid lengths.
    """

    def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        # kernel_size = lookahead + 1: each output frame sees itself plus
        # pre_lookahead_len future frames, so no left padding is required
        # and the output length equals the input length after padding.
        self.conv1 = nn.Conv1d(
            in_channels, channels,
            kernel_size=pre_lookahead_len+1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, in_channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs, ilens, context: torch.Tensor = torch.zeros(0, 0, 0)):
        """
        inputs: (batch_size, seq_len, channels)
        ilens: (batch_size,) valid lengths of ``inputs``
        context: (batch_size, pre_lookahead_len, channels) future frames for
            streaming; an empty tensor means "zero-pad on the right" instead.
            NOTE(review): the default tensor is created once at import time;
            it is only read (never mutated) here, so sharing it is safe.
        Returns:
            (outputs, ilens): residual output of the same shape, and the
            unchanged lengths.
        """
        outputs = inputs.transpose(1, 2).contiguous()
        context = context.transpose(1, 2).contiguous()
        # look ahead
        if context.size(2) == 0:
            outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0)
        else:
            assert context.size(2) == self.pre_lookahead_len
            outputs = torch.concat([outputs, context], dim=2)
        outputs = F.leaky_relu(self.conv1(outputs))
        # outputs
        # left-pad by 2 so conv2 (kernel 3) stays causal and preserves seq_len
        outputs = F.pad(outputs, (2, 0), mode='constant', value=0)
        outputs = self.conv2(outputs)
        outputs = outputs.transpose(1, 2).contiguous()

        mask = (~make_pad_mask(ilens).unsqueeze(-1).to(inputs.device))
        # residual connection
        outputs = (outputs + inputs) * mask

        return outputs, ilens
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class CosyVoiceFlowMatching(nn.Module):
    """Conditional flow-matching model: speech tokens -> mel features.

    Pipeline (inference): embed speech tokens, apply a lookahead conv,
    upsample to feature rate (``feat_token_ratio``), then integrate an ODE
    (Euler solver over a DiT vector field) from Gaussian noise to mel
    features, optionally conditioned on a speaker x-vector and a mel prompt,
    with classifier-free guidance.
    """

    def __init__(
        self,
        codebook_size: int,
        model_size: int,
        xvec_size: int = 198,
        dit_conf: Dict = {},
        mel_feat_conf: Dict = {},
        prompt_conf: Dict = None,
        **kwargs):
        """
        Args:
            codebook_size: number of speech-token codewords.
            model_size: hidden size of the lookahead conv block.
            xvec_size: speaker-embedding dim (projection built if not None).
            dit_conf: kwargs forwarded to the DiT vector-field model.
            mel_feat_conf: kwargs forwarded to Audio2Mel (must contain
                "n_mel_channels").
            prompt_conf: optional prompt-masking config (training).
        """
        super().__init__()

        # feat related
        self.feat_token_ratio = kwargs.get("feat_token_ratio", None)
        try:
            self.mel_extractor = Audio2Mel(**mel_feat_conf)
            self.sample_rate = self.mel_extractor.sampling_rate
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt /
        # SystemExit; keep the deliberate best-effort fallback but narrow it.
        except Exception:
            self.mel_extractor = None
            self.sample_rate = 24000
        self.mel_norm_type = kwargs.get("mel_norm_type", None)
        self.num_mels = num_mels = mel_feat_conf["n_mel_channels"]
        self.token_rate = kwargs.get("token_rate", 25)
        self.model_dtype = kwargs.get("model_dtype", "fp32")
        self.codebook_size = codebook_size

        # condition related
        self.prompt_conf = prompt_conf
        if self.prompt_conf is not None:
            self.prompt_masker = self.build_prompt_masker()

        # codec related
        self.codec_embedder = nn.Embedding(codebook_size, num_mels)
        lookahead_length = kwargs.get("lookahead_length", 4)
        self.lookahead_conv1d = LookaheadBlock(num_mels, model_size, lookahead_length)

        # spk embed related
        if xvec_size is not None:
            self.xvec_proj = torch.nn.Linear(xvec_size, num_mels)

        # dit model related
        self.dit_conf = dit_conf
        self.dit_model = DiT(**dit_conf)

        self.training_cfg_rate = kwargs.get("training_cfg_rate", 0)
        self.only_mask_loss = kwargs.get("only_mask_loss", True)

        # NOTE: the flow-matching model needs this much right-context
        # (lookahead) from the token stream when streaming.
        self.context_size = self.lookahead_conv1d.pre_lookahead_len

    def build_prompt_masker(self):
        """Build the training-time prompt masker from ``prompt_conf``.

        Only the "prefix" prompt type is implemented.
        """
        prompt_type = self.prompt_conf.get("prompt_type", "free")
        if prompt_type == "prefix":
            from funcineforge.models.utils.mask_along_axis import MaskTailVariableMaxWidth
            masker = MaskTailVariableMaxWidth(
                mask_width_ratio_range=self.prompt_conf["prompt_width_ratio_range"],
            )
        else:
            raise NotImplementedError

        return masker

    @staticmethod
    def norm_spk_emb(xvec):
        """Average per-frame x-vectors over time and L2-normalize.

        Frames whose norm is NaN/Inf are zeroed out before averaging.
        xvec: (batch, frames, dim) -> (batch, dim)
        """
        xvec_mask = (~xvec.norm(dim=-1).isnan()) * (~xvec.norm(dim=-1).isinf())
        xvec = xvec * xvec_mask.unsqueeze(-1)
        xvec = xvec.mean(dim=1)
        xvec = F.normalize(xvec, dim=1)

        return xvec

    def select_target_prompt(self, y: torch.Tensor, y_lengths: torch.Tensor):
        """Return the prompt-region mask for training targets."""
        # cond_mask: 1, 1, 1, ..., 0, 0, 0
        cond_mask = self.prompt_masker(y, y_lengths, return_mask=True)

        return cond_mask

    @torch.no_grad()
    def normalize_mel_feat(self, feat, feat_lengths):
        """Normalize mel features per utterance (masked to valid frames).

        "mean_std": zero-mean / unit-std over valid positions.
        "min_max": scale to [0, 3] (noise is ~N(0, I); P(x >= 3σ) ≈ 0.001,
        so 3 covers the noise range).
        """
        # feat in B,T,D
        if self.mel_norm_type == "mean_std":
            max_length = feat.shape[1]
            mask = (~make_pad_mask(feat_lengths, maxlen=max_length))
            mask = mask.unsqueeze(-1).to(feat)
            mean = ((feat * mask).sum(dim=(1, 2), keepdim=True) /
                    (mask.sum(dim=(1, 2), keepdim=True) * feat.shape[-1]))
            var = (((feat - mean)**2 * mask).sum(dim=(1, 2), keepdim=True) /
                   (mask.sum(dim=(1, 2), keepdim=True) * feat.shape[-1] - 1))  # -1 for unbiased estimation
            std = torch.sqrt(var)
            feat = (feat - mean) / std
            feat = feat * mask
            return feat
        if self.mel_norm_type == "min_max":
            bb, tt, dd = feat.shape
            mask = (~make_pad_mask(feat_lengths, maxlen=tt))
            mask = mask.unsqueeze(-1).to(feat)
            feat_min = (feat * mask).reshape([bb, tt * dd]).min(dim=1, keepdim=True).values.unsqueeze(-1)
            feat_max = (feat * mask).reshape([bb, tt * dd]).max(dim=1, keepdim=True).values.unsqueeze(-1)
            feat = (feat - feat_min) / (feat_max - feat_min)
            # noise ~ N(0, I), P(x >= 3sigma) = 0.001, 3 is enough.
            feat = (feat * 3) * mask  # feat in [-3, 3]
            return feat
        else:
            raise NotImplementedError

    @torch.no_grad()
    def extract_feat(self, y: torch.Tensor, y_lengths: torch.Tensor):
        """Waveform -> (normalized) mel features.

        Returns (feat [B,T,D], feat_lengths) where lengths are derived from
        the mel hop size.
        """
        mel_extractor = self.mel_extractor.float()
        feat = mel_extractor(y)
        feat = feat.transpose(1, 2)
        feat_lengths = (y_lengths / self.mel_extractor.hop_length).to(y_lengths)
        if self.mel_norm_type is not None:
            feat = self.normalize_mel_feat(feat, feat_lengths)
        return feat, feat_lengths

    def load_data(self, contents: dict, **kwargs):
        """Assemble a single-utterance inference batch.

        contents must contain "codec" (tokens, [1, T]) and "vocal" (path to
        a prompt wav / .npy x-vector, or None). Optional kwargs:
        "prompt_codec" (array or .npy path), "spk_emb" (array or .npy path),
        "fm_use_prompt".
        """
        fm_use_prompt = kwargs.get("fm_use_prompt", True)

        # codec
        codec = contents["codec"]
        if isinstance(codec, np.ndarray):
            codec = torch.from_numpy(codec)
            # codec = torch.from_numpy(codec)[None, :]
        codec_lengths = torch.tensor([codec.shape[1]], dtype=torch.int64)

        # prompt codec (optional)
        prompt_codec = kwargs.get("prompt_codec", None)
        prompt_codec_lengths = None
        if prompt_codec is not None and fm_use_prompt:
            if isinstance(prompt_codec, str) and os.path.exists(prompt_codec):
                prompt_codec = np.load(prompt_codec)
            if isinstance(prompt_codec, np.ndarray):
                prompt_codec = torch.from_numpy(prompt_codec)[None, :]
            prompt_codec_lengths = torch.tensor([prompt_codec.shape[1]], dtype=torch.int64)
        else:
            prompt_codec = None
        spk_emb = kwargs.get("spk_emb", None)
        spk_emb_lengths = None
        if spk_emb is not None:
            if isinstance(spk_emb, str) and os.path.exists(spk_emb):
                spk_emb = np.load(spk_emb)
            if isinstance(spk_emb, np.ndarray):
                spk_emb = torch.from_numpy(spk_emb)[None, :]
            spk_emb_lengths = torch.tensor([spk_emb.shape[1]], dtype=torch.int64)

        # prompt wav as condition
        prompt_wav = contents["vocal"]
        prompt_wav_lengths = None
        if prompt_wav is not None and fm_use_prompt and os.path.exists(prompt_wav):
            if prompt_wav.endswith(".npy"):
                # pre-computed x-vector; overrides any spk_emb kwarg
                spk_emb = np.load(prompt_wav)
                spk_emb_lengths = torch.tensor([spk_emb.shape[1]], dtype=torch.int64)
            else:
                spk_emb = extract_campp_xvec(prompt_wav, **kwargs)
                spk_emb = torch.from_numpy(spk_emb)
                spk_emb_lengths = torch.tensor([spk_emb.shape[1]], dtype=torch.int64)
            # prompt_wav = load_audio_text_image_video(prompt_wav, fs=self.sample_rate)
            # prompt_wav = prompt_wav[None, :]
            # prompt_wav_lengths = torch.tensor([prompt_wav.shape[1]], dtype=torch.int64)
        else:
            logging.info("[error] prompt_wav is None or not path or path not exists! Please provide the correct speaker embedding.")

        # NOTE: the raw prompt wav is intentionally not forwarded (only its
        # x-vector is used); see the commented-out lines above.
        output = {
            "codec": codec,
            "codec_lengths": codec_lengths,
            "prompt_codec": prompt_codec,
            "prompt_codec_lengths": prompt_codec_lengths,
            "prompt_wav": None,
            "prompt_wav_lengths": None,
            "xvec": spk_emb,
            "xvec_lengths": spk_emb_lengths,
        }

        return output

    @torch.no_grad()
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        chunk_size: int = -1,
        finalize: bool = True,
        **kwargs,
    ):
        """Public single-utterance entry point: tokens -> mel features.

        Returns a float (1, T, n_mels) feature tensor.
        """
        uttid = key[0]
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        batch = self.load_data(data_in[0], **kwargs)
        batch = to_device(batch, kwargs["device"])
        batch.update({'finalize': finalize, 'chunk_size': chunk_size})
        feat = self._inference(**batch, **kwargs)
        feat = feat.float()
        logging.info(f"{uttid}: feat lengths {feat.shape[1]}")

        return feat

    @torch.no_grad()
    def _inference(
        self,
        codec, codec_lengths,
        prompt_codec=None, prompt_codec_lengths=None,
        prompt_wav=None, prompt_wav_lengths=None,
        xvec=None, xvec_lengths=None, chunk_size=-1, finalize=False,
        **kwargs
    ):
        """Core decode: embed tokens, build conditions, run the ODE solver.

        When ``finalize`` is False (streaming), the last ``context_size``
        token embeddings are held back and fed to the lookahead conv as
        right-context instead of being decoded now.
        """
        fm_dtype = dtype_map[kwargs.get("fm_dtype", "fp32")]
        rand_xvec = None
        if xvec is not None:
            if xvec.dim() == 2:
                xvec = xvec.unsqueeze(1)
            rand_xvec = self.norm_spk_emb(xvec)
            self.xvec_proj.to(fm_dtype)
            rand_xvec = self.xvec_proj(rand_xvec.to(fm_dtype))
            rand_xvec = rand_xvec.unsqueeze(1)

        if (codec >= self.codebook_size).any():
            new_codec = codec[codec < self.codebook_size].unsqueeze(0)
            logging.info(f"remove out-of-range token for FM: from {codec.shape[1]} to {new_codec.shape[1]}.")
            codec_lengths = codec_lengths - (codec.shape[1] - new_codec.shape[1])
            codec = new_codec
        if prompt_codec is not None:
            codec, codec_lengths = self.concat_prompt(prompt_codec, prompt_codec_lengths, codec, codec_lengths)
        # -1 marks padding; clamp so the embedding lookup is valid, then mask
        mask = (codec != -1).float().unsqueeze(-1)
        codec_emb = self.codec_embedder(torch.clamp(codec, min=0)) * mask

        self.lookahead_conv1d.to(fm_dtype)
        if finalize is True:
            context = torch.zeros(1, 0, self.codec_embedder.embedding_dim).to(fm_dtype)
        else:
            codec_emb, context = codec_emb[:, :-self.context_size].to(fm_dtype), codec_emb[:, -self.context_size:].to(fm_dtype)
            codec_lengths = codec_lengths - self.context_size
        mu, _ = self.lookahead_conv1d(codec_emb, codec_lengths, context)
        # upsample token-rate features to mel frame rate
        mu = mu.repeat_interleave(self.feat_token_ratio, dim=1)
        conditions = torch.zeros([mu.size(0), mu.shape[1], self.num_mels]).to(mu)
        # get conditions
        if prompt_wav is not None:
            if prompt_wav.ndim == 2:
                prompt_wav, prompt_wav_lengths = self.extract_feat(prompt_wav, prompt_wav_lengths)
            # NOTE: for the fmax12k FM variant, try interpolating the mel to
            # 2x the token length instead of hard truncation.
            prompt_wav = prompt_wav.to(fm_dtype)
            for i, _len in enumerate(prompt_wav_lengths):
                conditions[i, :_len] = prompt_wav[i]

        feat_lengths = codec_lengths * self.feat_token_ratio
        # NOTE: add_optional_chunk_mask supports generating masks for
        # different chunk sizes (-1 / 1 / 15 / 30).
        mask = add_optional_chunk_mask(mu, torch.ones([1, 1, mu.shape[1]]).to(mu).bool(), False, False, 0, chunk_size, -1)
        feat = self.solve_ode(mu, rand_xvec, conditions.to(fm_dtype), mask, **kwargs)

        if prompt_codec is not None and prompt_wav is not None:
            # NOTE(review): prompt_wav_lengths here are mel-frame lengths
            # (after extract_feat above) — confirm they match the prompt
            # region of `feat` in frames.
            feat, feat_lens = self.remove_prompt(None, prompt_wav_lengths, feat, feat_lengths)

        return feat

    @staticmethod
    def concat_prompt(prompt, prompt_lengths, text, text_lengths):
        """Concatenate per-sample prompt and payload along time; re-pad.

        Returns (padded concat, new lengths).
        """
        xs_list, x_len_list = [], []
        for idx, (_prompt_len, _text_len) in enumerate(zip(prompt_lengths, text_lengths)):
            xs_list.append(torch.concat([prompt[idx, :_prompt_len], text[idx, :_text_len]], dim=0))
            x_len_list.append(_prompt_len + _text_len)

        xs = torch.nn.utils.rnn.pad_sequence(xs_list, batch_first=True, padding_value=0.0)
        x_lens = torch.tensor(x_len_list, dtype=torch.int64).to(xs.device)

        return xs, x_lens

    @staticmethod
    def remove_prompt(prompt, prompt_lengths, padded, padded_lengths):
        """Strip the leading prompt region from each padded sample.

        ``prompt`` itself is unused; only its lengths matter.
        Returns (re-padded payload, payload lengths).
        """
        xs_list = []
        for idx, (_prompt_len, _x_len) in enumerate(zip(prompt_lengths, padded_lengths)):
            xs_list.append(padded[idx, _prompt_len: _x_len])

        xs = torch.nn.utils.rnn.pad_sequence(xs_list, batch_first=True, padding_value=0.0)

        return xs, padded_lengths - prompt_lengths

    def get_rand_noise(self, mu: torch.Tensor, **kwargs):
        """Return the initial ODE noise z0 matching ``mu``'s shape.

        With ``use_fixed_noise_infer`` (default) a fixed noise tensor is
        cached and sliced, making decoding deterministic across calls.
        """
        use_fixed_noise_infer = kwargs.get("use_fixed_noise_infer", True)
        max_len = kwargs.get("max_len", 50*300)
        if use_fixed_noise_infer:
            # FIX: also invalidate the cache when it is *shorter* than the
            # current utterance (the old guard only compared the feature dim,
            # so a long utterance got a too-short slice and crashed later).
            if (not hasattr(self, "rand_noise") or self.rand_noise is None
                    or self.rand_noise.shape[1] < mu.shape[1]
                    or self.rand_noise.shape[2] != mu.shape[2]):
                self.rand_noise = torch.randn([1, max(max_len, mu.shape[1]), mu.shape[2]]).to(mu)
                logging.info("init random noise for Flow")
            # return self.rand_noise[:, :mu.shape[1], :]
            return torch.concat([self.rand_noise[:, :mu.shape[1], :] for _ in range(mu.size(0))], dim=0)
        else:
            return torch.randn_like(mu)

    def solve_ode(self, mu, rand_xvec, conditions, mask, **kwargs):
        """Set up the time grid and run the Euler solver; returns features."""
        fm_dtype = dtype_map[kwargs.get("fm_dtype", "fp32")]
        temperature = kwargs.get("temperature", 1.0)
        n_timesteps = kwargs.get("n_timesteps", 10)
        infer_t_scheduler = kwargs.get("infer_t_scheduler", "cosine")
        z = self.get_rand_noise(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1).to(mu)
        if infer_t_scheduler == 'cosine':
            # cosine warp: denser steps near t=0 where the field changes fast
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        fm_time = time.time()
        self.dit_model.to(fm_dtype)
        feat = self.solve_euler(
            z.to(fm_dtype), t_span=t_span.to(fm_dtype), mu=mu.to(fm_dtype), mask=mask,
            spks=rand_xvec.to(fm_dtype), cond=conditions.to(fm_dtype), **kwargs
        )
        escape_time = (time.time() - fm_time) * 1000.0
        logging.info(f"fm dec {n_timesteps} step time: {escape_time:.2f}, avg {escape_time/n_timesteps:.2f} ms")
        return feat

    def solve_euler(self, x, t_span, mu, mask, spks=None, cond=None, **kwargs):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        inference_cfg_rate = kwargs.get("inference_cfg_rate", 0.7)
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        steps = 1
        z, bz = x, x.shape[0]
        while steps <= len(t_span) - 1:
            if inference_cfg_rate > 0:
                # classifier-free guidance: batch the conditional and the
                # unconditional (zeroed spks/mu/cond) passes together
                x_in = torch.concat([x, x], dim=0)
                spks_in = torch.cat([spks, torch.zeros_like(spks)], dim=0)
                mask_in = torch.concat([mask, mask], dim=0)
                mu_in = torch.concat([mu, torch.zeros_like(mu)], dim=0)
                t_in = torch.concat([t.unsqueeze(0) for _ in range(mu_in.size(0))], dim=0)
                if isinstance(cond, torch.Tensor):
                    cond_in = torch.concat([cond, torch.zeros_like(cond)], dim=0)
                else:
                    cond_in = dict(
                        prompt=[
                            torch.concat([cond["prompt"][0], torch.zeros_like(cond["prompt"][0])], dim=0),
                            torch.concat([cond["prompt"][1], cond["prompt"][1]], dim=0),
                        ]
                    )
            else:
                x_in, mask_in, mu_in, spks_in, t_in, cond_in = x, mask, mu, spks, t, cond

            # if spks is not None:
            #     cond_in = cond_in + spks

            infer_causal_mask_type = kwargs.get("infer_causal_mask_type", 0)
            chunk_mask_value = self.dit_model.causal_mask_type[infer_causal_mask_type]["prob_min"]
            hint_once(
                f"flow mask type: {infer_causal_mask_type}, mask_rank value: {chunk_mask_value}.",
                "chunk_mask_value"
            )
            dphi_dt = self.dit_model(
                x_in, cond_in, mu_in, spks_in, t_in,
                mask=mask_in,
                mask_rand=torch.ones_like(t_in).reshape(-1, 1, 1) * chunk_mask_value
            )
            if inference_cfg_rate > 0:
                # combine conditional and unconditional field estimates
                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [bz, bz], dim=0)
                dphi_dt = ((1.0 + inference_cfg_rate) * dphi_dt -
                           inference_cfg_rate * cfg_dphi_dt)

            x = x + dt * dphi_dt
            t = t + dt
            if steps < len(t_span) - 1:
                dt = t_span[steps + 1] - t
            steps += 1

        return x
|
funcineforge/models/inference_model.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import logging
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
import torchaudio
|
| 7 |
+
import time
|
| 8 |
+
import shutil
|
| 9 |
+
from funcineforge.utils.set_all_random_seed import set_all_random_seed
|
| 10 |
+
from moviepy.video.io.VideoFileClip import VideoFileClip, AudioFileClip
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FunCineForgeInferModel(nn.Module):
    """End-to-end inference pipeline wrapper.

    Chains three sub-models: the language model (text/video -> speech
    tokens), the flow-matching model (tokens -> mel features), and the
    vocoder (features -> waveform), then optionally saves the feature/wav
    artifacts and muxes the audio back into the source video.
    """

    def __init__(
        self,
        lm_model,
        fm_model,
        voc_model,
        **kwargs
    ):
        # NOTE(review): AutoModel appears unused in this method — confirm
        # before removing the import.
        from funcineforge.auto.auto_model import AutoModel
        super().__init__()
        # each *_model argument is an AutoModel-style wrapper exposing
        # .model (the nn.Module) and .kwargs (its build-time config)
        self.tokenizer = lm_model.kwargs["tokenizer"]
        self.frontend = fm_model.kwargs["frontend"]
        self.lm_model = lm_model.model
        self.fm_model = fm_model.model
        self.voc_model = voc_model.model
        mel_extractor = self.fm_model.mel_extractor
        if mel_extractor:
            self.mel_frame_rate = mel_extractor.sampling_rate // mel_extractor.hop_length
            self.sample_rate = mel_extractor.sampling_rate
        else:
            # no mel extractor configured: assume a fixed hop of 480 samples
            self.mel_frame_rate = self.fm_model.sample_rate // 480
            self.sample_rate = self.fm_model.sample_rate

    @torch.no_grad()
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        **kwargs,
    ):
        """Run the full LM -> FM -> vocoder pipeline for one utterance.

        Args:
            data_in: list with one sample dict (must contain "video"; the
                LM stage reads the rest; "codec" is filled in here).
            data_lengths: unused pass-through for the sub-models.
            key: list with one utterance id (used for logging / filenames).
        Returns:
            ([[wav]], meta) where meta carries "batch_data_time", the
            generated audio duration in seconds.
        """
        uttid = key[0]
        logging.info(f"generating {uttid}")
        # text -> codec in [1, T]
        kwargs["tokenizer"] = self.tokenizer
        # re-seed before every stage so each stage is reproducible on its own
        set_all_random_seed(kwargs.get("random_seed", 0))
        lm_time = time.time()
        codec, hit_eos, states = self.lm_model.inference(data_in, data_lengths, key, **kwargs)
        logging.info(f"[llm time]: {((time.time()-lm_time)*1000):.2f} ms, [hit_eos]: {hit_eos}, [gen len]: {codec.shape[1]}, [speech tokens]: {codec[0].cpu().tolist()}")
        wav, batch_data_time = None, 1.0
        if codec.shape[1] > 0:
            fm_time = time.time()
            data_in[0]["codec"] = codec
            set_all_random_seed(kwargs.get("random_seed", 0))
            feat = self.fm_model.inference(data_in, data_lengths, key, **kwargs)
            # feat -> wav
            set_all_random_seed(kwargs.get("random_seed", 0))
            wav = self.voc_model.inference([feat[0]], data_lengths, key, **kwargs)
            # output save
            output_dir = kwargs.get("output_dir", None)
            if output_dir is not None:
                feat_out_dir = os.path.join(output_dir, "feat")
                os.makedirs(feat_out_dir, exist_ok=True)
                np.save(os.path.join(feat_out_dir, f"{key[0]}.npy"), feat[0].cpu().numpy())

                wav_out_dir = os.path.join(output_dir, "wav")
                os.makedirs(wav_out_dir, exist_ok=True)
                output_wav_path = os.path.join(wav_out_dir, f"{key[0]}.wav")
                torchaudio.save(
                    output_wav_path, wav.cpu(),
                    sample_rate=self.sample_rate, encoding='PCM_S', bits_per_sample=16
                )

                silent_video_path = data_in[0]["video"]
                if os.path.exists(silent_video_path):
                    # keep a copy of the source (silent) video and write the
                    # muxed result next to it
                    video_out_dir = os.path.join(output_dir, "mp4")
                    video_gt_dir = os.path.join(output_dir, "gt")
                    os.makedirs(video_out_dir, exist_ok=True)
                    os.makedirs(video_gt_dir, exist_ok=True)
                    output_video_path = os.path.join(video_out_dir, f"{key[0]}.mp4")
                    copy_video_path = os.path.join(video_gt_dir, f"{key[0]}.mp4")
                    shutil.copy2(silent_video_path, copy_video_path)
                    self.merge_video_audio(
                        silent_video_path=silent_video_path,
                        wav_path=output_wav_path,
                        output_path=output_video_path,
                    )

            logging.info(f"fm_voc time: {((time.time()-fm_time)*1000):.2f} ms")

            batch_data_time = wav.shape[1] / self.voc_model.sample_rate

        return [[wav]], {"batch_data_time": batch_data_time}

    def merge_video_audio(self, silent_video_path, wav_path, output_path):
        """Mux ``wav_path`` onto ``silent_video_path``, writing ``output_path``.

        The audio is truncated to the video duration when longer; if it is
        shorter, the tail of the video simply stays silent.
        """

        video_clip = VideoFileClip(silent_video_path)
        video_duration = video_clip.duration
        audio_clip = AudioFileClip(wav_path)
        audio_duration = audio_clip.duration

        if audio_duration >= video_duration:
            audio_clip = audio_clip.subclipped(0, video_duration)

        video_clip = video_clip.with_audio(audio_clip)
        video_clip.write_videofile(
            output_path,
            codec='libx264',
            audio_codec='aac',
            fps=video_clip.fps,
            logger=None
        )
        video_clip.close()
        audio_clip.close()
|
funcineforge/models/language_model.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from funcineforge.models.utils.llm_decoding import LLMDecoder
|
| 6 |
+
from funcineforge.utils.device_funcs import to_device
|
| 7 |
+
import numpy as np
|
| 8 |
+
from funcineforge.models.utils import dtype_map
|
| 9 |
+
from funcineforge.models import FunCineForgeSpecAug
|
| 10 |
+
from transformers import AutoModelForCausalLM
|
| 11 |
+
import pickle
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class FunCineForgeLM(nn.Module):
|
| 16 |
+
    def __init__(
        self,
        llm: str = None,
        llm_conf: dict = None,
        input_size: int = 80,
        length_normalized_loss: bool = False,
        **kwargs,
    ):
        """Build the FunCineForge language model around a pretrained LLM.

        Loads a causal LM via HuggingFace ``AutoModelForCausalLM``, optionally
        freezes it, and attaches codec/time-speaker embeddings, a codec
        output head, and a face-embedding projection on top.

        Args:
            llm: unused here — presumably an identifier; config comes from
                ``llm_conf`` (TODO confirm).
            llm_conf: dict with "init_param_path", optional "load_kwargs",
                "freeze", "use_lora"/"use_qlora", "activation_checkpoint",
                "llm_dtype".
            input_size: unused in this constructor — TODO confirm caller.
            length_normalized_loss: whether the training loss is normalized
                by sequence length (stored for later use).
        """
        super().__init__()

        # llm
        self.llm_conf = llm_conf
        self.llm = None

        init_param_path = llm_conf.get("init_param_path", "")
        llm_load_kwargs = llm_conf.get("load_kwargs", {})
        self.sample_rate = kwargs.get("sample_rate", 24000)
        self.token_rate = kwargs.get("token_rate", 25)

        # when LoRA weights were already merged into the checkpoint, disable
        # all adapter paths for inference
        if kwargs.get("infer_lora_merged", False):
            llm_conf["use_qlora"] = False
            llm_conf["use_lora"] = False
            kwargs["infer_use_lora"] = False


        model = AutoModelForCausalLM.from_pretrained(
            init_param_path,
            load_in_8bit=None,
            device_map=None,
            use_cache=None,
            **llm_load_kwargs,
        )

        # optionally freeze the backbone LLM (default: frozen)
        freeze = llm_conf.get("freeze", True)
        if freeze:
            for name, param in model.named_parameters():
                param.requires_grad = False
            model.eval()

        logging.info(f"use_lora: {llm_conf.get('use_lora', False)}, use_qlora: {llm_conf.get('use_qlora', False)}, infer_use_lora: {kwargs.get('infer_use_lora',False)}, infer_lora_merged: {kwargs.get('infer_lora_merged',False)}")

        if llm_conf.get("activation_checkpoint", False):
            model.gradient_checkpointing_enable()

        self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
        self.llm = model.to(dtype_map[self.llm_dtype])
        llm_dim = model.get_input_embeddings().weight.shape[-1]

        # the LM head is only needed for LoRA paths; drop it otherwise to
        # save memory (codec_head below replaces it)
        if (not llm_conf.get("use_lora", False)) and (not kwargs.get("infer_use_lora",False)):
            del self.llm.lm_head
        # vocabulary sizes: speech-codec tokens and time/speaker tokens;
        # index 0 is the padding idx for both embeddings
        self.codec_unit = kwargs.get("codec_unit", 6761)
        self.timespk_unit = kwargs.get("timespk_unit", 1550)
        self.codec_embed = nn.Embedding(self.codec_unit, llm_dim, 0)
        self.timespk_embed = nn.Embedding(self.timespk_unit, llm_dim, 0)
        self.codec_head = nn.Linear(llm_dim, self.codec_unit, bias=False)
        self.face_size = kwargs.get("face_size", 512)
        self.face_linear = nn.Linear(self.face_size, llm_dim)

        self.length_normalized_loss = length_normalized_loss
        self.ignore_id = kwargs.get("ignore_id", -100)

        # optional spec augmentation on input features
        specaug = kwargs.get("specaug", None)
        specaug_conf = kwargs.get("specaug_conf", {})
        if specaug is not None:
            specaug = FunCineForgeSpecAug(**specaug_conf)
        self.specaug = specaug
        rank = int(os.environ.get("RANK", 0))
        logging.info(f"rank: {rank}, model is builded.")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def insert_face_embeddings(
    self, inputs_embeds, face_emb, attention_mask, labels_ids,
    codec_len, insert_pos, device
):
    """Insert per-sample face embeddings into the input sequence.

    The face embeddings are spliced in at ``insert_pos`` (right after the
    SOS token); ``attention_mask`` and ``labels_ids`` are shifted in the
    same way. Positions beyond each sample's new length stay zero-padded
    (labels stay ``self.ignore_id``).

    Args:
        inputs_embeds: (batch, token_num, dims) input embeddings.
        face_emb: (batch, max_face_len, dims) face embeddings.
        attention_mask: (batch, token_num) attention mask.
        labels_ids: (batch, token_num) label ids.
        codec_len: (batch,) actual face-embedding length per sample.
        insert_pos: int, insertion position (after the SOS token).
        device: device for the newly allocated output tensors.

    Returns:
        (padded_inputs_embeds, padded_attention_mask, padded_labels),
        each of length token_num + max_face_len along dim 1.
    """
    batch_size, token_num, dims = inputs_embeds.shape
    max_face_len = face_emb.size(1)

    # maximum length of the new sequence
    new_max_length = token_num + max_face_len

    # Pre-allocate outputs. Keep the embedding dtype: the original code
    # always allocated fp32 here, silently upcasting fp16/bf16 embeddings.
    padded_inputs_embeds = torch.zeros(
        batch_size, new_max_length, dims, device=device, dtype=inputs_embeds.dtype
    )
    padded_attention_mask = torch.zeros(batch_size, new_max_length, device=device, dtype=attention_mask.dtype)
    padded_labels = torch.full((batch_size, new_max_length), self.ignore_id, device=device, dtype=labels_ids.dtype)

    for i in range(batch_size):
        current_face_len = codec_len[i].item()

        # fill directly instead of building intermediate concatenations
        padded_inputs_embeds[i, :insert_pos] = inputs_embeds[i, :insert_pos]
        padded_inputs_embeds[i, insert_pos:insert_pos+current_face_len] = face_emb[i, :current_face_len]
        padded_inputs_embeds[i, insert_pos+current_face_len:token_num+current_face_len] = inputs_embeds[i, insert_pos:]

        # same treatment for the mask and the labels
        padded_attention_mask[i, :insert_pos] = attention_mask[i, :insert_pos]
        padded_attention_mask[i, insert_pos:insert_pos+current_face_len] = 1
        padded_attention_mask[i, insert_pos+current_face_len:token_num+current_face_len] = attention_mask[i, insert_pos:]

        padded_labels[i, :insert_pos] = labels_ids[i, :insert_pos]
        padded_labels[i, insert_pos:insert_pos+current_face_len] = self.ignore_id
        padded_labels[i, insert_pos+current_face_len:token_num+current_face_len] = labels_ids[i, insert_pos:]

    return padded_inputs_embeds, padded_attention_mask, padded_labels
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def load_data(self, contents: dict, **kwargs):
    """Build one inference sample from a raw manifest entry.

    Args:
        contents: manifest entry with keys "text", "clue", "timespk_ids",
            "type_id", "speech_len" and "face" (path to a pickle file with
            per-frame face embeddings).
        **kwargs: expects "tokenizer", "dataset_conf" (providing the "sos"
            and "turn_of_speech" token ids) and optionally "lm_use_prompt".

    Returns:
        dict with batch-size-1 tensors: "input_ids", "face_embs", boolean
        "text_flag" / "timespk_flag" / "codec_flag", and "prompt_codec".
    """
    lm_use_prompt = kwargs.get("lm_use_prompt", True)
    tokenizer = kwargs.get("tokenizer")
    # text + clue
    text = contents["text"]
    clue = "<|startofclue|>" + contents["clue"] + "<|endofclue|>"
    if lm_use_prompt:
        text = clue + text
    text_ids = tokenizer.encode(text)
    text_len = len(text_ids)
    # timespk_ids
    timespk_ids = contents["timespk_ids"].tolist()
    type_id = contents["type_id"]
    # sequence: <sos> text... type_id timespk... <turn_of_speech>
    sequence = [
        kwargs['dataset_conf']["sos"],
        *text_ids,
        type_id,
        *timespk_ids,
        kwargs['dataset_conf']["turn_of_speech"]
    ]
    input_ids = torch.tensor(sequence, dtype=torch.int64)

    # flag tensors marking which positions use which embedding table
    text_flag = torch.zeros(len(sequence), dtype=torch.float32)
    timespk_flag = torch.zeros(len(sequence), dtype=torch.float32)
    codec_flag = torch.zeros(len(sequence), dtype=torch.float32)
    text_flag[1: text_len+1] = 1            # the text token span (after sos)
    timespk_flag[text_len+1: -1] = 1        # covers type_id + timespk ids
    codec_flag = 1 - text_flag - timespk_flag  # remainder: sos / turn_of_speech

    # face embs: one embedding per speech frame; each detected face
    # embedding is propagated forward for up to 5 frames
    speech_len = contents["speech_len"]
    face_embs = torch.zeros((speech_len, self.face_size), dtype=torch.float32)
    face_path = contents.get("face")
    with open(face_path, 'rb') as f:
        stat_obj = pickle.load(f)
    embeddings = stat_obj['embeddings']
    faceI = stat_obj['faceI']
    for emb, frameI in zip(embeddings, faceI):
        fi = int(frameI)
        if 0 <= fi < speech_len:
            end = min(fi + 5, speech_len)
            face_embs[fi:end] = torch.from_numpy(emb).expand(end - fi, -1)

    # batch dimension
    input_ids = input_ids[None, :]
    text_flag = text_flag[None, :]
    timespk_flag = timespk_flag[None, :]
    codec_flag = codec_flag[None, :]
    face_embs = face_embs[None, :, :]
    output = {
        "input_ids": input_ids,
        "face_embs": face_embs,
        "text_flag": text_flag > 0,
        "timespk_flag": timespk_flag > 0,
        "codec_flag": codec_flag > 0,
        "prompt_codec": None,  # you can add prompt codec here if needed
    }
    return output
|
| 196 |
+
|
| 197 |
+
def inference_prepare(self, data_in, **kwargs):
    """Turn one raw sample into the LLM input embedding sequence.

    Embeds the text / timespk / codec token spans with their respective
    embedding tables, projects the per-frame face embeddings, inserts them
    right after the SOS token, and optionally appends prompt-codec
    embeddings. Only batch_size == 1 is supported.
    """
    if kwargs.get("batch_size", 1) > 1:
        raise NotImplementedError("batch decoding is not implemented")
    output = self.load_data(data_in[0], **kwargs)
    batch = to_device(output, kwargs["device"])
    input_ids = batch["input_ids"]
    input_ids = input_ids * (input_ids > 0)  # clamp negative/padding ids to 0
    text_flag = batch["text_flag"]
    timespk_flag = batch["timespk_flag"]
    codec_flag = batch["codec_flag"]
    face_embs = batch["face_embs"]

    # With an unmerged (Q)LoRA adapter the base embedding table lives one
    # level deeper inside the PEFT wrapper.
    if (kwargs.get("use_qlora",False) or kwargs.get("infer_use_lora",False)) and (not kwargs.get("infer_lora_merged",False)):
        text_embeds = self.llm.base_model.model.model.get_input_embeddings()(input_ids * text_flag) * text_flag.unsqueeze(-1)
    else:
        text_embeds = self.llm.model.get_input_embeddings()(input_ids * text_flag) * text_flag.unsqueeze(-1)
    timespk_embeds = self.timespk_embed(input_ids * timespk_flag) * timespk_flag.unsqueeze(-1)
    codec_embs = self.codec_embed(input_ids * codec_flag) * codec_flag.unsqueeze(-1)
    face_embs = self.face_linear(face_embs)

    # the flags are disjoint, so summing keeps exactly one embedding per position
    inputs_embeds = text_embeds + timespk_embeds + codec_embs

    inputs_embeds = torch.cat([
        inputs_embeds[:, 0:1, :],  # sos token
        face_embs,  # face embeddings
        inputs_embeds[:, 1:, :]  # inputs_embeds after sos
    ], dim=1)

    prompt_codec = batch.get("prompt_codec", None)
    if prompt_codec is not None:
        codec_emb = self.codec_embed(prompt_codec)
        inputs_embeds = torch.cat((inputs_embeds, codec_emb), dim=1)

    return inputs_embeds
|
| 231 |
+
|
| 232 |
+
@torch.no_grad()
def inference(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    **kwargs,
):
    """Autoregressively generate codec tokens for one utterance.

    Args:
        data_in: list containing a single manifest entry (batch_size == 1).
        data_lengths: unused; kept for interface compatibility.
        key: list with a single utterance id; used for logging and as the
            output file name.
        **kwargs: decoding options ("min_length", "max_length", "llm_dtype",
            "output_dir", LoRA flags, "dataset_conf", "states", ...).

    Returns:
        (gen_codec, hit_eos, states) as produced by the LLM decoder.
    """
    uttid = key[0]
    inputs_emb = self.inference_prepare(data_in, **kwargs)

    logging.info(f"{uttid}: min length: {kwargs['min_length']}, max length: {kwargs['max_length']}")

    dtype = dtype_map[kwargs.get("llm_dtype", "fp32")]
    # lazily build the decoder once and reuse it across calls
    if not hasattr(self, "llm_generator"):
        llm_generator_conf = kwargs.get("dataset_conf", {})
        self.llm_generator = LLMDecoder(
            token_embeder=self.codec_embed,
            **llm_generator_conf
        ).to(dtype)

    # route the LLM output through the codec head instead of the text head;
    # the attribute path differs when an unmerged (Q)LoRA wrapper is present
    if (kwargs.get("use_qlora",False) or kwargs.get("infer_use_lora",False)) and (not kwargs.get("infer_lora_merged",False)):
        self.llm.base_model.model.lm_head = self.codec_head.to(dtype)
    else:
        self.llm.lm_head = self.codec_head.to(dtype)

    gen_codec, hit_eos, states = self.llm_generator(
        inputs_emb.to(dtype),
        self.llm,
        states=kwargs.get("states", {}),
        **kwargs
    )

    # optionally dump the generated codec tokens as <output_dir>/codec/<uttid>.npy
    output_dir = kwargs.get("output_dir", None)
    if output_dir is not None:
        output_dir = os.path.join(output_dir, "codec")
        os.makedirs(output_dir, exist_ok=True)
        np.save(
            os.path.join(output_dir, f"{key[0]}.npy"),
            gen_codec[0].cpu().numpy()
        )

    return gen_codec, hit_eos, states
|
funcineforge/models/modules/__init__.py
ADDED
|
File without changes
|
funcineforge/models/modules/dit_flow_matching/__init__.py
ADDED
|
File without changes
|
funcineforge/models/modules/dit_flow_matching/dit_model.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ein notation:
|
| 3 |
+
b - batch
|
| 4 |
+
n - sequence
|
| 5 |
+
nt - text sequence
|
| 6 |
+
nw - raw wave length
|
| 7 |
+
d - dimension
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch import nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from einops import repeat
|
| 16 |
+
from x_transformers.x_transformers import RotaryEmbedding
|
| 17 |
+
from funcineforge.models.utils.masks import causal_block_mask
|
| 18 |
+
|
| 19 |
+
from .dit_modules import (
|
| 20 |
+
TimestepEmbedding,
|
| 21 |
+
ConvNeXtV2Block,
|
| 22 |
+
CausalConvPositionEmbedding,
|
| 23 |
+
DiTBlock,
|
| 24 |
+
AdaLayerNormZero_Final,
|
| 25 |
+
precompute_freqs_cis,
|
| 26 |
+
get_pos_embed_indices,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Text embedding
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class TextEmbedding(nn.Module):
    """Embed padded text token ids, optionally followed by sinusoidal
    position embeddings and a ConvNeXt-V2 refinement stack."""

    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        batch, text_len = text.shape[0], text.shape[1]
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        text = F.pad(text, (0, seq_len - text_len), value=0)  # right-pad with the filler token

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# noised input audio and context mixing embedding
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class InputEmbedding(nn.Module):
    """Fuse noised audio, conditioning audio, text/mu features and an
    optional per-utterance speaker vector into the transformer width."""

    def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
        super().__init__()
        self.spk_dim = 0 if spk_dim is None else spk_dim
        self.proj = nn.Linear(mel_dim * 2 + text_dim + self.spk_dim, out_dim)
        self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)

    def forward(self, x, cond, text_embed, spks):
        features = [x, cond, text_embed]
        if self.spk_dim > 0:
            # broadcast the utterance-level speaker vector over time
            features.append(repeat(spks, "b c -> b t c", t=x.shape[1]))

        fused = self.proj(torch.cat(features, dim=-1))
        return self.conv_pos_embed(fused) + fused
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Transformer backbone using DiT blocks
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class DiT(nn.Module):
    """DiT-style transformer backbone for flow-matching mel generation.

    Consumes noised audio, masked conditioning audio, conditioning
    features `mu` and an optional speaker vector, everywhere modulated by
    the diffusion timestep embedding, and predicts a mel-dim output.
    """

    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=80,
        mu_dim=None,
        long_skip_connection=False,
        spk_dim=None,
        **kwargs
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        # mu defaults to the mel dimensionality when not given explicitly
        if mu_dim is None:
            mu_dim = mel_dim
        self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
        )
        # optional UNet-like skip from the input to after the last block
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None

        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)
        # list of {"prob_min", "prob_max", "block_size"[, "ratio"]} dicts, or None
        self.causal_mask_type = kwargs.get("causal_mask_type", None)

    def build_mix_causal_mask(self, attn_mask, rand=None, ratio=None):
        """Randomly mix block-causal masks into the full attention mask.

        For each sample a uniform draw in [0, 1) selects at most one entry
        of `causal_mask_type` — the one whose [prob_min, prob_max) interval
        contains the draw; that entry's block-causal mask is AND-ed with
        the padding mask.  Samples whose draw matches no entry keep the
        original mask unchanged.
        """
        b, _, _, t = attn_mask.shape
        if rand is None:
            rand = torch.rand((b, 1, 1, 1), device=attn_mask.device, dtype=torch.float32)
        mixed_mask = attn_mask.clone()
        for item in self.causal_mask_type:
            prob_min, prob_max = item["prob_min"], item["prob_max"]
            _ratio = 1
            if "ratio" in item:
                _ratio = item["ratio"]
            if ratio is not None:
                _ratio = ratio  # an explicit argument overrides the config value
            block_size = item["block_size"] * _ratio
            if block_size <= 0:
                causal_mask = attn_mask  # degenerate entry: keep the plain mask
            else:
                causal_mask = causal_block_mask(
                    t, block_size, attn_mask.device, torch.float32
                ).unsqueeze(0).unsqueeze(1)  # 1,1,T,T
            flag = (prob_min <= rand) & (rand < prob_max)
            mixed_mask = mixed_mask * (~flag) + (causal_mask * attn_mask) * flag

        return mixed_mask

    def forward(
        self,
        x: float["b n d"],  # nosied input audio
        cond: float["b n d"],  # masked cond audio
        mu: int["b nt d"],  # mu
        spks: float["b 1 d"],  # spk xvec
        time: float["b"] | float[""],  # time step
        return_hidden: bool = False,
        mask: bool["b 1 n"] | None = None,
        mask_rand: float["b 1 1"] = None,  # for mask flag type
        **kwargs,
    ):
        """Predict the flow-matching target for the noised input `x`.

        NOTE(review): despite the `| None` annotation, `mask` (and, when
        `causal_mask_type` is set, `mask_rand`) must actually be provided —
        `mask.unsqueeze(1)` below would raise on None. Confirm with callers.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)  # broadcast a scalar timestep over the batch

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        x = self.input_embed(x, cond, mu, spks.squeeze(1))

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        if self.long_skip_connection is not None:
            residual = x

        mask = mask.unsqueeze(1)  # B,1,1,T
        if self.causal_mask_type is not None:
            mask = self.build_mix_causal_mask(mask, rand=mask_rand.unsqueeze(-1))

        for block in self.transformer_blocks:
            # mask-out padded values for amp training
            x = x * mask[:, 0, -1, :].unsqueeze(-1)
            x = block(x, t, mask=mask.bool(), rope=rope)

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        if return_hidden:
            # hidden states are not tracked; None kept for interface compatibility
            return output, None

        return output
|
funcineforge/models/modules/dit_flow_matching/dit_modules.py
ADDED
|
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ein notation:
|
| 3 |
+
b - batch
|
| 4 |
+
n - sequence
|
| 5 |
+
nt - text sequence
|
| 6 |
+
nw - raw wave length
|
| 7 |
+
d - dimension
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
from typing import Optional
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch import nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
import torchaudio
|
| 18 |
+
|
| 19 |
+
from x_transformers.x_transformers import apply_rotary_pos_emb
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# raw wav to mel spec
|
| 23 |
+
class MelSpec(nn.Module):
    """Waveform -> log-mel-spectrogram front end built on torchaudio."""

    def __init__(
        self,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24_000,
        normalize=False,
        power=1,
        norm=None,
        center=True,
    ):
        super().__init__()
        self.n_mel_channels = n_mel_channels

        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate=target_sample_rate,
            n_fft=filter_length,
            win_length=win_length,
            hop_length=hop_length,
            n_mels=n_mel_channels,
            power=power,
            center=center,
            normalized=normalize,
            norm=norm,
        )

        # device tracker: lets forward() move the module to the input's device
        self.register_buffer("dummy", torch.tensor(0), persistent=False)

    def forward(self, inp):
        """Return log-mel features for a (b, nw) or (b, 1, nw) waveform."""
        if len(inp.shape) == 3:
            inp = inp.squeeze(1)  # 'b 1 nw -> b nw'

        assert len(inp.shape) == 2

        if self.dummy.device != inp.device:
            self.to(inp.device)

        mel = self.mel_stft(inp)
        mel = mel.clamp(min=1e-5).log()  # floor before log to avoid -inf
        return mel
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# sinusoidal position embedding
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class SinusPositionEmbedding(nn.Module):
    """Transformer-style sinusoidal embedding for scalar positions/timesteps.

    Maps a (b,) tensor to (b, dim): first dim//2 columns are sines, last
    dim//2 are cosines of geometrically spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        half = self.dim // 2
        # geometric frequency ladder: 1 .. 1/10000 over `half` steps
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(-step * torch.arange(half, device=x.device).float())
        angles = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# convolutional position embedding
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class ConvPositionEmbedding(nn.Module):
    """Convolutional position embedding with symmetric ("same") padding."""

    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0  # odd kernel keeps sequence length with same-padding
        self.conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
        )

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
        """Apply the conv stack; padded positions are zeroed before and after."""
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.0)

        x = x.permute(0, 2, 1)
        x = self.conv1d(x)
        out = x.permute(0, 2, 1)

        if mask is not None:
            out = out.masked_fill(~mask, 0.0)

        return out
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class CausalConvPositionEmbedding(nn.Module):
    """Causal convolutional position embedding: only left-padded, so each
    output frame depends on past frames only."""

    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.kernel_size = kernel_size
        self.conv1 = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
            nn.Mish(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
            nn.Mish(),
        )

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.0)

        x = x.permute(0, 2, 1)
        # left-pad by kernel_size - 1 before each conv -> causal receptive field
        x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
        x = self.conv1(x)
        x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
        x = self.conv2(x)
        out = x.permute(0, 2, 1)

        if mask is not None:
            out = out.masked_fill(~mask, 0.0)

        return out
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# rotary positional embedding related
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
    """Precompute rotary-embedding angle tables.

    Returns an (end, dim) tensor whose first dim//2 columns are cos(angle)
    and last dim//2 columns are sin(angle) for positions 0..end-1.
    `theta_rescale_factor` applies NTK-aware base rescaling so the table
    can stretch to longer sequences without fine-tuning.
    """
    # NTK-aware rescaling, proposed by reddit user bloc97:
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    rescaled_theta = theta * theta_rescale_factor ** (dim / (dim - 2))
    inv_freq = 1.0 / (rescaled_theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    return torch.cat([angles.cos(), angles.sin()], dim=-1)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
    """Build (batch, length) position indices starting at `start`, stepped
    by `scale`, and clipped to the valid range [0, max_pos - 1]."""
    # promote a scalar scale to a per-sample vector
    scale_vec = torch.ones_like(start, dtype=torch.float32) * scale
    offsets = torch.arange(length, device=start.device, dtype=torch.float32)
    pos = start.unsqueeze(1) + (offsets.unsqueeze(0) * scale_vec.unsqueeze(1)).long()
    # clip instead of erroring on extra-long sequences
    return pos.clamp(max=max_pos - 1)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# Global Response Normalization layer (Instance Normalization ?)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class GRN(nn.Module):
    """Global Response Normalization layer (ConvNeXt-V2).

    Rescales each channel by its sequence-wise L2 norm normalized across
    channels, through a learnable gain/bias that starts at zero — so the
    layer is initially an identity mapping.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        # force fp32 here: the norm statistics are sensitive to half precision
        with torch.cuda.amp.autocast(enabled=False):
            channel_norm = torch.norm(x, p=2, dim=1, keepdim=True)
            response = channel_norm / (channel_norm.mean(dim=-1, keepdim=True) + 1e-6)
            return self.gamma * (x * response) + self.beta + x
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
|
| 192 |
+
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class ConvNeXtV2Block(nn.Module):
    """ConvNeXt-V2 residual block operating on (b, n, d) sequences.

    Depthwise conv -> LayerNorm -> pointwise MLP with GELU and GRN,
    wrapped in a residual connection.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        # "same" padding for the dilated kernel-7 depthwise conv
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = x.transpose(1, 2)  # b n d -> b d n
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        # keep norm + MLP in fp32 under amp for numerical stability
        with torch.cuda.amp.autocast(enabled=False):
            x = self.norm(x)
            x = self.pwconv1(x)
            x = self.act(x)
            x = self.grn(x)
            x = self.pwconv2(x)
        return residual + x
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# AdaLayerNormZero
|
| 228 |
+
# return with modulated x for attn input, and params for later mlp modulation
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class AdaLayerNormZero(nn.Module):
    """adaLN-Zero: timestep-conditioned LayerNorm modulation (DiT style).

    Returns the modulated attention input together with the gate for the
    attention branch and the shift/scale/gate for the subsequent MLP.
    """

    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        # 6 chunks: shift/scale/gate for attention + shift/scale/gate for MLP
        self.linear = nn.Linear(dim, dim * 6)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)

        # norm computed in fp32 under amp
        with torch.cuda.amp.autocast(enabled=False):
            x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# AdaLayerNormZero for final layer
|
| 250 |
+
# return only with modulated x for attn input, cuz no more mlp modulation
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class AdaLayerNormZero_Final(nn.Module):
    """adaLN-Zero for the final layer: returns only the modulated input,
    since no MLP modulation follows."""

    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        # 2 chunks: scale and shift
        self.linear = nn.Linear(dim, dim * 2)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        scale, shift = torch.chunk(emb, 2, dim=1)

        # norm computed in fp32 under amp
        with torch.cuda.amp.autocast(enabled=False):
            x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# FeedForward
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
        super().__init__()
        hidden = int(dim * mult)
        out_dim = dim if dim_out is None else dim_out

        # inner Sequential mirrors the original project_in grouping so the
        # submodule layout (ff[0][0], ff[2]) stays identical
        self.ff = nn.Sequential(
            nn.Sequential(nn.Linear(dim, hidden), nn.GELU(approximate=approximate)),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        return self.ff(x)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# Attention with possible joint part
|
| 289 |
+
# modified from diffusers/src/diffusers/models/attention_processor.py
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class Attention(nn.Module):
    """Multi-head attention with an optional joint (context) part.

    Modified from diffusers' attention_processor. When ``context_dim`` is
    given, extra projections are created for the context stream so a
    JointAttnProcessor can attend over the concatenation of input and
    context (MM-DiT style).

    Args:
        processor: callable that implements the actual attention math.
        dim: model width of the noised-input stream.
        heads: number of attention heads.
        dim_head: per-head width; inner dim is ``heads * dim_head``.
        dropout: dropout applied after the output projection.
        context_dim: if not None, enables joint attention over a context.
        context_pre_only: last-layer flag; when not None and False, a
            separate output projection for the context is created.
    """

    def __init__(
        self,
        processor: JointAttnProcessor | AttnProcessor,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only=None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            # fix: error message said "equires"
            raise ImportError("Attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.processor = processor

        self.dim = dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout

        self.context_dim = context_dim
        self.context_pre_only = context_pre_only

        # projections for the main (noised-input) stream
        self.to_q = nn.Linear(dim, self.inner_dim)
        self.to_k = nn.Linear(dim, self.inner_dim)
        self.to_v = nn.Linear(dim, self.inner_dim)

        # extra projections for the context stream (joint attention only)
        if self.context_dim is not None:
            self.to_k_c = nn.Linear(context_dim, self.inner_dim)
            self.to_v_c = nn.Linear(context_dim, self.inner_dim)
            if self.context_pre_only is not None:
                self.to_q_c = nn.Linear(context_dim, self.inner_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

        # context output projection only when the context continues past this layer
        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_out_c = nn.Linear(self.inner_dim, dim)

    def forward(
        self,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b n d"] = None,  # context c  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.Tensor:
        # dispatch to the joint processor only when a context is supplied
        if c is not None:
            return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
        else:
            return self.processor(self, x, mask=mask, rope=rope)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# Attention processor
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class AttnProcessor:
|
| 353 |
+
def __init__(self):
|
| 354 |
+
pass
|
| 355 |
+
|
| 356 |
+
def __call__(
|
| 357 |
+
self,
|
| 358 |
+
attn: Attention,
|
| 359 |
+
x: float["b n d"], # noised input x # noqa: F722
|
| 360 |
+
mask: bool["b n"] | None = None, # noqa: F722
|
| 361 |
+
rope=None, # rotary position embedding
|
| 362 |
+
) -> torch.FloatTensor:
|
| 363 |
+
batch_size = x.shape[0]
|
| 364 |
+
|
| 365 |
+
# `sample` projections.
|
| 366 |
+
query = attn.to_q(x)
|
| 367 |
+
key = attn.to_k(x)
|
| 368 |
+
value = attn.to_v(x)
|
| 369 |
+
|
| 370 |
+
# apply rotary position embedding
|
| 371 |
+
if rope is not None:
|
| 372 |
+
freqs, xpos_scale = rope
|
| 373 |
+
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
| 374 |
+
|
| 375 |
+
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
| 376 |
+
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
| 377 |
+
|
| 378 |
+
# attention
|
| 379 |
+
inner_dim = key.shape[-1]
|
| 380 |
+
head_dim = inner_dim // attn.heads
|
| 381 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 382 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 383 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 384 |
+
|
| 385 |
+
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
| 386 |
+
if mask is not None:
|
| 387 |
+
attn_mask = mask
|
| 388 |
+
if attn_mask.dim() == 2:
|
| 389 |
+
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
| 390 |
+
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
| 391 |
+
else:
|
| 392 |
+
attn_mask = None
|
| 393 |
+
|
| 394 |
+
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
| 395 |
+
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 396 |
+
x = x.to(query.dtype)
|
| 397 |
+
|
| 398 |
+
# linear proj
|
| 399 |
+
x = attn.to_out[0](x)
|
| 400 |
+
# dropout
|
| 401 |
+
x = attn.to_out[1](x)
|
| 402 |
+
|
| 403 |
+
if mask is not None:
|
| 404 |
+
if mask.dim() == 2:
|
| 405 |
+
mask = mask.unsqueeze(-1)
|
| 406 |
+
else:
|
| 407 |
+
mask = mask[:, 0, -1].unsqueeze(-1)
|
| 408 |
+
x = x.masked_fill(~mask, 0.0)
|
| 409 |
+
|
| 410 |
+
return x
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
# Joint Attention processor for MM-DiT
|
| 414 |
+
# modified from diffusers/src/diffusers/models/attention_processor.py
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
class JointAttnProcessor:
|
| 418 |
+
def __init__(self):
|
| 419 |
+
pass
|
| 420 |
+
|
| 421 |
+
def __call__(
|
| 422 |
+
self,
|
| 423 |
+
attn: Attention,
|
| 424 |
+
x: float["b n d"], # noised input x # noqa: F722
|
| 425 |
+
c: float["b nt d"] = None, # context c, here text # noqa: F722
|
| 426 |
+
mask: bool["b n"] | None = None, # noqa: F722
|
| 427 |
+
rope=None, # rotary position embedding for x
|
| 428 |
+
c_rope=None, # rotary position embedding for c
|
| 429 |
+
) -> torch.FloatTensor:
|
| 430 |
+
residual = x
|
| 431 |
+
|
| 432 |
+
batch_size = c.shape[0]
|
| 433 |
+
|
| 434 |
+
# `sample` projections.
|
| 435 |
+
query = attn.to_q(x)
|
| 436 |
+
key = attn.to_k(x)
|
| 437 |
+
value = attn.to_v(x)
|
| 438 |
+
|
| 439 |
+
# `context` projections.
|
| 440 |
+
c_query = attn.to_q_c(c)
|
| 441 |
+
c_key = attn.to_k_c(c)
|
| 442 |
+
c_value = attn.to_v_c(c)
|
| 443 |
+
|
| 444 |
+
# apply rope for context and noised input independently
|
| 445 |
+
if rope is not None:
|
| 446 |
+
freqs, xpos_scale = rope
|
| 447 |
+
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
| 448 |
+
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
| 449 |
+
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
| 450 |
+
if c_rope is not None:
|
| 451 |
+
freqs, xpos_scale = c_rope
|
| 452 |
+
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
| 453 |
+
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
| 454 |
+
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
| 455 |
+
|
| 456 |
+
# attention
|
| 457 |
+
query = torch.cat([query, c_query], dim=1)
|
| 458 |
+
key = torch.cat([key, c_key], dim=1)
|
| 459 |
+
value = torch.cat([value, c_value], dim=1)
|
| 460 |
+
|
| 461 |
+
inner_dim = key.shape[-1]
|
| 462 |
+
head_dim = inner_dim // attn.heads
|
| 463 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 464 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 465 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 466 |
+
|
| 467 |
+
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
| 468 |
+
if mask is not None:
|
| 469 |
+
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
|
| 470 |
+
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
| 471 |
+
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
| 472 |
+
else:
|
| 473 |
+
attn_mask = None
|
| 474 |
+
|
| 475 |
+
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
| 476 |
+
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 477 |
+
x = x.to(query.dtype)
|
| 478 |
+
|
| 479 |
+
# Split the attention outputs.
|
| 480 |
+
x, c = (
|
| 481 |
+
x[:, : residual.shape[1]],
|
| 482 |
+
x[:, residual.shape[1] :],
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
# linear proj
|
| 486 |
+
x = attn.to_out[0](x)
|
| 487 |
+
# dropout
|
| 488 |
+
x = attn.to_out[1](x)
|
| 489 |
+
if not attn.context_pre_only:
|
| 490 |
+
c = attn.to_out_c(c)
|
| 491 |
+
|
| 492 |
+
if mask is not None:
|
| 493 |
+
mask = mask.unsqueeze(-1)
|
| 494 |
+
x = x.masked_fill(~mask, 0.0)
|
| 495 |
+
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
|
| 496 |
+
|
| 497 |
+
return x, c
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# DiT Block
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class DiTBlock(nn.Module):
    """Standard DiT transformer block: AdaLN-Zero modulated attention + MLP."""

    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, t, mask=None, rope=None):
        """x: noised input, t: time embedding."""
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # gated attention residual
        attn_output = self.attn(x=norm, mask=mask, rope=rope)
        x = x + gate_msa.unsqueeze(1) * attn_output

        # modulated pre-norm for the MLP; LayerNorm kept in fp32
        with torch.cuda.amp.autocast(enabled=False):
            ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(ff_norm)
        x = x + gate_mlp.unsqueeze(1) * ff_output

        return x
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
# MMDiT Block https://arxiv.org/abs/2403.03206
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
class MMDiTBlock(nn.Module):
    r"""
    MM-DiT block (https://arxiv.org/abs/2403.03206), modified from
    diffusers/src/diffusers/models/attention.py

    notes.
    _c: context related. text, cond, etc. (left part in sd3 fig2.b)
    _x: noised input related. (right part)
    context_pre_only: last layer only do prenorm + modulation cuz no more ffn
    """

    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
        super().__init__()

        self.context_pre_only = context_pre_only

        # the final layer's context branch only needs scale/shift modulation
        self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
        self.attn_norm_x = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=JointAttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            context_dim=dim,
            context_pre_only=context_pre_only,
        )

        if not context_pre_only:
            self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
        else:
            self.ff_norm_c = None
            self.ff_c = None
        self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, c, t, mask=None, rope=None, c_rope=None):
        """x: noised input, c: context, t: time embedding."""
        # pre-norm & modulation for attention input
        if self.context_pre_only:
            norm_c = self.attn_norm_c(c, t)
        else:
            norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
        norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)

        # joint attention over both streams
        x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)

        # context branch (dropped entirely in the final layer)
        if self.context_pre_only:
            c = None
        else:
            c = c + c_gate_msa.unsqueeze(1) * c_attn_output
            # modulated pre-norm for the context MLP; LayerNorm kept in fp32
            with torch.cuda.amp.autocast(enabled=False):
                norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            c_ff_output = self.ff_c(norm_c)
            c = c + c_gate_mlp.unsqueeze(1) * c_ff_output

        # input branch
        x = x + x_gate_msa.unsqueeze(1) * x_attn_output
        with torch.cuda.amp.autocast(enabled=False):
            norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
        x_ff_output = self.ff_x(norm_x)
        x = x + x_gate_mlp.unsqueeze(1) * x_ff_output

        return c, x
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
# time step conditioning embedding
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
class TimestepEmbedding(nn.Module):
    """Time-step conditioning embedding: sinusoidal features + 2-layer MLP."""

    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep: float["b"]):  # noqa: F821
        # sinusoidal features, cast back to the timestep dtype before the MLP
        hidden = self.time_embed(timestep).to(timestep.dtype)
        return self.time_mlp(hidden)  # b d
|
funcineforge/models/modules/hifigan/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
def get_padding(kernel_size, dilation=1):
    """Return the "same" padding for a conv with the given kernel size and dilation.

    Uses floor division so the result is an int without a float round-trip
    (``int((k*d - d) / 2)`` and ``(k*d - d) // 2`` agree for non-negative ints).
    """
    return (kernel_size * dilation - dilation) // 2
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialize conv-layer weights from N(mean, std); other modules are untouched.

    Matches any module whose class name contains "Conv" (Conv1d, Conv2d,
    ConvTranspose1d, ...), as in the upstream HiFi-GAN recipe.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from funcineforge.models.modules.hifigan.generator import HifiGenerator, NsfHifiGenerator, HiFTGenerator
|
| 13 |
+
from funcineforge.models.modules.hifigan.discriminator import MultipleDiscriminator
|
| 14 |
+
from funcineforge.models.modules.hifigan.nsf_utils import ConvRNNF0Predictor
|
funcineforge/models/modules/hifigan/activations.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn, sin, pow
|
| 6 |
+
from torch.nn import Parameter
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Snake(nn.Module):
    '''
    Sine-based periodic activation: Snake := x + 1/a * sin^2(x * a).

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha: trainable per-channel parameter controlling frequency
    References:
        - Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        - in_features: number of channels of the input
        - alpha: initial value; higher values = higher frequency.
          With alpha_logscale=True the parameter lives in log space and is
          initialized to zeros (exp(0) == 1).
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:
            self.alpha = Parameter(torch.ones(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable

        # epsilon guarding the 1/alpha division
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Apply the function elementwise: x + 1/a * sin^2(x * a).
        '''
        # [C] -> [1, C, 1] so alpha broadcasts over batch and time
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        return x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class SnakeBeta(nn.Module):
    '''
    Modified Snake with a separate magnitude parameter:
    SnakeBeta := x + 1/b * sin^2(x * a).

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha: trainable parameter controlling frequency
        - beta: trainable parameter controlling magnitude
    References:
        - Modified from Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        - in_features: number of channels of the input
        - alpha / beta: initialized to 1 by default; higher alpha = higher
          frequency, higher beta = higher magnitude.
          With alpha_logscale=True both live in log space and start at zeros
          (exp(0) == 1).
        NOTE: beta is initialized from the `alpha` argument, matching the
        upstream BigVGAN implementation.
        '''
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        # epsilon guarding the 1/beta division
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Apply the function elementwise: x + 1/b * sin^2(x * a).
        '''
        # [C] -> [1, C, 1] so the parameters broadcast over batch and time
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        return x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
funcineforge/models/modules/hifigan/discriminator.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""hifigan based dicriminator implementation.
|
| 2 |
+
|
| 3 |
+
This code is modified from https://github.com/jik876/hifi-gan and https://github.com/kan-bayashi/ParallelWaveGAN.
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import typing as tp
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch.nn import Conv2d, AvgPool1d, Conv1d
|
| 13 |
+
from torch.nn.utils import weight_norm, spectral_norm
|
| 14 |
+
|
| 15 |
+
from funcineforge.models.modules.hifigan import get_padding
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DiscriminatorP(torch.nn.Module):
    """HiFi-GAN period discriminator.

    Folds the waveform into (T/period, period) 2-D frames and applies
    strided 2-D convolutions along the time axis.
    """

    def __init__(self, period, kernel_size=5, stride=3,
                 use_spectral_norm=False, lrelu_slope=0.1):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.lrelu_slope = lrelu_slope

        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        # NOTE(review): padding uses get_padding(5, 1) regardless of
        # kernel_size, matching the upstream hifi-gan implementation.
        channels = [1, 32, 128, 512, 1024]
        layers = [
            norm_f(Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1),
                          padding=(get_padding(5, 1), 0)))
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ]
        layers.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))))
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d: reflect-pad so T is a multiple of period, then fold
        b, c, t = x.shape
        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.lrelu_slope)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
    """Ensemble of DiscriminatorP instances, one per period."""

    def __init__(self,
                 in_channels: int = 1,
                 periods: tp.List[int] = [2, 3, 5, 7, 11]):
        super(MultiPeriodDiscriminator, self).__init__()
        # NOTE(review): in_channels is accepted for interface uniformity;
        # DiscriminatorP itself always expects mono (1-channel) input.
        self.discriminators = nn.ModuleList(DiscriminatorP(p) for p in periods)

    def forward(self, x: torch.Tensor, return_intermediates: bool = True):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: per-discriminator (score, fmap) tuples when
                return_intermediates is True, otherwise scores only.
        """
        if return_intermediates:
            return [f(x) for f in self.discriminators]
        return [f(x)[0] for f in self.discriminators]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class DiscriminatorS(torch.nn.Module):
|
| 105 |
+
def __init__(self, use_spectral_norm=False, lrelu_slope=0.1):
|
| 106 |
+
super(DiscriminatorS, self).__init__()
|
| 107 |
+
self.lrelu_slope = lrelu_slope
|
| 108 |
+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
| 109 |
+
self.convs = nn.ModuleList([
|
| 110 |
+
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
|
| 111 |
+
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
| 112 |
+
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
| 113 |
+
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
| 114 |
+
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
| 115 |
+
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
| 116 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
| 117 |
+
])
|
| 118 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
| 119 |
+
|
| 120 |
+
def forward(self, x):
|
| 121 |
+
fmap = []
|
| 122 |
+
for l in self.convs:
|
| 123 |
+
x = l(x)
|
| 124 |
+
x = F.leaky_relu(x, self.lrelu_slope)
|
| 125 |
+
fmap.append(x)
|
| 126 |
+
x = self.conv_post(x)
|
| 127 |
+
fmap.append(x)
|
| 128 |
+
x = torch.flatten(x, 1, -1)
|
| 129 |
+
|
| 130 |
+
return x, fmap
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class MultiScaleDiscriminator(torch.nn.Module):
    """Ensemble of DiscriminatorS at successively average-pooled scales.

    The first discriminator sees the raw waveform and uses spectral norm;
    each subsequent one sees the signal downsampled by AvgPool1d(4, 2).

    Args:
        in_channels: accepted for interface uniformity (input is mono).
        nb_scales: number of scales / discriminators (default 3).
    """

    def __init__(self, in_channels: int = 1, nb_scales: int = 3):
        super(MultiScaleDiscriminator, self).__init__()
        # fix: honor nb_scales instead of always building exactly 3 scales
        # (the default nb_scales=3 reproduces the previous behavior exactly)
        self.discriminators = nn.ModuleList(
            [DiscriminatorS(use_spectral_norm=True)]
            + [DiscriminatorS() for _ in range(nb_scales - 1)]
        )
        self.meanpools = nn.ModuleList(
            [AvgPool1d(4, 2, padding=2) for _ in range(nb_scales - 1)]
        )

    def forward(self, x: torch.Tensor, return_intermediates: bool = True):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: per-discriminator (score, fmap) tuples when
                return_intermediates is True, otherwise scores only.
        """
        outs = []
        for i, f in enumerate(self.discriminators):
            if i != 0:
                # each later scale sees a further-downsampled signal
                x = self.meanpools[i - 1](x)
            outs.append(f(x) if return_intermediates else f(x)[0])
        return outs
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class DiscriminatorR(nn.Module):
|
| 168 |
+
def __init__(
|
| 169 |
+
self,
|
| 170 |
+
stft_params: tp.List[int],
|
| 171 |
+
lrelu_slope: float = 0.1,
|
| 172 |
+
use_spectral_norm: bool = False,
|
| 173 |
+
):
|
| 174 |
+
super().__init__()
|
| 175 |
+
|
| 176 |
+
self.stft_params = stft_params
|
| 177 |
+
self.lrelu_slope = lrelu_slope
|
| 178 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
| 179 |
+
|
| 180 |
+
self.convs = nn.ModuleList([
|
| 181 |
+
norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
|
| 182 |
+
norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
| 183 |
+
norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
| 184 |
+
norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
| 185 |
+
norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
|
| 186 |
+
])
|
| 187 |
+
self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
|
| 188 |
+
|
| 189 |
+
def spectrogram(self, x):
|
| 190 |
+
n_fft, hop_length, win_length = self.stft_params
|
| 191 |
+
x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
|
| 192 |
+
x = x.squeeze(1)
|
| 193 |
+
spec = torch.stft(x, n_fft, hop_length=hop_length, win_length=win_length,
|
| 194 |
+
center=False, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
|
| 195 |
+
|
| 196 |
+
spec = torch.view_as_real(spec) # [B, F, TT, 2]
|
| 197 |
+
mag = torch.norm(spec, p=2, dim =-1) #[B, F, TT]
|
| 198 |
+
|
| 199 |
+
return mag
|
| 200 |
+
|
| 201 |
+
def forward(self, x):
|
| 202 |
+
fmap = []
|
| 203 |
+
|
| 204 |
+
x = self.spectrogram(x).unsqueeze(1)
|
| 205 |
+
for l in self.convs:
|
| 206 |
+
x = l(x)
|
| 207 |
+
x = F.leaky_relu(x, self.lrelu_slope)
|
| 208 |
+
fmap.append(x)
|
| 209 |
+
x = self.conv_post(x)
|
| 210 |
+
fmap.append(x)
|
| 211 |
+
|
| 212 |
+
x = torch.flatten(x, 1, -1)
|
| 213 |
+
|
| 214 |
+
return x, fmap
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class MultiResolutionDiscriminator(nn.Module):
    """Ensemble of DiscriminatorR instances, one per STFT resolution."""

    def __init__(
        self,
        in_channels: int,
        fft_sizes: tp.List[int] = [1024, 2048, 512],
        hop_sizes: tp.List[int] = [120, 240, 50],
        win_lengths: tp.List[int] = [600, 1200, 240],
        lrelu_slope: float = 0.1,
    ):
        super().__init__()
        # NOTE(review): in_channels is accepted for interface uniformity only.
        self.discriminators = nn.ModuleList(
            DiscriminatorR([fft, hop, win], lrelu_slope)
            for fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths)
        )

    def forward(self, x: torch.Tensor, return_intermediates: bool = True):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: per-discriminator (score, fmap) tuples when
                return_intermediates is True, otherwise scores only.
        """
        if return_intermediates:
            return [f(x) for f in self.discriminators]
        return [f(x)[0] for f in self.discriminators]
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class MultipleDiscriminator(nn.Module):
    """Container that instantiates and runs several GAN discriminators.

    Args:
        input_size: Number of input channels passed to every sub-discriminator
            as ``in_channels``.
        disc_conf_list: One config dict per discriminator. Each dict must hold
            a ``"name"`` key ("mpd", "msd" or "mrd"); the remaining keys are
            forwarded as constructor kwargs. ``None`` means no discriminators.
    """

    def __init__(
        self,
        input_size: int = 1,
        disc_conf_list: tp.List[tp.Dict[str, tp.Any]] = None,
    ):
        super().__init__()

        self.support_disc_choices = dict(
            mpd=MultiPeriodDiscriminator,
            msd=MultiScaleDiscriminator,
            mrd=MultiResolutionDiscriminator,
        )

        self.discriminators = nn.ModuleList()
        self.discriminator_type_lst = []
        # Fixed: the default `disc_conf_list=None` was iterated directly,
        # raising TypeError; treat a missing list as "no discriminators".
        for args in (disc_conf_list or []):
            assert "name" in args, "disc_conf must have `name` attr to specific disc type."
            disc_type = args["name"]
            assert disc_type in self.support_disc_choices, \
                "Unsupported discriminator type, only support {}".format(
                    ",".join(self.support_disc_choices.keys())
                )

            # Fixed: `name` used to be pop()ed from the caller's dict and
            # re-inserted afterwards (so the config could still be dumped to
            # config.yaml). That temporarily mutates shared config and is not
            # exception-safe; build the kwargs from a filtered copy instead,
            # leaving the caller's dict untouched.
            ctor_kwargs = {k: v for k, v in args.items() if k != "name"}
            disc_class = self.support_disc_choices[disc_type]
            self.discriminators.append(disc_class(in_channels=input_size, **ctor_kwargs))
            self.discriminator_type_lst.append(disc_type)

    def get_discriminator_type_lst(self) -> tp.List[str]:
        """Return the sub-discriminator type names, in construction order."""
        return self.discriminator_type_lst

    def forward(self, x, return_intermediates=True):
        """Run all sub-discriminators and flatten their outputs into one list.

        Args:
            x: Input signal batch, forwarded unchanged to every discriminator.
            return_intermediates: Forwarded to each discriminator; when True
                they return feature maps alongside logits.

        Returns:
            list: outputs of all discriminators, flattened one level
            (list-valued results are extended, tuple results appended).

        Raises:
            TypeError: if a sub-discriminator returns neither tuple nor list.
        """
        retval = []
        for disc in self.discriminators:
            out = disc(x, return_intermediates=return_intermediates)
            if isinstance(out, tuple):
                # single-branch discriminator: one (logits, fmap) result
                retval.append(out)
            elif isinstance(out, list):
                # multi-branch discriminator: one result per internal branch
                retval.extend(out)
            else:
                raise TypeError("The return value of discriminator must be tuple or list[tuple]")

        return retval
|
funcineforge/models/modules/hifigan/generator.py
ADDED
|
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""hifigan based generator implementation.
|
| 2 |
+
|
| 3 |
+
This code is modified from https://github.com/jik876/hifi-gan
|
| 4 |
+
,https://github.com/kan-bayashi/ParallelWaveGAN and
|
| 5 |
+
https://github.com/NVIDIA/BigVGAN
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import typing as tp
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from scipy.signal import get_window
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
| 18 |
+
from torch.nn.utils import weight_norm
|
| 19 |
+
from torch.nn.utils import remove_weight_norm
|
| 20 |
+
|
| 21 |
+
from funcineforge.models.modules.hifigan import get_padding, init_weights
|
| 22 |
+
from funcineforge.models.modules.hifigan.activations import Snake, SnakeBeta
|
| 23 |
+
from funcineforge.models.modules.hifigan.nsf_utils import SourceModule, SourceModuleHnNSF
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN.

    Each residual branch applies ``activation -> dilated conv`` and, when
    ``use_additional_convs`` is enabled, a second ``activation -> conv``
    (always dilation 1) before adding the skip connection.

    Args:
        channels: Input/output channel count of every convolution.
        kernel_size: Kernel size shared by all convolutions.
        dilations: One dilation factor per residual branch.
        use_additional_convs: Whether to append the second conv per branch.
        nonlinear_activation: "LeakyReLU", "Snake" or "SnakeBeta".
        nonlinear_activation_params: Extra kwargs for the chosen activation.
    """
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: tp.List[int] = [1, 3, 5],
        use_additional_convs: bool = True,
        nonlinear_activation: str = "LeakyReLU",
        nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
    ):
        super(ResBlock, self).__init__()
        self.use_additional_convs = use_additional_convs

        self.convs1 = nn.ModuleList()
        if use_additional_convs:
            self.convs2 = nn.ModuleList()

        for dilation in dilations:
            self.convs1.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        padding=get_padding(kernel_size, dilation)
                    )
                )
            )

            if use_additional_convs:
                # The second conv of each branch always uses dilation 1.
                self.convs2.append(
                    weight_norm(
                        Conv1d(
                            channels,
                            channels,
                            kernel_size,
                            1,
                            dilation=1,
                            padding=get_padding(kernel_size, 1)
                        )
                    )
                )

        self.convs1.apply(init_weights)
        if use_additional_convs:
            self.convs2.apply(init_weights)

        # Fixed: the activation lists were built by three near-identical
        # copy-pasted branches (one per activation type). Build them from a
        # single lazily-evaluated factory instead; unknown names now raise
        # NotImplementedError with a useful message.
        activation_factories = {
            "LeakyReLU": lambda: nn.LeakyReLU(nonlinear_activation_params["negative_slope"]),
            "Snake": lambda: Snake(
                channels,
                alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False)),
            "SnakeBeta": lambda: SnakeBeta(
                channels,
                alpha_logscale=nonlinear_activation_params.get("alpha_logscale", False)),
        }
        if nonlinear_activation not in activation_factories:
            raise NotImplementedError(
                "Unsupported nonlinear_activation: {}".format(nonlinear_activation))
        make_activation = activation_factories[nonlinear_activation]
        self.activations1 = nn.ModuleList(
            [make_activation() for _ in range(len(self.convs1))])
        if use_additional_convs:
            self.activations2 = nn.ModuleList(
                [make_activation() for _ in range(len(self.convs2))])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all residual branches in sequence.

        Args:
            x: Input tensor (B, channels, T).

        Returns:
            Tensor of the same shape.
        """
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt = self.convs1[idx](xt)
            if self.use_additional_convs:
                xt = self.activations2[idx](xt)
                xt = self.convs2[idx](xt)
            x = xt + x  # residual connection
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every conv (call before inference export)."""
        for idx in range(len(self.convs1)):
            remove_weight_norm(self.convs1[idx])
            if self.use_additional_convs:
                remove_weight_norm(self.convs2[idx])
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class HifiGenerator(nn.Module):
    """HiFi-GAN style waveform generator.

    Maps a feature sequence (B, in_channels, T) to a waveform
    (B, 1, T * prod(upsample_rates)). Each upsampling stage is a
    weight-normalized ConvTranspose1d followed by ``num_kernels`` parallel
    ResBlocks whose outputs are averaged (multi-receptive-field fusion).
    An optional global conditioning vector ``g`` is added after ``conv_pre``
    and, when ``cond_in_each_up_layer`` is enabled, after every upsampling
    layer as well.

    NOTE: the submodule attribute layout defines checkpoint state_dict keys;
    do not restructure without a checkpoint migration.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        global_channels: int = -1,
        upsample_rates: tp.List[int] = [8, 8, 2, 2],
        upsample_kernel_sizes: tp.List[int] = [16, 16, 4, 4],
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_nonlinear_activation: str = "LeakyReLU",
        resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
        use_additional_convs: bool = True,
        cond_in_each_up_layer: bool = False,
        lrelu_slope: float = 0.1,
        act_pre_each_up_layer: bool = True
    ):
        super(HifiGenerator, self).__init__()

        self.out_channels = 1
        self.global_channels = global_channels
        self.use_additional_convs = use_additional_convs
        # Per-layer conditioning only makes sense when a global embedding exists.
        self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
        self.lrelu_slope = lrelu_slope
        self.act_pre_each_up_layer = act_pre_each_up_layer

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Upsampling stack: each stage halves the channel count and
        # upsamples time by its rate ``u``.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # num_kernels parallel ResBlocks per upsampling stage (flattened list;
        # forward() indexes them as i * num_kernels + j).
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
                                               resblock_nonlinear_activation,
                                               resblock_nonlinear_activation_params))

        if self.global_channels > 0:
            # Projects the global embedding onto the conv_pre output channels.
            self.conv_global_cond = weight_norm(
                Conv1d(global_channels, base_channels, 1)
            )
            self.conv_global_cond.apply(init_weights)

        if self.cond_in_each_up_layer:
            # One 1x1 projection per upsampling stage (matching its width).
            self.conv_conds = nn.ModuleList()
            for i in range(len(self.ups)):
                self.conv_conds.append(weight_norm(
                    nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
                )
            self.conv_conds.apply(init_weights)

        # ``ch`` is left over from the loop above: channel width of the last stage.
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def output_size(self):
        """Number of output channels (always 1: mono waveform)."""
        return self.out_channels

    def forward(self, x: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generate a waveform.

        Args:
            x: Input features, (B, in_channels, T).
            g: Optional global conditioning, (B, global_channels, 1).

        Returns:
            Waveform in [-1, 1], (B, 1, T * prod(upsample_rates)).
        """
        # x in (B, in_channels, T), g in (B, global_channels, 1)
        x = self.conv_pre(x)
        if self.global_channels > 0 and g is not None:
            x = x + self.conv_global_cond(g)

        for i in range(self.num_upsamples):
            if self.act_pre_each_up_layer:
                x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if self.cond_in_each_up_layer and g is not None:
                x = x + self.conv_conds[i](g)

            # Multi-receptive-field fusion: average the parallel ResBlocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)  # bound output to [-1, 1]

        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all weight-normed submodules (pre-inference)."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        if self.global_channels > 0:
            remove_weight_norm(self.conv_global_cond)
        if self.cond_in_each_up_layer:
            for l in self.conv_conds:
                remove_weight_norm(l)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class NsfHifiGenerator(nn.Module):
    """
    Neural Source Filter + HifiGan.

    Like HifiGenerator, but additionally conditions each upsampling stage on
    an excitation signal derived from F0 via ``SourceModule``. The excitation
    (at full audio rate) is downsampled by strided convs (``source_downs``)
    to match each stage's temporal resolution and added to the activations.

    NOTE: the submodule attribute layout defines checkpoint state_dict keys;
    do not restructure without a checkpoint migration.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        global_channels: int = -1,
        nb_harmonics: int = 7,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: tp.List[int] = [8, 8, 2, 2],
        upsample_kernel_sizes: tp.List[int] = [16, 16, 4, 4],
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_nonlinear_activation: str = "LeakyReLU",
        resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"negative_slope": 0.1},
        use_additional_convs: bool = True,
        cond_in_each_up_layer: bool = False,
        lrelu_slope: float = 0.1,
        act_pre_each_up_layer: bool = True
    ):
        super(NsfHifiGenerator, self).__init__()

        self.out_channels = 1
        self.global_channels = global_channels
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.use_additional_convs = use_additional_convs
        # Per-layer conditioning only makes sense when a global embedding exists.
        self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
        self.lrelu_slope = lrelu_slope
        self.act_pre_each_up_layer = act_pre_each_up_layer

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        # F0 -> excitation source at the full output rate (prod of upsample rates).
        self.source_module = SourceModule(nb_harmonics, np.cumprod(upsample_rates)[-1],
                                          sampling_rate, nsf_alpha, nsf_sigma, nsf_voiced_threshold)
        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
        # Down: source_downs[i] brings the full-rate excitation down to the
        # temporal resolution of upsampling stage i (stride = remaining factor).
        self.source_downs = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, u in enumerate(downsample_cum_rates[::-1]):
            if (u == 1):
                # Last stage runs at full rate: plain 1x1 projection, no striding.
                self.source_downs.append(
                    weight_norm(Conv1d(1, base_channels // (2 ** (i + 1)), 1, 1))
                )
            else:
                self.source_downs.append(
                    weight_norm(Conv1d(1, base_channels // (2 ** (i + 1)), u*2, u, padding=(u//2)))
                )

        # num_kernels parallel ResBlocks per stage (flattened; indexed i*num_kernels+j).
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
                                               resblock_nonlinear_activation,
                                               resblock_nonlinear_activation_params))

        if self.global_channels > 0:
            self.conv_global_cond = weight_norm(
                Conv1d(global_channels, base_channels, 1)
            )
            self.conv_global_cond.apply(init_weights)

        if self.cond_in_each_up_layer:
            self.conv_conds = nn.ModuleList()
            for i in range(len(self.ups)):
                self.conv_conds.append(weight_norm(
                    nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
                )
            self.conv_conds.apply(init_weights)

        # ``ch`` is left over from the loop above: channel width of the last stage.
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def output_size(self):
        """Number of output channels (always 1: mono waveform)."""
        return self.out_channels

    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
        # (B, T) -> (B, 1, T) before handing F0 to the source module.
        return self.source_module(f0.unsqueeze(1))

    def forward(self, x: torch.Tensor, f0: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generate a waveform conditioned on features, F0 and optional global embedding.

        Args:
            x: Input features, (B, in_channels, T).
            f0: Frame-level fundamental frequency, (B, T).
            g: Optional global conditioning, (B, global_channels, 1).

        Returns:
            Waveform in [-1, 1], (B, 1, T * prod(upsample_rates)).
        """
        # x in (B, in_channels, T), f0 in (B, T), g in (B, global_channels, 1)

        s = self._f02source(f0)

        x = self.conv_pre(x)
        if self.global_channels > 0 and g is not None:
            x = x + self.conv_global_cond(g)

        for i in range(self.num_upsamples):
            if self.act_pre_each_up_layer:
                x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if self.cond_in_each_up_layer and g is not None:
                x = x + self.conv_conds[i](g)

            # fusion: add the downsampled excitation at this stage's resolution
            x = x + self.source_downs[i](s)

            # Multi-receptive-field fusion: average the parallel ResBlocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)  # bound output to [-1, 1]

        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all weight-normed submodules (pre-inference)."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        if self.global_channels > 0:
            remove_weight_norm(self.conv_global_cond)
        if self.cond_in_each_up_layer:
            for l in self.conv_conds:
                remove_weight_norm(l)
        self.source_module.remove_weight_norm()
        for l in self.source_downs:
            remove_weight_norm(l)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493

    The upsampling stack predicts an (n_fft + 2)-channel spectrum; the first
    n_fft//2 + 1 channels become the log-magnitude and the rest the phase,
    and the waveform is reconstructed with an inverse STFT. The F0-driven
    excitation is fused into each stage in the STFT domain.
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        global_channels: int = -1,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: tp.List[int] = [8, 8],
        upsample_kernel_sizes: tp.List[int] = [16, 16],
        istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        resblock_nonlinear_activation: str = "Snake",
        resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"alpha_logscale": False},
        source_resblock_kernel_sizes: tp.List[int] = [7, 11],
        source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
        source_resblock_nonlinear_activation: str = "Snake",
        source_resblock_nonlinear_activation_params: tp.Dict[str, tp.Any] = {"alpha_logscale": False},
        use_additional_convs: bool = True,
        cond_in_each_up_layer: bool = False,
        lrelu_slope: float = 0.1,
        act_pre_each_up_layer: bool = True,
        audio_limit: float = 0.99,
    ):
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.global_channels = global_channels
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.use_additional_convs = use_additional_convs
        # Per-layer conditioning only makes sense when a global embedding exists.
        self.cond_in_each_up_layer = cond_in_each_up_layer if global_channels > 0 else False
        self.lrelu_slope = lrelu_slope
        self.act_pre_each_up_layer = act_pre_each_up_layer
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # F0 -> harmonic excitation at the full audio rate.
        # (sic: the callee's keyword is spelled "voiced_threshod")
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down: bring the excitation STFT (n_fft + 2 channels) to each stage's
        # resolution, then refine it with its own ResBlock. NOTE: unlike the
        # other generators, these convs are intentionally NOT weight-normed.
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
                                          source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u*2, u, padding=(u//2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d,
                         use_additional_convs, source_resblock_nonlinear_activation,
                         source_resblock_nonlinear_activation_params)
            )

        # num_kernels parallel ResBlocks per stage (flattened; indexed i*num_kernels+j).
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d, use_additional_convs,
                                               resblock_nonlinear_activation,
                                               resblock_nonlinear_activation_params))

        if self.global_channels > 0:
            self.conv_global_cond = weight_norm(
                Conv1d(global_channels, base_channels, 1)
            )
            self.conv_global_cond.apply(init_weights)

        if self.cond_in_each_up_layer:
            self.conv_conds = nn.ModuleList()
            for i in range(len(self.ups)):
                self.conv_conds.append(weight_norm(
                    nn.Conv1d(global_channels, base_channels // (2**(i + 1)), 1))
                )
            self.conv_conds.apply(init_weights)

        # ``ch`` is the channel width of the last stage; conv_post emits the
        # n_fft + 2 spectral channels consumed by the iSTFT head.
        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.register_buffer("stft_window", window)

    def output_size(self):
        """Number of output channels (always 1: mono waveform)."""
        return self.out_channels

    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
        """Upsample frame-level F0 to audio rate and render the harmonic source."""
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t

        har_source, _, _ = self.m_source(f0)
        return har_source.transpose(1, 2)

    def forward(self, x: torch.Tensor, f0: torch.Tensor, g: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generate a waveform.

        Args:
            x: Input features, (B, in_channels, T).
            f0: Frame-level fundamental frequency, (B, T).
            g: Optional global conditioning, (B, global_channels, 1).

        Returns:
            Waveform clamped to [-audio_limit, audio_limit].
        """
        # x in (B, in_channels, T), f0 in (B, T), g in (B, global_channels, 1)

        s = self._f02source(f0)

        # Excitation is fused in the STFT domain: real/imag parts stacked on channels.
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        if self.global_channels > 0 and g is not None:
            x = x + self.conv_global_cond(g)

        for i in range(self.num_upsamples):
            if self.act_pre_each_up_layer:
                x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if self.cond_in_each_up_layer and g is not None:
                x = x + self.conv_conds[i](g)

            if i == self.num_upsamples - 1:
                # Pad one frame on the left so the lengths line up with s_stft.
                x = self.reflection_pad(x)

            # fusion
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            # Multi-receptive-field fusion: average the parallel ResBlocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all weight-normed submodules (pre-inference)."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        if self.global_channels > 0:
            remove_weight_norm(self.conv_global_cond)
        if self.cond_in_each_up_layer:
            for l in self.conv_conds:
                remove_weight_norm(l)
        # Fixed: this class stores its NSF source as `self.m_source`; the old
        # code called `self.source_module.remove_weight_norm()` (copied from
        # NsfHifiGenerator) and raised AttributeError.
        # NOTE(review): guarded because it is unclear from here whether
        # SourceModuleHnNSF exposes remove_weight_norm() like SourceModule
        # does -- confirm in nsf_utils.
        source_rwn = getattr(self.m_source, "remove_weight_norm", None)
        if source_rwn is not None:
            source_rwn()
        # Fixed: the old code also called remove_weight_norm on each
        # source_downs conv, but those are built WITHOUT weight_norm in this
        # class (see __init__) and would raise; only the source ResBlocks
        # carry weight norm here.
        for l in self.source_resblocks:
            l.remove_weight_norm()

    def _stft(self, x):
        """STFT of a (B, T) signal; returns (real, imag) as (B, F, TT) tensors."""
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window,
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        """Inverse STFT from predicted magnitude/phase; returns (B, 1, T)."""
        # Cap the magnitude to keep the reconstruction numerically stable.
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(
            torch.complex(real, img),
            self.istft_params["n_fft"], self.istft_params["hop_len"],
            self.istft_params["n_fft"], window=self.stft_window,
            return_complex=False
        )

        return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation
|
funcineforge/models/modules/hifigan/mel_spectrum.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.utils.data
|
| 3 |
+
import numpy as np
|
| 4 |
+
from librosa.filters import mel as librosa_mel_fn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress a magnitude array (numpy): log(clip(x, clip_val) * C)."""
    floored = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(floored * C)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def dynamic_range_decompression(x, C=1):
    """Invert dynamic_range_compression (NumPy): exp(x) / C."""
    expanded = np.exp(x)
    return expanded / C
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress a magnitude tensor (torch): log(max(x, clip_val) * C)."""
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def dynamic_range_decompression_torch(x, C=1):
    """Invert dynamic_range_compression_torch: exp(x) / C."""
    expanded = torch.exp(x)
    return expanded / C
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def spectral_normalize_torch(magnitudes):
    """Log-compress spectral magnitudes (thin alias over the torch compressor)."""
    return dynamic_range_compression_torch(magnitudes)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def spectral_de_normalize_torch(magnitudes):
    """Undo spectral_normalize_torch (thin alias over the torch decompressor)."""
    return dynamic_range_decompression_torch(magnitudes)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Process-wide caches shared by the spectrogram helpers below:
# mel_basis keys look like "<fmax>_<device>" (mel filterbank tensors),
# hann_window keys are "<device>" (Hann window tensors).
mel_basis = {}
hann_window = {}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram of waveform ``y``.

    Args:
        y: waveform batch, (B, T), expected to lie in [-1, 1] (out-of-range
            values are only warned about, not clipped).
        n_fft / hop_size / win_size: STFT parameters.
        num_mels, sampling_rate, fmin, fmax: mel-filterbank parameters.
        center: passed through to ``torch.stft``.

    Returns:
        Log-compressed mel spectrogram, (B, num_mels, frames).
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # Fix: the original tested `fmax not in mel_basis` while storing keys as
    # "<fmax>_<device>", so the check never matched and the filterbank/window
    # were rebuilt on every call.  Cache per (fmax, device) / per device.
    mel_key = str(fmax) + '_' + str(y.device)
    win_key = str(y.device)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    if win_key not in hann_window:
        hann_window[win_key] = torch.hann_window(win_size).to(y.device)

    # Manual reflect padding to center the frames (center=False below).
    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # return_complex=True is required on modern torch; the real-view sum of
    # squares reproduces the old (return_complex=False) magnitude exactly.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[win_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True,
                      return_complex=True)
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + (1e-9))

    # Project to mel and log-compress.
    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def power_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed linear magnitude spectrogram of waveform ``y``.

    Same front end as :func:`mel_spectrogram` but without the mel projection.
    The mel filterbank is still built and cached so that
    :func:`mel_from_power_spectrogram` can reuse it later.

    Args:
        y: waveform batch, (B, T), expected to lie in [-1, 1].
        n_fft / hop_size / win_size: STFT parameters.
        num_mels, sampling_rate, fmin, fmax: mel-filterbank parameters
            (used here only to populate the shared cache).
        center: passed through to ``torch.stft``.

    Returns:
        Log-compressed magnitude spectrogram, (B, n_fft // 2 + 1, frames).
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # Fix: the original tested `fmax not in mel_basis` while storing keys as
    # "<fmax>_<device>", so the filterbank/window were rebuilt on every call.
    mel_key = str(fmax) + '_' + str(y.device)
    win_key = str(y.device)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    if win_key not in hann_window:
        hann_window[win_key] = torch.hann_window(win_size).to(y.device)

    # Manual reflect padding to center the frames (center=False below).
    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    # return_complex=True is required on modern torch; the real-view sum of
    # squares reproduces the old (return_complex=False) magnitude exactly.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[win_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True,
                      return_complex=True)
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + (1e-9))
    spec = spectral_normalize_torch(spec)

    return spec
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def mel_from_power_spectrogram(spec, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Project a log-compressed power spectrogram onto the cached mel basis.

    NOTE(review): requires ``mel_basis`` to already hold an entry for
    (fmax, spec.device) — i.e. mel_spectrogram/power_spectrogram must have
    been called first with the same fmax on the same device, otherwise this
    raises KeyError.  Most parameters are unused here and appear to be kept
    only for signature symmetry with the other helpers — confirm with callers.

    :param spec: log-compressed power spectrogram, (B, n_fft // 2 + 1, frames)
    :return: log-compressed mel spectrogram, (B, num_mels, frames)
    """
    global mel_basis, hann_window
    # Undo the log compression, apply the mel filterbank, re-compress.
    spec = spectral_de_normalize_torch(spec)
    spec = torch.matmul(mel_basis[str(fmax) + '_' + str(spec.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
|
funcineforge/models/modules/hifigan/nsf_utils.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Neural Source Filter based modules implementation.
|
| 3 |
+
|
| 4 |
+
Neural source-filter waveform models for statistical parametric speech synthesis
|
| 5 |
+
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import typing as tp
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
| 15 |
+
from torch.distributions.uniform import Uniform
|
| 16 |
+
from torch.distributions.normal import Normal
|
| 17 |
+
|
| 18 |
+
class SineGen(torch.nn.Module):
    """Sine-wave excitation generator for NSF-style vocoders.

    Builds a stack of harmonic sine waves from a sample-rate F0 contour,
    plus the voiced/unvoiced (U/V) mask and the additive noise component.

    Args:
        samp_rate: sampling rate in Hz.
        harmonic_num: number of harmonic overtones above F0 (default 0).
        sine_amp: amplitude of the sine waveform (default 0.1).
        noise_std: std of Gaussian noise in voiced regions (default 0.003).
        voiced_threshold: F0 threshold (Hz) for U/V classification (default 0).
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # generate uv signal: 1.0 where F0 exceeds the voicing threshold
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: (sine_waves [B, harmonic_num + 1, sample_len],
                  uv [B, 1, sample_len],
                  noise [B, harmonic_num + 1, sample_len])
        """
        # Normalized instantaneous frequency per harmonic: f0 * (i+1) / fs.
        F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
        for i in range(self.harmonic_num + 1):
            F_mat[:, i:i+1, :] = f0 * (i+1) / self.sampling_rate

        # Accumulated, wrapped phase; the `% 1` keeps the cumulative sum
        # numerically bounded without changing the wrapped phase.
        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        # Random initial phase per harmonic; fundamental pinned to phase 0.
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced regions the amplitude should be similar to
        # sine_amp (std = sine_amp/3 -> max value ~ sine_amp);
        # for voiced regions the std is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class SourceModuleHnNSF(torch.nn.Module):
    """Harmonic-plus-noise source module for hn-NSF vocoders.

    Generates a sine-based harmonic excitation (merged to a single channel
    by a linear layer + tanh), a Gaussian noise source, and the U/V mask.

    Args:
        sampling_rate: sampling rate in Hz.
        upsample_scale: NOTE(review): accepted but never used in this
            implementation — presumably kept for interface compatibility;
            confirm against callers.
        harmonic_num: number of harmonics above F0 (default: 0).
        sine_amp: amplitude of the sine source signal (default: 0.1).
        add_noise_std: std of additive Gaussian noise (default: 0.003);
            the noise amplitude in unvoiced regions is decided by sine_amp.
        voiced_threshod: threshold (Hz) to set U/V given F0 (default: 0).
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        :param x: F0 contour, (batchsize, length, 1), Hz
        :return: (sine_merge (batchsize, length, 1),
                  noise (batchsize, length, 1),
                  uv (batchsize, length, 1))
        """
        # source for harmonic branch; SineGen expects [B, 1, L], so transpose
        # in and back out.  Sine generation itself carries no gradients.
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1,2))
            sine_wavs = sine_wavs.transpose(1,2)
            uv = uv.transpose(1,2)
        # Trainable merge of the harmonic stack into one channel.
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class SourceModule(torch.nn.Module):
    """NSF excitation source: harmonic sines + noise, merged by a 1x1 conv.

    Args:
        nb_harmonics: number of harmonics above F0.
        upsample_ratio: frame-rate -> sample-rate upsampling factor for F0.
        sampling_rate: audio sampling rate in Hz.
        alpha: sine amplitude (default 0.1).
        sigma: std of the additive Gaussian noise (default 0.003).
        voiced_threshold: F0 threshold (Hz) for the U/V decision (default 10).
    """

    def __init__(self,
                 nb_harmonics: int,
                 upsample_ratio: int,
                 sampling_rate: int,
                 alpha: float = 0.1,
                 sigma: float = 0.003,
                 voiced_threshold: float = 10
                 ):
        super(SourceModule, self).__init__()

        self.nb_harmonics = nb_harmonics
        self.upsample_ratio = upsample_ratio
        self.sampling_rate = sampling_rate
        self.alpha = alpha
        self.sigma = sigma
        self.voiced_threshold = voiced_threshold

        # Trainable merge of the (nb_harmonics + 1) excitation channels
        # into a single channel.
        self.ffn = nn.Sequential(
            weight_norm(nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)),
            nn.Tanh())

    def f02uv(self, f0):
        # generate uv signal: 1.0 where F0 exceeds the voicing threshold
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def forward(self, f0):
        """
        :param f0: [B, 1, frame_len], Hz
        :return: [B, 1, sample_len]
        """
        # Excitation construction carries no gradients; only self.ffn trains.
        with torch.no_grad():
            uv = self.f02uv(f0)
            # Upsample F0 and the U/V mask from frame rate to sample rate.
            f0_samples = F.interpolate(f0, scale_factor=(self.upsample_ratio), mode='nearest')
            uv_samples = F.interpolate(uv, scale_factor=(self.upsample_ratio), mode='nearest')

            # Normalized instantaneous frequency per harmonic.
            F_mat = torch.zeros((f0_samples.size(0), self.nb_harmonics + 1, f0_samples.size(-1))).to(f0_samples.device)
            for i in range(self.nb_harmonics + 1):
                F_mat[:, i:i+1, :] = f0_samples * (i+1) / self.sampling_rate

            # Accumulated, wrapped phase with a random initial phase per
            # harmonic (fundamental pinned to phase 0).
            theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
            u_dist = Uniform(low=-np.pi, high=np.pi)
            phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.nb_harmonics + 1, 1)).to(F_mat.device)
            phase_vec[:, 0, :] = 0

            n_dist = Normal(loc=0., scale=self.sigma)
            noise = n_dist.sample(sample_shape=(f0_samples.size(0), self.nb_harmonics + 1, f0_samples.size(-1))).to(F_mat.device)

            # Voiced regions: noisy sines; unvoiced regions: scaled noise only.
            e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise
            e_unvoice = self.alpha / 3 / self.sigma * noise

            e = e_voice * uv_samples + e_unvoice * (1 - uv_samples)

        return self.ffn(e)

    def remove_weight_norm(self):
        # Strip weight_norm from the merging conv (inference/export path).
        remove_weight_norm(self.ffn[0])
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class ConvRNNF0Predictor(nn.Module):
    """Convolutional (optionally GRU-topped) F0 predictor.

    Maps an input feature map (B, in_channels, T) to a non-negative
    frame-level prediction via a 5-layer Conv1d/ELU stack, an optional GRU
    over time, and a linear classifier.

    Args:
        num_class: output dimension per frame (1 for scalar F0).
        in_channels: input feature channels.
        cond_channels: hidden width of the conv stack / GRU.
        use_cond_rnn: if True, run a GRU over time after the conv stack.
        bidirectional_rnn: make the GRU bidirectional; the hidden size is
            halved so the concatenated output stays ``cond_channels``.
    """
    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512,
                 use_cond_rnn: bool = True,
                 bidirectional_rnn: bool = False,
                 ):

        super().__init__()

        self.num_class = num_class
        self.use_cond_rnn = use_cond_rnn

        # Five weight-normalized Conv1d + ELU layers, kernel 3, padding 1
        # (time resolution preserved).
        self.condnet = nn.Sequential(
            weight_norm(
                nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
        )

        if self.use_cond_rnn:
            self.rnn = nn.GRU(
                cond_channels,
                cond_channels // 2 if bidirectional_rnn else cond_channels,
                num_layers=1,
                batch_first=True,
                bidirectional=bidirectional_rnn,
            )

        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, in_channels, T) -> non-negative prediction,
        (B, T) when num_class == 1."""
        x = self.condnet(x)
        if self.use_cond_rnn:
            # GRU wants (B, T, C); conv output is (B, C, T).
            x, _ = self.rnn(x.transpose(1, 2))
        else:
            x = x.transpose(1, 2)

        # abs() enforces a non-negative output (F0 cannot be negative).
        return torch.abs(self.classifier(x).squeeze(-1))
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
|
funcineforge/models/specaug/__init__.py
ADDED
|
File without changes
|
funcineforge/models/specaug/mask_along_axis.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from typing import Sequence
|
| 4 |
+
from typing import Union
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def mask_along_axis(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
    fill_value: float = 0.0,
):
    """Apply mask along the specified direction.

    Args:
        spec: (Batch, Length, Freq) or (Batch, Channel, Length, Freq)
        spec_lengths: (Length): Not using lengths in this implementation
        mask_width_range: Select the width randomly between this range
        dim: 1 masks along Length (time), 2 along Freq
        num_mask: number of independent masks per sample
        replace_with_zero: fill with ``fill_value`` if True, else with the
            spectrogram mean
        fill_value: value used when ``replace_with_zero`` is True

    Returns:
        (masked spec, reshaped back to the original size; spec_lengths
        unchanged)
    """

    org_size = spec.size()
    if spec.dim() == 4:
        # spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
        spec = spec.view(-1, spec.size(2), spec.size(3))

    B = spec.shape[0]
    # D = Length or Freq
    D = spec.shape[dim]
    # mask_length: (B, num_mask, 1)
    mask_length = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (B, num_mask),
        device=spec.device,
    ).unsqueeze(2)

    # mask_pos: (B, num_mask, 1); the upper bound keeps every mask fully
    # inside the axis (max(1, .) guards against D <= max width).
    mask_pos = torch.randint(
        0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
    ).unsqueeze(2)

    # aran: (1, 1, D)
    aran = torch.arange(D, device=spec.device)[None, None, :]
    # mask: (Batch, num_mask, D) — True inside each [pos, pos + length) span
    mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
    # Multiply masks: (Batch, num_mask, D) -> (Batch, D)
    mask = mask.any(dim=1)
    if dim == 1:
        # mask: (Batch, Length, 1)
        mask = mask.unsqueeze(2)
    elif dim == 2:
        # mask: (Batch, 1, Freq)
        mask = mask.unsqueeze(1)

    if replace_with_zero:
        value = fill_value
    else:
        value = spec.mean()

    spec = spec.masked_fill(mask, value)
    spec = spec.view(*org_size)
    return spec, spec_lengths
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class MaskAlongAxis(torch.nn.Module):
    """Randomly mask contiguous spans along the time or frequency axis.

    Thin ``nn.Module`` wrapper around :func:`mask_along_axis` that validates
    and stores the masking configuration.

    Args:
        mask_width_range: (min, max) span width; a bare int means (0, int).
        num_mask: number of masks per utterance.
        dim: 1 / "time" or 2 / "freq".
        replace_with_zero: fill masked cells with ``fill_value`` instead of
            the spectrogram mean.
        fill_value: fill value used when ``replace_with_zero`` is True.
    """

    def __init__(
        self,
        mask_width_range: Union[int, Sequence[int]] = (0, 30),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
        fill_value: float = 0.0,
    ):
        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: " f"{mask_width_range}",
            )
        assert mask_width_range[1] > mask_width_range[0]

        # Resolve a symbolic axis name to its tensor dimension.
        if isinstance(dim, str):
            name_to_dim = {"time": 1, "freq": 2}
            if dim not in name_to_dim:
                raise ValueError("dim must be int, 'time' or 'freq'")
            dim = name_to_dim[dim]
        self.mask_axis = {1: "time", 2: "freq"}.get(dim, "unknown")

        super().__init__()
        self.mask_width_range = mask_width_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero
        self.fill_value = fill_value

    def extra_repr(self):
        return (
            f"mask_width_range={self.mask_width_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Mask ``spec`` (Batch, Length, Freq); returns (masked, spec_lengths)."""
        return mask_along_axis(
            spec,
            spec_lengths,
            mask_width_range=self.mask_width_range,
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
            fill_value=self.fill_value,
        )
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
    """Mask input spec along an axis with a length-scaled maximum width.

    Formula:
        max_width = max_width_ratio * seq_len

    Args:
        mask_width_ratio_range: (min, max) ratios of the sequence length;
            a bare float means (0.0, float).
        num_mask: number of masks per utterance.
        dim: 1 / "time" or 2 / "freq".
        replace_with_zero: fill masked cells with ``fill_value`` instead of
            the spectrogram mean.
        fill_value: fill value used when ``replace_with_zero`` is True.
    """

    def __init__(
        self,
        mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
        fill_value: float = 0.0,
    ):
        if isinstance(mask_width_ratio_range, float):
            mask_width_ratio_range = (0.0, mask_width_ratio_range)
        if len(mask_width_ratio_range) != 2:
            raise TypeError(
                f"mask_width_ratio_range must be a tuple of float and float values: "
                f"{mask_width_ratio_range}",
            )
        assert mask_width_ratio_range[1] > mask_width_ratio_range[0]

        # Resolve a symbolic axis name to its tensor dimension.
        if isinstance(dim, str):
            name_to_dim = {"time": 1, "freq": 2}
            if dim not in name_to_dim:
                raise ValueError("dim must be int, 'time' or 'freq'")
            dim = name_to_dim[dim]
        self.mask_axis = {1: "time", 2: "freq"}.get(dim, "unknown")

        super().__init__()
        self.mask_width_ratio_range = mask_width_ratio_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero
        self.fill_value = fill_value

    def extra_repr(self):
        return (
            f"mask_width_ratio_range={self.mask_width_ratio_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Mask ``spec`` (Batch, Length, Freq) with length-scaled widths."""
        seq_len = spec.shape[self.dim]
        lo = max(0, math.floor(seq_len * self.mask_width_ratio_range[0]))
        hi = min(seq_len, math.floor(seq_len * self.mask_width_ratio_range[1]))

        # Degenerate range (e.g. very short inputs): nothing to mask.
        if hi <= lo:
            return spec, spec_lengths
        return mask_along_axis(
            spec,
            spec_lengths,
            mask_width_range=(lo, hi),
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
            fill_value=self.fill_value,
        )
|
funcineforge/models/specaug/specaug.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SpecAugment module."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from typing import Sequence
|
| 5 |
+
from typing import Union
|
| 6 |
+
|
| 7 |
+
from funcineforge.models.specaug.mask_along_axis import MaskAlongAxis
|
| 8 |
+
from funcineforge.models.specaug.mask_along_axis import MaskAlongAxisVariableMaxWidth
|
| 9 |
+
from funcineforge.models.specaug.time_warp import TimeWarp
|
| 10 |
+
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SpecAug(nn.Module):
    """Implementation of SpecAug.

    Reference:
        Daniel S. Park et al.
        "SpecAugment: A Simple Data
        Augmentation Method for Automatic Speech Recognition"

    Applies, in order: time warping, frequency masking, time masking.

    .. warning::
        When using cuda mode, time_warp doesn't have reproducibility
        due to `torch.nn.functional.interpolate`.
    """

    def __init__(
        self,
        apply_time_warp: bool = True,
        time_warp_window: int = 5,
        time_warp_mode: str = "bicubic",
        apply_freq_mask: bool = True,
        freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
        num_freq_mask: int = 2,
        apply_time_mask: bool = True,
        time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
        time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
        num_time_mask: int = 2,
        fill_value: float = 0.0,
    ):
        if not (apply_time_warp or apply_time_mask or apply_freq_mask):
            raise ValueError("Either one of time_warp, time_mask, or freq_mask should be applied")
        both_time_specs = (
            time_mask_width_range is not None
            and time_mask_width_ratio_range is not None
        )
        if apply_time_mask and both_time_specs:
            raise ValueError(
                'Either one of "time_mask_width_range" or '
                '"time_mask_width_ratio_range" can be used'
            )
        super().__init__()
        self.apply_time_warp = apply_time_warp
        self.apply_freq_mask = apply_freq_mask
        self.apply_time_mask = apply_time_mask

        # Time warping.
        self.time_warp = (
            TimeWarp(window=time_warp_window, mode=time_warp_mode)
            if apply_time_warp
            else None
        )

        # Frequency masking.
        if apply_freq_mask:
            self.freq_mask = MaskAlongAxis(
                dim="freq",
                mask_width_range=freq_mask_width_range,
                num_mask=num_freq_mask,
                fill_value=fill_value,
            )
        else:
            self.freq_mask = None

        # Time masking: fixed-width or length-ratio variant.
        if not apply_time_mask:
            self.time_mask = None
        elif time_mask_width_range is not None:
            self.time_mask = MaskAlongAxis(
                dim="time",
                mask_width_range=time_mask_width_range,
                num_mask=num_time_mask,
                fill_value=fill_value,
            )
        elif time_mask_width_ratio_range is not None:
            self.time_mask = MaskAlongAxisVariableMaxWidth(
                dim="time",
                mask_width_ratio_range=time_mask_width_ratio_range,
                num_mask=num_time_mask,
                fill_value=fill_value,
            )
        else:
            raise ValueError(
                'Either one of "time_mask_width_range" or '
                '"time_mask_width_ratio_range" should be used.'
            )

    def forward(self, x, x_lengths=None):
        """Apply the enabled augmentations, in order, to x (Batch, Length, Freq)."""
        for augment in (self.time_warp, self.freq_mask, self.time_mask):
            if augment is not None:
                x, x_lengths = augment(x, x_lengths)
        return x, x_lengths
|
funcineforge/models/specaug/time_warp.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Time warp module."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from funcineforge.models.utils.nets_utils import pad_list
|
| 6 |
+
|
| 7 |
+
DEFAULT_TIME_WARP_MODE = "bicubic"
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
    """Time warping using torch.interpolate.

    Args:
        x: (Batch, Time, Freq)
        window: time warp parameter
        mode: Interpolate mode

    Returns:
        Warped tensor with the same shape as ``x``.
    """

    # bicubic supports 4D or more dimension tensor
    org_size = x.size()
    if x.dim() == 3:
        # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
        x = x[:, None]

    t = x.shape[2]
    if t - window <= window:
        # Too short to pick a warp center inside [window, t - window); no-op.
        return x.view(*org_size)

    # Random warp center and its randomly shifted target position.
    center = torch.randint(window, t - window, (1,))[0]
    warped = torch.randint(center - window, center + window, (1,))[0] + 1

    # left: (Batch, Channel, warped, Freq)
    # right: (Batch, Channel, time - warped, Freq)
    left = torch.nn.functional.interpolate(
        x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False
    )
    right = torch.nn.functional.interpolate(
        x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=False
    )

    if x.requires_grad:
        # Out-of-place concat keeps the autograd history intact.
        x = torch.cat([left, right], dim=-2)
    else:
        # In-place write avoids an extra allocation when no grad is needed.
        x[:, :, :warped] = left
        x[:, :, warped:] = right

    return x.view(*org_size)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TimeWarp(torch.nn.Module):
    """Time warping using torch.interpolate.

    Args:
        window: time warp parameter
        mode: Interpolate mode
    """

    def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
        super().__init__()
        self.window = window
        self.mode = mode

    def extra_repr(self):
        return f"window={self.window}, mode={self.mode}"

    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
        """Forward function.

        Args:
            x: (Batch, Time, Freq)
            x_lengths: (Batch,)
        """
        uniform = x_lengths is None or all(le == x_lengths[0] for le in x_lengths)
        if uniform:
            # Every sample has the same length: apply one warp batch-wide.
            return time_warp(x, window=self.window, mode=self.mode), x_lengths

        # Ragged batch: warp each sample on its valid prefix, then re-pad.
        # FIXME(kamo): I have no idea to batchify Timewarp
        warped_samples = [
            time_warp(
                x[i][None, : x_lengths[i]],
                window=self.window,
                mode=self.mode,
            )[0]
            for i in range(x.size(0))
        ]
        return pad_list(warped_samples, 0.0), x_lengths
|
funcineforge/models/utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch

# Shorthand precision names mapped to the corresponding torch dtypes.
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
|
funcineforge/models/utils/llm_decoding.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import nullcontext
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from typing import Union
|
| 5 |
+
from funcineforge.utils.hinter import hint_once
|
| 6 |
+
import numpy as np

# Shorthand precision names mapped to torch dtypes; used to configure autocast
# inside LLMDecoder.__call__ below.
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class LLMDecoder(nn.Module):
    """Autoregressive token-decoding loop around a causal language model.

    Bundles several sampling strategies (greedy, top-k, nucleus/top-p, and
    VALL-E-2-style repetition-aware sampling) with a step-by-step generation
    loop that supports a KV cache, autocast precision, and (Q)LoRA models.

    Required kwargs:
        eos: int or list of ints — token id(s) that terminate generation.
        token_embeder: callable mapping a (1, 1) LongTensor of ids to embeddings.
    Optional kwargs:
        ras_conf: dict overriding repetition-aware-sampling hyperparameters.
        token_offset: int added to sampled ids before the embedding lookup
            (presumably when the generated vocabulary sits at an offset inside
            a shared embedding table — confirm against the embedder).
    """

    def __init__(self, **kwargs):
        super(LLMDecoder, self).__init__()
        self.eos_token = kwargs["eos"]
        if isinstance(self.eos_token, int):
            # Normalize to a list so `top_id in custom_eos_token` works uniformly.
            self.eos_token = [self.eos_token]
        self.token_embeder = kwargs["token_embeder"]
        self.ras_conf = kwargs.get("ras_conf", {})
        self.token_offset = kwargs.get("token_offset", 0)

    def nucleus_sampling(self, weighted_scores, top_p=0.8, top_k=25, beam_size=1):
        """Sample `beam_size` ids from the nucleus (top-p) of `weighted_scores`.

        The candidate set is the smallest prefix of the probability-sorted
        vocabulary whose cumulative mass reaches `top_p`, additionally capped
        at `top_k` entries.

        Args:
            weighted_scores: 1-D tensor of scores/logits over the vocabulary.
        Returns:
            LongTensor of `beam_size` sampled token ids.
        """
        prob, indices = [], []
        cum_prob = 0.0
        sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
        for i in range(len(sorted_idx)):
            # sampling both top-p and numbers.
            if cum_prob < top_p and len(prob) < top_k:
                cum_prob += sorted_value[i]
                prob.append(sorted_value[i])
                indices.append(sorted_idx[i])
            else:
                break
        # multinomial renormalizes internally, so no explicit division is needed.
        prob = torch.tensor(prob).to(weighted_scores)
        indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
        sampling_ids = prob.multinomial(beam_size, replacement=True)
        top_ids = indices[sampling_ids]
        return top_ids

    def random_sampling(self, weighted_scores, beam_size=1):
        """Sample ids from the full softmax distribution (no truncation)."""
        top_ids = weighted_scores.softmax(dim=0).multinomial(beam_size, replacement=True)
        return top_ids

    # Repetition Aware Sampling in VALL-E 2
    def ras_sampling(
        self, weighted_scores, decoded_tokens, *,
        top_p=0.8, top_k=25, win_size=10, tau_r=0.1
    ):
        """Nucleus sampling with a repetition fallback.

        If the nucleus-sampled id already appears at least `win_size * tau_r`
        times in the last `win_size` decoded tokens, fall back to sampling
        from the full distribution to break repetition loops. Keyword defaults
        are overridden by `self.ras_conf` when provided.
        """
        if self.ras_conf is not None:
            top_p = self.ras_conf.get("top_p", top_p)
            top_k = self.ras_conf.get("top_k", top_k)
            win_size = self.ras_conf.get("win_size", win_size)
            tau_r = self.ras_conf.get("tau_r", tau_r)

        hint_once(f"using Repetition Aware Sampling: top_p: {top_p}, top_k: {top_k},win_size: {win_size}, tau_r: {tau_r}", "ras_sampling")
        top_ids = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
        # Count how often the candidate already occurs in the recent window.
        rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(top_ids) == top_ids).sum().item()
        if rep_num >= win_size * tau_r:
            top_ids = self.random_sampling(weighted_scores)

        return top_ids

    def sampling_ids(
        self,
        weighted_scores: torch.Tensor,
        sampling: Union[bool, int, float] = True,
        decoded_tokens: list = None,
    ):
        """Dispatch to a sampling strategy based on the type of `sampling`.

        bool  -> True: sample from the full softmax; False: greedy argmax.
        int   -> top-k sampling with k = `sampling`.
        float -> top-p (nucleus) sampling with p = `sampling`, capped at 25 ids.
        "ras" -> repetition-aware sampling (requires `decoded_tokens`).
        """
        if isinstance(sampling, bool):
            if sampling:
                top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
            else:
                top_ids = weighted_scores.topk(1)[1]
        elif isinstance(sampling, int):
            prob, indices = weighted_scores.softmax(dim=0).topk(sampling)
            sampling_ids = prob.multinomial(1, replacement=True)
            top_ids = indices[sampling_ids]
        elif isinstance(sampling, float):
            # Inline top-p selection; mirrors nucleus_sampling with top_k fixed at 25.
            prob, indices = [], []
            cum_prob = 0.0
            sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
            for i in range(len(sorted_idx)):
                # sampling both top-p and numbers.
                if cum_prob < sampling and len(prob) < 25:
                    cum_prob += sorted_value[i]
                    prob.append(sorted_value[i])
                    indices.append(sorted_idx[i])
                else:
                    break
            prob = torch.tensor(prob).to(weighted_scores)
            indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
            sampling_ids = prob.multinomial(1, replacement=True)
            top_ids = indices[sampling_ids]
        elif isinstance(sampling, str) and sampling.lower() == "ras":
            top_ids = self.ras_sampling(weighted_scores, decoded_tokens=decoded_tokens)
        else:
            raise NotImplementedError(f"Not implemented for {type(sampling)} sampling")

        return top_ids

    def __call__(self, input_embeddings, llm, states, quantize=False, **kwargs):
        """Generate tokens autoregressively until an EOS id or `max_length`.

        Args:
            input_embeddings: prompt embeddings fed to the LM; batch size 1 is
                assumed (logits are squeezed on dim 0) — TODO confirm at callers.
            llm: causal LM with an `lm_head` and a HF-style forward accepting
                `inputs_embeds`, `use_cache`, `past_key_values`, etc.
            states: dict; `states["llm_cache"]` carries the KV cache across calls.
            quantize: if True, skip autocast and cast inputs to bfloat16.
        Returns:
            Tuple of (out_tokens: (1, N) int64 tensor, hit_eos: bool,
            states: dict holding the updated "llm_cache").
        """
        max_length = kwargs.get("max_length", 60 * 25)
        min_length = kwargs.get("min_length", 2 * 25)
        sampling = kwargs.get("sampling", True)
        device = kwargs.get("device", "cuda")
        llm_dtype = kwargs.get("llm_dtype", "fp32")
        use_llm_cache = kwargs.get("use_llm_cache", True)
        include_eos = kwargs.get("include_eos", False)
        custom_eos_token = kwargs.get("custom_eos_token", self.eos_token)
        avoid_token = kwargs.get("avoid_token", None)

        llm_cache = states.get("llm_cache", None)
        out_tokens, hit_eos = [], False
        for i in range(max_length):
            with torch.cuda.amp.autocast(
                enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
            ) if quantize is False else nullcontext():
                # default attention_mask is causal, no longer need manually construct
                # input_masks = torch.ones((1, input_embeddings.shape[1]), device=input_embeddings.device).to(torch.bool)

                if (kwargs.get("use_qlora",False) or kwargs.get("infer_use_lora",False)) and (not kwargs.get("infer_lora_merged",False)):
                    # Unmerged (Q)LoRA: call through the PEFT wrapper's base model.
                    outputs = llm.base_model.model(
                        inputs_embeds=input_embeddings.to(torch.bfloat16) if quantize is True else input_embeddings,
                        # attention_mask=input_masks,
                        output_hidden_states=True,
                        return_dict=True,
                        use_cache=use_llm_cache,
                        past_key_values=llm_cache,
                    )
                else:
                    outputs = llm(
                        inputs_embeds=input_embeddings.to(torch.bfloat16) if quantize is True else input_embeddings,
                        # attention_mask=input_masks,
                        output_hidden_states=True,
                        return_dict=True,
                        use_cache=use_llm_cache,
                        past_key_values=llm_cache,
                    )
                lm_hidden_states = outputs.hidden_states[-1]
                # Project only the last position's hidden state to vocabulary logits.
                h = llm.lm_head(lm_hidden_states[:, -1])
                # logp = h.log_softmax(dim=-1).squeeze(0)
                logp = h.squeeze(0)
                if use_llm_cache:
                    llm_cache = outputs.past_key_values

            pred = torch.log_softmax(logp, dim=-1)
            if min_length is not None and i < min_length:
                # Suppress EOS before min_length by pushing its log-prob to a
                # very negative representable value.
                for x in custom_eos_token:
                    if pred.dtype == torch.bfloat16:
                        # NOTE(review): float16's min is used as the "very negative"
                        # stand-in for bfloat16 — confirm this is intentional.
                        pred[x] = float(np.finfo(np.float16).min)
                    else:
                        pred[x] = float(np.finfo(np.float32).min)
            if avoid_token is not None and len(avoid_token) > 0:
                # Hard-ban explicitly avoided token ids the same way.
                for x in avoid_token:
                    if pred.dtype == torch.bfloat16:
                        pred[x] = float(np.finfo(np.float16).min)
                    else:
                        pred[x] = float(np.finfo(np.float32).min)
            top_id = self.sampling_ids(pred, sampling, out_tokens)[0].item()

            if top_id in custom_eos_token:
                if include_eos:
                    out_tokens.append(top_id)
                hit_eos = True
                break

            out_tokens.append(top_id)
            if use_llm_cache:
                # With a KV cache only the newly sampled token must be embedded.
                input_embeddings = self.token_embeder(torch.tensor([[top_id]], dtype=torch.int64, device=device) + self.token_offset)
            else:
                # Without a cache, re-feed the full (growing) embedding sequence.
                input_embeddings = torch.cat([
                    input_embeddings,
                    self.token_embeder(torch.tensor([[top_id]], dtype=torch.int64, device=device) + self.token_offset)
                ], dim=1)

        out_tokens = torch.tensor([out_tokens], dtype=torch.int64, device=device)

        states = {"llm_cache": llm_cache}

        return out_tokens, hit_eos, states
|
funcineforge/models/utils/mask_along_axis.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Sequence
|
| 3 |
+
from typing import Union
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MaskTailVariableMaxWidth(torch.nn.Module):
    """Produce a boolean keep-mask that drops a random-width tail per sample.

    For each batch item a ratio is drawn uniformly from
    ``mask_width_ratio_range``; the last ``ratio * length`` frames get False
    and everything before stays True.
    """

    def __init__(
        self,
        mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
        replace_value: float = 0.0,
    ):
        super().__init__()
        self.mask_width_ratio_range = mask_width_ratio_range
        # NOTE(review): replace_value is stored but not applied in forward —
        # presumably the caller consumes the returned mask; confirm before removing.
        self.replace_value = replace_value

    def extra_repr(self):
        return (
            f"mask_width_ratio_range={self.mask_width_ratio_range}, "
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Return a (Batch, Time, 1) bool mask; True marks frames to keep.

        Args:
            spec: (Batch, Time, Freq) padded features.
            spec_lengths: (Batch,) valid lengths; required despite the default.
        """
        batch, frames, _ = spec.shape

        # One tail ratio per batch item, rescaled into the configured range.
        ratio = torch.rand((batch, 1), device=spec.device)
        lo, hi = self.mask_width_ratio_range
        ratio = ratio * (hi - lo) + lo
        tail_len = (ratio * spec_lengths.unsqueeze(1)).to(spec_lengths)

        # Frames strictly before this boundary survive; the tail is masked out.
        keep_until = spec_lengths.unsqueeze(-1) - tail_len

        positions = torch.arange(frames, device=spec.device)[None, :]
        keep = positions < keep_until  # (Batch, Time)
        return keep.unsqueeze(2)  # (Batch, Time, 1)
|
| 39 |
+
|
| 40 |
+
class PrefixMaskVariableMaxWidth(torch.nn.Module):
    """Produce a boolean mask selecting a random-width window at the sequence end.

    One window width per batch item is drawn from ``mask_width_ratio_range``
    (as a fraction of the padded length); the returned mask is True exactly on
    the last ``width`` frames.
    NOTE(review): despite the "Prefix" name, the True region sits at the tail —
    presumably the unmasked prefix is what callers keep; verify at call sites.
    """

    def __init__(
        self,
        mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
        replace_value: float = 0.0,
    ):
        super().__init__()
        self.mask_width_ratio_range = mask_width_ratio_range
        # NOTE(review): replace_value is stored but never used in forward.
        self.replace_value = replace_value

    def extra_repr(self):
        return (
            f"mask_width_ratio_range={self.mask_width_ratio_range}, "
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None, return_mask: bool = False):
        """Return a (Batch, Time, 1) bool mask; True marks the tail window.

        Args:
            spec: (Batch, Time, Freq) padded features.
            spec_lengths: unused here; kept for interface compatibility.
            return_mask: unused here; kept for interface compatibility.
        """
        batch, frames, _ = spec.shape

        # Convert the ratio bounds into absolute frame counts for this length.
        bounds = torch.tensor(self.mask_width_ratio_range, dtype=torch.float32, device=spec.device)
        width_bounds = (bounds * frames).long()
        widths = torch.randint(
            width_bounds[0],
            width_bounds[1],
            (batch, 1),
            device=spec.device,
        ).unsqueeze(2)  # (Batch, 1, 1)

        # Window start positions, counted back from the padded end: (B, 1, 1).
        starts = frames - widths

        positions = torch.arange(frames, device=spec.device)[None, None, :]
        # True inside [start, start + width): (Batch, 1, L)
        window = (starts <= positions) * (positions < (starts + widths))
        # Collapse the window axis: (Batch, 1, L) -> (Batch, L, 1)
        return window.any(dim=1).unsqueeze(2)
|
funcineforge/models/utils/masks.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            # max(max_len, 2) guards the max_len == 1 case: torch.randint
            # requires high > low and would raise otherwise.
            chunk_size = torch.randint(1, max(max_len, 2), (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    assert chunk_masks.dtype == torch.bool
    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
        # An all-False attention row would make a downstream softmax degenerate;
        # force at least one position on and warn.
        print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!')
        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
    return chunk_masks
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a (size, size) chunk-causal mask for a streaming encoder.

    Row i may attend to every column j with j < (i // chunk_size + 1) * chunk_size,
    i.e. up to the end of its own chunk (full left context is always allowed).

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): kept for API compatibility; this ONNX-friendly
            implementation ignores it (full left context is used regardless).
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: bool mask of shape (size, size)

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    # NOTE this modified implementation meets onnx export requirements,
    # but it doesn't support num_left_chunks.
    idx = torch.arange(size, device=device)
    # Exclusive end of the chunk each row belongs to.
    chunk_end = (torch.div(idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
    return idx.unsqueeze(0) < chunk_end.unsqueeze(1)
|
| 113 |
+
|
| 114 |
+
def causal_block_mask(size, block_size=1, device="cpu", dtype=torch.bool):
    """Create a block-causal (size, size) mask.

    Row i may attend to columns j with j < (i // block_size + 1) * block_size.
    With block_size=1 this is the ordinary lower-triangular causal mask.

    :param int size: size of mask
    :param int block_size: block size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor
    >>> causal_block_mask(4, 2)
    [[1, 1, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 1],
     [1, 1, 1, 1]]
    """
    # assert size % block_size == 0
    idx = torch.arange(size, device=device)
    # Exclusive end of the block each row belongs to.
    block_end = (torch.div(idx, block_size, rounding_mode='trunc') + 1) * block_size
    mask = idx.unsqueeze(0) < block_end.unsqueeze(1)
    return mask.to(dtype)
|
funcineforge/models/utils/nets_utils.py
ADDED
|
@@ -0,0 +1,734 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
"""Network related utility tools."""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Dict, List, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def to_device(m, x):
    """Send tensor into the device of the module.

    Args:
        m (torch.nn.Module or torch.Tensor): Module whose first parameter's
            device is used, or a tensor whose device is used directly.
        x (Tensor): Torch tensor to move.

    Returns:
        Tensor: Torch tensor located in the same place as torch module.

    Raises:
        TypeError: If ``m`` is neither a module nor a tensor.
    """
    if isinstance(m, torch.nn.Module):
        # NOTE(review): raises StopIteration for a parameterless module —
        # callers presumably always pass modules with parameters; confirm.
        device = next(m.parameters()).device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        # Fixed message typos from the original ("torch.tensor", "bot got").
        raise TypeError("Expected torch.nn.Module or torch.Tensor, " f"but got: {type(m)}")
    return x.to(device)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def pad_list(xs, pad_value):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)];
            all tensors must share their trailing dimensions.
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    # new_full replaces the legacy Tensor.new(...).fill_() idiom: it allocates
    # on xs[0]'s device/dtype and initializes in a single step.
    pad = xs[0].new_full((n_batch, max_len, *xs[0].size()[1:]), pad_value)

    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]

    return pad
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def pad_list_all_dim(xs, pad_value):
    """Pad a list of tensors along every dimension.

    Unlike :func:`pad_list`, which pads only dim 0, each tensor is padded up
    to the per-dimension maximum across the whole list. Generalized from the
    original hard-coded 1-D/2-D/3-D branches to tensors of any rank.

    Args:
        xs (List): List of Tensors, all with the same number of dimensions.
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, max_d0, max_d1, ...).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> pad_list_all_dim(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    n_batch = len(xs)
    num_dim = len(xs[0].shape)
    max_len_all_dim = [max(x.size(d) for x in xs) for d in range(num_dim)]
    # new_full allocates on xs[0]'s device/dtype and fills in one step.
    pad = xs[0].new_full((n_batch, *max_len_all_dim), pad_value)

    for i, x in enumerate(xs):
        # Index equivalent to pad[i, :d0, :d1, ...] for arbitrary rank.
        region = (i, *(slice(0, s) for s in x.shape))
        pad[region] = x

    return pad
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make mask tensor containing indices of padded part.

    The mask is True (1) at padded positions and False (0) at valid ones.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension of ``xs`` that corresponds to
            the sequence length (must not be 0; negative values count from
            the end).
        maxlen (int, optional): Explicit mask width; only valid when ``xs``
            is None, and must be >= max(lengths).

    Returns:
        Tensor: Bool mask containing indices of the padded part.

    Examples:
        >>> make_pad_mask([5, 3, 2])
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask([5, 3, 2], xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    batch = int(len(lengths))

    # Resolve the mask width: explicit maxlen > xs's length dim > max(lengths).
    if maxlen is not None:
        assert xs is None
        assert maxlen >= int(max(lengths))
    elif xs is not None:
        maxlen = xs.size(length_dim)
    else:
        maxlen = int(max(lengths))

    # (B, maxlen) grid of positions compared against each sample's length.
    steps = torch.arange(0, maxlen, dtype=torch.int64).unsqueeze(0).expand(batch, maxlen)
    limits = torch.tensor(lengths, dtype=torch.int64).unsqueeze(-1)
    mask = steps >= limits

    if xs is not None:
        assert xs.size(0) == batch, (xs.size(0), batch)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # Insert singleton axes everywhere except batch and length dims, then
        # broadcast to xs's full shape: ind = (:, None, ..., :, None, ...).
        index = tuple(
            slice(None) if axis in (0, length_dim) else None for axis in range(xs.dim())
        )
        mask = mask[index].expand_as(xs).to(xs.device)
    return mask
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of the non-padded part.

    This is simply the logical inverse of :func:`make_pad_mask`: positions
    inside each sequence are marked truthy (1/True) and padded positions
    falsy (0/False).

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): Reference tensor; when given, the returned
            mask has the same shape as this tensor.
        length_dim (int, optional): Dimension of ``xs`` that ``lengths``
            indexes. See :func:`make_pad_mask` for details.

    Returns:
        ByteTensor: mask tensor containing indices of the non-padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        tensor([[1, 1, 1, 1, 1],
                [1, 1, 1, 0, 0],
                [1, 1, 0, 0, 0]], dtype=torch.uint8)

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
    """
    pad_mask = make_pad_mask(lengths, xs, length_dim)
    return ~pad_mask
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def mask_by_length(xs, lengths, fill=0):
    """Mask a batched tensor according to per-sequence lengths.

    Positions beyond each sequence's length are replaced with ``fill``.

    Args:
        xs (Tensor): Batch of input tensor (B, `*`).
        lengths (LongTensor or List): Batch of lengths (B,).
        fill (int or float): Value to fill masked part.

    Returns:
        Tensor: Batch of masked input tensor (B, `*`).

    Examples:
        >>> x = torch.arange(5).repeat(3, 1) + 1
        >>> lengths = [5, 3, 2]
        >>> mask_by_length(x, lengths)
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 0, 0],
                [1, 2, 0, 0, 0]])

    """
    assert xs.size(0) == len(lengths)
    # new_full keeps the dtype/device of xs without going through the
    # deprecated `.data` attribute (old code: xs.data.new(...).fill_(fill)).
    ret = xs.new_full(xs.size(), fill)
    for i, l in enumerate(lengths):
        ret[i, :l] = xs[i, :l]
    return ret
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor,
            or a dict with "real" and "imag" entries.

    Returns:
        Tensor or ComplexTensor: Type converted inputs.

    Raises:
        ValueError: If ``x`` is not one of the supported types, or is a dict
            missing the "real"/"imag" keys.

    Examples:
        >>> xs = np.ones(3, dtype=np.float32)
        >>> xs = to_torch_tensor(xs)
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
        >>> xs = {'real': xs, 'imag': xs}
        >>> to_torch_tensor(xs)
        ComplexTensor(
        Real:
        tensor([1., 1., 1.])
        Imag;
        tensor([1., 1., 1.])
        )

    """
    # If numpy, change to torch tensor
    if isinstance(x, np.ndarray):
        if x.dtype.kind == "c":
            # Dynamically importing because torch_complex requires python3
            from torch_complex.tensor import ComplexTensor

            return ComplexTensor(x)
        else:
            return torch.from_numpy(x)

    # If {'real': ..., 'imag': ...}, convert to ComplexTensor
    elif isinstance(x, dict):
        # Dynamically importing because torch_complex requires python3
        from torch_complex.tensor import ComplexTensor

        if "real" not in x or "imag" not in x:
            # Fixed message: the old text ("has 'real' and 'imag' keys")
            # read as a statement rather than a requirement.
            raise ValueError(
                "x must have both 'real' and 'imag' keys: {}".format(list(x))
            )
        # Relative importing because of using python3 syntax
        return ComplexTensor(x["real"], x["imag"])

    # If torch.Tensor, as it is
    elif isinstance(x, torch.Tensor):
        return x

    else:
        error = (
            "x must be numpy.ndarray, torch.Tensor or a dict like "
            "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
            "but got {}".format(type(x))
        )
        try:
            from torch_complex.tensor import ComplexTensor
        except Exception:
            # torch_complex unavailable (e.g. PY2): x can only be invalid here.
            raise ValueError(error)
        else:
            # If PY3
            if isinstance(x, ComplexTensor):
                return x
            else:
                raise ValueError(error)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def get_subsample(train_args, mode, arch):
    """Parse the subsampling factors from the args for the specified `mode` and `arch`.

    Args:
        train_args: argument Namespace containing options.
        mode: one of ('asr', 'mt', 'st')
        arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')

    Returns:
        np.ndarray / List[np.ndarray]: subsampling factors.

    Raises:
        ValueError: If the (mode, arch) combination is unsupported.
    """

    def _fill_subsample(n_layers, etype, subsample_str):
        """Build the per-layer subsample array from the '_'-joined option string."""
        # +1 means input (+1) and layer outputs (n_layers)
        subsample = np.ones(n_layers + 1, dtype=np.int32)
        # Subsampling only applies to projection RNN encoders ("...p"),
        # never to the vgg* front-ends (they pool inside the CNN).
        if etype.endswith("p") and not etype.startswith("vgg"):
            ss = subsample_str.split("_")
            for j in range(min(n_layers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                "Subsampling is not performed for vgg*. "
                "It is performed in max pooling layers at CNN."
            )
        return subsample

    if arch == "transformer":
        return np.array([1])

    elif mode == "mt" and arch == "rnn":
        # +1 means input (+1) and layers outputs (train_args.elayers)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
        logging.warning("Subsampling is not performed for machine translation.")
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif (mode == "asr" and arch in ("rnn", "rnn-t")) or (mode == "st" and arch == "rnn"):
        # NOTE: the original also tested (mode == "mt" and arch == "rnn")
        # here, but that combination is already handled by the branch above,
        # so the dead condition was removed.
        subsample = _fill_subsample(train_args.elayers, train_args.etype, train_args.subsample)
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif mode == "asr" and arch == "rnn_mix":
        subsample = _fill_subsample(
            train_args.elayers_sd + train_args.elayers, train_args.etype, train_args.subsample
        )
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample

    elif mode == "asr" and arch == "rnn_mulenc":
        # One subsample array per encoder; options are indexed per encoder.
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int32)
            if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
                ss = train_args.subsample[idx].split("_")
                for j in range(min(train_args.elayers[idx] + 1, len(ss))):
                    subsample[j] = int(ss[j])
            else:
                logging.warning(
                    "Encoder %d: Subsampling is not performed for vgg*. "
                    "It is performed in max pooling layers at CNN.",
                    idx + 1,
                )
            logging.info("subsample: " + " ".join([str(x) for x in subsample]))
            subsample_list.append(subsample)
        return subsample_list

    else:
        raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
    """Replace keys of old prefix with new prefix in state dict (in place).

    Only the leading occurrence of ``old_prefix`` is rewritten. The previous
    implementation used ``str.replace``, which also rewrote later occurrences
    of the prefix inside a key (e.g. "enc.enc.w" -> "dec.dec.w" instead of
    "dec.enc.w").

    Args:
        old_prefix: Prefix to strip from matching keys.
        new_prefix: Prefix substituted in its place.
        state_dict: Mapping of parameter names to tensors, mutated in place.
    """
    # need this list not to break the dict iterator
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
    if len(old_keys) > 0:
        logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
    for k in old_keys:
        v = state_dict.pop(k)
        # Rewrite only the prefix, never interior occurrences.
        new_k = new_prefix + k[len(old_prefix):]
        state_dict[new_k] = v
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
class Swish(torch.nn.Module):
    """Swish / E-Swish activation.

    Computes ``Swish(x) = (beta * x) * sigmoid(x)``; with ``beta == 1`` this
    is the standard Swish activation.

    References:
        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
        E-swish variant: https://arxiv.org/abs/1801.07145.

    Args:
        beta: Beta parameter for E-Swish.
            (beta >= 1. If beta < 1, use standard Swish).
        use_builtin: Whether to use PyTorch function if available.

    """

    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
        super().__init__()

        self.beta = beta

        # Flattened selection: E-Swish, built-in SiLU, or manual Swish.
        if beta > 1:
            self.swish = lambda inputs: (self.beta * inputs) * torch.sigmoid(inputs)
        elif use_builtin:
            self.swish = torch.nn.SiLU()
        else:
            self.swish = lambda inputs: inputs * torch.sigmoid(inputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the selected activation to ``x``."""
        return self.swish(x)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def get_activation(act):
    """Return a new instance of the activation module named by ``act``.

    Args:
        act: One of "hardtanh", "tanh", "relu", "selu", "swish".

    Raises:
        KeyError: If ``act`` is not a known activation name.
    """
    builders = {
        "hardtanh": torch.nn.Hardtanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": Swish,
        "tanh": torch.nn.Tanh,
    }
    return builders[act]()
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
class TooShortUttError(Exception):
    """Raised when the utt is too short for subsampling.

    Args:
        message: Error message to display.
        actual_size: The size that cannot pass the subsampling.
        limit: The size limit for subsampling.

    """

    def __init__(self, message: str, actual_size: int, limit: int) -> None:
        """Store the offending size and its limit alongside the message."""
        super().__init__(message)
        # Kept so callers can report or re-pad the offending input.
        self.actual_size = actual_size
        self.limit = limit
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
    """Check if the input is too short for subsampling.

    Args:
        sub_factor: Subsampling factor for Conv2DSubsampling.
        size: Input size.

    Returns:
        : Whether an error should be sent.
        : Size limit for specified subsampling factor.

    """
    # (minimum size, reported limit) per supported subsampling factor.
    limits = {2: (3, 7), 4: (7, 7), 6: (11, 11)}
    if sub_factor in limits:
        min_size, reported_limit = limits[sub_factor]
        if size < min_size:
            return True, reported_limit
    return False, -1
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
    """Get conv2D second layer parameters for given subsampling factor.

    Args:
        sub_factor: Subsampling factor (1/X).
        input_size: Input size.

    Returns:
        : Kernel size for second convolution.
        : Stride for second convolution.
        : Conv2DSubsampling output size.

    """
    if sub_factor not in (2, 4, 6):
        raise ValueError("subsampling_factor parameter should be set to either 2, 4 or 6.")

    # Every factor shares the same first conv stage, which halves the size.
    first_stage = (input_size - 1) // 2
    if sub_factor == 2:
        return 3, 1, first_stage - 2
    if sub_factor == 4:
        return 3, 2, (first_stage - 1) // 2
    return 5, 3, (first_stage - 2) // 3
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def make_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = 0,
    device: torch.device = None,
) -> torch.Tensor:
    """Create chunk mask for the subsequent steps (size, size).

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        size: Size of the source mask.
        chunk_size: Number of frames in chunk.
        left_chunk_size: Size of the left context in chunks (0 means full context).
        device: Device for the mask tensor.

    Returns:
        mask: Chunk mask. (size, size)

    """
    # Allowed positions are collected as True, then inverted on return so
    # that True marks a *masked* position.
    allowed = torch.zeros(size, size, device=device, dtype=torch.bool)

    for row in range(size):
        chunk_idx = row // chunk_size
        # Negative left context means "see everything to the left".
        if left_chunk_size < 0:
            first = 0
        else:
            first = max((chunk_idx - left_chunk_size) * chunk_size, 0)
        last = min((chunk_idx + 1) * chunk_size, size)
        allowed[row, first:last] = True

    return ~allowed
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create source mask for given lengths.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        lengths: Sequence lengths. (B,)

    Returns:
        : Mask for the sequence lengths. (B, max_len)
          True marks a padded position.

    """
    batch = lengths.size(0)
    max_len = int(lengths.max())

    # A position is padding when its index reaches the sequence's length.
    positions = torch.arange(max_len, device=lengths.device, dtype=lengths.dtype)
    return positions.unsqueeze(0).expand(batch, max_len) >= lengths.unsqueeze(1)
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get Transducer loss I/O.

    Args:
        labels: Label ID sequences, padded with ``ignore_id``. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.

    Returns:
        decoder_in: Decoder inputs (labels prefixed with blank). (B, U)
        target: Target label ID sequences. (B, U)
        t_len: Time lengths. (B,)
        u_len: Label lengths. (B,)

    """
    device = labels.device

    # Strip padding from each row; the blank symbol is prepended for the
    # decoder input only.
    label_seqs = [row[row != ignore_id] for row in labels]
    blank = labels[0].new([blank_id])

    decoder_in = torch.nn.utils.rnn.pad_sequence(
        [torch.cat([blank, seq], dim=0) for seq in label_seqs],
        batch_first=True,
        padding_value=blank_id,
    ).to(device)

    target = (
        torch.nn.utils.rnn.pad_sequence(
            label_seqs, batch_first=True, padding_value=blank_id
        )
        .type(torch.int32)
        .to(device)
    )

    t_len = torch.IntTensor([int(length) for length in encoder_out_lens]).to(device)
    u_len = torch.IntTensor([seq.size(0) for seq in label_seqs]).to(device)

    return decoder_in, target, t_len, u_len
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Right-pad tensor ``t`` with zeros along ``dim`` up to length ``pad_len``."""
    deficit = pad_len - t.size(dim)
    if deficit == 0:
        # Already the requested length: return the tensor untouched.
        return t
    filler_shape = list(t.shape)
    filler_shape[dim] = deficit
    # new_zeros inherits dtype and device from t.
    return torch.cat([t, t.new_zeros(filler_shape)], dim=dim)
|
funcineforge/tokenizer/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .tokenizer import FunCineForgeTokenizer
|