Commit 8b651bb (verified)

Duplicate from camenduru/dreamtalk
Co-authored-by: camenduru <camenduru@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.
- .gitattributes +37 -0
- damo/dreamtalk/.mdl +0 -0
- damo/dreamtalk/.msc +0 -0
- damo/dreamtalk/README.md +131 -0
- damo/dreamtalk/checkpoints/denoising_network.pth +3 -0
- damo/dreamtalk/checkpoints/renderer.pt +3 -0
- damo/dreamtalk/configs/default.py +91 -0
- damo/dreamtalk/configuration.json +11 -0
- damo/dreamtalk/core/networks/__init__.py +14 -0
- damo/dreamtalk/core/networks/diffusion_net.py +340 -0
- damo/dreamtalk/core/networks/diffusion_util.py +131 -0
- damo/dreamtalk/core/networks/disentangle_decoder.py +240 -0
- damo/dreamtalk/core/networks/dynamic_conv.py +156 -0
- damo/dreamtalk/core/networks/dynamic_fc_decoder.py +178 -0
- damo/dreamtalk/core/networks/dynamic_linear.py +50 -0
- damo/dreamtalk/core/networks/generator.py +309 -0
- damo/dreamtalk/core/networks/mish.py +51 -0
- damo/dreamtalk/core/networks/self_attention_pooling.py +53 -0
- damo/dreamtalk/core/networks/transformer.py +293 -0
- damo/dreamtalk/core/utils.py +456 -0
- damo/dreamtalk/data/audio/German1.wav +0 -0
- damo/dreamtalk/data/audio/German2.wav +0 -0
- damo/dreamtalk/data/audio/German3.wav +0 -0
- damo/dreamtalk/data/audio/German4.wav +0 -0
- damo/dreamtalk/data/audio/acknowledgement_chinese.m4a +0 -0
- damo/dreamtalk/data/audio/acknowledgement_english.m4a +0 -0
- damo/dreamtalk/data/audio/chinese1_haierlizhi.wav +0 -0
- damo/dreamtalk/data/audio/chinese2_guanyu.wav +0 -0
- damo/dreamtalk/data/audio/french1.wav +0 -0
- damo/dreamtalk/data/audio/french2.wav +0 -0
- damo/dreamtalk/data/audio/french3.wav +0 -0
- damo/dreamtalk/data/audio/italian1.wav +0 -0
- damo/dreamtalk/data/audio/italian2.wav +0 -0
- damo/dreamtalk/data/audio/italian3.wav +0 -0
- damo/dreamtalk/data/audio/japan1.wav +0 -0
- damo/dreamtalk/data/audio/japan2.wav +0 -0
- damo/dreamtalk/data/audio/japan3.wav +0 -0
- damo/dreamtalk/data/audio/korean1.wav +0 -0
- damo/dreamtalk/data/audio/korean2.wav +0 -0
- damo/dreamtalk/data/audio/korean3.wav +0 -0
- damo/dreamtalk/data/audio/noisy_audio_cafeter_snr_0.wav +0 -0
- damo/dreamtalk/data/audio/noisy_audio_meeting_snr_0.wav +0 -0
- damo/dreamtalk/data/audio/noisy_audio_meeting_snr_10.wav +0 -0
- damo/dreamtalk/data/audio/noisy_audio_meeting_snr_20.wav +0 -0
- damo/dreamtalk/data/audio/noisy_audio_narrative.wav +0 -0
- damo/dreamtalk/data/audio/noisy_audio_office_snr_0.wav +0 -0
- damo/dreamtalk/data/audio/out_of_domain_narrative.wav +0 -0
- damo/dreamtalk/data/audio/spanish1.wav +0 -0
- damo/dreamtalk/data/audio/spanish2.wav +0 -0
- damo/dreamtalk/data/audio/spanish3.wav +0 -0
.gitattributes
ADDED
@@ -0,0 +1,37 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+damo/dreamtalk/data/pose/RichardShelby_front_neutral_level1_001.mat filter=lfs diff=lfs merge=lfs -text
+damo/dreamtalk/media/teaser.gif filter=lfs diff=lfs merge=lfs -text
damo/dreamtalk/.mdl
ADDED
Binary file (37 Bytes).
damo/dreamtalk/.msc
ADDED
Binary file (12.5 kB).
damo/dreamtalk/README.md
ADDED
@@ -0,0 +1,131 @@
+# DreamTalk: When Expressive Talking Head Generation Meets Diffusion Probabilistic Models
+
+<a href='https://dreamtalk-project.github.io/'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2312.09767'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> [](https://youtu.be/VF4vlE6ZqWQ)
+
+DreamTalk is a diffusion-based, audio-driven expressive talking head generation framework that can produce high-quality talking head videos across diverse speaking styles. DreamTalk exhibits robust performance with a diverse array of inputs, including songs, speech in multiple languages, noisy audio, and out-of-domain portraits.
+
+![teaser](media/teaser.gif)
+
+## News
+- __[2023.12]__ Release inference code and pretrained checkpoint.
+
+## Install Dependencies
+```
+pip install dlib
+```
+
+## Installation
+
+Some pre-generated files are already included in the `output_video` folder; you can run the script below and compare your results against them.
+
+```python
+from modelscope.utils.constant import Tasks
+from modelscope.pipelines import pipeline
+import os
+
+pipe = pipeline(task=Tasks.text_to_video_synthesis, model='damo/dreamtalk',
+                style_clip_path="data/style_clip/3DMM/M030_front_surprised_level3_001.mat",
+                pose_path="data/pose/RichardShelby_front_neutral_level1_001.mat",
+                model_revision='master'
+                )
+# ,model_revision='master')
+inputs={
+    "output_name": "songbie_yk_male",
+    "wav_path": "data/audio/acknowledgement_english.m4a",
+    "img_crop": True,
+    "image_path": "data/src_img/uncropped/male_face.png",
+    "max_gen_len": 20
+}
+pipe(input=inputs)
+print("end")
+```
+
+`wav_path` is the path to the input audio;
+
+`style_clip_path` is the expression reference file, extracted from a video with emotion; it can be used to control the expression of the generated video;
+
+`pose_path` is the head motion reference file, extracted from a video; it can be used to control the head motion of the generated video;
+
+`image_path` is the speaker portrait, preferably a frontal face; input of any resolution is supported in principle and will be cropped to $256\times256$;
+
+`max_gen_len` is the maximum video generation duration in seconds; if the input audio is longer than this, it will be truncated;
+
+`output_name` is the output name; the final generated video is placed in the `output_video` folder, and intermediate results are placed in the `tmp` folder.
+
+If the input image is already $256\times256$ and does not need cropping, you can use `disable_img_crop` to skip the cropping step, as follows:
+
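Editor's note: the committed README stops short of showing that call. A minimal sketch follows, assuming that disabling cropping in the ModelScope pipeline means passing `False` for the `img_crop` key shown in the inputs dict above (mirroring the CLI flag `--disable_img_crop`); the exact key behaviour is an assumption, not confirmed by the source.

```python
# Hypothetical variant of the inputs dict above; assumes "img_crop": False skips cropping.
inputs = {
    "output_name": "songbie_yk_male",
    "wav_path": "data/audio/acknowledgement_english.m4a",
    "img_crop": False,                            # portrait is assumed to be 256x256 already
    "image_path": "data/src_img/cropped/zp1.png",
    "max_gen_len": 20,
}
pipe(input=inputs)
```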
+## Download Checkpoints
+Download the checkpoint of the denoising network:
+* [ModelScope](tmp)
+
+
+Download the checkpoint of the renderer:
+* [ModelScope](tmp)
+
+Put the downloaded checkpoints into the `checkpoints` folder.
+
+
+## Inference
+Run the script:
+
+```
+python inference_for_demo_video.py \
+--wav_path data/audio/acknowledgement_english.m4a \
+--style_clip_path data/style_clip/3DMM/M030_front_neutral_level1_001.mat \
+--pose_path data/pose/RichardShelby_front_neutral_level1_001.mat \
+--image_path data/src_img/uncropped/male_face.png \
+--cfg_scale 1.0 \
+--max_gen_len 30 \
+--output_name acknowledgement_english@M030_front_neutral_level1_001@male_face
+```
+
+`wav_path` specifies the input audio. Common audio file extensions such as wav, mp3, m4a, and mp4 (video with sound) should all be compatible.
+
+`style_clip_path` specifies the reference speaking style and `pose_path` specifies the head pose. They are 3DMM parameter sequences extracted from reference videos. You can follow [PIRenderer](https://github.com/RenYurui/PIRender) to extract 3DMM parameters from your own videos. Note that the video frame rate should be 25 FPS. Besides, videos used for head pose reference should first be cropped to $256\times256$ using scripts in [FOMM video preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing).
+
+`image_path` specifies the input portrait. Its resolution should be larger than $256\times256$. Frontal portraits, with the face directly facing forward and not tilted to one side, usually achieve satisfactory results. The input portrait will be cropped to $256\times256$. If your portrait is already cropped to $256\times256$ and you want to disable cropping, use the option `--disable_img_crop` like this:
+
+```
+python inference_for_demo_video.py \
+--wav_path data/audio/acknowledgement_chinese.m4a \
+--style_clip_path data/style_clip/3DMM/M030_front_surprised_level3_001.mat \
+--pose_path data/pose/RichardShelby_front_neutral_level1_001.mat \
+--image_path data/src_img/cropped/zp1.png \
+--disable_img_crop \
+--cfg_scale 1.0 \
+--max_gen_len 30 \
+--output_name acknowledgement_chinese@M030_front_surprised_level3_001@zp1
+```
+
+`cfg_scale` controls the scale of classifier-free guidance. It can adjust the intensity of speaking styles.
+
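For reference, the guided prediction at each denoising step combines the conditional and unconditional outputs exactly as implemented in `core/networks/diffusion_net.py` below:

$$\hat{x}_0 = \hat{x}_0^{\text{uncond}} + \text{cfg\_scale}\cdot\left(\hat{x}_0^{\text{cond}} - \hat{x}_0^{\text{uncond}}\right)$$

so `--cfg_scale 1.0` reduces to the purely conditional prediction, and larger values exaggerate the reference speaking style.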
+`max_gen_len` is the maximum video generation duration, measured in seconds. If the input audio exceeds this length, it will be truncated.
+
+The generated video will be named `$(output_name).mp4` and put in the `output_video` folder. Intermediate results, including the cropped portrait, will be in the `tmp/$(output_name)` folder.
+
+Sample inputs are presented in the `data` folder. Due to copyright issues, we are unable to include the songs we have used in this folder.
+
+
+## Acknowledgements
+
+We extend our heartfelt thanks for the invaluable contributions made by preceding works to the development of DreamTalk. This includes, but is not limited to:
+[PIRenderer](https://github.com/RenYurui/PIRender)
+,[AVCT](https://github.com/FuxiVirtualHuman/AAAI22-one-shot-talking-face)
+,[StyleTalk](https://github.com/FuxiVirtualHuman/styletalk)
+,[Deep3DFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch)
+,[Wav2vec2.0](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-english)
+,[diffusion-point-cloud](https://github.com/luost26/diffusion-point-cloud)
+,[FOMM video preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing). We are dedicated to advancing upon these foundational works with the utmost respect for their original contributions.
+
+## Citation
+If you find this codebase useful for your research, please use the following entry.
+```BibTeX
+@article{ma2023dreamtalk,
+  title={DreamTalk: When Expressive Talking Head Generation Meets Diffusion Probabilistic Models},
+  author={Ma, Yifeng and Zhang, Shiwei and Wang, Jiayu and Wang, Xiang and Zhang, Yingya and Deng, Zhidong},
+  journal={arXiv preprint arXiv:2312.09767},
+  year={2023}
+}
+```
+
+
damo/dreamtalk/checkpoints/denoising_network.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93864d1316f60e75b40bd820707bb2464f790b1636ae2b9275ee500d41c4e3ae
+size 47908943
damo/dreamtalk/checkpoints/renderer.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a67014839d42d592255c9fc3b3ceecbcd62c27ce0c0a89ed6628292447404242
+size 335281551
damo/dreamtalk/configs/default.py
ADDED
@@ -0,0 +1,91 @@
+from yacs.config import CfgNode as CN
+
+
+_C = CN()
+_C.TAG = "style_id_emotion"
+_C.DECODER_TYPE = "DisentangleDecoder"
+_C.CONTENT_ENCODER_TYPE = "ContentW2VEncoder"
+_C.STYLE_ENCODER_TYPE = "StyleEncoder"
+
+_C.DIFFNET_TYPE = "DiffusionNet"
+
+_C.WIN_SIZE = 5
+_C.D_MODEL = 256
+
+_C.DATASET = CN()
+_C.DATASET.FACE3D_DIM = 64
+_C.DATASET.NUM_FRAMES = 64
+_C.DATASET.STYLE_MAX_LEN = 256
+
+_C.TRAIN = CN()
+_C.TRAIN.FACE3D_LATENT = CN()
+_C.TRAIN.FACE3D_LATENT.TYPE = "face3d"
+
+_C.DIFFUSION = CN()
+_C.DIFFUSION.PREDICT_WHAT = "x0"  # noise | x0
+_C.DIFFUSION.SCHEDULE = CN()
+_C.DIFFUSION.SCHEDULE.NUM_STEPS = 1000
+_C.DIFFUSION.SCHEDULE.BETA_1 = 1e-4
+_C.DIFFUSION.SCHEDULE.BETA_T = 0.02
+_C.DIFFUSION.SCHEDULE.MODE = "linear"
+
+_C.CONTENT_ENCODER = CN()
+_C.CONTENT_ENCODER.d_model = _C.D_MODEL
+_C.CONTENT_ENCODER.nhead = 8
+_C.CONTENT_ENCODER.num_encoder_layers = 3
+_C.CONTENT_ENCODER.dim_feedforward = 4 * _C.D_MODEL
+_C.CONTENT_ENCODER.dropout = 0.1
+_C.CONTENT_ENCODER.activation = "relu"
+_C.CONTENT_ENCODER.normalize_before = False
+_C.CONTENT_ENCODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
+
+_C.STYLE_ENCODER = CN()
+_C.STYLE_ENCODER.d_model = _C.D_MODEL
+_C.STYLE_ENCODER.nhead = 8
+_C.STYLE_ENCODER.num_encoder_layers = 3
+_C.STYLE_ENCODER.dim_feedforward = 4 * _C.D_MODEL
+_C.STYLE_ENCODER.dropout = 0.1
+_C.STYLE_ENCODER.activation = "relu"
+_C.STYLE_ENCODER.normalize_before = False
+_C.STYLE_ENCODER.pos_embed_len = _C.DATASET.STYLE_MAX_LEN
+_C.STYLE_ENCODER.aggregate_method = (
+    "self_attention_pooling"  # average | self_attention_pooling
+)
+# _C.STYLE_ENCODER.input_dim = _C.DATASET.FACE3D_DIM
+
+_C.DECODER = CN()
+_C.DECODER.d_model = _C.D_MODEL
+_C.DECODER.nhead = 8
+_C.DECODER.num_decoder_layers = 3
+_C.DECODER.dim_feedforward = 4 * _C.D_MODEL
+_C.DECODER.dropout = 0.1
+_C.DECODER.activation = "relu"
+_C.DECODER.normalize_before = False
+_C.DECODER.return_intermediate_dec = False
+_C.DECODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
+_C.DECODER.network_type = "TransformerDecoder"
+_C.DECODER.dynamic_K = None
+_C.DECODER.dynamic_ratio = None
+# _C.DECODER.output_dim = _C.DATASET.FACE3D_DIM
+# LSFM basis:
+# _C.DECODER.upper_face3d_indices = tuple(list(range(19)) + list(range(46, 51)))
+# _C.DECODER.lower_face3d_indices = tuple(range(19, 46))
+# BFM basis:
+# fmt: off
+_C.DECODER.upper_face3d_indices = [6, 8, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
+# fmt: on
+_C.DECODER.lower_face3d_indices = [0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14]
+
+_C.CF_GUIDANCE = CN()
+_C.CF_GUIDANCE.TRAINING = True
+_C.CF_GUIDANCE.INFERENCE = True
+_C.CF_GUIDANCE.NULL_PROB = 0.1
+_C.CF_GUIDANCE.SCALE = 1.0
+
+_C.INFERENCE = CN()
+_C.INFERENCE.CHECKPOINT = "checkpoints/denoising_network.pth"
+
+
+def get_cfg_defaults():
+    """Get a yacs CfgNode object with default values for my_project."""
+    return _C.clone()
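The defaults above are consumed through the usual yacs clone/merge/freeze workflow; a minimal sketch of that pattern, mirroring the `__main__` block in `core/networks/disentangle_decoder.py` (the override file name is hypothetical):

```python
from configs.default import get_cfg_defaults

cfg = get_cfg_defaults()                            # clone of the _C defaults above
# cfg.merge_from_file("configs/my_override.yaml")   # hypothetical experiment override
cfg.freeze()                                        # make the config read-only before use
print(cfg.DIFFUSION.SCHEDULE.NUM_STEPS)             # -> 1000
```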
damo/dreamtalk/configuration.json
ADDED
@@ -0,0 +1,11 @@
+{
+    "framework": "pytorch",
+    "task": "text-to-video-synthesis",
+    "model": {
+        "type": "Dreamtalk-Generation"
+    },
+    "pipeline": {
+        "type": "Dreamtalk-generation-pipe"
+    },
+    "allow_remote": true
+}
damo/dreamtalk/core/networks/__init__.py
ADDED
@@ -0,0 +1,14 @@
+from core.networks.generator import (
+    StyleEncoder,
+    Decoder,
+    ContentW2VEncoder,
+)
+from core.networks.disentangle_decoder import DisentangleDecoder
+
+
+def get_network(name: str):
+    obj = globals().get(name)
+    if obj is None:
+        raise KeyError("Unknown Network: %s" % name)
+    else:
+        return obj
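`get_network` is a lookup-by-name registry over the classes imported above; a short sketch of how the config-driven construction used elsewhere in this commit (e.g. `NoisePredictor` in `diffusion_util.py`) resolves classes:

```python
from core.networks import get_network

decoder_cls = get_network("DisentangleDecoder")   # resolves one of the imported classes
style_cls = get_network("StyleEncoder")

try:
    get_network("NoSuchNetwork")                  # unknown names raise KeyError
except KeyError as err:
    print(err)
```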
damo/dreamtalk/core/networks/diffusion_net.py
ADDED
@@ -0,0 +1,340 @@
+import math
+import torch
+import torch.nn.functional as F
+from torch.nn import Module
+from core.networks.diffusion_util import VarianceSchedule
+import numpy as np
+
+
+def face3d_raw_to_norm(face3d_raw, exp_min, exp_max):
+    """
+
+    Args:
+        face3d_raw (_type_): (B, L, C_face3d)
+        exp_min (_type_): (C_face3d)
+        exp_max (_type_): (C_face3d)
+
+    Returns:
+        _type_: (B, L, C_face3d) in [-1, 1]
+    """
+    exp_min_expand = exp_min[None, None, :]
+    exp_max_expand = exp_max[None, None, :]
+    face3d_norm_01 = (face3d_raw - exp_min_expand) / (exp_max_expand - exp_min_expand)
+    face3d_norm = face3d_norm_01 * 2 - 1
+    return face3d_norm
+
+
+def face3d_norm_to_raw(face3d_norm, exp_min, exp_max):
+    """
+
+    Args:
+        face3d_norm (_type_): (B, L, C_face3d)
+        exp_min (_type_): (C_face3d)
+        exp_max (_type_): (C_face3d)
+
+    Returns:
+        _type_: (B, L, C_face3d)
+    """
+    exp_min_expand = exp_min[None, None, :]
+    exp_max_expand = exp_max[None, None, :]
+    face3d_norm_01 = (face3d_norm + 1) / 2
+    face3d_raw = face3d_norm_01 * (exp_max_expand - exp_min_expand) + exp_min_expand
+    return face3d_raw
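The two helpers above are inverse affine maps between raw 3DMM expression coefficients and the $[-1, 1]$ range the diffusion model operates in:

$$x_{\text{norm}} = 2\,\frac{x_{\text{raw}} - x_{\min}}{x_{\max} - x_{\min}} - 1, \qquad x_{\text{raw}} = \frac{x_{\text{norm}} + 1}{2}\,(x_{\max} - x_{\min}) + x_{\min}$$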
+
+
+class DiffusionNet(Module):
+    def __init__(self, cfg, net, var_sched: VarianceSchedule):
+        super().__init__()
+        self.cfg = cfg
+        self.net = net
+        self.var_sched = var_sched
+        self.face3d_latent_type = self.cfg.TRAIN.FACE3D_LATENT.TYPE
+        self.predict_what = self.cfg.DIFFUSION.PREDICT_WHAT
+
+        if self.cfg.CF_GUIDANCE.TRAINING:
+            null_style_clip = torch.zeros(
+                self.cfg.DATASET.STYLE_MAX_LEN, self.cfg.DATASET.FACE3D_DIM
+            )
+            self.register_buffer("null_style_clip", null_style_clip)
+
+            null_pad_mask = torch.tensor([False] * self.cfg.DATASET.STYLE_MAX_LEN)
+            self.register_buffer("null_pad_mask", null_pad_mask)
+
+    def _face3d_to_latent(self, face3d):
+        latent = None
+        if self.face3d_latent_type == "face3d":
+            latent = face3d
+        elif self.face3d_latent_type == "normalized_face3d":
+            latent = face3d_raw_to_norm(
+                face3d, exp_min=self.exp_min, exp_max=self.exp_max
+            )
+        else:
+            raise ValueError(f"Invalid face3d latent type: {self.face3d_latent_type}")
+        return latent
+
+    def _latent_to_face3d(self, latent):
+        face3d = None
+        if self.face3d_latent_type == "face3d":
+            face3d = latent
+        elif self.face3d_latent_type == "normalized_face3d":
+            latent = torch.clamp(latent, min=-1, max=1)
+            face3d = face3d_norm_to_raw(
+                latent, exp_min=self.exp_min, exp_max=self.exp_max
+            )
+        else:
+            raise ValueError(f"Invalid face3d latent type: {self.face3d_latent_type}")
+        return face3d
+
+    def ddim_sample(
+        self,
+        audio,
+        style_clip,
+        style_pad_mask,
+        output_dim,
+        flexibility=0.0,
+        ret_traj=False,
+        use_cf_guidance=False,
+        cfg_scale=2.0,
+        ddim_num_step=50,
+        ready_style_code=None,
+    ):
+        """
+
+        Args:
+            audio (_type_): (B, L, W) or (B, L, W, C)
+            style_clip (_type_): (B, L_clipmax, C_face3d)
+            style_pad_mask : (B, L_clipmax)
+            pose_dim (_type_): int
+            flexibility (float, optional): _description_. Defaults to 0.0.
+            ret_traj (bool, optional): _description_. Defaults to False.
+
+
+        Returns:
+            _type_: (B, L, C_face)
+        """
+        if self.predict_what != "x0":
+            raise NotImplementedError(self.predict_what)
+
+        if ready_style_code is not None and use_cf_guidance:
+            raise NotImplementedError("not implement cfg for ready style code")
+
+        c = self.var_sched.num_steps // ddim_num_step
+        time_steps = torch.tensor(
+            np.asarray(list(range(0, self.var_sched.num_steps, c))) + 1
+        )
+        assert len(time_steps) == ddim_num_step
+        prev_time_steps = torch.cat((torch.tensor([0]), time_steps[:-1]))
+
+        batch_size, output_len = audio.shape[:2]
+        # batch_size = context.size(0)
+        context = {
+            "audio": audio,
+            "style_clip": style_clip,
+            "style_pad_mask": style_pad_mask,
+            "ready_style_code": ready_style_code,
+        }
+        if use_cf_guidance:
+            uncond_style_clip = self.null_style_clip.unsqueeze(0).repeat(
+                batch_size, 1, 1
+            )
+            uncond_pad_mask = self.null_pad_mask.unsqueeze(0).repeat(batch_size, 1)
+
+            context_double = {
+                "audio": torch.cat([audio] * 2, dim=0),
+                "style_clip": torch.cat([style_clip, uncond_style_clip], dim=0),
+                "style_pad_mask": torch.cat([style_pad_mask, uncond_pad_mask], dim=0),
+                "ready_style_code": None
+                if ready_style_code is None
+                else torch.cat(
+                    [
+                        ready_style_code,
+                        self.net.style_encoder(uncond_style_clip, uncond_pad_mask),
+                    ],
+                    dim=0,
+                ),
+            }
+
+        x_t = torch.randn([batch_size, output_len, output_dim]).to(audio.device)
+
+        for idx in list(range(ddim_num_step))[::-1]:
+            t = time_steps[idx]
+            t_prev = prev_time_steps[idx]
+            ddim_alpha = self.var_sched.alpha_bars[t]
+            ddim_alpha_prev = self.var_sched.alpha_bars[t_prev]
+
+            t_tensor = torch.tensor([t] * batch_size).to(audio.device).float()
+            if use_cf_guidance:
+                x_t_double = torch.cat([x_t] * 2, dim=0)
+                t_tensor_double = torch.cat([t_tensor] * 2, dim=0)
+                cond_output, uncond_output = self.net(
+                    x_t_double, t=t_tensor_double, **context_double
+                ).chunk(2)
+                diff_output = uncond_output + cfg_scale * (cond_output - uncond_output)
+            else:
+                diff_output = self.net(x_t, t=t_tensor, **context)
+
+            pred_x0 = diff_output
+            eps = (x_t - torch.sqrt(ddim_alpha) * pred_x0) / torch.sqrt(1 - ddim_alpha)
+            c1 = torch.sqrt(ddim_alpha_prev)
+            c2 = torch.sqrt(1 - ddim_alpha_prev)
+
+            x_t = c1 * pred_x0 + c2 * eps
+
+        latent_output = x_t
+        face3d_output = self._latent_to_face3d(latent_output)
+        return face3d_output
+
+    def sample(
+        self,
+        audio,
+        style_clip,
+        style_pad_mask,
+        output_dim,
+        flexibility=0.0,
+        ret_traj=False,
+        use_cf_guidance=False,
+        cfg_scale=2.0,
+        sample_method="ddpm",
+        ddim_num_step=50,
+        ready_style_code=None,
+    ):
+        # sample_method = kwargs["sample_method"]
+        if sample_method == "ddpm":
+            if ready_style_code is not None:
+                raise NotImplementedError("ready style code in ddpm")
+            return self.ddpm_sample(
+                audio,
+                style_clip,
+                style_pad_mask,
+                output_dim,
+                flexibility=flexibility,
+                ret_traj=ret_traj,
+                use_cf_guidance=use_cf_guidance,
+                cfg_scale=cfg_scale,
+            )
+        elif sample_method == "ddim":
+            return self.ddim_sample(
+                audio,
+                style_clip,
+                style_pad_mask,
+                output_dim,
+                flexibility=flexibility,
+                ret_traj=ret_traj,
+                use_cf_guidance=use_cf_guidance,
+                cfg_scale=cfg_scale,
+                ddim_num_step=ddim_num_step,
+                ready_style_code=ready_style_code,
+            )
+
+    def ddpm_sample(
+        self,
+        audio,
+        style_clip,
+        style_pad_mask,
+        output_dim,
+        flexibility=0.0,
+        ret_traj=False,
+        use_cf_guidance=False,
+        cfg_scale=2.0,
+    ):
+        """
+
+        Args:
+            audio (_type_): (B, L, W) or (B, L, W, C)
+            style_clip (_type_): (B, L_clipmax, C_face3d)
+            style_pad_mask : (B, L_clipmax)
+            pose_dim (_type_): int
+            flexibility (float, optional): _description_. Defaults to 0.0.
+            ret_traj (bool, optional): _description_. Defaults to False.
+
+
+        Returns:
+            _type_: (B, L, C_face)
+        """
+        batch_size, output_len = audio.shape[:2]
+        # batch_size = context.size(0)
+        context = {
+            "audio": audio,
+            "style_clip": style_clip,
+            "style_pad_mask": style_pad_mask,
+        }
+        if use_cf_guidance:
+            uncond_style_clip = self.null_style_clip.unsqueeze(0).repeat(
+                batch_size, 1, 1
+            )
+            uncond_pad_mask = self.null_pad_mask.unsqueeze(0).repeat(batch_size, 1)
+            context_double = {
+                "audio": torch.cat([audio] * 2, dim=0),
+                "style_clip": torch.cat([style_clip, uncond_style_clip], dim=0),
+                "style_pad_mask": torch.cat([style_pad_mask, uncond_pad_mask], dim=0),
+            }
+
+        x_T = torch.randn([batch_size, output_len, output_dim]).to(audio.device)
+        traj = {self.var_sched.num_steps: x_T}
+        for t in range(self.var_sched.num_steps, 0, -1):
+            alpha = self.var_sched.alphas[t]
+            alpha_bar = self.var_sched.alpha_bars[t]
+            alpha_bar_prev = self.var_sched.alpha_bars[t - 1]
+            sigma = self.var_sched.get_sigmas(t, flexibility)
+
+            z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
+            x_t = traj[t]
+            t_tensor = torch.tensor([t] * batch_size).to(audio.device).float()
+            if use_cf_guidance:
+                x_t_double = torch.cat([x_t] * 2, dim=0)
+                t_tensor_double = torch.cat([t_tensor] * 2, dim=0)
+                cond_output, uncond_output = self.net(
+                    x_t_double, t=t_tensor_double, **context_double
+                ).chunk(2)
+                diff_output = uncond_output + cfg_scale * (cond_output - uncond_output)
+            else:
+                diff_output = self.net(x_t, t=t_tensor, **context)
+
+            if self.predict_what == "noise":
+                c0 = 1.0 / torch.sqrt(alpha)
+                c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
+                x_next = c0 * (x_t - c1 * diff_output) + sigma * z
+            elif self.predict_what == "x0":
+                d0 = torch.sqrt(alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar)
+                d1 = torch.sqrt(alpha_bar_prev) * (1 - alpha) / (1 - alpha_bar)
+                x_next = d0 * x_t + d1 * diff_output + sigma * z
+            traj[t - 1] = x_next.detach()
+            traj[t] = traj[t].cpu()
+            if not ret_traj:
+                del traj[t]
+
+        if ret_traj:
+            raise NotImplementedError
+            return traj
+        else:
+            latent_output = traj[0]
+            face3d_output = self._latent_to_face3d(latent_output)
+            return face3d_output
+
+
+if __name__ == "__main__":
+    from core.networks.diffusion_util import NoisePredictor, VarianceSchedule
+
+    diffnet = DiffusionNet(
+        net=NoisePredictor(),
+        var_sched=VarianceSchedule(
+            num_steps=500, beta_1=1e-4, beta_T=0.02, mode="linear"
+        ),
+    )
+
+    import torch
+
+    gt_face3d = torch.randn(16, 64, 64)
+    audio = torch.randn(16, 64, 11)
+    style_clip = torch.randn(16, 256, 64)
+    style_pad_mask = torch.ones(16, 256)
+
+    context = {
+        "audio": audio,
+        "style_clip": style_clip,
+        "style_pad_mask": style_pad_mask,
+    }
+
+    loss = diffnet.get_loss(gt_face3d, context)
+
+    print("hello")
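For orientation, a minimal sketch of how `DiffusionNet.sample` is driven at inference time; the tensor shapes follow the docstrings above, while the concrete sizes and the pre-built `diffnet` object are illustrative assumptions rather than the repo's actual inference script:

```python
import torch

B, L, W, C_face3d, L_style = 1, 64, 11, 64, 256
audio = torch.randn(B, L, W)                        # windowed audio features
style_clip = torch.randn(B, L_style, C_face3d)      # style reference 3DMM clip
style_pad_mask = torch.zeros(B, L_style, dtype=torch.bool)

face3d = diffnet.sample(                            # diffnet: a constructed DiffusionNet
    audio,
    style_clip,
    style_pad_mask,
    output_dim=C_face3d,
    sample_method="ddim",                           # 50-step DDIM instead of 1000-step DDPM
    use_cf_guidance=True,                           # classifier-free guidance
    cfg_scale=1.0,                                  # matches --cfg_scale in the README
)
# face3d: (B, L, C_face3d) predicted 3DMM expression sequence
```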
damo/dreamtalk/core/networks/diffusion_util.py
ADDED
@@ -0,0 +1,131 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import Module
+from core.networks import get_network
+from core.utils import sinusoidal_embedding
+
+
+class VarianceSchedule(Module):
+    def __init__(self, num_steps, beta_1, beta_T, mode="linear"):
+        super().__init__()
+        assert mode in ("linear",)
+        self.num_steps = num_steps
+        self.beta_1 = beta_1
+        self.beta_T = beta_T
+        self.mode = mode
+
+        if mode == "linear":
+            betas = torch.linspace(beta_1, beta_T, steps=num_steps)
+
+        betas = torch.cat([torch.zeros([1]), betas], dim=0)  # Padding
+
+        alphas = 1 - betas
+        log_alphas = torch.log(alphas)
+        for i in range(1, log_alphas.size(0)):  # 1 to T
+            log_alphas[i] += log_alphas[i - 1]
+        alpha_bars = log_alphas.exp()
+
+        sigmas_flex = torch.sqrt(betas)
+        sigmas_inflex = torch.zeros_like(sigmas_flex)
+        for i in range(1, sigmas_flex.size(0)):
+            sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[
+                i
+            ]
+        sigmas_inflex = torch.sqrt(sigmas_inflex)
+
+        self.register_buffer("betas", betas)
+        self.register_buffer("alphas", alphas)
+        self.register_buffer("alpha_bars", alpha_bars)
+        self.register_buffer("sigmas_flex", sigmas_flex)
+        self.register_buffer("sigmas_inflex", sigmas_inflex)
+
+    def uniform_sample_t(self, batch_size):
+        ts = np.random.choice(np.arange(1, self.num_steps + 1), batch_size)
+        return ts.tolist()
+
+    def get_sigmas(self, t, flexibility):
+        assert 0 <= flexibility and flexibility <= 1
+        sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (
+            1 - flexibility
+        )
+        return sigmas
+
+
+class NoisePredictor(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+
+        content_encoder_class = get_network(cfg.CONTENT_ENCODER_TYPE)
+        self.content_encoder = content_encoder_class(**cfg.CONTENT_ENCODER)
+
+        style_encoder_class = get_network(cfg.STYLE_ENCODER_TYPE)
+        cfg.defrost()
+        cfg.STYLE_ENCODER.input_dim = cfg.DATASET.FACE3D_DIM
+        cfg.freeze()
+        self.style_encoder = style_encoder_class(**cfg.STYLE_ENCODER)
+
+        decoder_class = get_network(cfg.DECODER_TYPE)
+        cfg.defrost()
+        cfg.DECODER.output_dim = cfg.DATASET.FACE3D_DIM
+        cfg.freeze()
+        self.decoder = decoder_class(**cfg.DECODER)
+
+        self.content_xt_to_decoder_input_wo_time = nn.Sequential(
+            nn.Linear(cfg.D_MODEL + cfg.DATASET.FACE3D_DIM, cfg.D_MODEL),
+            nn.ReLU(),
+            nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
+            nn.ReLU(),
+            nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
+        )
+
+        self.time_sinusoidal_dim = cfg.D_MODEL
+        self.time_embed_net = nn.Sequential(
+            nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
+            nn.SiLU(),
+            nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
+        )
+
+    def forward(self, x_t, t, audio, style_clip, style_pad_mask, ready_style_code=None):
+        """_summary_
+
+        Args:
+            x_t (_type_): (B, L, C_face)
+            t (_type_): (B,) dtype:float32
+            audio (_type_): (B, L, W)
+            style_clip (_type_): (B, L_clipmax, C_face3d)
+            style_pad_mask : (B, L_clipmax)
+            ready_style_code: (B, C_model)
+        Returns:
+            e_theta : (B, L, C_face)
+        """
+        W = audio.shape[2]
+        content = self.content_encoder(audio)
+        # (B, L, W, C_model)
+        x_t_expand = x_t.unsqueeze(2).repeat(1, 1, W, 1)
+        # (B, L, C_face) -> (B, L, W, C_face)
+        content_xt_concat = torch.cat((content, x_t_expand), dim=3)
+        # (B, L, W, C_model+C_face)
+        decoder_input_without_time = self.content_xt_to_decoder_input_wo_time(
+            content_xt_concat
+        )
+        # (B, L, W, C_model)
+
+        time_sinusoidal = sinusoidal_embedding(t, self.time_sinusoidal_dim)
+        # (B, C_embed)
+        time_embedding = self.time_embed_net(time_sinusoidal)
+        # (B, C_model)
+        B, C = time_embedding.shape
+        time_embed_expand = time_embedding.view(B, 1, 1, C)
+        decoder_input = decoder_input_without_time + time_embed_expand
+        # (B, L, W, C_model)
+
+        if ready_style_code is not None:
+            style_code = ready_style_code
+        else:
+            style_code = self.style_encoder(style_clip, style_pad_mask)
+        # (B, C_model)
+
+        e_theta = self.decoder(decoder_input, style_code)
+        # (B, L, C_face)
+        return e_theta
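The buffers registered by `VarianceSchedule` follow the standard DDPM definitions, with a zero padded at index 0 so that index $t$ corresponds to diffusion step $t$:

$$\alpha_t = 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s, \qquad \sigma_t^{\text{inflex}} = \sqrt{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t}, \qquad \sigma_t^{\text{flex}} = \sqrt{\beta_t}$$

with $\beta_t$ linearly spaced in $[\beta_1, \beta_T]$, and `get_sigmas` linearly interpolating between the two sigmas via the `flexibility` argument.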
damo/dreamtalk/core/networks/disentangle_decoder.py
ADDED
@@ -0,0 +1,240 @@
+import torch
+from torch import nn
+
+from .transformer import (
+    PositionalEncoding,
+    TransformerDecoderLayer,
+    TransformerDecoder,
+)
+from core.networks.dynamic_fc_decoder import DynamicFCDecoderLayer, DynamicFCDecoder
+from core.utils import _reset_parameters
+
+
+def get_decoder_network(
+    network_type,
+    d_model,
+    nhead,
+    dim_feedforward,
+    dropout,
+    activation,
+    normalize_before,
+    num_decoder_layers,
+    return_intermediate_dec,
+    dynamic_K,
+    dynamic_ratio,
+):
+    decoder = None
+    if network_type == "TransformerDecoder":
+        decoder_layer = TransformerDecoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        norm = nn.LayerNorm(d_model)
+        decoder = TransformerDecoder(
+            decoder_layer,
+            num_decoder_layers,
+            norm,
+            return_intermediate_dec,
+        )
+    elif network_type == "DynamicFCDecoder":
+        d_style = d_model
+        decoder_layer = DynamicFCDecoderLayer(
+            d_model,
+            nhead,
+            d_style,
+            dynamic_K,
+            dynamic_ratio,
+            dim_feedforward,
+            dropout,
+            activation,
+            normalize_before,
+        )
+        norm = nn.LayerNorm(d_model)
+        decoder = DynamicFCDecoder(
+            decoder_layer, num_decoder_layers, norm, return_intermediate_dec
+        )
+    elif network_type == "DynamicFCEncoder":
+        d_style = d_model
+        decoder_layer = DynamicFCEncoderLayer(
+            d_model,
+            nhead,
+            d_style,
+            dynamic_K,
+            dynamic_ratio,
+            dim_feedforward,
+            dropout,
+            activation,
+            normalize_before,
+        )
+        norm = nn.LayerNorm(d_model)
+        decoder = DynamicFCEncoder(decoder_layer, num_decoder_layers, norm)
+
+    else:
+        raise ValueError(f"Invalid network_type {network_type}")
+
+    return decoder
+
+
+class DisentangleDecoder(nn.Module):
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_decoder_layers=3,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+        return_intermediate_dec=False,
+        pos_embed_len=80,
+        upper_face3d_indices=tuple(list(range(19)) + list(range(46, 51))),
+        lower_face3d_indices=tuple(range(19, 46)),
+        network_type="None",
+        dynamic_K=None,
+        dynamic_ratio=None,
+        **_,
+    ) -> None:
+        super().__init__()
+
+        self.upper_face3d_indices = upper_face3d_indices
+        self.lower_face3d_indices = lower_face3d_indices
+
+        # upper_decoder_layer = TransformerDecoderLayer(
+        #     d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        # )
+        # upper_decoder_norm = nn.LayerNorm(d_model)
+        # self.upper_decoder = TransformerDecoder(
+        #     upper_decoder_layer,
+        #     num_decoder_layers,
+        #     upper_decoder_norm,
+        #     return_intermediate=return_intermediate_dec,
+        # )
+        self.upper_decoder = get_decoder_network(
+            network_type,
+            d_model,
+            nhead,
+            dim_feedforward,
+            dropout,
+            activation,
+            normalize_before,
+            num_decoder_layers,
+            return_intermediate_dec,
+            dynamic_K,
+            dynamic_ratio,
+        )
+        _reset_parameters(self.upper_decoder)
+
+        # lower_decoder_layer = TransformerDecoderLayer(
+        #     d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        # )
+        # lower_decoder_norm = nn.LayerNorm(d_model)
+        # self.lower_decoder = TransformerDecoder(
+        #     lower_decoder_layer,
+        #     num_decoder_layers,
+        #     lower_decoder_norm,
+        #     return_intermediate=return_intermediate_dec,
+        # )
+        self.lower_decoder = get_decoder_network(
+            network_type,
+            d_model,
+            nhead,
+            dim_feedforward,
+            dropout,
+            activation,
+            normalize_before,
+            num_decoder_layers,
+            return_intermediate_dec,
+            dynamic_K,
+            dynamic_ratio,
+        )
+        _reset_parameters(self.lower_decoder)
+
+        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
+
+        tail_hidden_dim = d_model // 2
+        self.upper_tail_fc = nn.Sequential(
+            nn.Linear(d_model, tail_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(tail_hidden_dim, tail_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(tail_hidden_dim, len(upper_face3d_indices)),
+        )
+        self.lower_tail_fc = nn.Sequential(
+            nn.Linear(d_model, tail_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(tail_hidden_dim, tail_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(tail_hidden_dim, len(lower_face3d_indices)),
+        )
+
+    def forward(self, content, style_code):
+        """
+
+        Args:
+            content (_type_): (B, num_frames, window, C_dmodel)
+            style_code (_type_): (B, C_dmodel)
+
+        Returns:
+            face3d: (B, L_clip, C_3dmm)
+        """
+        B, N, W, C = content.shape
+        style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
+        style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
+        # (W, B*N, C)
+
+        content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
+        # (W, B*N, C)
+        tgt = torch.zeros_like(style)
+        pos_embed = self.pos_embed(W)
+        pos_embed = pos_embed.permute(1, 0, 2)
+
+        upper_face3d_feat = self.upper_decoder(
+            tgt, content, pos=pos_embed, query_pos=style
+        )[0]
+        # (W, B*N, C)
+        upper_face3d_feat = upper_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[
+            :, :, W // 2, :
+        ]
+        # (B, N, C)
+        upper_face3d = self.upper_tail_fc(upper_face3d_feat)
+        # (B, N, C_exp)
+
+        lower_face3d_feat = self.lower_decoder(
+            tgt, content, pos=pos_embed, query_pos=style
+        )[0]
+        lower_face3d_feat = lower_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[
+            :, :, W // 2, :
+        ]
+        lower_face3d = self.lower_tail_fc(lower_face3d_feat)
+        C_exp = len(self.upper_face3d_indices) + len(self.lower_face3d_indices)
+        face3d = torch.zeros(B, N, C_exp).to(upper_face3d)
+        face3d[:, :, self.upper_face3d_indices] = upper_face3d
+        face3d[:, :, self.lower_face3d_indices] = lower_face3d
+        return face3d
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.path.append("/home/mayifeng/Research/styleTH")
+
+    from configs.default import get_cfg_defaults
+
+    cfg = get_cfg_defaults()
+    cfg.merge_from_file("configs/styleTH_unpair_lsfm_emotion.yaml")
+    cfg.freeze()
+
+    # content_encoder = ContentEncoder(**cfg.CONTENT_ENCODER)
+
+    # dummy_audio = torch.randint(0, 41, (5, 64, 11))
+    # dummy_content = content_encoder(dummy_audio)
+
+    # style_encoder = StyleEncoder(**cfg.STYLE_ENCODER)
+    # dummy_face3d_seq = torch.randn(5, 64, 64)
+    # dummy_style_code = style_encoder(dummy_face3d_seq)
+
+    decoder = DisentangleDecoder(**cfg.DECODER)
+    dummy_content = torch.randn(5, 64, 11, 256)
+    dummy_style = torch.randn(5, 256)
+    dummy_output = decoder(dummy_content, dummy_style)
+
+    print("hello")
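The upper/lower index lists consumed by `DisentangleDecoder` (set for the BFM basis in `configs/default.py` above) are intended to partition the 64 expression coefficients between the two decoders; a small sanity-check sketch:

```python
upper = [6, 8] + list(range(15, 64))                    # 51 upper-face coefficients
lower = [0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14]    # 13 lower-face coefficients

# Together they cover each of the 64 expression dimensions exactly once,
# so the decoder's scatter into `face3d` fills the whole vector.
assert sorted(upper + lower) == list(range(64))
```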
damo/dreamtalk/core/networks/dynamic_conv.py
ADDED
@@ -0,0 +1,156 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class Attention(nn.Module):
+    def __init__(self, cond_planes, ratio, K, temperature=30, init_weight=True):
+        super().__init__()
+        # self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.temprature = temperature
+        assert cond_planes > ratio
+        hidden_planes = cond_planes // ratio
+        self.net = nn.Sequential(
+            nn.Conv2d(cond_planes, hidden_planes, kernel_size=1, bias=False),
+            nn.ReLU(),
+            nn.Conv2d(hidden_planes, K, kernel_size=1, bias=False),
+        )
+
+        if init_weight:
+            self._initialize_weights()
+
+    def update_temprature(self):
+        if self.temprature > 1:
+            self.temprature -= 1
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            if isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, cond):
+        """
+
+        Args:
+            cond (_type_): (B, C_style)
+
+        Returns:
+            _type_: (B, K)
+        """
+
+        # att = self.avgpool(cond)  # bs,dim,1,1
+        att = cond.view(cond.shape[0], cond.shape[1], 1, 1)
+        att = self.net(att).view(cond.shape[0], -1)  # bs,K
+        return F.softmax(att / self.temprature, -1)
+
+
+class DynamicConv(nn.Module):
+    def __init__(
+        self,
+        in_planes,
+        out_planes,
+        cond_planes,
+        kernel_size,
+        stride,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+        K=4,
+        temperature=30,
+        ratio=4,
+        init_weight=True,
+    ):
+        super().__init__()
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+        self.cond_planes = cond_planes
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.bias = bias
+        self.K = K
+        self.init_weight = init_weight
+        self.attention = Attention(
+            cond_planes=cond_planes, ratio=ratio, K=K, temperature=temperature, init_weight=init_weight
+        )
+
+        self.weight = nn.Parameter(
+            torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size), requires_grad=True
+        )
+        if bias:
+            self.bias = nn.Parameter(torch.randn(K, out_planes), requires_grad=True)
+        else:
+            self.bias = None
+
+        if self.init_weight:
+            self._initialize_weights()
+
+    def _initialize_weights(self):
+        for i in range(self.K):
+            nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
+            if self.bias is not None:
+                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
+                if fan_in != 0:
+                    bound = 1 / math.sqrt(fan_in)
+                    nn.init.uniform_(self.bias, -bound, bound)
+
+    def forward(self, x, cond):
+        """
+
+        Args:
+            x (_type_): (B, C_in, L, 1)
+            cond (_type_): (B, C_style)
+
+        Returns:
+            _type_: (B, C_out, L, 1)
+        """
+        bs, in_planels, h, w = x.shape
+        softmax_att = self.attention(cond)  # bs,K
+        x = x.view(1, -1, h, w)
+        weight = self.weight.view(self.K, -1)  # K,-1
+        aggregate_weight = torch.mm(softmax_att, weight).view(
+            bs * self.out_planes, self.in_planes // self.groups, self.kernel_size, self.kernel_size
+        )  # bs*out_p,in_p,k,k
+
+        if self.bias is not None:
+            bias = self.bias.view(self.K, -1)  # K,out_p
+            aggregate_bias = torch.mm(softmax_att, bias).view(-1)  # bs*out_p
+            output = F.conv2d(
+                x,  # 1, bs*in_p, L, 1
+                weight=aggregate_weight,
+                bias=aggregate_bias,
+                stride=self.stride,
+                padding=self.padding,
+                groups=self.groups * bs,
+                dilation=self.dilation,
+            )
+        else:
+            output = F.conv2d(
+                x,
+                weight=aggregate_weight,
+                bias=None,
+                stride=self.stride,
+                padding=self.padding,
+                groups=self.groups * bs,
+                dilation=self.dilation,
+            )
+
+        output = output.view(bs, self.out_planes, h, w)
+        return output
+
+
+if __name__ == "__main__":
+    input = torch.randn(3, 32, 64, 64)
+    m = DynamicConv(in_planes=32, out_planes=64, kernel_size=3, stride=1, padding=1, bias=True)
+    out = m(input)
+    print(out.shape)
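Note that the `__main__` block above predates the conditioning interface: `DynamicConv.__init__` now requires `cond_planes` and `forward` takes a `cond` tensor. A corrected usage sketch under those assumptions (shapes per the `forward` docstring; the concrete sizes are illustrative):

```python
import torch

x = torch.randn(3, 32, 64, 1)      # (B, C_in, L, 1) feature sequence
cond = torch.randn(3, 256)         # (B, C_style) style/conditioning vector
m = DynamicConv(
    in_planes=32, out_planes=64, cond_planes=256,
    kernel_size=3, stride=1, padding=1, bias=True,
)
out = m(x, cond)                   # (3, 64, 64, 1): per-sample aggregated kernels via grouped conv
print(out.shape)
```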
damo/dreamtalk/core/networks/dynamic_fc_decoder.py
ADDED
|
@@ -0,0 +1,178 @@
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from core.networks.transformer import _get_activation_fn, _get_clones
|
| 5 |
+
from core.networks.dynamic_linear import DynamicLinear
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DynamicFCDecoderLayer(nn.Module):
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
d_model,
|
| 12 |
+
nhead,
|
| 13 |
+
d_style,
|
| 14 |
+
dynamic_K,
|
| 15 |
+
dynamic_ratio,
|
| 16 |
+
dim_feedforward=2048,
|
| 17 |
+
dropout=0.1,
|
| 18 |
+
activation="relu",
|
| 19 |
+
normalize_before=False,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
| 23 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
| 24 |
+
# Implementation of Feedforward model
|
| 25 |
+
# self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 26 |
+
self.linear1 = DynamicLinear(d_model, dim_feedforward, d_style, K=dynamic_K, ratio=dynamic_ratio)
|
| 27 |
+
self.dropout = nn.Dropout(dropout)
|
| 28 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 29 |
+
# self.linear2 = DynamicLinear(dim_feedforward, d_model, d_style, K=dynamic_K, ratio=dynamic_ratio)
|
| 30 |
+
|
| 31 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 32 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 33 |
+
self.norm3 = nn.LayerNorm(d_model)
|
| 34 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 35 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 36 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 37 |
+
|
| 38 |
+
self.activation = _get_activation_fn(activation)
|
| 39 |
+
self.normalize_before = normalize_before
|
| 40 |
+
|
| 41 |
+
def with_pos_embed(self, tensor, pos):
|
| 42 |
+
return tensor if pos is None else tensor + pos
|
| 43 |
+
|
| 44 |
+
def forward_post(
|
| 45 |
+
self,
|
| 46 |
+
tgt,
|
| 47 |
+
memory,
|
| 48 |
+
style,
|
| 49 |
+
tgt_mask=None,
|
| 50 |
+
memory_mask=None,
|
| 51 |
+
tgt_key_padding_mask=None,
|
| 52 |
+
memory_key_padding_mask=None,
|
| 53 |
+
pos=None,
|
| 54 |
+
query_pos=None,
|
| 55 |
+
):
|
| 56 |
+
# q = k = self.with_pos_embed(tgt, query_pos)
|
| 57 |
+
tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
| 58 |
+
tgt = tgt + self.dropout1(tgt2)
|
| 59 |
+
tgt = self.norm1(tgt)
|
| 60 |
+
tgt2 = self.multihead_attn(
|
| 61 |
+
query=tgt, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
|
| 62 |
+
)[0]
|
| 63 |
+
tgt = tgt + self.dropout2(tgt2)
|
| 64 |
+
tgt = self.norm2(tgt)
|
| 65 |
+
# tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))), style)
|
| 66 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))))
|
| 67 |
+
tgt = tgt + self.dropout3(tgt2)
|
| 68 |
+
tgt = self.norm3(tgt)
|
| 69 |
+
return tgt
|
| 70 |
+
|
| 71 |
+
# def forward_pre(
|
| 72 |
+
# self,
|
| 73 |
+
# tgt,
|
| 74 |
+
# memory,
|
| 75 |
+
# tgt_mask=None,
|
| 76 |
+
# memory_mask=None,
|
| 77 |
+
# tgt_key_padding_mask=None,
|
| 78 |
+
# memory_key_padding_mask=None,
|
| 79 |
+
# pos=None,
|
| 80 |
+
# query_pos=None,
|
| 81 |
+
# ):
|
| 82 |
+
# tgt2 = self.norm1(tgt)
|
| 83 |
+
# # q = k = self.with_pos_embed(tgt2, query_pos)
|
| 84 |
+
# tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
| 85 |
+
# tgt = tgt + self.dropout1(tgt2)
|
| 86 |
+
# tgt2 = self.norm2(tgt)
|
| 87 |
+
# tgt2 = self.multihead_attn(
|
| 88 |
+
# query=tgt2, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
|
| 89 |
+
# )[0]
|
| 90 |
+
# tgt = tgt + self.dropout2(tgt2)
|
| 91 |
+
# tgt2 = self.norm3(tgt)
|
| 92 |
+
# tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
| 93 |
+
# tgt = tgt + self.dropout3(tgt2)
|
| 94 |
+
# return tgt
|
| 95 |
+
|
| 96 |
+
def forward(
|
| 97 |
+
self,
|
| 98 |
+
tgt,
|
| 99 |
+
memory,
|
| 100 |
+
style,
|
| 101 |
+
tgt_mask=None,
|
| 102 |
+
memory_mask=None,
|
| 103 |
+
tgt_key_padding_mask=None,
|
| 104 |
+
memory_key_padding_mask=None,
|
| 105 |
+
pos=None,
|
| 106 |
+
query_pos=None,
|
| 107 |
+
):
|
| 108 |
+
if self.normalize_before:
|
| 109 |
+
raise NotImplementedError
|
| 110 |
+
# return self.forward_pre(
|
| 111 |
+
# tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
|
| 112 |
+
# )
|
| 113 |
+
return self.forward_post(
|
| 114 |
+
tgt, memory, style, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class DynamicFCDecoder(nn.Module):
|
| 119 |
+
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
| 122 |
+
self.num_layers = num_layers
|
| 123 |
+
self.norm = norm
|
| 124 |
+
self.return_intermediate = return_intermediate
|
| 125 |
+
|
| 126 |
+
def forward(
|
| 127 |
+
self,
|
| 128 |
+
tgt,
|
| 129 |
+
memory,
|
| 130 |
+
tgt_mask=None,
|
| 131 |
+
memory_mask=None,
|
| 132 |
+
tgt_key_padding_mask=None,
|
| 133 |
+
memory_key_padding_mask=None,
|
| 134 |
+
pos=None,
|
| 135 |
+
query_pos=None,
|
| 136 |
+
):
|
| 137 |
+
style = query_pos[0]
|
| 138 |
+
# (B*N, C)
|
| 139 |
+
output = tgt + pos + query_pos
|
| 140 |
+
|
| 141 |
+
intermediate = []
|
| 142 |
+
|
| 143 |
+
for layer in self.layers:
|
| 144 |
+
output = layer(
|
| 145 |
+
output,
|
| 146 |
+
memory,
|
| 147 |
+
style,
|
| 148 |
+
tgt_mask=tgt_mask,
|
| 149 |
+
memory_mask=memory_mask,
|
| 150 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
| 151 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
| 152 |
+
pos=pos,
|
| 153 |
+
query_pos=query_pos,
|
| 154 |
+
)
|
| 155 |
+
if self.return_intermediate:
|
| 156 |
+
intermediate.append(self.norm(output))
|
| 157 |
+
|
| 158 |
+
if self.norm is not None:
|
| 159 |
+
output = self.norm(output)
|
| 160 |
+
if self.return_intermediate:
|
| 161 |
+
intermediate.pop()
|
| 162 |
+
intermediate.append(output)
|
| 163 |
+
|
| 164 |
+
if self.return_intermediate:
|
| 165 |
+
return torch.stack(intermediate)
|
| 166 |
+
|
| 167 |
+
return output.unsqueeze(0)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
if __name__ == "__main__":
|
| 171 |
+
query = torch.randn(11, 1024, 256)
|
| 172 |
+
content = torch.randn(11, 1024, 256)
|
| 173 |
+
style = torch.randn(1024, 256)
|
| 174 |
+
pos = torch.randn(11, 1, 256)
|
| 175 |
+
m = DynamicFCDecoderLayer(256, 4, 256, 4, 4, 1024)
|
| 176 |
+
|
| 177 |
+
out = m(query, content, style, pos=pos)
|
| 178 |
+
print(out.shape)
|
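A hedged usage sketch for the decoder above (not part of the repository; it assumes the repo's core/ package is importable from the working directory): it stacks DynamicFCDecoderLayer into DynamicFCDecoder and feeds it the (window, batch, channel) tensors the surrounding code uses, with the style code passed through query_pos.

import torch
from torch import nn
from core.networks.dynamic_fc_decoder import DynamicFCDecoderLayer, DynamicFCDecoder

layer = DynamicFCDecoderLayer(d_model=256, nhead=4, d_style=256,
                              dynamic_K=4, dynamic_ratio=4, dim_feedforward=1024)
decoder = DynamicFCDecoder(layer, num_layers=3, norm=nn.LayerNorm(256))

tgt = torch.zeros(11, 64, 256)        # (W, B*N, C) decoder queries
memory = torch.randn(11, 64, 256)     # (W, B*N, C) audio content features
pos = torch.randn(11, 1, 256)         # positional encoding, broadcast over B*N
query_pos = torch.randn(11, 64, 256)  # style code expanded along the window
out = decoder(tgt, memory, pos=pos, query_pos=query_pos)
print(out.shape)  # torch.Size([1, 11, 64, 256]); the leading 1 comes from unsqueeze(0)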
damo/dreamtalk/core/networks/dynamic_linear.py
ADDED
|
@@ -0,0 +1,50 @@
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
from core.networks.dynamic_conv import DynamicConv
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DynamicLinear(nn.Module):
|
| 11 |
+
def __init__(self, in_planes, out_planes, cond_planes, bias=True, K=4, temperature=30, ratio=4, init_weight=True):
|
| 12 |
+
super().__init__()
|
| 13 |
+
|
| 14 |
+
self.dynamic_conv = DynamicConv(
|
| 15 |
+
in_planes,
|
| 16 |
+
out_planes,
|
| 17 |
+
cond_planes,
|
| 18 |
+
kernel_size=1,
|
| 19 |
+
stride=1,
|
| 20 |
+
padding=0,
|
| 21 |
+
bias=bias,
|
| 22 |
+
K=K,
|
| 23 |
+
ratio=ratio,
|
| 24 |
+
temperature=temperature,
|
| 25 |
+
init_weight=init_weight,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
def forward(self, x, cond):
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
x (_type_): (L, B, C_in)
|
| 33 |
+
cond (_type_): (B, C_style)
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
_type_: (L, B, C_out)
|
| 37 |
+
"""
|
| 38 |
+
x = x.permute(1, 2, 0).unsqueeze(-1)
|
| 39 |
+
out = self.dynamic_conv(x, cond)
|
| 40 |
+
# (B, C_out, L, 1)
|
| 41 |
+
out = out.squeeze().permute(2, 0, 1)
|
| 42 |
+
return out
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == "__main__":
|
| 46 |
+
input = torch.randn(11, 1024, 255)
|
| 47 |
+
cond = torch.randn(1024, 256)
|
| 48 |
+
m = DynamicLinear(255, 1000, 256, K=7, temperature=5, ratio=8)
|
| 49 |
+
out = m(input, cond)
|
| 50 |
+
print(out.shape)
|
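Not part of the repository: a standalone check of the idea DynamicLinear relies on, namely that a 1x1 convolution over a (B, C, L, 1) tensor equals a position-wise linear layer on a (L, B, C) sequence, which is why the forward above permutes to (B, C_in, L, 1) before calling DynamicConv. Note that out.squeeze() in the repository code would also drop the batch dimension if B were 1; the sketch below uses squeeze(-1) to keep the shapes explicit.

import torch
import torch.nn.functional as F

L, B, C_in, C_out = 11, 2, 8, 16
x = torch.randn(L, B, C_in)
weight = torch.randn(C_out, C_in)
bias = torch.randn(C_out)

linear_out = F.linear(x, weight, bias)                      # (L, B, C_out)

x_conv = x.permute(1, 2, 0).unsqueeze(-1)                   # (B, C_in, L, 1)
conv_out = F.conv2d(x_conv, weight.view(C_out, C_in, 1, 1), bias)
conv_out = conv_out.squeeze(-1).permute(2, 0, 1)            # back to (L, B, C_out)

print(torch.allclose(linear_out, conv_out, atol=1e-5))      # True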
damo/dreamtalk/core/networks/generator.py
ADDED
|
@@ -0,0 +1,309 @@
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
from .transformer import (
|
| 5 |
+
TransformerEncoder,
|
| 6 |
+
TransformerEncoderLayer,
|
| 7 |
+
PositionalEncoding,
|
| 8 |
+
TransformerDecoderLayer,
|
| 9 |
+
TransformerDecoder,
|
| 10 |
+
)
|
| 11 |
+
from core.utils import _reset_parameters
|
| 12 |
+
from core.networks.self_attention_pooling import SelfAttentionPooling
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# class ContentEncoder(nn.Module):
|
| 16 |
+
# def __init__(
|
| 17 |
+
# self,
|
| 18 |
+
# d_model=512,
|
| 19 |
+
# nhead=8,
|
| 20 |
+
# num_encoder_layers=6,
|
| 21 |
+
# dim_feedforward=2048,
|
| 22 |
+
# dropout=0.1,
|
| 23 |
+
# activation="relu",
|
| 24 |
+
# normalize_before=False,
|
| 25 |
+
# pos_embed_len=80,
|
| 26 |
+
# ph_embed_dim=128,
|
| 27 |
+
# ):
|
| 28 |
+
# super().__init__()
|
| 29 |
+
|
| 30 |
+
# encoder_layer = TransformerEncoderLayer(
|
| 31 |
+
# d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
| 32 |
+
# )
|
| 33 |
+
# encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
| 34 |
+
# self.encoder = TransformerEncoder(
|
| 35 |
+
# encoder_layer, num_encoder_layers, encoder_norm
|
| 36 |
+
# )
|
| 37 |
+
|
| 38 |
+
# _reset_parameters(self.encoder)
|
| 39 |
+
|
| 40 |
+
# self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
|
| 41 |
+
|
| 42 |
+
# self.ph_embedding = nn.Embedding(41, ph_embed_dim)
|
| 43 |
+
# self.increase_embed_dim = nn.Linear(ph_embed_dim, d_model)
|
| 44 |
+
|
| 45 |
+
# def forward(self, x):
|
| 46 |
+
# """
|
| 47 |
+
|
| 48 |
+
# Args:
|
| 49 |
+
# x (_type_): (B, num_frames, window)
|
| 50 |
+
|
| 51 |
+
# Returns:
|
| 52 |
+
# content: (B, num_frames, window, C_dmodel)
|
| 53 |
+
# """
|
| 54 |
+
# x_embedding = self.ph_embedding(x)
|
| 55 |
+
# x_embedding = self.increase_embed_dim(x_embedding)
|
| 56 |
+
# # (B, N, W, C)
|
| 57 |
+
# B, N, W, C = x_embedding.shape
|
| 58 |
+
# x_embedding = x_embedding.reshape(B * N, W, C)
|
| 59 |
+
# x_embedding = x_embedding.permute(1, 0, 2)
|
| 60 |
+
# # (W, B*N, C)
|
| 61 |
+
|
| 62 |
+
# pos = self.pos_embed(W)
|
| 63 |
+
# pos = pos.permute(1, 0, 2)
|
| 64 |
+
# # (W, 1, C)
|
| 65 |
+
|
| 66 |
+
# content = self.encoder(x_embedding, pos=pos)
|
| 67 |
+
# # (W, B*N, C)
|
| 68 |
+
# content = content.permute(1, 0, 2).reshape(B, N, W, C)
|
| 69 |
+
# # (B, N, W, C)
|
| 70 |
+
|
| 71 |
+
# return content
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class ContentW2VEncoder(nn.Module):
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
d_model=512,
|
| 78 |
+
nhead=8,
|
| 79 |
+
num_encoder_layers=6,
|
| 80 |
+
dim_feedforward=2048,
|
| 81 |
+
dropout=0.1,
|
| 82 |
+
activation="relu",
|
| 83 |
+
normalize_before=False,
|
| 84 |
+
pos_embed_len=80,
|
| 85 |
+
ph_embed_dim=128,
|
| 86 |
+
):
|
| 87 |
+
super().__init__()
|
| 88 |
+
|
| 89 |
+
encoder_layer = TransformerEncoderLayer(
|
| 90 |
+
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
| 91 |
+
)
|
| 92 |
+
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
| 93 |
+
self.encoder = TransformerEncoder(
|
| 94 |
+
encoder_layer, num_encoder_layers, encoder_norm
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
_reset_parameters(self.encoder)
|
| 98 |
+
|
| 99 |
+
self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
|
| 100 |
+
|
| 101 |
+
self.increase_embed_dim = nn.Linear(1024, d_model)
|
| 102 |
+
|
| 103 |
+
def forward(self, x):
|
| 104 |
+
"""
|
| 105 |
+
Args:
|
| 106 |
+
x (_type_): (B, num_frames, window, C_wav2vec)
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
content: (B, num_frames, window, C_dmodel)
|
| 110 |
+
"""
|
| 111 |
+
x_embedding = self.increase_embed_dim(
|
| 112 |
+
x
|
| 113 |
+
) # [16, 64, 11, 1024] -> [16, 64, 11, 256]
|
| 114 |
+
# (B, N, W, C)
|
| 115 |
+
B, N, W, C = x_embedding.shape
|
| 116 |
+
x_embedding = x_embedding.reshape(B * N, W, C)
|
| 117 |
+
x_embedding = x_embedding.permute(1, 0, 2) # [11, 1024, 256]
|
| 118 |
+
# (W, B*N, C)
|
| 119 |
+
|
| 120 |
+
pos = self.pos_embed(W)
|
| 121 |
+
pos = pos.permute(1, 0, 2) # [11, 1, 256]
|
| 122 |
+
# (W, 1, C)
|
| 123 |
+
|
| 124 |
+
content = self.encoder(x_embedding, pos=pos) # [11, 1024, 256]
|
| 125 |
+
# (W, B*N, C)
|
| 126 |
+
content = content.permute(1, 0, 2).reshape(B, N, W, C)
|
| 127 |
+
# (B, N, W, C)
|
| 128 |
+
|
| 129 |
+
return content
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class StyleEncoder(nn.Module):
|
| 133 |
+
def __init__(
|
| 134 |
+
self,
|
| 135 |
+
d_model=512,
|
| 136 |
+
nhead=8,
|
| 137 |
+
num_encoder_layers=6,
|
| 138 |
+
dim_feedforward=2048,
|
| 139 |
+
dropout=0.1,
|
| 140 |
+
activation="relu",
|
| 141 |
+
normalize_before=False,
|
| 142 |
+
pos_embed_len=80,
|
| 143 |
+
input_dim=128,
|
| 144 |
+
aggregate_method="average",
|
| 145 |
+
):
|
| 146 |
+
super().__init__()
|
| 147 |
+
encoder_layer = TransformerEncoderLayer(
|
| 148 |
+
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
| 149 |
+
)
|
| 150 |
+
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
| 151 |
+
self.encoder = TransformerEncoder(
|
| 152 |
+
encoder_layer, num_encoder_layers, encoder_norm
|
| 153 |
+
)
|
| 154 |
+
_reset_parameters(self.encoder)
|
| 155 |
+
|
| 156 |
+
self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
|
| 157 |
+
|
| 158 |
+
self.increase_embed_dim = nn.Linear(input_dim, d_model)
|
| 159 |
+
|
| 160 |
+
self.aggregate_method = None
|
| 161 |
+
if aggregate_method == "self_attention_pooling":
|
| 162 |
+
self.aggregate_method = SelfAttentionPooling(d_model)
|
| 163 |
+
elif aggregate_method == "average":
|
| 164 |
+
pass
|
| 165 |
+
else:
|
| 166 |
+
raise ValueError(f"Invalid aggregate method {aggregate_method}")
|
| 167 |
+
|
| 168 |
+
def forward(self, x, pad_mask=None):
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
x (_type_): (B, num_frames(L), C_exp)
|
| 173 |
+
pad_mask: (B, num_frames)
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
style_code: (B, C_model)
|
| 177 |
+
"""
|
| 178 |
+
x = self.increase_embed_dim(x)
|
| 179 |
+
# (B, L, C)
|
| 180 |
+
x = x.permute(1, 0, 2)
|
| 181 |
+
# (L, B, C)
|
| 182 |
+
|
| 183 |
+
pos = self.pos_embed(x.shape[0])
|
| 184 |
+
pos = pos.permute(1, 0, 2)
|
| 185 |
+
# (L, 1, C)
|
| 186 |
+
|
| 187 |
+
style = self.encoder(x, pos=pos, src_key_padding_mask=pad_mask)
|
| 188 |
+
# (L, B, C)
|
| 189 |
+
|
| 190 |
+
if self.aggregate_method is not None:
|
| 191 |
+
permute_style = style.permute(1, 0, 2)
|
| 192 |
+
# (B, L, C)
|
| 193 |
+
style_code = self.aggregate_method(permute_style, pad_mask)
|
| 194 |
+
return style_code
|
| 195 |
+
|
| 196 |
+
if pad_mask is None:
|
| 197 |
+
style = style.permute(1, 2, 0)
|
| 198 |
+
# (B, C, L)
|
| 199 |
+
style_code = style.mean(2)
|
| 200 |
+
# (B, C)
|
| 201 |
+
else:
|
| 202 |
+
permute_style = style.permute(1, 0, 2)
|
| 203 |
+
# (B, L, C)
|
| 204 |
+
permute_style[pad_mask] = 0
|
| 205 |
+
sum_style_code = permute_style.sum(dim=1)
|
| 206 |
+
# (B, C)
|
| 207 |
+
valid_token_num = (~pad_mask).sum(dim=1).unsqueeze(-1)
|
| 208 |
+
# (B, 1)
|
| 209 |
+
style_code = sum_style_code / valid_token_num
|
| 210 |
+
# (B, C)
|
| 211 |
+
|
| 212 |
+
return style_code
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class Decoder(nn.Module):
|
| 216 |
+
def __init__(
|
| 217 |
+
self,
|
| 218 |
+
d_model=512,
|
| 219 |
+
nhead=8,
|
| 220 |
+
num_decoder_layers=3,
|
| 221 |
+
dim_feedforward=2048,
|
| 222 |
+
dropout=0.1,
|
| 223 |
+
activation="relu",
|
| 224 |
+
normalize_before=False,
|
| 225 |
+
return_intermediate_dec=False,
|
| 226 |
+
pos_embed_len=80,
|
| 227 |
+
output_dim=64,
|
| 228 |
+
**_,
|
| 229 |
+
) -> None:
|
| 230 |
+
super().__init__()
|
| 231 |
+
|
| 232 |
+
decoder_layer = TransformerDecoderLayer(
|
| 233 |
+
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
| 234 |
+
)
|
| 235 |
+
decoder_norm = nn.LayerNorm(d_model)
|
| 236 |
+
self.decoder = TransformerDecoder(
|
| 237 |
+
decoder_layer,
|
| 238 |
+
num_decoder_layers,
|
| 239 |
+
decoder_norm,
|
| 240 |
+
return_intermediate=return_intermediate_dec,
|
| 241 |
+
)
|
| 242 |
+
_reset_parameters(self.decoder)
|
| 243 |
+
|
| 244 |
+
self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
|
| 245 |
+
|
| 246 |
+
tail_hidden_dim = d_model // 2
|
| 247 |
+
self.tail_fc = nn.Sequential(
|
| 248 |
+
nn.Linear(d_model, tail_hidden_dim),
|
| 249 |
+
nn.ReLU(),
|
| 250 |
+
nn.Linear(tail_hidden_dim, tail_hidden_dim),
|
| 251 |
+
nn.ReLU(),
|
| 252 |
+
nn.Linear(tail_hidden_dim, output_dim),
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
def forward(self, content, style_code):
|
| 256 |
+
"""
|
| 257 |
+
|
| 258 |
+
Args:
|
| 259 |
+
content (_type_): (B, num_frames, window, C_dmodel)
|
| 260 |
+
style_code (_type_): (B, C_dmodel)
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
face3d: (B, num_frames, C_3dmm)
|
| 264 |
+
"""
|
| 265 |
+
B, N, W, C = content.shape
|
| 266 |
+
style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
|
| 267 |
+
style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
|
| 268 |
+
# (W, B*N, C)
|
| 269 |
+
|
| 270 |
+
content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
|
| 271 |
+
# (W, B*N, C)
|
| 272 |
+
tgt = torch.zeros_like(style)
|
| 273 |
+
pos_embed = self.pos_embed(W)
|
| 274 |
+
pos_embed = pos_embed.permute(1, 0, 2)
|
| 275 |
+
face3d_feat = self.decoder(tgt, content, pos=pos_embed, query_pos=style)[0]
|
| 276 |
+
# (W, B*N, C)
|
| 277 |
+
face3d_feat = face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]
|
| 278 |
+
# (B, N, C)
|
| 279 |
+
face3d = self.tail_fc(face3d_feat)
|
| 280 |
+
# (B, N, C_exp)
|
| 281 |
+
return face3d
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
if __name__ == "__main__":
|
| 285 |
+
import sys
|
| 286 |
+
|
| 287 |
+
sys.path.append("/home/mayifeng/Research/styleTH")
|
| 288 |
+
|
| 289 |
+
from configs.default import get_cfg_defaults
|
| 290 |
+
|
| 291 |
+
cfg = get_cfg_defaults()
|
| 292 |
+
cfg.merge_from_file("configs/styleTH_bp.yaml")
|
| 293 |
+
cfg.freeze()
|
| 294 |
+
|
| 295 |
+
# content_encoder = ContentEncoder(**cfg.CONTENT_ENCODER)
|
| 296 |
+
|
| 297 |
+
# dummy_audio = torch.randint(0, 41, (5, 64, 11))
|
| 298 |
+
# dummy_content = content_encoder(dummy_audio)
|
| 299 |
+
|
| 300 |
+
# style_encoder = StyleEncoder(**cfg.STYLE_ENCODER)
|
| 301 |
+
# dummy_face3d_seq = torch.randn(5, 64, 64)
|
| 302 |
+
# dummy_style_code = style_encoder(dummy_face3d_seq)
|
| 303 |
+
|
| 304 |
+
decoder = Decoder(**cfg.DECODER)
|
| 305 |
+
dummy_content = torch.randn(5, 64, 11, 512)
|
| 306 |
+
dummy_style = torch.randn(5, 512)
|
| 307 |
+
dummy_output = decoder(dummy_content, dummy_style)
|
| 308 |
+
|
| 309 |
+
print("hello")
|
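Not part of the repository: a small standalone sketch of the masked-average branch in StyleEncoder.forward, where padded frames are zeroed out and the per-clip sum is divided by the number of valid frames.

import torch

B, L, C = 2, 6, 4
style = torch.randn(B, L, C)
pad_mask = torch.tensor([[False] * 6,
                         [False] * 4 + [True] * 2])  # second clip has 2 padded frames

style = style.clone()
style[pad_mask] = 0                                   # zero out padded frames
style_code = style.sum(dim=1) / (~pad_mask).sum(dim=1, keepdim=True)
print(style_code.shape)                               # torch.Size([2, 4])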
damo/dreamtalk/core/networks/mish.py
ADDED
|
@@ -0,0 +1,51 @@
|
| 1 |
+
"""
|
| 2 |
+
Applies the mish function element-wise:
|
| 3 |
+
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# import pytorch
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
@torch.jit.script
|
| 12 |
+
def mish(input):
|
| 13 |
+
"""
|
| 14 |
+
Applies the mish function element-wise:
|
| 15 |
+
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
|
| 16 |
+
See additional documentation for mish class.
|
| 17 |
+
"""
|
| 18 |
+
return input * torch.tanh(F.softplus(input))
|
| 19 |
+
|
| 20 |
+
class Mish(nn.Module):
|
| 21 |
+
"""
|
| 22 |
+
Applies the mish function element-wise:
|
| 23 |
+
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
|
| 24 |
+
|
| 25 |
+
Shape:
|
| 26 |
+
- Input: (N, *) where * means, any number of additional
|
| 27 |
+
dimensions
|
| 28 |
+
- Output: (N, *), same shape as the input
|
| 29 |
+
|
| 30 |
+
Examples:
|
| 31 |
+
>>> m = Mish()
|
| 32 |
+
>>> input = torch.randn(2)
|
| 33 |
+
>>> output = m(input)
|
| 34 |
+
|
| 35 |
+
Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self):
|
| 39 |
+
"""
|
| 40 |
+
Init method.
|
| 41 |
+
"""
|
| 42 |
+
super().__init__()
|
| 43 |
+
|
| 44 |
+
def forward(self, input):
|
| 45 |
+
"""
|
| 46 |
+
Forward pass of the function.
|
| 47 |
+
"""
|
| 48 |
+
if torch.__version__ >= "1.9":
|
| 49 |
+
return F.mish(input)
|
| 50 |
+
else:
|
| 51 |
+
return mish(input)
|
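Quick standalone check (not part of the repository) that the scripted mish above agrees with torch.nn.functional.mish, which is what the Mish module dispatches to on PyTorch 1.9 and later.

import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, steps=11)
manual = x * torch.tanh(F.softplus(x))      # mish(x) = x * tanh(softplus(x))
print(torch.allclose(manual, F.mish(x), atol=1e-6))  # True on PyTorch >= 1.9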
damo/dreamtalk/core/networks/self_attention_pooling.py
ADDED
|
@@ -0,0 +1,53 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from core.networks.mish import Mish
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SelfAttentionPooling(nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
Implementation of SelfAttentionPooling
|
| 9 |
+
Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
|
| 10 |
+
https://arxiv.org/pdf/2008.01077v1.pdf
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
def __init__(self, input_dim):
|
| 14 |
+
super(SelfAttentionPooling, self).__init__()
|
| 15 |
+
self.W = nn.Sequential(nn.Linear(input_dim, input_dim), Mish(), nn.Linear(input_dim, 1))
|
| 16 |
+
self.softmax = nn.functional.softmax
|
| 17 |
+
|
| 18 |
+
def forward(self, batch_rep, att_mask=None):
|
| 19 |
+
"""
|
| 20 |
+
N: batch size, T: sequence length, H: Hidden dimension
|
| 21 |
+
input:
|
| 22 |
+
batch_rep : size (N, T, H)
|
| 23 |
+
attention_weight:
|
| 24 |
+
att_w : size (N, T, 1)
|
| 25 |
+
att_mask:
|
| 26 |
+
att_mask: size (N, T): if True, mask this item.
|
| 27 |
+
return:
|
| 28 |
+
utter_rep: size (N, H)
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
att_logits = self.W(batch_rep).squeeze(-1)
|
| 32 |
+
# (N, T)
|
| 33 |
+
if att_mask is not None:
|
| 34 |
+
att_mask_logits = att_mask.to(dtype=batch_rep.dtype) * -100000.0
|
| 35 |
+
# (N, T)
|
| 36 |
+
att_logits = att_mask_logits + att_logits
|
| 37 |
+
|
| 38 |
+
att_w = self.softmax(att_logits, dim=-1).unsqueeze(-1)
|
| 39 |
+
utter_rep = torch.sum(batch_rep * att_w, dim=1)
|
| 40 |
+
|
| 41 |
+
return utter_rep
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == "__main__":
|
| 45 |
+
batch = torch.randn(8, 64, 256)
|
| 46 |
+
self_attn_pool = SelfAttentionPooling(256)
|
| 47 |
+
att_mask = torch.zeros(8, 64)
|
| 48 |
+
att_mask[:, 60:] = 1
|
| 49 |
+
att_mask = att_mask.to(torch.bool)
|
| 50 |
+
output = self_attn_pool(batch, att_mask)
|
| 51 |
+
# (8, 256)
|
| 52 |
+
|
| 53 |
+
print("hello")
|
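Not part of the repository: a standalone illustration of the additive masking trick used in SelfAttentionPooling.forward. Adding a large negative constant to the masked logits drives their softmax weights to effectively zero, so padded frames do not contribute to the pooled utterance vector.

import torch

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[False, False, True, True]])    # last two frames are padding
masked_logits = logits + mask.float() * -100000.0    # same scheme as the class above
weights = torch.softmax(masked_logits, dim=-1)
print(weights)  # the masked positions get weights that are effectively zero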
damo/dreamtalk/core/networks/transformer.py
ADDED
|
@@ -0,0 +1,293 @@
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import copy
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class PositionalEncoding(nn.Module):
|
| 9 |
+
|
| 10 |
+
def __init__(self, d_hid, n_position=200):
|
| 11 |
+
super(PositionalEncoding, self).__init__()
|
| 12 |
+
|
| 13 |
+
# Not a parameter
|
| 14 |
+
self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
|
| 15 |
+
|
| 16 |
+
def _get_sinusoid_encoding_table(self, n_position, d_hid):
|
| 17 |
+
''' Sinusoid position encoding table '''
|
| 18 |
+
# TODO: make it with torch instead of numpy
|
| 19 |
+
|
| 20 |
+
def get_position_angle_vec(position):
|
| 21 |
+
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
| 22 |
+
|
| 23 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
| 24 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
| 25 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
| 26 |
+
|
| 27 |
+
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
| 28 |
+
|
| 29 |
+
def forward(self, winsize):
|
| 30 |
+
return self.pos_table[:, :winsize].clone().detach()
|
| 31 |
+
|
| 32 |
+
def _get_activation_fn(activation):
|
| 33 |
+
"""Return an activation function given a string"""
|
| 34 |
+
if activation == "relu":
|
| 35 |
+
return F.relu
|
| 36 |
+
if activation == "gelu":
|
| 37 |
+
return F.gelu
|
| 38 |
+
if activation == "glu":
|
| 39 |
+
return F.glu
|
| 40 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
| 41 |
+
|
| 42 |
+
def _get_clones(module, N):
|
| 43 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
| 44 |
+
|
| 45 |
+
class Transformer(nn.Module):
|
| 46 |
+
|
| 47 |
+
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
|
| 48 |
+
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
|
| 49 |
+
activation="relu", normalize_before=False,
|
| 50 |
+
return_intermediate_dec=True):
|
| 51 |
+
super().__init__()
|
| 52 |
+
|
| 53 |
+
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
| 54 |
+
dropout, activation, normalize_before)
|
| 55 |
+
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
| 56 |
+
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
| 57 |
+
|
| 58 |
+
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
|
| 59 |
+
dropout, activation, normalize_before)
|
| 60 |
+
decoder_norm = nn.LayerNorm(d_model)
|
| 61 |
+
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
|
| 62 |
+
return_intermediate=return_intermediate_dec)
|
| 63 |
+
|
| 64 |
+
self._reset_parameters()
|
| 65 |
+
|
| 66 |
+
self.d_model = d_model
|
| 67 |
+
self.nhead = nhead
|
| 68 |
+
|
| 69 |
+
def _reset_parameters(self):
|
| 70 |
+
for p in self.parameters():
|
| 71 |
+
if p.dim() > 1:
|
| 72 |
+
nn.init.xavier_uniform_(p)
|
| 73 |
+
|
| 74 |
+
def forward(self,opt, src, query_embed, pos_embed):
|
| 75 |
+
# flatten NxCxHxW to HWxNxC
|
| 76 |
+
|
| 77 |
+
src = src.permute(1, 0, 2)
|
| 78 |
+
pos_embed = pos_embed.permute(1, 0, 2)
|
| 79 |
+
query_embed = query_embed.permute(1, 0, 2)
|
| 80 |
+
|
| 81 |
+
tgt = torch.zeros_like(query_embed)
|
| 82 |
+
memory = self.encoder(src, pos=pos_embed)
|
| 83 |
+
|
| 84 |
+
hs = self.decoder(tgt, memory,
|
| 85 |
+
pos=pos_embed, query_pos=query_embed)
|
| 86 |
+
return hs
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class TransformerEncoder(nn.Module):
|
| 90 |
+
|
| 91 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
| 92 |
+
super().__init__()
|
| 93 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
| 94 |
+
self.num_layers = num_layers
|
| 95 |
+
self.norm = norm
|
| 96 |
+
|
| 97 |
+
def forward(self, src, mask = None, src_key_padding_mask = None, pos = None):
|
| 98 |
+
output = src+pos
|
| 99 |
+
|
| 100 |
+
for layer in self.layers:
|
| 101 |
+
output = layer(output, src_mask=mask,
|
| 102 |
+
src_key_padding_mask=src_key_padding_mask, pos=pos)
|
| 103 |
+
|
| 104 |
+
if self.norm is not None:
|
| 105 |
+
output = self.norm(output)
|
| 106 |
+
|
| 107 |
+
return output
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class TransformerDecoder(nn.Module):
|
| 111 |
+
|
| 112 |
+
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
| 113 |
+
super().__init__()
|
| 114 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
| 115 |
+
self.num_layers = num_layers
|
| 116 |
+
self.norm = norm
|
| 117 |
+
self.return_intermediate = return_intermediate
|
| 118 |
+
|
| 119 |
+
def forward(self, tgt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None,
|
| 120 |
+
memory_key_padding_mask = None,
|
| 121 |
+
pos = None,
|
| 122 |
+
query_pos = None):
|
| 123 |
+
output = tgt+pos+query_pos
|
| 124 |
+
|
| 125 |
+
intermediate = []
|
| 126 |
+
|
| 127 |
+
for layer in self.layers:
|
| 128 |
+
output = layer(output, memory, tgt_mask=tgt_mask,
|
| 129 |
+
memory_mask=memory_mask,
|
| 130 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
| 131 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
| 132 |
+
pos=pos, query_pos=query_pos)
|
| 133 |
+
if self.return_intermediate:
|
| 134 |
+
intermediate.append(self.norm(output))
|
| 135 |
+
|
| 136 |
+
if self.norm is not None:
|
| 137 |
+
output = self.norm(output)
|
| 138 |
+
if self.return_intermediate:
|
| 139 |
+
intermediate.pop()
|
| 140 |
+
intermediate.append(output)
|
| 141 |
+
|
| 142 |
+
if self.return_intermediate:
|
| 143 |
+
return torch.stack(intermediate)
|
| 144 |
+
|
| 145 |
+
return output.unsqueeze(0)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class TransformerEncoderLayer(nn.Module):
|
| 149 |
+
|
| 150 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
| 151 |
+
activation="relu", normalize_before=False):
|
| 152 |
+
super().__init__()
|
| 153 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
| 154 |
+
# Implementation of Feedforward model
|
| 155 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 156 |
+
self.dropout = nn.Dropout(dropout)
|
| 157 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 158 |
+
|
| 159 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 160 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 161 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 162 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 163 |
+
|
| 164 |
+
self.activation = _get_activation_fn(activation)
|
| 165 |
+
self.normalize_before = normalize_before
|
| 166 |
+
|
| 167 |
+
def with_pos_embed(self, tensor, pos):
|
| 168 |
+
return tensor if pos is None else tensor + pos
|
| 169 |
+
|
| 170 |
+
def forward_post(self,
|
| 171 |
+
src,
|
| 172 |
+
src_mask = None,
|
| 173 |
+
src_key_padding_mask = None,
|
| 174 |
+
pos = None):
|
| 175 |
+
# q = k = self.with_pos_embed(src, pos)
|
| 176 |
+
src2 = self.self_attn(src, src, value=src, attn_mask=src_mask,
|
| 177 |
+
key_padding_mask=src_key_padding_mask)[0]
|
| 178 |
+
src = src + self.dropout1(src2)
|
| 179 |
+
src = self.norm1(src)
|
| 180 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
| 181 |
+
src = src + self.dropout2(src2)
|
| 182 |
+
src = self.norm2(src)
|
| 183 |
+
return src
|
| 184 |
+
|
| 185 |
+
def forward_pre(self, src,
|
| 186 |
+
src_mask = None,
|
| 187 |
+
src_key_padding_mask = None,
|
| 188 |
+
pos = None):
|
| 189 |
+
src2 = self.norm1(src)
|
| 190 |
+
# q = k = self.with_pos_embed(src2, pos)
|
| 191 |
+
src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask,
|
| 192 |
+
key_padding_mask=src_key_padding_mask)[0]
|
| 193 |
+
src = src + self.dropout1(src2)
|
| 194 |
+
src2 = self.norm2(src)
|
| 195 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
| 196 |
+
src = src + self.dropout2(src2)
|
| 197 |
+
return src
|
| 198 |
+
|
| 199 |
+
def forward(self, src,
|
| 200 |
+
src_mask = None,
|
| 201 |
+
src_key_padding_mask = None,
|
| 202 |
+
pos = None):
|
| 203 |
+
if self.normalize_before:
|
| 204 |
+
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
| 205 |
+
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class TransformerDecoderLayer(nn.Module):
|
| 209 |
+
|
| 210 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
| 211 |
+
activation="relu", normalize_before=False):
|
| 212 |
+
super().__init__()
|
| 213 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
| 214 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
| 215 |
+
# Implementation of Feedforward model
|
| 216 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 217 |
+
self.dropout = nn.Dropout(dropout)
|
| 218 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 219 |
+
|
| 220 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 221 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 222 |
+
self.norm3 = nn.LayerNorm(d_model)
|
| 223 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 224 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 225 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 226 |
+
|
| 227 |
+
self.activation = _get_activation_fn(activation)
|
| 228 |
+
self.normalize_before = normalize_before
|
| 229 |
+
|
| 230 |
+
def with_pos_embed(self, tensor, pos):
|
| 231 |
+
return tensor if pos is None else tensor + pos
|
| 232 |
+
|
| 233 |
+
def forward_post(self, tgt, memory,
|
| 234 |
+
tgt_mask = None,
|
| 235 |
+
memory_mask = None,
|
| 236 |
+
tgt_key_padding_mask = None,
|
| 237 |
+
memory_key_padding_mask = None,
|
| 238 |
+
pos = None,
|
| 239 |
+
query_pos = None):
|
| 240 |
+
# q = k = self.with_pos_embed(tgt, query_pos)
|
| 241 |
+
tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask,
|
| 242 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
| 243 |
+
tgt = tgt + self.dropout1(tgt2)
|
| 244 |
+
tgt = self.norm1(tgt)
|
| 245 |
+
tgt2 = self.multihead_attn(query=tgt,
|
| 246 |
+
key=memory,
|
| 247 |
+
value=memory, attn_mask=memory_mask,
|
| 248 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
| 249 |
+
tgt = tgt + self.dropout2(tgt2)
|
| 250 |
+
tgt = self.norm2(tgt)
|
| 251 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
| 252 |
+
tgt = tgt + self.dropout3(tgt2)
|
| 253 |
+
tgt = self.norm3(tgt)
|
| 254 |
+
return tgt
|
| 255 |
+
|
| 256 |
+
def forward_pre(self, tgt, memory,
|
| 257 |
+
tgt_mask = None,
|
| 258 |
+
memory_mask = None,
|
| 259 |
+
tgt_key_padding_mask = None,
|
| 260 |
+
memory_key_padding_mask = None,
|
| 261 |
+
pos = None,
|
| 262 |
+
query_pos = None):
|
| 263 |
+
tgt2 = self.norm1(tgt)
|
| 264 |
+
# q = k = self.with_pos_embed(tgt2, query_pos)
|
| 265 |
+
tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask,
|
| 266 |
+
key_padding_mask=tgt_key_padding_mask)[0]
|
| 267 |
+
tgt = tgt + self.dropout1(tgt2)
|
| 268 |
+
tgt2 = self.norm2(tgt)
|
| 269 |
+
tgt2 = self.multihead_attn(query=tgt2,
|
| 270 |
+
key=memory,
|
| 271 |
+
value=memory, attn_mask=memory_mask,
|
| 272 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
| 273 |
+
tgt = tgt + self.dropout2(tgt2)
|
| 274 |
+
tgt2 = self.norm3(tgt)
|
| 275 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
| 276 |
+
tgt = tgt + self.dropout3(tgt2)
|
| 277 |
+
return tgt
|
| 278 |
+
|
| 279 |
+
def forward(self, tgt, memory,
|
| 280 |
+
tgt_mask = None,
|
| 281 |
+
memory_mask = None,
|
| 282 |
+
tgt_key_padding_mask = None,
|
| 283 |
+
memory_key_padding_mask = None,
|
| 284 |
+
pos = None,
|
| 285 |
+
query_pos = None):
|
| 286 |
+
if self.normalize_before:
|
| 287 |
+
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
|
| 288 |
+
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
| 289 |
+
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
|
| 290 |
+
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
|
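Not part of the repository: the sinusoid table from PositionalEncoding rewritten with torch only (the TODO in the file notes that the numpy version could be replaced). It may help when checking the even/odd sine/cosine layout and the (1, n_position, d_hid) buffer shape.

import torch

def sinusoid_table(n_position, d_hid):
    pos = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)   # (N, 1)
    idx = torch.arange(d_hid, dtype=torch.float32)                     # (D,)
    angle = pos / torch.pow(10000.0, 2 * (idx // 2) / d_hid)           # (N, D)
    table = torch.zeros(n_position, d_hid)
    table[:, 0::2] = torch.sin(angle[:, 0::2])   # even dims -> sine
    table[:, 1::2] = torch.cos(angle[:, 1::2])   # odd dims  -> cosine
    return table.unsqueeze(0)                    # (1, N, D), same layout as pos_table

print(sinusoid_table(80, 256).shape)  # torch.Size([1, 80, 256])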
damo/dreamtalk/core/utils.py
ADDED
|
@@ -0,0 +1,456 @@
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
import logging
|
| 5 |
+
import pickle
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
from scipy.io import loadmat
|
| 12 |
+
|
| 13 |
+
from configs.default import get_cfg_defaults
|
| 14 |
+
import dlib
|
| 15 |
+
import cv2
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _reset_parameters(model):
|
| 19 |
+
for p in model.parameters():
|
| 20 |
+
if p.dim() > 1:
|
| 21 |
+
nn.init.xavier_uniform_(p)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_video_style(video_name, style_type):
|
| 25 |
+
person_id, direction, emotion, level, *_ = video_name.split("_")
|
| 26 |
+
if style_type == "id_dir_emo_level":
|
| 27 |
+
style = "_".join([person_id, direction, emotion, level])
|
| 28 |
+
elif style_type == "emotion":
|
| 29 |
+
style = emotion
|
| 30 |
+
elif style_type == "id":
|
| 31 |
+
style = person_id
|
| 32 |
+
else:
|
| 33 |
+
raise ValueError("Unknown style type")
|
| 34 |
+
|
| 35 |
+
return style
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_style_video_lists(video_list, style_type):
|
| 39 |
+
style2video_list = defaultdict(list)
|
| 40 |
+
for video in video_list:
|
| 41 |
+
style = get_video_style(video, style_type)
|
| 42 |
+
style2video_list[style].append(video)
|
| 43 |
+
|
| 44 |
+
return style2video_list
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_face3d_clip(
|
| 48 |
+
video_name, video_root_dir, num_frames, start_idx, dtype=torch.float32
|
| 49 |
+
):
|
| 50 |
+
"""_summary_
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
video_name (_type_): _description_
|
| 54 |
+
video_root_dir (_type_): _description_
|
| 55 |
+
num_frames (_type_): _description_
|
| 56 |
+
start_idx (_type_): "random" , middle, int
|
| 57 |
+
dtype (_type_, optional): _description_. Defaults to torch.float32.
|
| 58 |
+
|
| 59 |
+
Raises:
|
| 60 |
+
ValueError: _description_
|
| 61 |
+
ValueError: _description_
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
_type_: _description_
|
| 65 |
+
"""
|
| 66 |
+
video_path = os.path.join(video_root_dir, video_name)
|
| 67 |
+
if video_path[-3:] == "mat":
|
| 68 |
+
face3d_all = loadmat(video_path)["coeff"]
|
| 69 |
+
face3d_exp = face3d_all[:, 80:144] # expression 3DMM range
|
| 70 |
+
elif video_path[-3:] == "txt":
|
| 71 |
+
face3d_exp = np.loadtxt(video_path)
|
| 72 |
+
else:
|
| 73 |
+
raise ValueError("Invalid 3DMM file extension")
|
| 74 |
+
|
| 75 |
+
length = face3d_exp.shape[0]
|
| 76 |
+
clip_num_frames = num_frames
|
| 77 |
+
if start_idx == "random":
|
| 78 |
+
clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
|
| 79 |
+
elif start_idx == "middle":
|
| 80 |
+
clip_start_idx = (length - clip_num_frames + 1) // 2
|
| 81 |
+
elif isinstance(start_idx, int):
|
| 82 |
+
clip_start_idx = start_idx
|
| 83 |
+
else:
|
| 84 |
+
raise ValueError(f"Invalid start_idx {start_idx}")
|
| 85 |
+
|
| 86 |
+
face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
|
| 87 |
+
face3d_clip = torch.tensor(face3d_clip, dtype=dtype)
|
| 88 |
+
|
| 89 |
+
return face3d_clip
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_video_style_clip(
|
| 93 |
+
video_name,
|
| 94 |
+
video_root_dir,
|
| 95 |
+
style_max_len,
|
| 96 |
+
start_idx="random",
|
| 97 |
+
dtype=torch.float32,
|
| 98 |
+
return_start_idx=False,
|
| 99 |
+
):
|
| 100 |
+
video_path = os.path.join(video_root_dir, video_name)
|
| 101 |
+
if video_path[-3:] == "mat":
|
| 102 |
+
face3d_all = loadmat(video_path)["coeff"]
|
| 103 |
+
face3d_exp = face3d_all[:, 80:144] # expression 3DMM range
|
| 104 |
+
elif video_path[-3:] == "txt":
|
| 105 |
+
face3d_exp = np.loadtxt(video_path)
|
| 106 |
+
else:
|
| 107 |
+
raise ValueError("Invalid 3DMM file extension")
|
| 108 |
+
|
| 109 |
+
face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
|
| 110 |
+
|
| 111 |
+
length = face3d_exp.shape[0]
|
| 112 |
+
if length >= style_max_len:
|
| 113 |
+
clip_num_frames = style_max_len
|
| 114 |
+
if start_idx == "random":
|
| 115 |
+
clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
|
| 116 |
+
elif start_idx == "middle":
|
| 117 |
+
clip_start_idx = (length - clip_num_frames + 1) // 2
|
| 118 |
+
elif isinstance(start_idx, int):
|
| 119 |
+
clip_start_idx = start_idx
|
| 120 |
+
else:
|
| 121 |
+
raise ValueError(f"Invalid start_idx {start_idx}")
|
| 122 |
+
|
| 123 |
+
face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
|
| 124 |
+
pad_mask = torch.tensor([False] * style_max_len)
|
| 125 |
+
else:
|
| 126 |
+
clip_start_idx = None
|
| 127 |
+
padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
|
| 128 |
+
face3d_clip = torch.cat((face3d_exp, padding), dim=0)
|
| 129 |
+
pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))
|
| 130 |
+
|
| 131 |
+
if return_start_idx:
|
| 132 |
+
return face3d_clip, pad_mask, clip_start_idx
|
| 133 |
+
else:
|
| 134 |
+
return face3d_clip, pad_mask
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_video_style_clip_from_np(
|
| 138 |
+
face3d_exp,
|
| 139 |
+
style_max_len,
|
| 140 |
+
start_idx="random",
|
| 141 |
+
dtype=torch.float32,
|
| 142 |
+
return_start_idx=False,
|
| 143 |
+
):
|
| 144 |
+
face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
|
| 145 |
+
|
| 146 |
+
length = face3d_exp.shape[0]
|
| 147 |
+
if length >= style_max_len:
|
| 148 |
+
clip_num_frames = style_max_len
|
| 149 |
+
if start_idx == "random":
|
| 150 |
+
clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
|
| 151 |
+
elif start_idx == "middle":
|
| 152 |
+
clip_start_idx = (length - clip_num_frames + 1) // 2
|
| 153 |
+
elif isinstance(start_idx, int):
|
| 154 |
+
clip_start_idx = start_idx
|
| 155 |
+
else:
|
| 156 |
+
raise ValueError(f"Invalid start_idx {start_idx}")
|
| 157 |
+
|
| 158 |
+
face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
|
| 159 |
+
pad_mask = torch.tensor([False] * style_max_len)
|
| 160 |
+
else:
|
| 161 |
+
clip_start_idx = None
|
| 162 |
+
padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
|
| 163 |
+
face3d_clip = torch.cat((face3d_exp, padding), dim=0)
|
| 164 |
+
pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))
|
| 165 |
+
|
| 166 |
+
if return_start_idx:
|
| 167 |
+
return face3d_clip, pad_mask, clip_start_idx
|
| 168 |
+
else:
|
| 169 |
+
return face3d_clip, pad_mask
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
audio_feat (np.ndarray): (N, 1024)
|
| 177 |
+
start_idx (_type_): _description_
|
| 178 |
+
num_frames (_type_): _description_
|
| 179 |
+
"""
|
| 180 |
+
center_idx_list = [2 * idx for idx in range(start_idx, start_idx + num_frames)]
|
| 181 |
+
audio_window_list = []
|
| 182 |
+
padding = np.zeros(audio_feat.shape[1], dtype=np.float32)
|
| 183 |
+
for center_idx in center_idx_list:
|
| 184 |
+
cur_audio_window = []
|
| 185 |
+
for i in range(center_idx - win_size, center_idx + win_size + 1):
|
| 186 |
+
if i < 0:
|
| 187 |
+
cur_audio_window.append(padding)
|
| 188 |
+
elif i >= len(audio_feat):
|
| 189 |
+
cur_audio_window.append(padding)
|
| 190 |
+
else:
|
| 191 |
+
cur_audio_window.append(audio_feat[i])
|
| 192 |
+
cur_audio_win_array = np.stack(cur_audio_window, axis=0)
|
| 193 |
+
audio_window_list.append(cur_audio_win_array)
|
| 194 |
+
|
| 195 |
+
audio_window_array = np.stack(audio_window_list, axis=0)
|
| 196 |
+
return audio_window_array
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def setup_config():
|
| 200 |
+
parser = argparse.ArgumentParser(description="voice2pose main program")
|
| 201 |
+
parser.add_argument(
|
| 202 |
+
"--config_file", default="", metavar="FILE", help="path to config file"
|
| 203 |
+
)
|
| 204 |
+
parser.add_argument(
|
| 205 |
+
"--resume_from", type=str, default=None, help="the checkpoint to resume from"
|
| 206 |
+
)
|
| 207 |
+
parser.add_argument(
|
| 208 |
+
"--test_only", action="store_true", help="perform testing and evaluation only"
|
| 209 |
+
)
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"--demo_input", type=str, default=None, help="path to input for demo"
|
| 212 |
+
)
|
| 213 |
+
parser.add_argument(
|
| 214 |
+
"--checkpoint", type=str, default=None, help="the checkpoint to test with"
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument("--tag", type=str, default="", help="tag for the experiment")
|
| 217 |
+
parser.add_argument(
|
| 218 |
+
"opts",
|
| 219 |
+
help="Modify config options using the command-line",
|
| 220 |
+
default=None,
|
| 221 |
+
nargs=argparse.REMAINDER,
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--local_rank",
|
| 225 |
+
type=int,
|
| 226 |
+
help="local rank for DistributedDataParallel",
|
| 227 |
+
)
|
| 228 |
+
parser.add_argument(
|
| 229 |
+
"--master_port",
|
| 230 |
+
type=str,
|
| 231 |
+
default="12345",
|
| 232 |
+
)
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--max_audio_len",
|
| 235 |
+
type=int,
|
| 236 |
+
default=450,
|
| 237 |
+
help="max_audio_len for inference",
|
| 238 |
+
)
|
| 239 |
+
parser.add_argument(
|
| 240 |
+
"--ddim_num_step",
|
| 241 |
+
type=int,
|
| 242 |
+
default=10,
|
| 243 |
+
)
|
| 244 |
+
parser.add_argument(
|
| 245 |
+
"--inference_seed",
|
| 246 |
+
type=int,
|
| 247 |
+
default=1,
|
| 248 |
+
)
|
| 249 |
+
parser.add_argument(
|
| 250 |
+
"--inference_sample_method",
|
| 251 |
+
type=str,
|
| 252 |
+
default="ddim",
|
| 253 |
+
)
|
| 254 |
+
args = parser.parse_args()
|
| 255 |
+
|
| 256 |
+
cfg = get_cfg_defaults()
|
| 257 |
+
cfg.merge_from_file(args.config_file)
|
| 258 |
+
cfg.merge_from_list(args.opts)
|
| 259 |
+
cfg.freeze()
|
| 260 |
+
return args, cfg
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def setup_logger(base_path, exp_name):
|
| 264 |
+
rootLogger = logging.getLogger()
|
| 265 |
+
rootLogger.setLevel(logging.INFO)
|
| 266 |
+
|
| 267 |
+
logFormatter = logging.Formatter("%(asctime)s [%(levelname)-0.5s] %(message)s")
|
| 268 |
+
|
| 269 |
+
log_path = "{0}/{1}.log".format(base_path, exp_name)
|
| 270 |
+
fileHandler = logging.FileHandler(log_path)
|
| 271 |
+
fileHandler.setFormatter(logFormatter)
|
| 272 |
+
rootLogger.addHandler(fileHandler)
|
| 273 |
+
|
| 274 |
+
consoleHandler = logging.StreamHandler()
|
| 275 |
+
consoleHandler.setFormatter(logFormatter)
|
| 276 |
+
rootLogger.addHandler(consoleHandler)
|
| 277 |
+
rootLogger.handlers[0].setLevel(logging.INFO)
|
| 278 |
+
|
| 279 |
+
logging.info("log path: %s" % log_path)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def cosine_loss(a, v, y, logloss=nn.BCELoss()):
|
| 283 |
+
d = nn.functional.cosine_similarity(a, v)
|
| 284 |
+
loss = logloss(d.unsqueeze(1), y)
|
| 285 |
+
return loss
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def get_pose_params(mat_path):
|
| 289 |
+
"""Get pose parameters from mat file
|
| 290 |
+
|
| 291 |
+
Args:
|
| 292 |
+
mat_path (str): path of mat file
|
| 293 |
+
|
| 294 |
+
Returns:
|
| 295 |
+
pose_params (numpy.ndarray): shape (L_video, 9), angle, translation, crop parameters
|
| 296 |
+
"""
|
| 297 |
+
mat_dict = loadmat(mat_path)
|
| 298 |
+
|
| 299 |
+
np_3dmm = mat_dict["coeff"]
|
| 300 |
+
angles = np_3dmm[:, 224:227]
|
| 301 |
+
translations = np_3dmm[:, 254:257]
|
| 302 |
+
|
| 303 |
+
np_trans_params = mat_dict["transform_params"]
|
| 304 |
+
crop = np_trans_params[:, -3:]
|
| 305 |
+
|
| 306 |
+
pose_params = np.concatenate((angles, translations, crop), axis=1)
|
| 307 |
+
|
| 308 |
+
return pose_params
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def sinusoidal_embedding(timesteps, dim):
|
| 312 |
+
"""
|
| 313 |
+
|
| 314 |
+
Args:
|
| 315 |
+
timesteps (_type_): (B,)
|
| 316 |
+
dim (_type_): (C_embed)
|
| 317 |
+
|
| 318 |
+
Returns:
|
| 319 |
+
_type_: (B, C_embed)
|
| 320 |
+
"""
|
| 321 |
+
# check input
|
| 322 |
+
half = dim // 2
|
| 323 |
+
timesteps = timesteps.float()
|
| 324 |
+
|
| 325 |
+
# compute sinusoidal embedding
|
| 326 |
+
sinusoid = torch.outer(
|
| 327 |
+
timesteps, torch.pow(10000, -torch.arange(half).to(timesteps).div(half))
|
| 328 |
+
)
|
| 329 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
| 330 |
+
if dim % 2 != 0:
|
| 331 |
+
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
|
| 332 |
+
return x
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
|
| 336 |
+
"""
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
audio_feat (np.ndarray): (250, 1024)
|
| 340 |
+
start_idx (_type_): _description_
|
| 341 |
+
num_frames (_type_): _description_
|
| 342 |
+
"""
|
| 343 |
+
center_idx_list = [2 * idx for idx in range(start_idx, start_idx + num_frames)]
|
| 344 |
+
audio_window_list = []
|
| 345 |
+
padding = np.zeros(audio_feat.shape[1], dtype=np.float32)
|
| 346 |
+
for center_idx in center_idx_list:
|
| 347 |
+
cur_audio_window = []
|
| 348 |
+
for i in range(center_idx - win_size, center_idx + win_size + 1):
|
| 349 |
+
if i < 0:
|
| 350 |
+
cur_audio_window.append(padding)
|
| 351 |
+
elif i >= len(audio_feat):
|
| 352 |
+
cur_audio_window.append(padding)
|
| 353 |
+
else:
|
| 354 |
+
cur_audio_window.append(audio_feat[i])
|
| 355 |
+
cur_audio_win_array = np.stack(cur_audio_window, axis=0)
|
| 356 |
+
audio_window_list.append(cur_audio_win_array)
|
| 357 |
+
|
| 358 |
+
audio_window_array = np.stack(audio_window_list, axis=0)
|
| 359 |
+
return audio_window_array
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def reshape_audio_feat(style_audio_all_raw, stride):
|
| 363 |
+
"""_summary_
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
style_audio_all_raw (_type_): (stride * L, C)
|
| 367 |
+
stride (_type_): int
|
| 368 |
+
|
| 369 |
+
Returns:
|
| 370 |
+
_type_: (L, C * stride)
|
| 371 |
+
"""
|
| 372 |
+
style_audio_all_raw = style_audio_all_raw[
|
| 373 |
+
: style_audio_all_raw.shape[0] // stride * stride
|
| 374 |
+
]
|
| 375 |
+
style_audio_all_raw = style_audio_all_raw.reshape(
|
| 376 |
+
style_audio_all_raw.shape[0] // stride, stride, style_audio_all_raw.shape[1]
|
| 377 |
+
)
|
| 378 |
+
style_audio_all = style_audio_all_raw.reshape(style_audio_all_raw.shape[0], -1)
|
| 379 |
+
return style_audio_all
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
import random
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def get_derangement_tuple(n):
|
| 386 |
+
while True:
|
| 387 |
+
v = [i for i in range(n)]
|
| 388 |
+
for j in range(n - 1, -1, -1):
|
| 389 |
+
p = random.randint(0, j)
|
| 390 |
+
if v[p] == j:
|
| 391 |
+
break
|
| 392 |
+
else:
|
| 393 |
+
v[j], v[p] = v[p], v[j]
|
| 394 |
+
else:
|
| 395 |
+
if v[0] != 0:
|
| 396 |
+
return tuple(v)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def compute_aspect_preserved_bbox(bbox, increase_area, h, w):
|
| 400 |
+
left, top, right, bot = bbox
|
| 401 |
+
width = right - left
|
| 402 |
+
height = bot - top
|
| 403 |
+
|
| 404 |
+
width_increase = max(
|
| 405 |
+
increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)
|
| 406 |
+
)
|
| 407 |
+
height_increase = max(
|
| 408 |
+
increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
left_t = int(left - width_increase * width)
|
| 412 |
+
top_t = int(top - height_increase * height)
|
| 413 |
+
right_t = int(right + width_increase * width)
|
| 414 |
+
bot_t = int(bot + height_increase * height)
|
| 415 |
+
|
| 416 |
+
left_oob = -min(0, left_t)
|
| 417 |
+
right_oob = right - min(right_t, w)
|
| 418 |
+
top_oob = -min(0, top_t)
|
| 419 |
+
bot_oob = bot - min(bot_t, h)
|
| 420 |
+
|
| 421 |
+
if max(left_oob, right_oob, top_oob, bot_oob) > 0:
|
| 422 |
+
max_w = max(left_oob, right_oob)
|
| 423 |
+
max_h = max(top_oob, bot_oob)
|
| 424 |
+
if max_w > max_h:
|
| 425 |
+
return left_t + max_w, top_t + max_w, right_t - max_w, bot_t - max_w
|
| 426 |
+
else:
|
| 427 |
+
return left_t + max_h, top_t + max_h, right_t - max_h, bot_t - max_h
|
| 428 |
+
|
| 429 |
+
else:
|
| 430 |
+
return (left_t, top_t, right_t, bot_t)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def crop_src_image(src_img, save_img, increase_ratio, detector=None):
|
| 434 |
+
if detector is None:
|
| 435 |
+
detector = dlib.get_frontal_face_detector()
|
| 436 |
+
|
| 437 |
+
img = cv2.imread(src_img)
|
| 438 |
+
faces = detector(img, 0)
|
| 439 |
+
h, width, _ = img.shape
|
| 440 |
+
if len(faces) > 0:
|
| 441 |
+
bbox = [faces[0].left(), faces[0].top(), faces[0].right(), faces[0].bottom()]
|
| 442 |
+
l = bbox[3] - bbox[1]
|
| 443 |
+
bbox[1] = bbox[1] - l * 0.1
|
| 444 |
+
bbox[3] = bbox[3] - l * 0.1
|
| 445 |
+
bbox[1] = max(0, bbox[1])
|
| 446 |
+
bbox[3] = min(h, bbox[3])
|
| 447 |
+
bbox = compute_aspect_preserved_bbox(
|
| 448 |
+
tuple(bbox), increase_ratio, img.shape[0], img.shape[1]
|
| 449 |
+
)
|
| 450 |
+
img = img[bbox[1] : bbox[3], bbox[0] : bbox[2]]
|
| 451 |
+
img = cv2.resize(img, (256, 256))
|
| 452 |
+
cv2.imwrite(save_img, img)
|
| 453 |
+
else:
|
| 454 |
+
raise ValueError("No face detected in the input image")
|
| 455 |
+
# img = cv2.resize(img, (256, 256))
|
| 456 |
+
# cv2.imwrite(save_img, img)
|
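A minimal usage sketch for the utilities above (not part of this commit): it assumes the script is run from the dreamtalk repository root so that `core.utils` is importable, and the shapes, frame counts, and window size are illustrative values, not ones prescribed by this diff.

import numpy as np
import torch

from core.utils import sinusoidal_embedding, get_wav2vec_audio_window

# Timestep embeddings for a denoising network: (B,) integer timesteps -> (B, 128).
t = torch.randint(0, 1000, (4,))
t_embed = sinusoidal_embedding(t, 128)  # shape (4, 128)

# Window wav2vec features (two audio feature frames per video frame) around each
# of 64 video frames, with +/- 5 feature frames of context.
audio_feat = np.random.randn(250, 1024).astype(np.float32)
windows = get_wav2vec_audio_window(audio_feat, start_idx=0, num_frames=64, win_size=5)
# windows.shape == (64, 11, 1024)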
damo/dreamtalk/data/audio/German1.wav
ADDED
Binary file (279 kB)

damo/dreamtalk/data/audio/German2.wav
ADDED
Binary file (219 kB)

damo/dreamtalk/data/audio/German3.wav
ADDED
Binary file (240 kB)

damo/dreamtalk/data/audio/German4.wav
ADDED
Binary file (219 kB)

damo/dreamtalk/data/audio/acknowledgement_chinese.m4a
ADDED
Binary file (537 kB)

damo/dreamtalk/data/audio/acknowledgement_english.m4a
ADDED
Binary file (511 kB)

damo/dreamtalk/data/audio/chinese1_haierlizhi.wav
ADDED
Binary file (420 kB)

damo/dreamtalk/data/audio/chinese2_guanyu.wav
ADDED
Binary file (638 kB)

damo/dreamtalk/data/audio/french1.wav
ADDED
Binary file (220 kB)

damo/dreamtalk/data/audio/french2.wav
ADDED
Binary file (177 kB)

damo/dreamtalk/data/audio/french3.wav
ADDED
Binary file (168 kB)

damo/dreamtalk/data/audio/italian1.wav
ADDED
Binary file (285 kB)

damo/dreamtalk/data/audio/italian2.wav
ADDED
Binary file (170 kB)

damo/dreamtalk/data/audio/italian3.wav
ADDED
Binary file (197 kB)

damo/dreamtalk/data/audio/japan1.wav
ADDED
Binary file (197 kB)

damo/dreamtalk/data/audio/japan2.wav
ADDED
Binary file (231 kB)

damo/dreamtalk/data/audio/japan3.wav
ADDED
Binary file (234 kB)

damo/dreamtalk/data/audio/korean1.wav
ADDED
Binary file (328 kB)

damo/dreamtalk/data/audio/korean2.wav
ADDED
Binary file (210 kB)

damo/dreamtalk/data/audio/korean3.wav
ADDED
Binary file (148 kB)

damo/dreamtalk/data/audio/noisy_audio_cafeter_snr_0.wav
ADDED
Binary file (206 kB)

damo/dreamtalk/data/audio/noisy_audio_meeting_snr_0.wav
ADDED
Binary file (206 kB)

damo/dreamtalk/data/audio/noisy_audio_meeting_snr_10.wav
ADDED
Binary file (206 kB)

damo/dreamtalk/data/audio/noisy_audio_meeting_snr_20.wav
ADDED
Binary file (206 kB)

damo/dreamtalk/data/audio/noisy_audio_narrative.wav
ADDED
Binary file (206 kB)

damo/dreamtalk/data/audio/noisy_audio_office_snr_0.wav
ADDED
Binary file (206 kB)

damo/dreamtalk/data/audio/out_of_domain_narrative.wav
ADDED
Binary file (445 kB)

damo/dreamtalk/data/audio/spanish1.wav
ADDED
Binary file (144 kB)

damo/dreamtalk/data/audio/spanish2.wav
ADDED
Binary file (150 kB)

damo/dreamtalk/data/audio/spanish3.wav
ADDED
Binary file (212 kB)