├── README.md
├── __pycache__
│   ├── data_preparation.cpython-310.pyc
│   └── demo.cpython-310.pyc
├── back.mp4
├── checkpoint
│   ├── audio.pkl
│   ├── pca.pkl
│   ├── render.pth.gz.001
│   └── render.pth.gz.002
├── circle.mp4
├── data
│   ├── face_pts_mean.txt
│   ├── face_pts_mean_mainKps.txt
│   ├── pca.pkl
│   └── video_concat.txt
├── data_preparation.py
├── demo.py
├── demo_avatar.py
├── front.mp4
├── go-web.bat
├── images
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   └── 4.png
├── inp_keypoint.pkl
├── keypoint_rotate.pkl
├── requirements.txt
├── talkingface
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   ├── audio_model.cpython-310.pyc
│   │   ├── render_model.cpython-310.pyc
│   │   ├── run_utils.cpython-310.pyc
│   │   └── utils.cpython-310.pyc
│   ├── audio_model.py
│   ├── config
│   │   └── config.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   └── few_shot_dataset.cpython-310.pyc
│   │   ├── face_mask.py
│   │   └── few_shot_dataset.py
│   ├── face_pts_mean.txt
│   ├── models
│   │   ├── DINet.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── DINet.cpython-310.pyc
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   └── audio2bs_lstm.cpython-310.pyc
│   │   ├── audio2bs_lstm.py
│   │   ├── common
│   │   │   ├── Discriminator.py
│   │   │   └── VGG19.py
│   │   └── speed_test.py
│   ├── preprocess.py
│   ├── render_model.py
│   ├── run_utils.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── get_data.py
│   │   ├── html.py
│   │   ├── image_pool.py
│   │   ├── log_board.py
│   │   ├── smooth.py
│   │   ├── util.py
│   │   ├── utils.py
│   │   └── visualizer.py
│   └── utils.py
├── train.py
├── train_input_validation.py
├── video_data
│   ├── audio0.wav
│   ├── audio1.wav
│   ├── circle.mp4
│   ├── demo.mp4
│   ├── keypoint_rotate.pkl
│   └── test
│       ├── circle.mp4
│       └── keypoint_rotate.pkl
└── webapp.py

/README.md:
--------------------------------------------------------------------------------
# Digital Human Generation Tool

## Project Overview

This project is an improved version of DH_live that adds text-to-speech support for generating the audio that drives the digital human. A Gradio web UI is included to make it easy to operate. I am a programming beginner and the code still has plenty of rough edges, so feedback and corrections are welcome. The goal is to provide a complete workflow covering, among other things, text-to-speech, video processing, model training, and face generation.

## Special Thanks

Special thanks to kleinlee for the open-source project https://github.com/kleinlee/DH_live.git

### Features

- **Video processing**: cut, concatenate, and convert videos, suitable for creating and editing digital-human footage.
- **Text-to-speech**: several preset voices/languages to choose from; later updates may bring improvements.
- **Digital-human synthesis**: choose an audio file and start synthesizing the digital human.
- **Planned improvements**: further refinements are planned, for example to text-to-speech, along with possible additional features.

### Screenshots

Example screenshots of the main features:

1. **Video processing page**

   ![Video processing page](images/1.png)
   *Figure 1: cutting and concatenating videos with the tool.*

2. **Text-to-speech page**

   ![Text-to-speech page](images/2.png)
   *Figure 2: converting text to speech with the tool.*

3. **Digital-human synthesis page**

   ![Digital-human synthesis page](images/3.png)
   *Figure 3: generating and processing the digital-human video.*

4. **Planned improvements**

   ![Planned improvements](images/4.png)
   *Figure 4: planned future improvements and feature extensions.*

### Installation and Usage

1. **Requirements**

   - **Python**: Python 3.10.6 is required
   - **Operating system**: Windows, macOS, and Linux are supported

2. **Create a virtual environment and install**

   A virtual environment is recommended to isolate the project dependencies:

   ```bash
   git clone https://github.com/xingxing2233/Digital-human-generation-tool.git
   cd Digital-human-generation-tool

   # create and activate a virtual environment (Windows activation shown)
   python -m venv venv
   venv\Scripts\activate

   # install the dependencies and start the web UI
   pip install -r requirements.txt
   python webapp.py
   ```

   Then open http://127.0.0.1:7860 in your browser.

### Note: this project is for research and learning only. Commercial use is prohibited, and it must not be used for any illegal purpose.

### Support

To discuss improvements, join the Knowledge Planet (知识星球) group: https://t.zsxq.com/lnroF

![Support group image](https://private-user-images.githubusercontent.com/26740174/273180278-42119c68-9aa7-4e97-931d-9836369ccf80.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjUxMTYxMjYsIm5iZiI6MTcyNTExNTgyNiwicGF0aCI6Ii8yNjc0MDE3NC8yNzMxODAyNzgtNDIxMTljNjgtOWFhNy00ZTk3LTkzMWQtOTgzNjM2OWNjZjgwLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA4MzElMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwODMxVDE0NTAyNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTYzOTM4MzAyYTlhNjI3NTk0MmMyZmZjZDAxZWNhMGE1ZDllODE1Y2I4Yjg0ZGIwMTdhOWFkMTNiNjg5M2ViYjUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.LALvZGnx9lTTD58lPladTUop-cJPziEQLSqGGCHXqWU)

--------------------------------------------------------------------------------
/__pycache__/data_preparation.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/__pycache__/data_preparation.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/demo.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/__pycache__/demo.cpython-310.pyc
--------------------------------------------------------------------------------
/back.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/back.mp4
--------------------------------------------------------------------------------
/checkpoint/audio.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/checkpoint/audio.pkl
--------------------------------------------------------------------------------
/checkpoint/pca.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/checkpoint/pca.pkl
--------------------------------------------------------------------------------
/checkpoint/render.pth.gz.001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/checkpoint/render.pth.gz.001
--------------------------------------------------------------------------------
/checkpoint/render.pth.gz.002:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/checkpoint/render.pth.gz.002 -------------------------------------------------------------------------------- /circle.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/circle.mp4 -------------------------------------------------------------------------------- /data/face_pts_mean.txt: -------------------------------------------------------------------------------- 1 | 500.0 701.5 110.2 2 | 500.0 588.2 29.8 3 | 500.0 624.2 113.5 4 | 472.2 470.8 75.8 5 | 500.0 551.6 18.9 6 | 500.0 504.4 32.9 7 | 500.0 389.6 124.6 8 | 266.6 375.6 255.2 9 | 500.0 295.9 148.2 10 | 500.0 247.2 142.4 11 | 500.0 78.0 180.5 12 | 500.0 715.3 114.0 13 | 500.0 729.9 124.9 14 | 500.0 739.6 140.7 15 | 500.0 779.9 151.7 16 | 500.0 797.2 147.6 17 | 500.0 818.5 142.1 18 | 500.0 839.8 149.8 19 | 500.0 867.8 184.2 20 | 500.0 607.1 46.5 21 | 470.4 606.8 91.8 22 | 140.2 243.2 416.6 23 | 356.6 405.7 220.1 24 | 324.8 408.5 222.6 25 | 293.6 407.6 232.1 26 | 252.1 388.8 260.1 27 | 383.6 398.0 223.9 28 | 309.4 302.6 209.1 29 | 346.6 307.0 209.2 30 | 274.6 309.0 220.3 31 | 252.1 324.0 234.0 32 | 223.9 412.4 276.5 33 | 377.4 906.5 221.4 34 | 254.6 367.0 266.4 35 | 125.2 398.4 426.0 36 | 186.7 386.1 310.9 37 | 334.1 559.7 186.7 38 | 460.9 695.2 115.1 39 | 465.3 729.6 130.0 40 | 422.6 710.5 135.2 41 | 398.4 726.6 161.4 42 | 436.1 733.8 143.9 43 | 414.1 740.5 168.1 44 | 350.4 790.8 217.8 45 | 473.1 587.5 32.8 46 | 468.6 552.3 22.9 47 | 205.6 290.1 245.7 48 | 401.5 469.0 183.1 49 | 398.6 579.0 114.5 50 | 395.6 560.0 124.7 51 | 234.5 556.2 228.8 52 | 470.1 508.3 43.3 53 | 282.4 247.9 188.2 54 | 237.0 262.0 214.6 55 | 173.4 178.9 351.5 56 | 427.6 280.1 158.4 57 | 378.4 322.9 218.7 58 | 327.4 752.8 223.3 59 | 167.8 734.2 510.8 60 | 424.0 599.4 127.0 61 | 449.6 609.9 123.7 62 | 372.3 755.8 219.1 63 | 381.0 755.8 209.7 64 | 219.2 238.8 230.9 65 | 397.6 595.0 132.7 66 | 344.0 250.4 168.7 67 | 338.6 221.1 162.2 68 | 307.9 95.4 223.1 69 | 196.4 211.8 282.4 70 | 327.3 160.9 190.7 71 | 185.9 274.5 274.8 72 | 162.1 263.7 340.9 73 | 462.1 713.7 118.7 74 | 429.1 723.0 137.9 75 | 405.7 733.8 161.8 76 | 431.6 606.0 134.1 77 | 376.8 755.8 214.1 78 | 389.4 770.3 196.1 79 | 385.3 755.7 208.8 80 | 438.7 587.3 81.2 81 | 420.7 744.5 173.8 82 | 442.1 740.2 157.5 83 | 468.6 738.7 144.9 84 | 453.8 864.8 185.3 85 | 458.0 835.2 151.8 86 | 460.2 814.0 145.9 87 | 463.8 793.0 151.9 88 | 468.1 777.5 155.9 89 | 420.5 768.3 182.2 90 | 413.1 775.0 178.0 91 | 404.6 785.0 173.8 92 | 397.8 796.2 180.2 93 | 363.1 688.8 172.6 94 | 127.1 556.0 561.8 95 | 500.0 615.5 90.0 96 | 404.6 763.2 198.6 97 | 397.0 766.3 195.8 98 | 458.7 623.5 122.3 99 | 408.5 616.0 157.0 100 | 454.4 617.0 122.2 101 | 366.2 489.9 194.2 102 | 311.4 514.0 203.8 103 | 390.2 576.2 143.4 104 | 226.3 128.2 282.6 105 | 248.7 175.8 233.4 106 | 270.8 219.8 194.3 107 | 375.8 819.5 204.0 108 | 417.6 231.3 146.1 109 | 408.2 160.0 166.3 110 | 396.1 80.8 188.6 111 | 267.3 401.3 246.6 112 | 188.4 438.0 294.8 113 | 401.5 389.6 228.5 114 | 219.8 339.9 263.2 115 | 425.0 445.5 171.2 116 | 417.8 565.2 82.1 117 | 151.9 466.5 331.4 118 | 210.6 463.5 264.0 119 | 252.9 476.8 232.9 120 | 316.5 470.8 218.2 121 | 362.8 455.8 210.8 122 | 397.3 439.3 202.1 123 | 468.7 398.1 137.3 124 | 157.8 542.2 321.0 125 | 193.8 334.6 279.5 126 | 483.9 606.0 48.5 127 | 404.6 504.4 168.0 128 | 116.1 395.5 
538.4 129 | 423.4 420.2 197.3 130 | 387.0 578.5 177.1 131 | 239.0 367.2 272.9 132 | 415.9 542.1 94.9 133 | 141.3 642.8 545.4 134 | 401.2 377.3 237.1 135 | 442.2 521.1 63.0 136 | 240.9 821.1 337.1 137 | 242.6 860.2 395.8 138 | 127.6 550.7 440.7 139 | 199.3 772.5 376.6 140 | 142.3 326.2 393.2 141 | 371.6 939.5 239.9 142 | 487.4 614.2 92.2 143 | 374.6 531.0 180.0 144 | 158.6 393.1 344.5 145 | 297.6 385.6 234.9 146 | 324.4 387.8 226.7 147 | 382.5 775.3 201.5 148 | 166.8 614.8 330.6 149 | 430.8 988.5 249.8 150 | 333.1 938.2 312.8 151 | 291.5 906.2 348.9 152 | 500.0 162.0 158.9 153 | 500.0 996.2 240.7 154 | 350.8 385.9 225.2 155 | 375.8 381.8 229.8 156 | 392.6 379.7 236.6 157 | 170.2 327.6 313.6 158 | 372.8 355.1 224.7 159 | 345.1 346.5 219.1 160 | 317.4 345.0 220.8 161 | 289.9 349.5 229.2 162 | 272.8 356.6 240.1 163 | 122.5 311.7 485.2 164 | 279.4 381.0 245.7 165 | 500.0 652.6 120.9 166 | 394.2 667.7 154.4 167 | 423.9 593.3 114.8 168 | 454.8 655.0 124.2 169 | 500.0 343.4 149.1 170 | 284.5 867.4 307.1 171 | 326.4 903.1 275.6 172 | 429.2 966.8 208.8 173 | 203.4 807.4 455.2 174 | 392.2 368.3 232.0 175 | 447.3 455.1 125.6 176 | 500.0 973.9 201.1 177 | 377.8 967.1 274.9 178 | 140.9 630.4 434.6 179 | 441.8 773.2 167.7 180 | 434.6 784.7 163.4 181 | 428.1 800.7 158.8 182 | 423.3 819.3 164.4 183 | 409.4 845.8 190.5 184 | 396.3 747.9 190.2 185 | 388.4 744.8 190.3 186 | 381.8 741.0 190.3 187 | 339.9 717.7 197.4 188 | 211.9 634.2 263.6 189 | 447.4 419.2 156.6 190 | 426.4 347.2 213.3 191 | 406.0 353.0 226.2 192 | 403.4 749.6 193.6 193 | 212.1 721.4 314.0 194 | 456.9 345.2 171.1 195 | 392.1 874.4 208.8 196 | 500.0 465.4 64.6 197 | 470.1 436.4 103.5 198 | 500.0 429.0 95.9 199 | 424.6 511.5 124.8 200 | 500.0 940.8 184.4 201 | 500.0 901.5 185.2 202 | 443.8 897.1 191.5 203 | 320.2 803.3 237.8 204 | 359.4 607.5 185.9 205 | 352.3 843.3 224.8 206 | 283.1 602.8 200.4 207 | 329.4 644.0 192.2 208 | 258.2 661.2 226.9 209 | 433.5 934.8 192.9 210 | 405.2 532.0 148.2 211 | 293.4 829.1 269.1 212 | 334.8 869.5 249.6 213 | 295.1 756.7 239.1 214 | 178.8 676.5 342.4 215 | 254.2 761.9 271.6 216 | 162.2 704.5 419.8 217 | 305.8 690.8 208.7 218 | 426.7 479.2 148.4 219 | 428.6 581.2 74.8 220 | 409.4 590.7 110.6 221 | 441.4 557.0 51.4 222 | 400.1 310.8 203.6 223 | 346.2 286.2 198.4 224 | 298.7 282.1 201.2 225 | 259.4 290.6 215.2 226 | 232.5 309.3 236.1 227 | 216.5 376.2 286.4 228 | 121.6 473.8 435.0 229 | 241.7 427.6 259.0 230 | 275.4 437.9 238.9 231 | 319.4 437.1 225.1 232 | 360.4 428.8 218.6 233 | 392.1 417.5 217.8 234 | 414.2 406.1 215.6 235 | 120.2 475.0 561.5 236 | 410.2 600.4 127.1 237 | 447.1 487.4 98.8 238 | 450.6 582.8 48.9 239 | 468.0 600.7 62.0 240 | 451.2 589.2 64.9 241 | 418.9 611.0 138.4 242 | 473.9 603.8 53.1 243 | 477.5 612.0 92.2 244 | 413.2 378.0 231.9 245 | 431.5 386.1 211.7 246 | 443.0 392.0 188.4 247 | 262.6 362.1 251.6 248 | 238.2 342.3 253.4 249 | 527.8 470.8 75.8 250 | 733.5 375.6 255.2 251 | 529.6 606.8 91.8 252 | 859.8 243.2 416.6 253 | 643.5 405.7 220.1 254 | 675.2 408.5 222.6 255 | 706.4 407.6 232.1 256 | 747.9 388.8 260.1 257 | 616.4 398.0 223.9 258 | 690.5 302.6 209.1 259 | 653.4 307.0 209.2 260 | 725.4 309.0 220.3 261 | 748.0 324.0 234.0 262 | 776.1 412.4 276.5 263 | 622.5 906.5 221.4 264 | 745.4 367.0 266.4 265 | 874.8 398.4 426.0 266 | 813.3 386.1 310.9 267 | 665.9 559.7 186.7 268 | 539.1 695.2 115.1 269 | 534.7 729.6 130.0 270 | 577.4 710.5 135.2 271 | 601.6 726.6 161.4 272 | 563.9 733.8 143.9 273 | 586.0 740.5 168.1 274 | 649.6 790.8 217.8 275 | 526.9 587.5 32.8 276 | 531.4 552.3 22.9 277 | 794.5 
290.1 245.7 278 | 598.5 469.0 183.1 279 | 601.4 579.0 114.5 280 | 604.4 560.0 124.7 281 | 765.5 556.2 228.8 282 | 529.9 508.3 43.3 283 | 717.6 247.9 188.2 284 | 763.0 262.0 214.6 285 | 826.6 178.9 351.5 286 | 572.5 280.1 158.4 287 | 621.6 322.9 218.7 288 | 672.6 752.8 223.3 289 | 832.2 734.2 510.8 290 | 576.0 599.4 127.0 291 | 550.5 609.9 123.7 292 | 627.7 755.8 219.1 293 | 619.0 755.8 209.7 294 | 780.8 238.8 230.9 295 | 602.5 595.0 132.7 296 | 656.0 250.4 168.7 297 | 661.4 221.1 162.2 298 | 692.0 95.4 223.1 299 | 803.5 211.8 282.4 300 | 672.7 160.9 190.7 301 | 814.1 274.5 274.8 302 | 837.9 263.7 340.9 303 | 537.9 713.7 118.7 304 | 571.0 723.0 137.9 305 | 594.3 733.8 161.8 306 | 568.4 606.0 134.1 307 | 623.2 755.8 214.1 308 | 610.6 770.3 196.1 309 | 614.7 755.7 208.8 310 | 561.3 587.3 81.2 311 | 579.3 744.5 173.8 312 | 558.0 740.2 157.5 313 | 531.5 738.7 144.9 314 | 546.2 864.8 185.3 315 | 542.0 835.2 151.8 316 | 539.8 814.0 145.9 317 | 536.2 793.0 151.9 318 | 531.9 777.5 155.9 319 | 579.5 768.3 182.2 320 | 586.9 775.0 178.0 321 | 595.4 785.0 173.8 322 | 602.2 796.2 180.2 323 | 636.9 688.8 172.6 324 | 872.9 556.0 561.8 325 | 595.4 763.2 198.6 326 | 603.0 766.3 195.8 327 | 541.3 623.5 122.3 328 | 591.5 616.0 157.0 329 | 545.6 617.0 122.2 330 | 633.8 489.9 194.2 331 | 688.6 514.0 203.8 332 | 609.8 576.2 143.4 333 | 773.7 128.2 282.6 334 | 751.3 175.8 233.4 335 | 729.2 219.8 194.3 336 | 624.2 819.5 204.0 337 | 582.4 231.3 146.1 338 | 591.8 160.0 166.3 339 | 604.0 80.8 188.6 340 | 732.7 401.3 246.6 341 | 811.6 438.0 294.8 342 | 598.5 389.6 228.5 343 | 780.2 339.9 263.2 344 | 575.0 445.5 171.2 345 | 582.2 565.2 82.1 346 | 848.1 466.5 331.4 347 | 789.4 463.5 264.0 348 | 747.0 476.8 232.9 349 | 683.5 470.8 218.2 350 | 637.2 455.8 210.8 351 | 602.7 439.3 202.1 352 | 531.3 398.1 137.3 353 | 842.2 542.2 321.0 354 | 806.2 334.6 279.5 355 | 516.1 606.0 48.5 356 | 595.5 504.4 168.0 357 | 883.9 395.5 538.4 358 | 576.6 420.2 197.3 359 | 613.0 578.5 177.1 360 | 761.0 367.2 272.9 361 | 584.0 542.1 94.9 362 | 858.7 642.8 545.4 363 | 598.8 377.3 237.1 364 | 557.8 521.1 63.0 365 | 759.1 821.1 337.1 366 | 757.4 860.2 395.8 367 | 872.4 550.7 440.7 368 | 800.7 772.5 376.6 369 | 857.7 326.2 393.2 370 | 628.4 939.5 239.9 371 | 512.6 614.2 92.2 372 | 625.4 531.0 180.0 373 | 841.5 393.1 344.5 374 | 702.5 385.6 234.9 375 | 675.6 387.8 226.7 376 | 617.5 775.3 201.5 377 | 833.2 614.8 330.6 378 | 569.2 988.5 249.8 379 | 666.9 938.2 312.8 380 | 708.5 906.2 348.9 381 | 649.2 385.9 225.2 382 | 624.2 381.8 229.8 383 | 607.5 379.7 236.6 384 | 829.8 327.6 313.6 385 | 627.2 355.1 224.7 386 | 654.9 346.5 219.1 387 | 682.6 345.0 220.8 388 | 710.0 349.5 229.2 389 | 727.2 356.6 240.1 390 | 877.5 311.7 485.2 391 | 720.6 381.0 245.7 392 | 605.8 667.7 154.4 393 | 576.0 593.3 114.8 394 | 545.2 655.0 124.2 395 | 715.5 867.4 307.1 396 | 673.6 903.1 275.6 397 | 570.8 966.8 208.8 398 | 796.6 807.4 455.2 399 | 607.8 368.3 232.0 400 | 552.7 455.1 125.6 401 | 622.2 967.1 274.9 402 | 859.0 630.4 434.6 403 | 558.2 773.2 167.7 404 | 565.4 784.7 163.4 405 | 572.0 800.7 158.8 406 | 576.7 819.3 164.4 407 | 590.6 845.8 190.5 408 | 603.7 747.9 190.2 409 | 611.5 744.8 190.3 410 | 618.2 741.0 190.3 411 | 660.1 717.7 197.4 412 | 788.1 634.2 263.6 413 | 552.6 419.2 156.6 414 | 573.6 347.2 213.3 415 | 594.0 353.0 226.2 416 | 596.6 749.6 193.6 417 | 787.9 721.4 314.0 418 | 543.0 345.2 171.1 419 | 607.9 874.4 208.8 420 | 529.9 436.4 103.5 421 | 575.5 511.5 124.8 422 | 556.2 897.1 191.5 423 | 679.8 803.3 237.8 424 | 640.6 607.5 185.9 425 | 647.7 843.3 224.8 
426 | 717.0 602.8 200.4 427 | 670.6 644.0 192.2 428 | 741.8 661.2 226.9 429 | 566.5 934.8 192.9 430 | 594.8 532.0 148.2 431 | 706.5 829.1 269.1 432 | 665.2 869.5 249.6 433 | 705.0 756.7 239.1 434 | 821.2 676.5 342.4 435 | 745.8 761.9 271.6 436 | 837.8 704.5 419.8 437 | 694.2 690.8 208.7 438 | 573.3 479.2 148.4 439 | 571.5 581.2 74.8 440 | 590.5 590.7 110.6 441 | 558.5 557.0 51.4 442 | 599.9 310.8 203.6 443 | 653.8 286.2 198.4 444 | 701.3 282.1 201.2 445 | 740.6 290.6 215.2 446 | 767.5 309.3 236.1 447 | 783.5 376.2 286.4 448 | 878.5 473.8 435.0 449 | 758.3 427.6 259.0 450 | 724.5 437.9 238.9 451 | 680.6 437.1 225.1 452 | 639.6 428.8 218.6 453 | 607.9 417.5 217.8 454 | 585.8 406.1 215.6 455 | 879.8 475.0 561.5 456 | 589.8 600.4 127.1 457 | 552.9 487.4 98.8 458 | 549.4 582.8 48.9 459 | 532.0 600.7 62.0 460 | 548.8 589.2 64.9 461 | 581.1 611.0 138.4 462 | 526.1 603.8 53.1 463 | 522.5 612.0 92.2 464 | 586.8 378.0 231.9 465 | 568.5 386.1 211.7 466 | 557.0 392.0 188.4 467 | 737.4 362.1 251.6 468 | 761.8 342.3 253.4 469 | 328.8 363.2 235.9 470 | 361.5 362.8 235.9 471 | 328.6 333.5 235.9 472 | 295.7 363.5 235.9 473 | 328.8 392.9 235.9 474 | 671.2 363.2 235.9 475 | 704.3 363.5 235.9 476 | 671.4 333.5 235.9 477 | 638.5 362.8 235.9 478 | 671.2 392.9 235.9 479 | -------------------------------------------------------------------------------- /data/face_pts_mean_mainKps.txt: -------------------------------------------------------------------------------- 1 | 8.660000000000000000e+02 3.320000000000000000e+02 1.300000000000000000e+01 2 | 8.500000000000000000e+02 3.170000000000000000e+02 -3.000000000000000000e+00 3 | 8.260000000000000000e+02 3.110000000000000000e+02 -1.600000000000000000e+01 4 | 7.930000000000000000e+02 3.140000000000000000e+02 -2.600000000000000000e+01 5 | 7.470000000000000000e+02 3.300000000000000000e+02 -3.100000000000000000e+01 6 | 7.520000000000000000e+02 3.050000000000000000e+02 -3.700000000000000000e+01 7 | 7.960000000000000000e+02 2.990000000000000000e+02 -2.900000000000000000e+01 8 | 8.320000000000000000e+02 2.970000000000000000e+02 -1.300000000000000000e+01 9 | 8.580000000000000000e+02 3.060000000000000000e+02 5.000000000000000000e+00 10 | 8.750000000000000000e+02 3.240000000000000000e+02 2.800000000000000000e+01 11 | 5.480000000000000000e+02 3.260000000000000000e+02 1.800000000000000000e+01 12 | 5.640000000000000000e+02 3.120000000000000000e+02 1.000000000000000000e+00 13 | 5.880000000000000000e+02 3.060000000000000000e+02 -1.300000000000000000e+01 14 | 6.210000000000000000e+02 3.080000000000000000e+02 -2.400000000000000000e+01 15 | 6.670000000000000000e+02 3.250000000000000000e+02 -3.000000000000000000e+01 16 | 6.620000000000000000e+02 2.990000000000000000e+02 -3.600000000000000000e+01 17 | 6.190000000000000000e+02 2.930000000000000000e+02 -2.700000000000000000e+01 18 | 5.820000000000000000e+02 2.910000000000000000e+02 -1.000000000000000000e+01 19 | 5.550000000000000000e+02 3.000000000000000000e+02 1.000000000000000000e+01 20 | 5.390000000000000000e+02 3.180000000000000000e+02 3.300000000000000000e+01 21 | 7.460000000000000000e+02 4.170000000000000000e+02 -2.500000000000000000e+01 22 | 7.570000000000000000e+02 4.480000000000000000e+02 -2.700000000000000000e+01 23 | 7.680000000000000000e+02 4.870000000000000000e+02 -2.200000000000000000e+01 24 | 7.550000000000000000e+02 5.070000000000000000e+02 -3.300000000000000000e+01 25 | 7.260000000000000000e+02 5.120000000000000000e+02 -5.100000000000000000e+01 26 | 7.030000000000000000e+02 5.120000000000000000e+02 -5.600000000000000000e+01 
27 | 6.790000000000000000e+02 5.110000000000000000e+02 -5.100000000000000000e+01 28 | 6.520000000000000000e+02 5.050000000000000000e+02 -3.200000000000000000e+01 29 | 6.410000000000000000e+02 4.840000000000000000e+02 -2.100000000000000000e+01 30 | 6.520000000000000000e+02 4.460000000000000000e+02 -2.600000000000000000e+01 31 | 6.650000000000000000e+02 4.160000000000000000e+02 -2.300000000000000000e+01 32 | 7.060000000000000000e+02 3.870000000000000000e+02 -4.900000000000000000e+01 33 | 7.050000000000000000e+02 4.090000000000000000e+02 -6.500000000000000000e+01 34 | 7.050000000000000000e+02 4.280000000000000000e+02 -8.200000000000000000e+01 35 | 7.040000000000000000e+02 4.480000000000000000e+02 -9.900000000000000000e+01 36 | 7.030000000000000000e+02 4.740000000000000000e+02 -1.070000000000000000e+02 37 | 6.324848484848484986e+02 5.827878787878787534e+02 3.545454545454545414e+00 38 | 6.364040404040404155e+02 5.946868686868687064e+02 -5.989898989898989612e+00 39 | 6.433333333333333712e+02 6.077878787878787534e+02 -1.729292929292929415e+01 40 | 6.563232323232323324e+02 6.224848484848484986e+02 -2.594949494949494806e+01 41 | 6.754141414141414543e+02 6.323030303030302548e+02 -3.345454545454545325e+01 42 | 6.995353535353535790e+02 6.354242424242423795e+02 -3.512121212121212466e+01 43 | 7.245050505050504626e+02 6.321717171717172050e+02 -3.454545454545454675e+01 44 | 7.451818181818181301e+02 6.218888888888889142e+02 -2.800000000000000000e+01 45 | 7.605050505050504626e+02 6.067676767676767895e+02 -1.987878787878787890e+01 46 | 7.698888888888889142e+02 5.930909090909091219e+02 -9.252525252525252597e+00 47 | 7.758484848484848726e+02 5.806767676767676676e+02 -2.222222222222222099e-01 48 | 7.710808080808080831e+02 5.711313131313131635e+02 -1.558585858585858652e+01 49 | 7.615151515151515014e+02 5.625555555555555429e+02 -3.061616161616161591e+01 50 | 7.474040404040404155e+02 5.534141414141414543e+02 -4.424242424242424221e+01 51 | 7.245353535353535790e+02 5.453232323232323324e+02 -5.421212121212121104e+01 52 | 7.015050505050504626e+02 5.487474747474747119e+02 -5.628282828282828376e+01 53 | 6.792626262626262132e+02 5.458989898989899530e+02 -5.314141414141413833e+01 54 | 6.583232323232323324e+02 5.545757575757576205e+02 -4.211111111111111427e+01 55 | 6.454343434343434183e+02 5.639797979797980361e+02 -2.765656565656565746e+01 56 | 6.369494949494949196e+02 5.729292929292929557e+02 -1.195959595959596022e+01 57 | 6.399898989898989612e+02 5.829696969696969973e+02 -2.060606060606060552e+00 58 | 6.498181818181818699e+02 5.886161616161616621e+02 -7.171717171717172157e+00 59 | 6.575353535353535790e+02 5.926363636363636260e+02 -1.580808080808080796e+01 60 | 6.682222222222221717e+02 5.965050505050504626e+02 -2.379797979797979934e+01 61 | 6.823636363636363740e+02 5.992626262626262132e+02 -3.029292929292929415e+01 62 | 7.001515151515151274e+02 6.004949494949495374e+02 -3.307070707070707272e+01 63 | 7.189191919191919169e+02 5.986262626262625872e+02 -3.135353535353535293e+01 64 | 7.344343434343434183e+02 5.951818181818181301e+02 -2.544444444444444287e+01 65 | 7.469797979797980361e+02 5.910505050505050804e+02 -1.815151515151515227e+01 66 | 7.564343434343434183e+02 5.868585858585859114e+02 -1.007070707070707094e+01 67 | 7.680000000000000000e+02 5.807878787878787534e+02 -5.333333333333333037e+00 68 | 7.580707070707070443e+02 5.764747474747474598e+02 -1.369696969696969724e+01 69 | 7.479898989898989612e+02 5.731616161616161662e+02 -2.410101010101010033e+01 70 | 7.355858585858585457e+02 5.706969696969697452e+02 
-3.276767676767676818e+01 71 | 7.201616161616161662e+02 5.697777777777778283e+02 -3.927272727272727337e+01 72 | 7.016868686868687064e+02 5.645858585858585457e+02 -4.913131313131312794e+01 73 | 6.847070707070706703e+02 5.702525252525252881e+02 -3.825252525252525260e+01 74 | 6.703434343434342964e+02 5.717171717171717091e+02 -3.096969696969696884e+01 75 | 6.591111111111110858e+02 5.745858585858585457e+02 -2.187878787878787890e+01 76 | 6.499898989898989612e+02 5.783030303030302548e+02 -1.081818181818181834e+01 77 | 8.430000000000000000e+02 3.720000000000000000e+02 2.500000000000000000e+01 78 | 8.360000000000000000e+02 3.780000000000000000e+02 1.900000000000000000e+01 79 | 8.290000000000000000e+02 3.830000000000000000e+02 1.400000000000000000e+01 80 | 8.190000000000000000e+02 3.860000000000000000e+02 9.000000000000000000e+00 81 | 8.040000000000000000e+02 3.870000000000000000e+02 5.000000000000000000e+00 82 | 7.890000000000000000e+02 3.850000000000000000e+02 4.000000000000000000e+00 83 | 7.740000000000000000e+02 3.830000000000000000e+02 7.000000000000000000e+00 84 | 7.640000000000000000e+02 3.820000000000000000e+02 1.100000000000000000e+01 85 | 7.590000000000000000e+02 3.800000000000000000e+02 1.100000000000000000e+01 86 | 7.650000000000000000e+02 3.740000000000000000e+02 9.000000000000000000e+00 87 | 7.760000000000000000e+02 3.640000000000000000e+02 5.000000000000000000e+00 88 | 7.920000000000000000e+02 3.570000000000000000e+02 2.000000000000000000e+00 89 | 8.090000000000000000e+02 3.550000000000000000e+02 2.000000000000000000e+00 90 | 8.250000000000000000e+02 3.580000000000000000e+02 6.000000000000000000e+00 91 | 8.340000000000000000e+02 3.640000000000000000e+02 1.200000000000000000e+01 92 | 8.390000000000000000e+02 3.680000000000000000e+02 1.800000000000000000e+01 93 | 5.740000000000000000e+02 3.660000000000000000e+02 2.900000000000000000e+01 94 | 5.800000000000000000e+02 3.720000000000000000e+02 2.300000000000000000e+01 95 | 5.860000000000000000e+02 3.770000000000000000e+02 1.800000000000000000e+01 96 | 5.960000000000000000e+02 3.800000000000000000e+02 1.200000000000000000e+01 97 | 6.110000000000000000e+02 3.820000000000000000e+02 8.000000000000000000e+00 98 | 6.260000000000000000e+02 3.810000000000000000e+02 7.000000000000000000e+00 99 | 6.400000000000000000e+02 3.780000000000000000e+02 1.000000000000000000e+01 100 | 6.500000000000000000e+02 3.770000000000000000e+02 1.300000000000000000e+01 101 | 6.550000000000000000e+02 3.760000000000000000e+02 1.300000000000000000e+01 102 | 6.500000000000000000e+02 3.690000000000000000e+02 1.100000000000000000e+01 103 | 6.390000000000000000e+02 3.590000000000000000e+02 7.000000000000000000e+00 104 | 6.240000000000000000e+02 3.520000000000000000e+02 4.000000000000000000e+00 105 | 6.080000000000000000e+02 3.500000000000000000e+02 5.000000000000000000e+00 106 | 5.920000000000000000e+02 3.530000000000000000e+02 1.000000000000000000e+01 107 | 5.830000000000000000e+02 3.580000000000000000e+02 1.600000000000000000e+01 108 | 5.780000000000000000e+02 3.620000000000000000e+02 2.200000000000000000e+01 109 | 9.080000000000000000e+02 4.380000000000000000e+02 1.790000000000000000e+02 110 | 9.070000000000000000e+02 4.820000000000000000e+02 1.790000000000000000e+02 111 | 9.010000000000000000e+02 5.300000000000000000e+02 1.710000000000000000e+02 112 | 8.880000000000000000e+02 5.820000000000000000e+02 1.550000000000000000e+02 113 | 8.680000000000000000e+02 6.240000000000000000e+02 1.270000000000000000e+02 114 | 8.470000000000000000e+02 6.540000000000000000e+02 
9.900000000000000000e+01 115 | 8.190000000000000000e+02 6.800000000000000000e+02 7.700000000000000000e+01 116 | 7.950000000000000000e+02 6.990000000000000000e+02 6.000000000000000000e+01 117 | 7.690000000000000000e+02 7.160000000000000000e+02 4.100000000000000000e+01 118 | 7.380000000000000000e+02 7.300000000000000000e+02 2.900000000000000000e+01 119 | 7.000000000000000000e+02 7.340000000000000000e+02 2.500000000000000000e+01 120 | 6.610000000000000000e+02 7.280000000000000000e+02 3.000000000000000000e+01 121 | 6.330000000000000000e+02 7.130000000000000000e+02 4.300000000000000000e+01 122 | 6.100000000000000000e+02 6.930000000000000000e+02 6.300000000000000000e+01 123 | 5.890000000000000000e+02 6.730000000000000000e+02 8.000000000000000000e+01 124 | 5.640000000000000000e+02 6.450000000000000000e+02 1.020000000000000000e+02 125 | 5.450000000000000000e+02 6.140000000000000000e+02 1.310000000000000000e+02 126 | 5.270000000000000000e+02 5.710000000000000000e+02 1.590000000000000000e+02 127 | 5.140000000000000000e+02 5.190000000000000000e+02 1.760000000000000000e+02 128 | 5.080000000000000000e+02 4.710000000000000000e+02 1.840000000000000000e+02 129 | 5.060000000000000000e+02 4.270000000000000000e+02 1.840000000000000000e+02 130 | 7.740000000000000000e+02 4.620000000000000000e+02 -2.100000000000000000e+01 131 | 7.960000000000000000e+02 4.780000000000000000e+02 -1.700000000000000000e+01 132 | 8.240000000000000000e+02 5.020000000000000000e+02 -1.000000000000000000e+01 133 | 8.370000000000000000e+02 5.350000000000000000e+02 7.000000000000000000e+00 134 | 8.400000000000000000e+02 5.930000000000000000e+02 3.400000000000000000e+01 135 | 8.220000000000000000e+02 6.570000000000000000e+02 5.700000000000000000e+01 136 | 5.860000000000000000e+02 6.490000000000000000e+02 6.000000000000000000e+01 137 | 5.700000000000000000e+02 5.850000000000000000e+02 3.900000000000000000e+01 138 | 5.720000000000000000e+02 5.280000000000000000e+02 1.100000000000000000e+01 139 | 5.850000000000000000e+02 4.970000000000000000e+02 -6.000000000000000000e+00 140 | 6.130000000000000000e+02 4.740000000000000000e+02 -1.500000000000000000e+01 141 | 6.350000000000000000e+02 4.590000000000000000e+02 -1.900000000000000000e+01 142 | -------------------------------------------------------------------------------- /data/pca.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/data/pca.pkl -------------------------------------------------------------------------------- /data/video_concat.txt: -------------------------------------------------------------------------------- 1 | file 'front.mp4' 2 | file 'back.mp4' -------------------------------------------------------------------------------- /data_preparation.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import tqdm 3 | import numpy as np 4 | import cv2 5 | import sys 6 | import os 7 | import math 8 | import pickle 9 | import mediapipe as mp 10 | mp_face_mesh = mp.solutions.face_mesh 11 | mp_face_detection = mp.solutions.face_detection 12 | 13 | def detect_face(frame): 14 | # 剔除掉多个人脸、大角度侧脸(鼻子不在两个眼之间)、部分人脸框在画面外、人脸像素低于80*80的 15 | with mp_face_detection.FaceDetection( 16 | model_selection=1, min_detection_confidence=0.6) as face_detection: 17 | results = face_detection.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 18 | if not results.detections or len(results.detections) > 1: 19 | return -1, None 20 
| rect = results.detections[0].location_data.relative_bounding_box 21 | out_rect = [rect.xmin, rect.xmin + rect.width, rect.ymin, rect.ymin + rect.height] 22 | nose_ = mp_face_detection.get_key_point( 23 | results.detections[0], mp_face_detection.FaceKeyPoint.NOSE_TIP) 24 | l_eye_ = mp_face_detection.get_key_point( 25 | results.detections[0], mp_face_detection.FaceKeyPoint.LEFT_EYE) 26 | r_eye_ = mp_face_detection.get_key_point( 27 | results.detections[0], mp_face_detection.FaceKeyPoint.RIGHT_EYE) 28 | # print(nose_, l_eye_, r_eye_) 29 | if nose_.x > l_eye_.x or nose_.x < r_eye_.x: 30 | return -2, out_rect 31 | 32 | h, w = frame.shape[:2] 33 | # print(frame.shape) 34 | if rect.xmin < 0 or rect.ymin < 0 or rect.xmin + rect.width > w or rect.ymin + rect.height > h: 35 | return -3, out_rect 36 | if rect.width * w < 100 or rect.height * h < 100: 37 | return -4, out_rect 38 | return 1, out_rect 39 | 40 | 41 | def calc_face_interact(face0, face1): 42 | x_min = min(face0[0], face1[0]) 43 | x_max = max(face0[1], face1[1]) 44 | y_min = min(face0[2], face1[2]) 45 | y_max = max(face0[3], face1[3]) 46 | tmp0 = ((face0[1] - face0[0]) * (face0[3] - face0[2])) / ((x_max - x_min) * (y_max - y_min)) 47 | tmp1 = ((face1[1] - face1[0]) * (face1[3] - face1[2])) / ((x_max - x_min) * (y_max - y_min)) 48 | return min(tmp0, tmp1) 49 | 50 | 51 | def detect_face_mesh(frame): 52 | with mp_face_mesh.FaceMesh( 53 | static_image_mode=True, 54 | max_num_faces=1, 55 | refine_landmarks=True, 56 | min_detection_confidence=0.5) as face_mesh: 57 | results = face_mesh.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 58 | pts_3d = np.zeros([478, 3]) 59 | if not results.multi_face_landmarks: 60 | print("****** WARNING! No face detected! ******") 61 | else: 62 | image_height, image_width = frame.shape[:2] 63 | for face_landmarks in results.multi_face_landmarks: 64 | for index_, i in enumerate(face_landmarks.landmark): 65 | x_px = min(math.floor(i.x * image_width), image_width - 1) 66 | y_px = min(math.floor(i.y * image_height), image_height - 1) 67 | z_px = min(math.floor(i.z * image_width), image_width - 1) 68 | pts_3d[index_] = np.array([x_px, y_px, z_px]) 69 | return pts_3d 70 | 71 | 72 | def ExtractFromVideo(video_path, circle = False): 73 | cap = cv2.VideoCapture(video_path) 74 | if not cap.isOpened(): 75 | return 0 76 | 77 | dir_path = os.path.dirname(video_path) 78 | vid_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # 宽度 79 | vid_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # 高度 80 | 81 | totalFrames = cap.get(cv2.CAP_PROP_FRAME_COUNT) # 总帧数 82 | totalFrames = int(totalFrames) 83 | pts_3d = np.zeros([totalFrames, 478, 3]) 84 | frame_index = 0 85 | face_rect_list = [] 86 | mat_list = [] 87 | model_name = os.path.basename(video_path)[:-4] 88 | 89 | # os.makedirs("../preparation/{}/image".format(model_name)) 90 | for frame_index in tqdm.tqdm(range(totalFrames)): 91 | ret, frame = cap.read() # 按帧读取视频 92 | # #到视频结尾时终止 93 | if ret is False: 94 | break 95 | # cv2.imwrite("../preparation/{}/image/{:0>6d}.png".format(model_name, frame_index), frame) 96 | tag_, rect = detect_face(frame) 97 | if frame_index == 0 and tag_ != 1: 98 | print("第一帧人脸检测异常,请剔除掉多个人脸、大角度侧脸(鼻子不在两个眼之间)、部分人脸框在画面外、人脸像素低于80*80") 99 | pts_3d = -1 100 | break 101 | elif tag_ == -1: # 有时候人脸检测会失败,就用上一帧的结果替代这一帧的结果 102 | rect = face_rect_list[-1] 103 | elif tag_ != 1: 104 | print("第{}帧人脸检测异常,请剔除掉多个人脸、大角度侧脸(鼻子不在两个眼之间)、部分人脸框在画面外、人脸像素低于80*80, tag: {}".format(frame_index, tag_)) 105 | # exit() 106 | if len(face_rect_list) > 0: 107 | face_area_inter = 
calc_face_interact(face_rect_list[-1], rect) 108 | # print(frame_index, face_area_inter) 109 | if face_area_inter < 0.6: 110 | print("人脸区域变化幅度太大,请复查,超出值为{}, frame_num: {}".format(face_area_inter, frame_index)) 111 | pts_3d = -2 112 | break 113 | 114 | face_rect_list.append(rect) 115 | 116 | x_min = rect[0] * vid_width 117 | y_min = rect[2] * vid_height 118 | x_max = rect[1] * vid_width 119 | y_max = rect[3] * vid_height 120 | seq_w, seq_h = x_max - x_min, y_max - y_min 121 | x_mid, y_mid = (x_min + x_max) / 2, (y_min + y_max) / 2 122 | # x_min = int(max(0, x_mid - seq_w * 0.65)) 123 | # y_min = int(max(0, y_mid - seq_h * 0.4)) 124 | # x_max = int(min(vid_width, x_mid + seq_w * 0.65)) 125 | # y_max = int(min(vid_height, y_mid + seq_h * 0.8)) 126 | crop_size = int(max(seq_w * 1.35, seq_h * 1.35)) 127 | x_min = int(max(0, x_mid - crop_size * 0.5)) 128 | y_min = int(max(0, y_mid - crop_size * 0.45)) 129 | x_max = int(min(vid_width, x_min + crop_size)) 130 | y_max = int(min(vid_height, y_min + crop_size)) 131 | 132 | frame_face = frame[y_min:y_max, x_min:x_max] 133 | print(y_min, y_max, x_min, x_max) 134 | # cv2.imshow("s", frame_face) 135 | # cv2.waitKey(20) 136 | frame_kps = detect_face_mesh(frame_face) 137 | pts_3d[frame_index] = frame_kps + np.array([x_min, y_min, 0]) 138 | cap.release() # 释放视频对象 139 | return pts_3d 140 | 141 | 142 | def CirculateVideo(video_in_path, video_out_path, export_imgs = False): 143 | # 1 视频转换为25FPS, 并折叠循环拼接 144 | front_video_path = "front.mp4" 145 | back_video_path = "back.mp4" 146 | # ffmpeg_cmd = "ffmpeg -i {} -r 25 -ss 00:00:00 -t 00:02:00 -an -loglevel quiet -y {}".format(video_in_path, front_video_path) 147 | ffmpeg_cmd = "ffmpeg -i {} -r 25 -an -loglevel quiet -y {}".format(video_in_path, front_video_path) 148 | os.system(ffmpeg_cmd) 149 | 150 | # front_video_path = video_in_path 151 | 152 | cap = cv2.VideoCapture(front_video_path) 153 | vid_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # 宽度 154 | vid_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # 高度 155 | frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) 156 | cap.release() 157 | 158 | 159 | ffmpeg_cmd = "ffmpeg -i {} -vf reverse -y {}".format(front_video_path, back_video_path) 160 | os.system(ffmpeg_cmd) 161 | ffmpeg_cmd = "ffmpeg -f concat -i {} -c:v copy -y {}".format("data/video_concat.txt", video_out_path) 162 | os.system(ffmpeg_cmd) 163 | # exit() 164 | print("正向视频帧数:", frames) 165 | pts_3d = ExtractFromVideo(front_video_path) 166 | if type(pts_3d) is np.ndarray and len(pts_3d) == frames: 167 | print("关键点已提取") 168 | pts_3d = np.concatenate([pts_3d, pts_3d[::-1]], axis=0) 169 | Path_output_pkl = "{}/keypoint_rotate.pkl".format(os.path.dirname(video_out_path)) 170 | with open(Path_output_pkl, "wb") as f: 171 | pickle.dump(pts_3d, f) 172 | 173 | if export_imgs: 174 | # 计算整个视频中人脸的范围 175 | x_min, y_min, x_max, y_max = np.min(pts_3d[:, :, 0]), np.min( 176 | pts_3d[:, :, 1]), np.max( 177 | pts_3d[:, :, 0]), np.max(pts_3d[:, :, 1]) 178 | new_w = int((x_max - x_min) * 0.55) * 2 179 | new_h = int((y_max - y_min) * 0.6) * 2 180 | center_x = int((x_max + x_min) / 2.) 
181 | center_y = int(y_min + (y_max - y_min) * 0.6) 182 | size = max(new_h, new_w) 183 | x_min, y_min, x_max, y_max = int(center_x - size // 2), int(center_y - size // 2), int( 184 | center_x + size // 2), int(center_y + size // 2) 185 | 186 | # 确定裁剪区域上边top和左边left坐标 187 | top = y_min 188 | left = x_min 189 | # 裁剪区域与原图的重合区域 190 | top_coincidence = int(max(top, 0)) 191 | bottom_coincidence = int(min(y_max, vid_height)) 192 | left_coincidence = int(max(left, 0)) 193 | right_coincidence = int(min(x_max, vid_width)) 194 | print("人脸活动范围:{}:{}, {}:{}".format(top_coincidence, bottom_coincidence, left_coincidence, right_coincidence)) 195 | np.savetxt("{}/face_rect.txt".format(os.path.dirname(video_out_path)), 196 | np.array([top_coincidence, bottom_coincidence, left_coincidence, right_coincidence])) 197 | os.makedirs("{}/image".format(os.path.dirname(video_out_path))) 198 | ffmpeg_cmd = "ffmpeg -i {} -vf crop={}:{}:{}:{},scale=512:512:flags=neighbor -loglevel quiet -y {}/image/%06d.png".format( 199 | front_video_path, 200 | right_coincidence - left_coincidence, 201 | bottom_coincidence - top_coincidence, 202 | left_coincidence, 203 | top_coincidence, 204 | os.path.dirname(video_out_path) 205 | ) 206 | os.system(ffmpeg_cmd) 207 | 208 | cap = cv2.VideoCapture(video_out_path) 209 | frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) 210 | cap.release() 211 | print("循环视频帧数:", frames) 212 | 213 | 214 | def main(): 215 | # 检查命令行参数的数量 216 | if len(sys.argv) != 2: 217 | print("Usage: python data_preparation.py ") 218 | sys.exit(1) # 参数数量不正确时退出程序 219 | 220 | # 获取video_name参数 221 | video_name = sys.argv[1] 222 | print(f"Video name is set to: {video_name}") 223 | 224 | new_data_path = "video_data/{}".format(uuid.uuid1()) 225 | os.makedirs(new_data_path, exist_ok=True) 226 | video_out_path = "{}/circle.mp4".format(new_data_path) 227 | CirculateVideo(video_name, video_out_path, export_imgs=False) 228 | 229 | if __name__ == "__main__": 230 | main() 231 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import numpy as np 4 | import uuid 5 | import cv2 6 | import tqdm 7 | import shutil 8 | from talkingface.audio_model import AudioModel 9 | from talkingface.render_model import RenderModel 10 | 11 | def merge_audio_video(video_path, audio_path, output_video_name): 12 | print(f"Video path is set to: {video_path}") 13 | print(f"Audio path is set to: {audio_path}") 14 | print(f"Output video name is set to: {output_video_name}") 15 | 16 | audioModel = AudioModel() 17 | audioModel.loadModel("checkpoint/audio.pkl") 18 | 19 | renderModel = RenderModel() 20 | renderModel.loadModel("checkpoint/render.pth") 21 | pkl_path = os.path.join(video_path, "keypoint_rotate.pkl") 22 | video_file_path = os.path.join(video_path, "circle.mp4") 23 | renderModel.reset_charactor(video_file_path, pkl_path) 24 | 25 | wavpath = audio_path 26 | mouth_frame = audioModel.interface_wav(wavpath) 27 | cap_input = cv2.VideoCapture(video_file_path) 28 | vid_width = cap_input.get(cv2.CAP_PROP_FRAME_WIDTH) # 宽度 29 | vid_height = cap_input.get(cv2.CAP_PROP_FRAME_HEIGHT) # 高度 30 | cap_input.release() 31 | 32 | task_id = str(uuid.uuid1()) 33 | os.makedirs(f"output/{task_id}", exist_ok=True) 34 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 35 | save_path = f"output/{task_id}/silence.mp4" 36 | videoWriter = cv2.VideoWriter(save_path, fourcc, 25, (int(vid_width), int(vid_height))) 37 | 38 | for frame in 
tqdm.tqdm(mouth_frame): 39 | frame = renderModel.interface(frame) 40 | videoWriter.write(frame) 41 | 42 | videoWriter.release() 43 | 44 | final_video_path = f"../output/{output_video_name}.mp4" 45 | os.system(f"ffmpeg -i {save_path} -i {wavpath} -c:v libx264 -pix_fmt yuv420p -loglevel quiet {final_video_path}") 46 | shutil.rmtree(f"output/{task_id}") 47 | 48 | return final_video_path 49 | 50 | if __name__ == "__main__": 51 | video_path = "path_to_video" 52 | audio_path = "path_to_audio" 53 | output_video_name = "output_video_name" 54 | merge_audio_video(video_path, audio_path, output_video_name) 55 | 56 | -------------------------------------------------------------------------------- /demo_avatar.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import numpy as np 4 | from scipy.io import wavfile 5 | import cv2 6 | import glob 7 | from talkingface.audio_model import AudioModel 8 | from talkingface.render_model import RenderModel 9 | 10 | audioModel = AudioModel() 11 | audioModel.loadModel("checkpoint/audio.pkl") 12 | 13 | renderModel = RenderModel() 14 | renderModel.loadModel("checkpoint/render.pth") 15 | test_video = "test" 16 | pkl_path = "video_data/{}/keypoint_rotate.pkl".format(test_video) 17 | video_path = "video_data/{}/circle.mp4".format(test_video) 18 | renderModel.reset_charactor(video_path, pkl_path) 19 | 20 | wavpath = "video_data/audio0.wav" 21 | rate, wav = wavfile.read(wavpath, mmap=False) 22 | index_ = 0 23 | frame_index__ = 0 24 | import sounddevice as sd 25 | sample_rate = 16000 26 | samples_per_read = int(0.04 * sample_rate) 27 | with sd.InputStream( 28 | channels=1, dtype="float32", samplerate=sample_rate 29 | ) as s: 30 | while True: 31 | samples, _ = s.read(samples_per_read) # a blocking read 32 | pcm_data = samples.reshape(-1) 33 | print(pcm_data.shape) 34 | mouth_frame = audioModel.interface_frame(pcm_data) 35 | frame = renderModel.interface(mouth_frame) 36 | cv2.imshow("s", frame) 37 | cv2.waitKey(10) 38 | index_ += 1 39 | -------------------------------------------------------------------------------- /front.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/front.mp4 -------------------------------------------------------------------------------- /go-web.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | set "PATH=%CD%;%PATH%" 3 | set PYTHONUSERBASE=kelong\Lib\site-packages 4 | kelong\Scripts\python.exe webapp.py 5 | pause -------------------------------------------------------------------------------- /images/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/images/1.png -------------------------------------------------------------------------------- /images/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/images/2.png -------------------------------------------------------------------------------- /images/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/images/3.png 
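
For reference, the two command-line entry points listed above (`data_preparation.py` and `demo.py`) can also be driven directly from Python, without the Gradio UI started by `go-web.bat`/`webapp.py`. The snippet below is a minimal sketch and is not part of the repository: it assumes ffmpeg is on the PATH, that the model files under `checkpoint/` are in place, and `my_face_video.mp4` is only a placeholder input name.

```python
# Minimal sketch of chaining the repo's two CLI entry points (assumptions noted above).
import glob
import os
import subprocess

from demo import merge_audio_video

RAW_VIDEO = "my_face_video.mp4"      # placeholder: your own face video
AUDIO = "video_data/audio0.wav"      # sample audio shipped with the repo

# Step 1: build the looped video and face keypoints under video_data/<uuid>/.
subprocess.run(["python", "data_preparation.py", RAW_VIDEO], check=True)

# data_preparation.py names its output folder with a fresh uuid, so pick the
# most recently modified directory under video_data/.
candidates = [d for d in glob.glob("video_data/*") if os.path.isdir(d)]
video_dir = max(candidates, key=os.path.getmtime)

# Step 2: drive the mouth region with the audio and mux the result via ffmpeg.
result = merge_audio_video(video_dir, AUDIO, "my_first_digital_human")
print("written:", result)
```

This is presumably the same two-step flow that `webapp.py` wires up behind the Gradio buttons shown in the screenshots.
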
-------------------------------------------------------------------------------- /images/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/images/4.png -------------------------------------------------------------------------------- /inp_keypoint.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/inp_keypoint.pkl -------------------------------------------------------------------------------- /keypoint_rotate.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/keypoint_rotate.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | beautifulsoup4==4.12.3 2 | dominate==2.9.0 3 | fastapi==0.111.1 4 | kaldi_native_fbank==1.19.1 5 | librosa==0.10.1 6 | mediapipe==0.10.11 7 | numpy==1.24.3 8 | opencv_contrib_python==4.8.1.78 9 | opencv_python==4.9.0.80 10 | pandas==2.0.3 11 | Pillow==9.4.0 12 | Requests==2.32.3 13 | scipy==1.11.1 14 | sounddevice==0.4.6 15 | thop==0.1.1.post2209072238 16 | torch 17 | torchvision 18 | tqdm==4.65.0 19 | uvicorn==0.30.3 20 | visdom==0.2.4 21 | wandb==0.16.5 22 | gradio 23 | edge_tts -------------------------------------------------------------------------------- /talkingface/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/__init__.py -------------------------------------------------------------------------------- /talkingface/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /talkingface/__pycache__/audio_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/__pycache__/audio_model.cpython-310.pyc -------------------------------------------------------------------------------- /talkingface/__pycache__/render_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/__pycache__/render_model.cpython-310.pyc -------------------------------------------------------------------------------- /talkingface/__pycache__/run_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/__pycache__/run_utils.cpython-310.pyc -------------------------------------------------------------------------------- /talkingface/__pycache__/utils.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /talkingface/audio_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import kaldi_native_fbank as knf 4 | from scipy.io import wavfile 5 | import torch 6 | import librosa 7 | import pickle 8 | device = "cuda" if torch.cuda.is_available() else "cpu" 9 | import pickle 10 | import os 11 | def pca_process(x): 12 | a = x.reshape(15, 30, 3) 13 | # a = pca.mean_.reshape(15,30,3) 14 | tmp = a[:, :15] + a[:, 15:][:, ::-1] 15 | a[:, :15] = tmp / 2 16 | a[:, 15:] = a[:, :15][:, ::-1] 17 | return a.flatten() 18 | class AudioModel: 19 | def __init__(self): 20 | self.__net = None 21 | self.__fbank = None 22 | self.__fbank_processed_index = 0 23 | self.frame_index = 0 24 | 25 | current_dir = os.path.dirname(os.path.abspath(__file__)) 26 | Path_output_pkl = os.path.join(current_dir, "../data/pca.pkl") 27 | with open(Path_output_pkl, "rb") as f: 28 | pca = pickle.load(f) 29 | self.pca_mean_ = pca_process(pca.mean_) 30 | self.pca_components_ = np.zeros_like(pca.components_) 31 | self.pca_components_[0] = pca_process(pca.components_[0]) 32 | self.pca_components_[1] = pca_process(pca.components_[1]) 33 | self.pca_components_[2] = pca_process(pca.components_[2]) 34 | self.pca_components_[3] = pca_process(pca.components_[3]) 35 | self.pca_components_[4] = pca_process(pca.components_[4]) 36 | self.pca_components_[5] = pca_process(pca.components_[5]) 37 | 38 | self.reset() 39 | 40 | def loadModel(self, ckpt_path): 41 | # if method == "lstm": 42 | # ckpt_path = 'checkpoint/lstm/lstm_model_epoch_560.pth' 43 | # Audio2FeatureModel = torch.load(model_path).to(device) 44 | # Audio2FeatureModel.eval() 45 | from talkingface.models.audio2bs_lstm import Audio2Feature 46 | self.__net = Audio2Feature() # 调用模型Model 47 | self.__net.load_state_dict(torch.load(ckpt_path)) 48 | self.__net = self.__net.to(device) 49 | self.__net.eval() 50 | 51 | def reset(self): 52 | opts = knf.FbankOptions() 53 | opts.frame_opts.dither = 0 54 | opts.frame_opts.frame_length_ms = 50 55 | opts.frame_opts.frame_shift_ms = 20 56 | opts.mel_opts.num_bins = 80 57 | opts.frame_opts.snip_edges = False 58 | opts.mel_opts.debug_mel = False 59 | self.__fbank = knf.OnlineFbank(opts) 60 | 61 | self.h0 = torch.zeros(2, 1, 192).to(device) 62 | self.c0 = torch.zeros(2, 1, 192).to(device) 63 | 64 | self.__fbank_processed_index = 0 65 | 66 | audio_samples = np.zeros([320]) 67 | self.__fbank.accept_waveform(16000, audio_samples.tolist()) 68 | 69 | def interface_frame(self, audio_samples): 70 | # pcm为uint16位数据。 只处理一帧的数据, 16000/25 = 640 71 | self.__fbank.accept_waveform(16000, audio_samples.tolist()) 72 | orig_mel = np.zeros([2, 80]) 73 | 74 | orig_mel[0] = self.__fbank.get_frame(self.__fbank_processed_index) 75 | orig_mel[1] = self.__fbank.get_frame(self.__fbank_processed_index + 1) 76 | 77 | input = torch.from_numpy(orig_mel).unsqueeze(0).float().to(device) 78 | bs_array, self.h0, self.c0 = self.__net(input, self.h0, self.c0) 79 | bs_array = bs_array[0].detach().cpu().float().numpy() 80 | bs_real = bs_array[0] 81 | # print(self.__fbank_processed_index, self.__fbank.num_frames_ready, bs_real) 82 | 83 | frame = np.dot(bs_real[:6], self.pca_components_[:6]) + self.pca_mean_ 84 | # 
print(frame_index, frame.shape) 85 | frame = frame.reshape(15, 30, 3).clip(0, 255).astype(np.uint8) 86 | self.__fbank_processed_index += 2 87 | return frame 88 | 89 | def interface_wav(self, wavpath): 90 | rate, wav = wavfile.read(wavpath, mmap=False) 91 | augmented_samples = wav 92 | augmented_samples2 = augmented_samples.astype(np.float32, order='C') / 32768.0 93 | # print(augmented_samples2.shape, augmented_samples2.shape[0] / 16000) 94 | 95 | opts = knf.FbankOptions() 96 | opts.frame_opts.dither = 0 97 | opts.frame_opts.frame_length_ms = 50 98 | opts.frame_opts.frame_shift_ms = 20 99 | opts.mel_opts.num_bins = 80 100 | opts.frame_opts.snip_edges = False 101 | opts.mel_opts.debug_mel = False 102 | fbank = knf.OnlineFbank(opts) 103 | fbank.accept_waveform(16000, augmented_samples2.tolist()) 104 | seq_len = fbank.num_frames_ready // 2 105 | A2Lsamples = np.zeros([2 * seq_len, 80]) 106 | for i in range(2 * seq_len): 107 | f2 = fbank.get_frame(i) 108 | A2Lsamples[i] = f2 109 | 110 | orig_mel = A2Lsamples 111 | # print(orig_mel.shape) 112 | input = torch.from_numpy(orig_mel).unsqueeze(0).float().to(device) 113 | # print(input.shape) 114 | h0 = torch.zeros(2, 1, 192).to(device) 115 | c0 = torch.zeros(2, 1, 192).to(device) 116 | bs_array, hn, cn = self.__net(input, h0, c0) 117 | bs_array = bs_array[0].detach().cpu().float().numpy() 118 | bs_array = bs_array[4:] 119 | 120 | frame_num = len(bs_array) 121 | output = np.zeros([frame_num, 15, 30, 3], dtype = np.uint8) 122 | for frame_index in range(frame_num): 123 | bs_real = bs_array[frame_index] 124 | # bs_real[1:4] = - bs_real[1:4] 125 | frame = np.dot(bs_real[:6], self.pca_components_[:6]) + self.pca_mean_ 126 | # print(frame_index, frame.shape) 127 | frame = frame.reshape(15, 30, 3).clip(0, 255).astype(np.uint8) 128 | output[frame_index] = frame 129 | 130 | return output -------------------------------------------------------------------------------- /talkingface/config/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | class DataProcessingOptions(): 4 | def __init__(self): 5 | self.parser = argparse.ArgumentParser() 6 | 7 | def parse_args(self): 8 | self.parser.add_argument('--extract_video_frame', action='store_true', help='extract video frame') 9 | self.parser.add_argument('--extract_audio', action='store_true', help='extract audio files from videos') 10 | self.parser.add_argument('--extract_deep_speech', action='store_true', help='extract deep speech features') 11 | self.parser.add_argument('--crop_face', action='store_true', help='crop face') 12 | self.parser.add_argument('--generate_training_json', action='store_true', help='generate training json file') 13 | 14 | self.parser.add_argument('--source_video_dir', type=str, default="./asserts/training_data/split_video_25fps", 15 | help='path of source video in 25 fps') 16 | self.parser.add_argument('--openface_landmark_dir', type=str, default="./asserts/training_data/split_video_25fps_landmark_openface", 17 | help='path of openface landmark dir') 18 | self.parser.add_argument('--video_frame_dir', type=str, default="./asserts/training_data/split_video_25fps_frame", 19 | help='path of video frames') 20 | self.parser.add_argument('--audio_dir', type=str, default="./asserts/training_data/split_video_25fps_audio", 21 | help='path of audios') 22 | self.parser.add_argument('--deep_speech_dir', type=str, default="./asserts/training_data/split_video_25fps_deepspeech", 23 | help='path of deep speech') 24 | 
self.parser.add_argument('--crop_face_dir', type=str, default="./asserts/training_data/split_video_25fps_crop_face", 25 | help='path of crop face dir') 26 | self.parser.add_argument('--json_path', type=str, default="./asserts/training_data/training_json.json", 27 | help='path of training json') 28 | self.parser.add_argument('--clip_length', type=int, default=9, help='clip length') 29 | self.parser.add_argument('--deep_speech_model', type=str, default="./asserts/output_graph.pb", 30 | help='path of pretrained deepspeech model') 31 | return self.parser.parse_args() 32 | 33 | class DINetTrainingOptions(): 34 | def __init__(self): 35 | self.parser = argparse.ArgumentParser() 36 | 37 | def parse_args(self): 38 | self.parser.add_argument('--seed', type=int, default=456, help='random seed to use.') 39 | self.parser.add_argument('--source_channel', type=int, default=3, help='input source image channels') 40 | self.parser.add_argument('--ref_channel', type=int, default=15, help='input reference image channels') 41 | self.parser.add_argument('--audio_channel', type=int, default=29, help='input audio channels') 42 | self.parser.add_argument('--augment_num', type=int, default=32, help='augment training data') 43 | self.parser.add_argument('--mouth_region_size', type=int, default=64, help='augment training data') 44 | self.parser.add_argument('--train_data', type=str, default=r"./asserts/training_data/training_json.json", 45 | help='path of training json') 46 | self.parser.add_argument('--batch_size', type=int, default=24, help='training batch size') 47 | self.parser.add_argument('--lamb_perception', type=int, default=10, help='weight of perception loss') 48 | self.parser.add_argument('--lamb_syncnet_perception', type=int, default=0.1, help='weight of perception loss') 49 | self.parser.add_argument('--lr_g', type=float, default=0.0001, help='initial learning rate for adam') 50 | self.parser.add_argument('--lr_d', type=float, default=0.0001, help='initial learning rate for adam') 51 | self.parser.add_argument('--start_epoch', default=1, type=int, help='start epoch in training stage') 52 | self.parser.add_argument('--non_decay', default=80, type=int, help='num of epoches with fixed learning rate') 53 | self.parser.add_argument('--decay', default=80, type=int, help='num of linearly decay epochs') 54 | self.parser.add_argument('--checkpoint', type=int, default=2, help='num of checkpoints in training stage') 55 | self.parser.add_argument('--result_path', type=str, default=r"./asserts/training_model_weight/frame_training_64", 56 | help='result path to save model') 57 | self.parser.add_argument('--coarse2fine', action='store_true', help='If true, load pretrained model path.') 58 | self.parser.add_argument('--coarse_model_path', 59 | default='', 60 | type=str, 61 | help='Save data (.pth) of previous training') 62 | self.parser.add_argument('--pretrained_syncnet_path', 63 | default='', 64 | type=str, 65 | help='Save data (.pth) of pretrained syncnet') 66 | self.parser.add_argument('--pretrained_frame_DINet_path', 67 | default='', 68 | type=str, 69 | help='Save data (.pth) of frame trained DINet') 70 | # ========================= Discriminator ========================== 71 | self.parser.add_argument('--D_num_blocks', type=int, default=4, help='num of down blocks in discriminator') 72 | self.parser.add_argument('--D_block_expansion', type=int, default=64, help='block expansion in discriminator') 73 | self.parser.add_argument('--D_max_features', type=int, default=256, help='max channels in discriminator') 74 | 
return self.parser.parse_args() 75 | 76 | 77 | class DINetInferenceOptions(): 78 | def __init__(self): 79 | self.parser = argparse.ArgumentParser() 80 | 81 | def parse_args(self): 82 | self.parser.add_argument('--source_channel', type=int, default=3, help='channels of source image') 83 | self.parser.add_argument('--ref_channel', type=int, default=15, help='channels of reference image') 84 | self.parser.add_argument('--audio_channel', type=int, default=29, help='channels of audio feature') 85 | self.parser.add_argument('--mouth_region_size', type=int, default=256, help='help to resize window') 86 | self.parser.add_argument('--source_video_path', 87 | default='./asserts/examples/test4.mp4', 88 | type=str, 89 | help='path of source video') 90 | self.parser.add_argument('--source_openface_landmark_path', 91 | default='./asserts/examples/test4.csv', 92 | type=str, 93 | help='path of detected openface landmark') 94 | self.parser.add_argument('--driving_audio_path', 95 | default='./asserts/examples/driving_audio_1.wav', 96 | type=str, 97 | help='path of driving audio') 98 | self.parser.add_argument('--pretrained_clip_DINet_path', 99 | default='./asserts/clip_training_DINet_256mouth.pth', 100 | type=str, 101 | help='pretrained model of DINet(clip trained)') 102 | self.parser.add_argument('--deepspeech_model_path', 103 | default='./asserts/output_graph.pb', 104 | type=str, 105 | help='path of deepspeech model') 106 | self.parser.add_argument('--res_video_dir', 107 | default='./asserts/inference_result', 108 | type=str, 109 | help='path of generated videos') 110 | return self.parser.parse_args() -------------------------------------------------------------------------------- /talkingface/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/data/__init__.py -------------------------------------------------------------------------------- /talkingface/data/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/data/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /talkingface/data/__pycache__/few_shot_dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/data/__pycache__/few_shot_dataset.cpython-310.pyc -------------------------------------------------------------------------------- /talkingface/data/face_mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["KMP_DUPLICATE_LIB_OK"] = "true" 3 | import pickle 4 | import cv2 5 | import numpy as np 6 | import os 7 | import glob 8 | from talkingface.util.smooth import smooth_array 9 | from talkingface.run_utils import calc_face_mat 10 | import tqdm 11 | from talkingface.utils import * 12 | 13 | path_ = r"../../../preparation_mix" 14 | video_list = [os.path.join(path_, i) for i in os.listdir(path_)] 15 | path_ = r"../../../preparation_hdtf" 16 | video_list += [os.path.join(path_, i) for i in os.listdir(path_)] 17 | path_ = r"../../../preparation_vfhq" 18 | video_list += [os.path.join(path_, i) for i in os.listdir(path_)] 19 | 
path_ = r"../../../preparation_bilibili" 20 | video_list += [os.path.join(path_, i) for i in os.listdir(path_)] 21 | print(video_list) 22 | video_list = video_list[:] 23 | img_all = [] 24 | keypoints_all = [] 25 | point_size = 1 26 | point_color = (0, 0, 255) # BGR 27 | thickness = 4 # 0 、4、8 28 | for path_ in tqdm.tqdm(video_list): 29 | img_filelist = glob.glob("{}/image/*.png".format(path_)) 30 | img_filelist.sort() 31 | if len(img_filelist) == 0: 32 | continue 33 | img_all.append(img_filelist) 34 | 35 | Path_output_pkl = "{}/keypoint_rotate.pkl".format(path_) 36 | 37 | with open(Path_output_pkl, "rb") as f: 38 | images_info = pickle.load(f)[:, main_keypoints_index, :] 39 | pts_driven = images_info.reshape(len(images_info), -1) 40 | pts_driven = smooth_array(pts_driven).reshape(len(pts_driven), -1, 3) 41 | 42 | face_pts_mean = np.loadtxt(r"data\face_pts_mean_mainKps.txt") 43 | mat_list,pts_normalized_list,face_pts_mean_personal = calc_face_mat(pts_driven, face_pts_mean) 44 | pts_normalized_list = np.array(pts_normalized_list) 45 | # print(face_pts_mean_personal[INDEX_FACE_OVAL[:10], 1]) 46 | # print(np.max(pts_normalized_list[:,INDEX_FACE_OVAL[:10], 1], axis = 1)) 47 | face_pts_mean_personal[INDEX_FACE_OVAL[:10], 1] = np.max(pts_normalized_list[:,INDEX_FACE_OVAL[:10], 1], axis = 0) + np.arange(5,25,2) 48 | face_pts_mean_personal[INDEX_FACE_OVAL[:10], 0] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[:10], 0], axis=0) - (9 - np.arange(0,10)) 49 | face_pts_mean_personal[INDEX_FACE_OVAL[-10:], 1] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[-10:], 1], axis=0) - np.arange(5,25,2) + 28 50 | face_pts_mean_personal[INDEX_FACE_OVAL[-10:], 0] = np.min(pts_normalized_list[:, INDEX_FACE_OVAL[-10:], 0], axis=0) + np.arange(0,10) 51 | 52 | face_pts_mean_personal[INDEX_FACE_OVAL[10], 1] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[10], 1], axis=0) + 25 53 | 54 | # for keypoints_normalized in pts_normalized_list: 55 | # img = np.zeros([1000,1000,3], dtype=np.uint8) 56 | # for coor in face_pts_mean_personal: 57 | # # coor = (coor +1 )/2. 58 | # cv2.circle(img, (int(coor[0]), int(coor[1])), point_size, (255, 0, 0), thickness) 59 | # for coor in keypoints_normalized: 60 | # # coor = (coor +1 )/2. 
61 | # cv2.circle(img, (int(coor[0]), int(coor[1])), point_size, point_color, thickness) 62 | # cv2.imshow("a", img) 63 | # cv2.waitKey(30) 64 | 65 | with open("{}/face_mat_mask20240722.pkl".format(path_), "wb") as f: 66 | pickle.dump([mat_list, face_pts_mean_personal], f) 67 | -------------------------------------------------------------------------------- /talkingface/data/few_shot_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import tqdm 4 | import copy 5 | from talkingface.utils import * 6 | import glob 7 | import pickle 8 | import torch 9 | import torch.utils.data as data 10 | def get_image(A_path, crop_coords, input_type, resize= (256, 256)): 11 | (x_min, y_min, x_max, y_max) = crop_coords 12 | size = (x_max - x_min, y_max - y_min) 13 | 14 | if input_type == 'mediapipe': 15 | pose_pts = (A_path - np.array([x_min, y_min])) * resize / size 16 | return pose_pts[:, :2] 17 | else: 18 | img_output = A_path[y_min:y_max, x_min:x_max, :] 19 | img_output = cv2.resize(img_output, resize) 20 | return img_output 21 | def generate_input(img, keypoints, mask_keypoints, is_train = False, mode=["mouth_bias"], mouth_width = None, mouth_height = None): 22 | # 根据关键点决定正方形裁剪区域 23 | crop_coords = crop_face(keypoints, size=img.shape[:2], is_train=is_train) 24 | target_keypoints = get_image(keypoints[:,:2], crop_coords, input_type='mediapipe') 25 | target_img = get_image(img, crop_coords, input_type='img') 26 | 27 | target_mask_keypoints = get_image(mask_keypoints[:,:2], crop_coords, input_type='mediapipe') 28 | 29 | # source_img信息:扣出嘴部区域 30 | source_img = copy.deepcopy(target_img) 31 | source_keypoints = target_keypoints 32 | 33 | pts = source_keypoints.copy() 34 | 35 | face_edge_start_index = 2 36 | 37 | pts[INDEX_FACE_OVAL[face_edge_start_index:-face_edge_start_index], 1] = target_mask_keypoints[face_edge_start_index:-face_edge_start_index, 1] 38 | 39 | # pts = pts[INDEX_FACE_OVAL[face_edge_start_index:-face_edge_start_index] + INDEX_NOSE_EDGE[::-1], :2] 40 | pts = pts[FACE_MASK_INDEX + INDEX_NOSE_EDGE[::-1], :2] 41 | 42 | pts = pts.reshape((-1, 1, 2)).astype(np.int32) 43 | cv2.fillPoly(source_img, [pts], color=(0, 0, 0)) 44 | source_face_egde = draw_face_feature_maps(source_keypoints, mode=mode, im_edges=target_img, 45 | mouth_width = mouth_width * (256/(crop_coords[2] - crop_coords[0])), mouth_height = mouth_height * (256/(crop_coords[2] - crop_coords[0]))) 46 | source_img = np.concatenate([source_img, source_face_egde], axis=2) 47 | return source_img,target_img,crop_coords 48 | 49 | def generate_ref(img, keypoints, is_train=False, alpha = None, beta = None): 50 | crop_coords = crop_face(keypoints, size=img.shape[:2], is_train=is_train) 51 | ref_keypoints = get_image(keypoints, crop_coords, input_type='mediapipe') 52 | ref_img = get_image(img, crop_coords, input_type='img') 53 | 54 | if beta is not None: 55 | if alpha: 56 | ref_img[:, :, :3] = cv2.add(ref_img[:, :, :3], beta) 57 | else: 58 | ref_img[:, :, :3] = cv2.subtract(ref_img[:, :, :3], beta) 59 | ref_face_edge = draw_face_feature_maps(ref_keypoints, mode=["mouth", "nose", "eye", "oval_all","muscle"]) 60 | ref_img = np.concatenate([ref_img, ref_face_edge], axis=2) 61 | return ref_img 62 | 63 | def select_ref_index(driven_keypoints, n_ref = 5, ratio = 1/3.): 64 | # 根据嘴巴开合程度,选取开合最大的那一半 65 | lips_distance = np.linalg.norm( 66 | driven_keypoints[:, INDEX_LIPS_INNER[5]] - driven_keypoints[:, INDEX_LIPS_INNER[-5]], axis=1) 67 | selected_index_list = 
np.argsort(lips_distance).tolist()[int(len(lips_distance) * ratio):] 68 | ref_img_index_list = random.sample(selected_index_list, n_ref) # 从当前视频选n_ref个图片 69 | return ref_img_index_list 70 | 71 | def get_ref_images_fromVideo(cap, ref_img_index_list, ref_keypoints): 72 | ref_img_list = [] 73 | for index in ref_img_index_list: 74 | cap.set(cv2.CAP_PROP_POS_FRAMES, index) # 设置要获取的帧号 75 | ret, frame = cap.read() 76 | if ret is False: 77 | print("请检查当前视频, 错误帧数:", index) 78 | ref_img = generate_ref(frame, ref_keypoints[index]) 79 | ref_img_list.append(ref_img) 80 | ref_img = np.concatenate(ref_img_list, axis=2) 81 | return ref_img 82 | 83 | 84 | 85 | 86 | class Few_Shot_Dataset(data.Dataset): 87 | def __init__(self, dict_info, n_ref = 2, is_train = False): 88 | super(Few_Shot_Dataset, self).__init__() 89 | self.driven_images = dict_info["driven_images"] 90 | self.driven_keypoints = dict_info["driven_keypoints"] 91 | self.driving_keypoints = dict_info["driving_keypoints"] 92 | self.driven_mask_keypoints = dict_info["driven_mask_keypoints"] 93 | self.is_train = is_train 94 | 95 | assert len(self.driven_images) == len(self.driven_keypoints) 96 | assert len(self.driven_images) == len(self.driving_keypoints) 97 | 98 | self.out_size = (256, 256) 99 | 100 | self.sample_num = np.sum([len(i) for i in self.driven_images]) 101 | 102 | # list: 每个视频序列的视频块个数 103 | self.clip_count_list = [] # number of frames in each sequence 104 | for path in self.driven_images: 105 | self.clip_count_list.append(len(path)) 106 | self.n_ref = n_ref 107 | 108 | def get_ref_images(self, video_index, ref_img_index_list): 109 | # 参考图片 110 | ref_img_list = [] 111 | for ref_img_index in ref_img_index_list: 112 | ref_img = cv2.imread(self.driven_images[video_index][ref_img_index]) 113 | # ref_img = cv2.convertScaleAbs(ref_img, alpha=self.alpha, beta=self.beta) 114 | 115 | 116 | ref_keypoints = self.driven_keypoints[video_index][ref_img_index] 117 | ref_img = generate_ref(ref_img, ref_keypoints, self.is_train, self.alpha, self.beta) 118 | 119 | ref_img_list.append(ref_img) 120 | self.ref_img = np.concatenate(ref_img_list, axis=2) 121 | 122 | def __getitem__(self, index): 123 | 124 | # 调整亮度和对比度 125 | # self.alpha = random.uniform(0.8,1.25) # 缩放因子 126 | # self.beta = random.uniform(-50,50) # 移位因子 127 | # adjusted = cv2.convertScaleAbs(img, alpha=alpha, beta=beta) 128 | 129 | self.alpha = (random.random() > 0.5) # 正负因子 130 | self.beta = np.ones([256,256,3]) * np.random.rand(3) * 20 # 色彩调整0-20个色差 131 | self.beta = self.beta.astype(np.uint8) 132 | 133 | 134 | if self.is_train: 135 | video_index = random.randint(0, len(self.driven_images) - 1) 136 | current_clip = random.randint(0, self.clip_count_list[video_index] - 1) 137 | ref_img_index_list = select_ref_index(self.driven_keypoints[video_index], n_ref = self.n_ref) # 从当前视频选n_ref个图片 138 | self.get_ref_images(video_index, ref_img_index_list) 139 | else: 140 | video_index = 0 141 | current_clip = index 142 | 143 | if index == 0: 144 | ref_img_index_list = select_ref_index(self.driven_keypoints[video_index], n_ref=self.n_ref5) # 从当前视频选n_ref个图片 145 | self.get_ref_images(video_index, ref_img_index_list) 146 | 147 | # target图片 148 | target_img = cv2.imread(self.driven_images[video_index][current_clip]) 149 | # target_img = cv2.convertScaleAbs(target_img, alpha=self.alpha, beta=self.beta) 150 | 151 | target_keypoints = self.driving_keypoints[video_index][current_clip] 152 | target_mask_keypoints = self.driven_mask_keypoints[video_index][current_clip] 153 | 154 | mouth_rect = 
self.driving_keypoints[video_index][:, INDEX_LIPS].max(axis=1) - self.driving_keypoints[video_index][:, INDEX_LIPS].min(axis=1) 155 | mouth_width = mouth_rect[:, 0].max() 156 | mouth_height = mouth_rect[:, 1].max() 157 | 158 | # source_img, target_img,crop_coords = generate_input(target_img, target_keypoints, target_mask_keypoints, self.is_train) 159 | source_img, target_img,crop_coords = generate_input(target_img, target_keypoints, target_mask_keypoints, self.is_train, mode=["mouth_bias", "nose", "eye"], 160 | mouth_width = mouth_width, mouth_height = mouth_height) 161 | 162 | target_img = target_img/255. 163 | source_img = source_img/255. 164 | ref_img = self.ref_img / 255. 165 | 166 | # tensor 167 | source_tensor = torch.from_numpy(source_img).float().permute(2, 0, 1) 168 | ref_tensor = torch.from_numpy(ref_img).float().permute(2, 0, 1) 169 | target_tensor = torch.from_numpy(target_img).float().permute(2, 0, 1) 170 | return source_tensor, ref_tensor, target_tensor 171 | 172 | def __len__(self): 173 | if self.is_train: 174 | return len(self.driven_images) 175 | else: 176 | return len(self.driven_images[0]) 177 | # return self.sample_num 178 | def data_preparation(train_video_list): 179 | img_all = [] 180 | keypoints_all = [] 181 | mask_all = [] 182 | point_size = 1 183 | point_color = (0, 0, 255) # BGR 184 | thickness = 4 # 0 、4、8 185 | for i in tqdm.tqdm(train_video_list): 186 | # for i in ["xiaochangzhang/00004"]: 187 | model_name = i 188 | img_filelist = glob.glob("{}/image/*.png".format(model_name)) 189 | img_filelist.sort() 190 | if len(img_filelist) == 0: 191 | continue 192 | img_all.append(img_filelist) 193 | 194 | Path_output_pkl = "{}/keypoint_rotate.pkl".format(model_name) 195 | with open(Path_output_pkl, "rb") as f: 196 | images_info = pickle.load(f) 197 | keypoints_all.append(images_info[:, main_keypoints_index, :2]) 198 | 199 | Path_output_pkl = "{}/face_mat_mask20240722.pkl".format(model_name) 200 | with open(Path_output_pkl, "rb") as f: 201 | mat_list, face_pts_mean_personal = pickle.load(f) 202 | 203 | face_pts_mean_personal = face_pts_mean_personal[INDEX_FACE_OVAL] 204 | face_mask_pts = np.zeros([len(mat_list), len(face_pts_mean_personal), 2]) 205 | for index_ in range(len(mat_list)): 206 | # img = np.zeros([1000,1000,3], dtype=np.uint8) 207 | # img = cv2.imread(img_filelist[index_]) 208 | 209 | rotationMatrix = mat_list[index_] 210 | 211 | keypoints = np.ones([4, len(face_pts_mean_personal)]) 212 | keypoints[:3, :] = face_pts_mean_personal.T 213 | driving_mask = rotationMatrix.dot(keypoints).T 214 | face_mask_pts[index_] = driving_mask[:, :2] 215 | 216 | 217 | 218 | # for coor in driving_mask: 219 | # # coor = (coor +1 )/2. 
220 | # cv2.circle(img, (int(coor[0]), int(coor[1])), point_size, point_color, thickness) 221 | # cv2.imshow("a", img) 222 | # cv2.waitKey(30) 223 | mask_all.append(face_mask_pts) 224 | 225 | print("train size: ", len(img_all)) 226 | dict_info = {} 227 | dict_info["driven_images"] = img_all 228 | dict_info["driven_keypoints"] = keypoints_all 229 | dict_info["driving_keypoints"] = keypoints_all 230 | dict_info["driven_mask_keypoints"] = mask_all 231 | return dict_info 232 | 233 | 234 | def generate_input_pixels(img, keypoints, rotationMatrix, pixels_mouth, mask_keypoints, coords_array): 235 | # 根据关键点决定正方形裁剪区域 236 | crop_coords = crop_face(keypoints, size=img.shape[:2], is_train=False) 237 | target_keypoints = get_image(keypoints[:, :2], crop_coords, input_type='mediapipe') 238 | 239 | # 画出嘴部像素图 240 | pixels_mouth_coords = rotationMatrix.dot(coords_array).T 241 | pixels_mouth_coords = pixels_mouth_coords[:, :2].astype(int) 242 | pixels_mouth_coords = (pixels_mouth_coords[:, 1], pixels_mouth_coords[:, 0]) 243 | 244 | source_face_egde = np.zeros_like(img, dtype=np.uint8) 245 | # out_frame = img.copy() 246 | frame = pixels_mouth.reshape(15, 30, 3).clip(0, 255).astype(np.uint8) 247 | 248 | frame = cv2.resize(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY), (150, 100)) 249 | sharpen_image = frame.astype(np.float32) 250 | mean_ = int(np.mean(sharpen_image)) 251 | max_, min_ = mean_ + 60, mean_ - 60 252 | sharpen_image = (sharpen_image - min_) / (max_ - min_) * 255. 253 | sharpen_image = sharpen_image.clip(0, 255).astype(np.uint8) 254 | 255 | sharpen_image = np.concatenate( 256 | [sharpen_image[:, :, np.newaxis], sharpen_image[:, :, np.newaxis], sharpen_image[:, :, np.newaxis]], axis=2) 257 | # sharpen_image = cv2.resize(sharpen_image, (150, 100)) 258 | source_face_egde[pixels_mouth_coords] = sharpen_image.reshape(-1, 3) 259 | # cv2.imshow("sharpen_image", source_face_egde) 260 | # cv2.waitKey(40) 261 | 262 | source_face_egde = get_image(source_face_egde, crop_coords, input_type='image') 263 | source_face_egde = draw_face_feature_maps(target_keypoints, mode = ["nose", "eye"], im_edges=source_face_egde) 264 | # cv2.imshow("sharpen_image", source_face_egde) 265 | # cv2.waitKey(40) 266 | 267 | 268 | target_img = get_image(img, crop_coords, input_type='img') 269 | target_mask_keypoints = get_image(mask_keypoints[:, :2], crop_coords, input_type='mediapipe') 270 | # source_img信息:扣出嘴部区域 271 | source_img = copy.deepcopy(target_img) 272 | source_keypoints = target_keypoints 273 | pts = source_keypoints.copy() 274 | face_edge_start_index = 3 275 | pts[INDEX_FACE_OVAL[face_edge_start_index:-face_edge_start_index], 1] = target_mask_keypoints[ 276 | face_edge_start_index:-face_edge_start_index, 277 | 1] 278 | pts = pts[INDEX_FACE_OVAL[face_edge_start_index:-face_edge_start_index] + INDEX_NOSE_EDGE[::-1], :2] 279 | pts = pts.reshape((-1, 1, 2)).astype(np.int32) 280 | cv2.fillPoly(source_img, [pts], color=(0, 0, 0)) 281 | source_img = np.concatenate([source_img, source_face_egde], axis=2) 282 | return source_img, target_img, crop_coords -------------------------------------------------------------------------------- /talkingface/models/DINet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import math 5 | import cv2 6 | import numpy as np 7 | from torch.nn import BatchNorm2d 8 | from torch.nn import BatchNorm1d 9 | 10 | def make_coordinate_grid_3d(spatial_size, type): 11 | ''' 12 | generate 3D coordinate 
grid 13 | ''' 14 | d, h, w = spatial_size 15 | x = torch.arange(w).type(type) 16 | y = torch.arange(h).type(type) 17 | z = torch.arange(d).type(type) 18 | x = (2 * (x / (w - 1)) - 1) 19 | y = (2 * (y / (h - 1)) - 1) 20 | z = (2 * (z / (d - 1)) - 1) 21 | yy = y.view(1,-1, 1).repeat(d,1, w) 22 | xx = x.view(1,1, -1).repeat(d,h, 1) 23 | zz = z.view(-1,1,1).repeat(1,h,w) 24 | meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3)], 3) 25 | return meshed,zz 26 | 27 | class ResBlock1d(nn.Module): 28 | ''' 29 | basic block 30 | ''' 31 | def __init__(self, in_features,out_features, kernel_size, padding): 32 | super(ResBlock1d, self).__init__() 33 | self.in_features = in_features 34 | self.out_features = out_features 35 | self.conv1 = nn.Conv1d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 36 | padding=padding) 37 | self.conv2 = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 38 | padding=padding) 39 | if out_features != in_features: 40 | self.channel_conv = nn.Conv1d(in_features,out_features,1) 41 | self.norm1 = BatchNorm1d(in_features) 42 | self.norm2 = BatchNorm1d(in_features) 43 | self.relu = nn.ReLU() 44 | def forward(self, x): 45 | out = self.norm1(x) 46 | out = self.relu(out) 47 | out = self.conv1(out) 48 | out = self.norm2(out) 49 | out = self.relu(out) 50 | out = self.conv2(out) 51 | if self.in_features != self.out_features: 52 | out += self.channel_conv(x) 53 | else: 54 | out += x 55 | return out 56 | 57 | class ResBlock2d(nn.Module): 58 | ''' 59 | basic block 60 | ''' 61 | def __init__(self, in_features,out_features, kernel_size, padding): 62 | super(ResBlock2d, self).__init__() 63 | self.in_features = in_features 64 | self.out_features = out_features 65 | self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 66 | padding=padding) 67 | self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 68 | padding=padding) 69 | if out_features != in_features: 70 | self.channel_conv = nn.Conv2d(in_features,out_features,1) 71 | self.norm1 = BatchNorm2d(in_features) 72 | self.norm2 = BatchNorm2d(in_features) 73 | self.relu = nn.ReLU() 74 | def forward(self, x): 75 | out = self.norm1(x) 76 | out = self.relu(out) 77 | out = self.conv1(out) 78 | out = self.norm2(out) 79 | out = self.relu(out) 80 | out = self.conv2(out) 81 | if self.in_features != self.out_features: 82 | out += self.channel_conv(x) 83 | else: 84 | out += x 85 | return out 86 | 87 | class UpBlock2d(nn.Module): 88 | ''' 89 | basic block 90 | ''' 91 | def __init__(self, in_features, out_features, kernel_size=3, padding=1): 92 | super(UpBlock2d, self).__init__() 93 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 94 | padding=padding) 95 | self.norm = BatchNorm2d(out_features) 96 | self.relu = nn.ReLU() 97 | def forward(self, x): 98 | out = F.interpolate(x, scale_factor=2) 99 | out = self.conv(out) 100 | out = self.norm(out) 101 | out = F.relu(out) 102 | return out 103 | 104 | class DownBlock1d(nn.Module): 105 | ''' 106 | basic block 107 | ''' 108 | def __init__(self, in_features, out_features, kernel_size, padding): 109 | super(DownBlock1d, self).__init__() 110 | self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 111 | padding=padding,stride=2) 112 | self.norm = BatchNorm1d(out_features) 113 | self.relu = nn.ReLU() 114 | def forward(self, x): 115 | out = self.conv(x) 116 | out = 
self.norm(out) 117 | out = self.relu(out) 118 | return out 119 | 120 | class DownBlock2d(nn.Module): 121 | ''' 122 | basic block 123 | ''' 124 | def __init__(self, in_features, out_features, kernel_size=3, padding=1, stride=2): 125 | super(DownBlock2d, self).__init__() 126 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 127 | padding=padding, stride=stride) 128 | self.norm = BatchNorm2d(out_features) 129 | self.relu = nn.ReLU() 130 | def forward(self, x): 131 | out = self.conv(x) 132 | out = self.norm(out) 133 | out = self.relu(out) 134 | return out 135 | 136 | class SameBlock1d(nn.Module): 137 | ''' 138 | basic block 139 | ''' 140 | def __init__(self, in_features, out_features, kernel_size, padding): 141 | super(SameBlock1d, self).__init__() 142 | self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, 143 | kernel_size=kernel_size, padding=padding) 144 | self.norm = BatchNorm1d(out_features) 145 | self.relu = nn.ReLU() 146 | def forward(self, x): 147 | out = self.conv(x) 148 | out = self.norm(out) 149 | out = self.relu(out) 150 | return out 151 | 152 | class SameBlock2d(nn.Module): 153 | ''' 154 | basic block 155 | ''' 156 | def __init__(self, in_features, out_features, kernel_size=3, padding=1): 157 | super(SameBlock2d, self).__init__() 158 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, 159 | kernel_size=kernel_size, padding=padding) 160 | self.norm = BatchNorm2d(out_features) 161 | self.relu = nn.ReLU() 162 | def forward(self, x): 163 | out = self.conv(x) 164 | out = self.norm(out) 165 | out = self.relu(out) 166 | return out 167 | 168 | class AdaAT(nn.Module): 169 | ''' 170 | AdaAT operator 171 | ''' 172 | def __init__(self, para_ch,feature_ch, cuda = True): 173 | super(AdaAT, self).__init__() 174 | self.para_ch = para_ch 175 | self.feature_ch = feature_ch 176 | self.commn_linear = nn.Sequential( 177 | nn.Linear(para_ch, para_ch), 178 | nn.ReLU() 179 | ) 180 | self.scale = nn.Sequential( 181 | nn.Linear(para_ch, feature_ch), 182 | nn.Sigmoid() 183 | ) 184 | self.rotation = nn.Sequential( 185 | nn.Linear(para_ch, feature_ch), 186 | nn.Tanh() 187 | ) 188 | self.translation = nn.Sequential( 189 | nn.Linear(para_ch, 2 * feature_ch), 190 | nn.Tanh() 191 | ) 192 | self.tanh = nn.Tanh() 193 | self.sigmoid = nn.Sigmoid() 194 | 195 | def forward(self, feature_map,para_code): 196 | batch,d, h, w = feature_map.size(0), feature_map.size(1), feature_map.size(2), feature_map.size(3) 197 | para_code = self.commn_linear(para_code) 198 | scale = self.scale(para_code).unsqueeze(-1) * 2 199 | angle = self.rotation(para_code).unsqueeze(-1) * 3.14159# 200 | rotation_matrix = torch.cat([torch.cos(angle), -torch.sin(angle), torch.sin(angle), torch.cos(angle)], -1) 201 | rotation_matrix = rotation_matrix.view(batch, self.feature_ch, 2, 2) 202 | translation = self.translation(para_code).view(batch, self.feature_ch, 2) 203 | grid_xy, grid_z = make_coordinate_grid_3d((d, h, w), feature_map.type()) 204 | grid_xy = grid_xy.unsqueeze(0).repeat(batch, 1, 1, 1, 1) 205 | grid_z = grid_z.unsqueeze(0).repeat(batch, 1, 1, 1) 206 | scale = scale.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1) 207 | rotation_matrix = rotation_matrix.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1, 1) 208 | translation = translation.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1) 209 | trans_grid = torch.matmul(rotation_matrix, grid_xy.unsqueeze(-1)).squeeze(-1) * scale + translation 210 | full_grid = torch.cat([trans_grid, 
grid_z.unsqueeze(-1)], -1) 211 | trans_feature = F.grid_sample(feature_map.unsqueeze(1), full_grid).squeeze(1) 212 | return trans_feature 213 | 214 | class DINet_five_Ref(nn.Module): 215 | def __init__(self, source_channel,ref_channel, cuda = True): 216 | super(DINet_five_Ref, self).__init__() 217 | self.source_in_conv = nn.Sequential( 218 | SameBlock2d(source_channel,32,kernel_size=7, padding=3), 219 | DownBlock2d(32, 64, kernel_size=3, padding=1), 220 | DownBlock2d(64, 128,kernel_size=3, padding=1) 221 | ) 222 | self.ref_in_conv = nn.Sequential( 223 | SameBlock2d(ref_channel, 64, kernel_size=7, padding=3), 224 | DownBlock2d(64, 128, kernel_size=3, padding=1), 225 | DownBlock2d(128, 256, kernel_size=3, padding=1), 226 | ) 227 | self.trans_conv = nn.Sequential( 228 | # 20 →10 229 | SameBlock2d(384, 128, kernel_size=3, padding=1), 230 | # SameBlock2d(128, 128, kernel_size=11, padding=5), 231 | SameBlock2d(128, 128, kernel_size=7, padding=3), 232 | DownBlock2d(128, 128, kernel_size=3, padding=1), 233 | # 10 →5 234 | SameBlock2d(128, 128, kernel_size=7, padding=3), 235 | # SameBlock2d(128, 128, kernel_size=7, padding=3), 236 | DownBlock2d(128, 128, kernel_size=3, padding=1), 237 | # 5 →3 238 | SameBlock2d(128, 128, kernel_size=3, padding=1), 239 | DownBlock2d(128, 128, kernel_size=3, padding=1), 240 | # 3 →2 241 | SameBlock2d(128, 128, kernel_size=3, padding=1), 242 | DownBlock2d(128, 128, kernel_size=3, padding=1), 243 | 244 | ) 245 | 246 | appearance_conv_list = [] 247 | for i in range(2): 248 | appearance_conv_list.append( 249 | nn.Sequential( 250 | ResBlock2d(256, 256, 3, 1), 251 | ResBlock2d(256, 256, 3, 1), 252 | ) 253 | ) 254 | self.appearance_conv_list = nn.ModuleList(appearance_conv_list) 255 | self.adaAT = AdaAT(128, 256, cuda) 256 | self.out_conv = nn.Sequential( 257 | SameBlock2d(384, 128, kernel_size=3, padding=1), 258 | UpBlock2d(128,128,kernel_size=3, padding=1), 259 | ResBlock2d(128, 128, 3, 1), 260 | UpBlock2d(128, 128, kernel_size=3, padding=1), 261 | nn.Conv2d(128, 3, kernel_size=3, padding=1), 262 | nn.Sigmoid() 263 | ) 264 | self.global_avg2d = nn.AdaptiveAvgPool2d(1) 265 | self.global_avg1d = nn.AdaptiveAvgPool1d(1) 266 | def ref_input(self, ref_img): 267 | ## reference image encoder 268 | self.ref_in_feature = self.ref_in_conv(ref_img) 269 | # print(self.ref_in_feature.size(), self.ref_in_feature) 270 | ## use AdaAT do spatial deformation on reference feature maps 271 | self.ref_trans_feature0 = self.appearance_conv_list[0](self.ref_in_feature) 272 | def interface(self, source_img, source_prompt): 273 | self.source_img = torch.cat([source_img, source_prompt], dim=1) 274 | ## source image encoder 275 | source_in_feature = self.source_in_conv(self.source_img) 276 | # print(source_in_feature.size(), source_in_feature) 277 | 278 | ## alignment encoder 279 | img_para = self.trans_conv(torch.cat([source_in_feature,self.ref_in_feature],1)) 280 | img_para = self.global_avg2d(img_para).squeeze(3).squeeze(2) 281 | # print(img_para.size(), img_para) 282 | ## concat alignment feature and audio feature 283 | trans_para = img_para 284 | 285 | ref_trans_feature = self.adaAT(self.ref_trans_feature0, trans_para) 286 | ref_trans_feature = self.appearance_conv_list[1](ref_trans_feature) 287 | # print(ref_trans_feature.size(), ref_trans_feature) 288 | ## feature decoder 289 | merge_feature = torch.cat([source_in_feature,ref_trans_feature],1) 290 | # print(merge_feature.size(), merge_feature) 291 | out = self.out_conv(merge_feature) 292 | return out 293 | def forward(self, source_img, 
source_prompt, ref_img): 294 | self.ref_input(ref_img) 295 | out = self.interface(source_img, source_prompt) 296 | return out 297 | 298 | 299 | # from torch import nn 300 | # import time 301 | # import torch 302 | # 303 | # device = "cuda" if torch.cuda.is_available() else "cpu" 304 | # model_Generator = DINet_five_Ref(6, 30).to(device) 305 | # torch.save(model_Generator.state_dict(), "DINet_five_Ref.pth") 306 | # 307 | # source_img = torch.zeros([1,3,256,256]).cuda() 308 | # source_prompt = torch.zeros([1,3, 256,256]).cuda() 309 | # ref_img = torch.zeros([1,30,256,256]).cuda() 310 | # 311 | # model_Generator.ref_input(ref_img) 312 | # start_time = time.time() 313 | # for i in range(2000): 314 | # print(i, time.time() - start_time) 315 | # out = model_Generator.interface(source_img, source_prompt) -------------------------------------------------------------------------------- /talkingface/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/models/__init__.py -------------------------------------------------------------------------------- /talkingface/models/__pycache__/DINet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/models/__pycache__/DINet.cpython-310.pyc -------------------------------------------------------------------------------- /talkingface/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /talkingface/models/__pycache__/audio2bs_lstm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/talkingface/models/__pycache__/audio2bs_lstm.cpython-310.pyc -------------------------------------------------------------------------------- /talkingface/models/audio2bs_lstm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | class Audio2Feature(nn.Module): 4 | def __init__(self): 5 | super(Audio2Feature, self).__init__() 6 | num_pred = 1 7 | self.output_size = 6 8 | self.ndim = 80 9 | APC_hidden_size = 80 10 | # define networks 11 | self.downsample = nn.Sequential( 12 | nn.Linear(in_features=APC_hidden_size * 2, out_features=APC_hidden_size), 13 | nn.BatchNorm1d(APC_hidden_size), 14 | nn.LeakyReLU(0.2), 15 | nn.Linear(APC_hidden_size, APC_hidden_size), 16 | ) 17 | self.LSTM = nn.LSTM(input_size=APC_hidden_size, 18 | hidden_size=192, 19 | num_layers=2, 20 | dropout=0, 21 | bidirectional=False, 22 | batch_first=True) 23 | self.fc = nn.Sequential( 24 | nn.Linear(in_features=192, out_features=256), 25 | nn.BatchNorm1d(256), 26 | nn.LeakyReLU(0.2), 27 | nn.Linear(256, 256), 28 | nn.BatchNorm1d(256), 29 | nn.LeakyReLU(0.2), 30 | nn.Linear(256, self.output_size)) 31 | 32 | def forward(self, audio_features, h0, c0): 33 | ''' 34 | Args: 35 | audio_features: [b, T, ndim] 36 | ''' 37 | self.item_len = audio_features.size()[1] 38 | # new in 0324 
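# Note on the reshape below (descriptive comment, based on the surrounding code): the fbank
# features are produced at 50 frames/s (20 ms hop, see audio_model.py), so pairs of consecutive
# 80-dim frames are stacked into 160-dim vectors and projected back to 80 dims by self.downsample,
# halving the sequence to 25 feature frames/s, which matches the 25 fps video pipeline. The
# two-layer LSTM (hidden size 192) then predicts self.output_size = 6 coefficients per video
# frame; audio_model.py combines these with the first 6 PCA components to render the 15x30 mouth patch.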
39 | audio_features = audio_features.reshape(-1, self.ndim * 2) 40 | down_audio_feats = self.downsample(audio_features) 41 | # print(down_audio_feats) 42 | down_audio_feats = down_audio_feats.reshape(-1, int(self.item_len / 2), self.ndim) 43 | output, (hn, cn) = self.LSTM(down_audio_feats, (h0, c0)) 44 | 45 | # output, (hn, cn) = self.LSTM(audio_features) 46 | pred = self.fc(output.reshape(-1, 192)).reshape(-1, int(self.item_len / 2), self.output_size) 47 | return pred, hn, cn 48 | -------------------------------------------------------------------------------- /talkingface/models/common/Discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | class DownBlock2d(nn.Module): 5 | def __init__(self, in_features, out_features, kernel_size=4, pool=False): 6 | super(DownBlock2d, self).__init__() 7 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) 8 | self.pool = pool 9 | def forward(self, x): 10 | out = x 11 | out = self.conv(out) 12 | out = F.leaky_relu(out, 0.2) 13 | if self.pool: 14 | out = F.avg_pool2d(out, (2, 2)) 15 | return out 16 | 17 | 18 | class Discriminator(nn.Module): 19 | """ 20 | Discriminator for GAN loss 21 | """ 22 | def __init__(self, num_channels, block_expansion=64, num_blocks=4, max_features=512): 23 | super(Discriminator, self).__init__() 24 | down_blocks = [] 25 | for i in range(num_blocks): 26 | down_blocks.append( 27 | DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), 28 | min(max_features, block_expansion * (2 ** (i + 1))), 29 | kernel_size=4, pool=(i != num_blocks - 1))) 30 | self.down_blocks = nn.ModuleList(down_blocks) 31 | self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) 32 | def forward(self, x): 33 | feature_maps = [] 34 | out = x 35 | for down_block in self.down_blocks: 36 | feature_maps.append(down_block(out)) 37 | out = feature_maps[-1] 38 | out = self.conv(out) 39 | return feature_maps, out 40 | -------------------------------------------------------------------------------- /talkingface/models/common/VGG19.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import models 3 | import numpy as np 4 | 5 | 6 | class Vgg19(torch.nn.Module): 7 | """ 8 | Vgg19 network for perceptual loss 9 | """ 10 | def __init__(self, requires_grad=False): 11 | super(Vgg19, self).__init__() 12 | vgg_model = models.vgg19(pretrained=True) 13 | vgg_pretrained_features = vgg_model.features 14 | self.slice1 = torch.nn.Sequential() 15 | self.slice2 = torch.nn.Sequential() 16 | self.slice3 = torch.nn.Sequential() 17 | self.slice4 = torch.nn.Sequential() 18 | self.slice5 = torch.nn.Sequential() 19 | for x in range(2): 20 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 21 | for x in range(2, 7): 22 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 23 | for x in range(7, 12): 24 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 25 | for x in range(12, 21): 26 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 27 | for x in range(21, 30): 28 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 29 | 30 | self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), 31 | requires_grad=False) 32 | self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 
1))), 33 | requires_grad=False) 34 | 35 | if not requires_grad: 36 | for param in self.parameters(): 37 | param.requires_grad = False 38 | 39 | def forward(self, X): 40 | X = (X - self.mean) / self.std 41 | h_relu1 = self.slice1(X) 42 | h_relu2 = self.slice2(h_relu1) 43 | h_relu3 = self.slice3(h_relu2) 44 | h_relu4 = self.slice4(h_relu3) 45 | h_relu5 = self.slice5(h_relu4) 46 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 47 | return out 48 | -------------------------------------------------------------------------------- /talkingface/models/speed_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from talkingface.models.audio2bs_lstm import Audio2Feature 3 | import time 4 | import random 5 | import numpy as np 6 | import cv2 7 | device = "cpu" 8 | 9 | model = Audio2Feature() 10 | model.eval() 11 | x = torch.ones((1, 2, 80)) 12 | h0 = torch.zeros(2, 1, 192) 13 | c0 = torch.zeros(2, 1, 192) 14 | y, hn, cn = model(x, h0, c0) 15 | start_time = time.time() 16 | 17 | from thop import profile 18 | from thop import clever_format 19 | flops, params = profile(model.to(device), inputs=(x, h0, c0)) 20 | flops, params = clever_format([flops, params], "%.3f") 21 | print(flops, params) 22 | -------------------------------------------------------------------------------- /talkingface/preprocess.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import numpy as np 3 | import cv2 4 | import os 5 | import sys 6 | import time 7 | import argparse 8 | from talkingface.run_utils import video_pts_process, concat_output_2binfile 9 | from talkingface.mediapipe_utils import detect_face_mesh,detect_face 10 | from talkingface.utils import main_keypoints_index,INDEX_LIPS 11 | # 1、是否是mp4,宽高是否大于200,时长是否大于2s,可否成功转换为符合格式的mp4 12 | # 2、面部关键点检测及是否可以构成循环视频 13 | # 4、旋转矩阵、面部mask估计 14 | # 5、验证文件完整性 15 | 16 | dir_ = "data/asset/Actor" 17 | def print_log(task_id, progress, status, Error, mode = 0): 18 | ''' 19 | status: -1代表未开始, 0代表处理中, 1代表已完成, 2代表出错中断 20 | progress: 0-1000, 进度千分比 21 | ''' 22 | print("task_id: {}. progress: {:0>4d}. status: {}. mode: {}. Error: {}".format(task_id, progress, status, mode, Error)) 23 | sys.stdout.flush() 24 | 25 | def check_step0(task_id, video_path): 26 | try: 27 | cap = cv2.VideoCapture(video_path) 28 | vid_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # 宽度 29 | vid_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # 高度 30 | frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) 31 | fps = cap.get(cv2.CAP_PROP_FPS) 32 | cap.release() 33 | if vid_width < 200 or vid_height < 200: 34 | print_log(task_id, 0, 2, "video width/height < 200") 35 | return 0 36 | if frames < 2*fps: 37 | print_log(task_id, 0, 2, "video duration < 2s") 38 | return 0 39 | os.makedirs(os.path.join(dir_, task_id), exist_ok=True) 40 | front_video_path = os.path.join("data", "front.mp4") 41 | scale = max(vid_width / 720., vid_height / 1280.) 
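# Descriptive note on the branch below: if the source exceeds 720x1280, the video is scaled down
# proportionally (width/height rounded down to even values for the encoder); in either branch the
# clip is resampled to 25 fps, truncated to the first 2 minutes and stripped of audio (-an) before
# any face detection runs on it.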
42 | if scale > 1: 43 | new_width = int(vid_width / scale + 0.1)//2 * 2 44 | new_height = int(vid_height / scale + 0.1)//2 * 2 45 | ffmpeg_cmd = "ffmpeg -i {} -r 25 -ss 00:00:00 -t 00:02:00 -vf scale={}:{} -an -loglevel quiet -y {}".format( 46 | video_path,new_width,new_height,front_video_path) 47 | else: 48 | ffmpeg_cmd = "ffmpeg -i {} -r 25 -ss 00:00:00 -t 00:02:00 -an -loglevel quiet -y {}".format( 49 | video_path, front_video_path) 50 | os.system(ffmpeg_cmd) 51 | if not os.path.isfile(front_video_path): 52 | return 0 53 | return 1 54 | except: 55 | print_log(task_id, 0, 2, "video cant be opened") 56 | return 0 57 | 58 | def check_step1(task_id): 59 | front_video_path = os.path.join("data", "front.mp4") 60 | back_video_path = os.path.join("data", "back.mp4") 61 | video_out_path = os.path.join(dir_, task_id, "video.mp4") 62 | face_info_path = os.path.join(dir_, task_id, "video_info.bin") 63 | preview_path = os.path.join(dir_, task_id, "preview.jpg") 64 | if ExtractFromVideo(task_id, front_video_path) != 1: 65 | shutil.rmtree(os.path.join(dir_, task_id)) 66 | return 0 67 | ffmpeg_cmd = "ffmpeg -i {} -vf reverse -loglevel quiet -y {}".format(front_video_path, back_video_path) 68 | os.system(ffmpeg_cmd) 69 | ffmpeg_cmd = "ffmpeg -f concat -i {} -loglevel quiet -y {}".format("data/video_concat.txt", video_out_path) 70 | os.system(ffmpeg_cmd) 71 | ffmpeg_cmd = "ffmpeg -i {} -vf crop=w='min(iw\,ih)':h='min(iw\,ih)',scale=256:256,setsar=1 -vframes 1 {}".format(front_video_path, preview_path) 72 | # ffmpeg_cmd = "ffmpeg -i {} -vf scale=256:-1 -loglevel quiet -y {}".format(front_video_path, preview_path) 73 | os.system(ffmpeg_cmd) 74 | if os.path.isfile(front_video_path): 75 | os.remove(front_video_path) 76 | if os.path.isfile(back_video_path): 77 | os.remove(back_video_path) 78 | if os.path.isfile(video_out_path) and os.path.isfile(face_info_path): 79 | return 1 80 | else: 81 | return 0 82 | 83 | # def check_step2(task_id, ): 84 | # mat_list, pts_normalized_list, face_mask_pts = video_pts_process(pts_array_origin) 85 | 86 | 87 | def ExtractFromVideo(task_id, front_video_path): 88 | cap = cv2.VideoCapture(front_video_path) 89 | if not cap.isOpened(): 90 | print_log(task_id, 0, 2, "front_video cant be opened by opencv") 91 | return -1 92 | 93 | vid_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # 宽度 94 | vid_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # 高度 95 | 96 | totalFrames = cap.get(cv2.CAP_PROP_FRAME_COUNT) # 总帧数 97 | totalFrames = int(totalFrames) 98 | pts_3d = np.zeros([totalFrames, 478, 3]) 99 | frame_index = 0 100 | face_rect_list = [] 101 | start_time = time.time() 102 | while cap.isOpened(): 103 | ret, frame = cap.read() # 按帧读取视频 104 | # #到视频结尾时终止 105 | if ret is False: 106 | break 107 | rect_2d = detect_face([frame]) 108 | rect = rect_2d[0] 109 | tag_ = 1 if np.sum(rect) > 0 else 0 110 | if frame_index == 0 and tag_ != 1: 111 | print_log(task_id, 0, 2, "no face detected in first frame") 112 | cap.release() # 释放视频对象 113 | return 0 114 | elif tag_ == 0: # 有时候人脸检测会失败,就用上一帧的结果替代这一帧的结果 115 | rect = face_rect_list[-1] 116 | 117 | face_rect_list.append(rect) 118 | 119 | x_min = rect[0] * vid_width 120 | y_min = rect[2] * vid_height 121 | x_max = rect[1] * vid_width 122 | y_max = rect[3] * vid_height 123 | seq_w, seq_h = x_max - x_min, y_max - y_min 124 | x_mid, y_mid = (x_min + x_max) / 2, (y_min + y_max) / 2 125 | x_min = int(max(0, x_mid - seq_w * 0.65)) 126 | y_min = int(max(0, y_mid - seq_h * 0.4)) 127 | x_max = int(min(vid_width, x_mid + seq_w * 0.65)) 128 | y_max = int(min(vid_height, y_mid 
+ seq_h * 0.8)) 129 | 130 | frame_face = frame[y_min:y_max, x_min:x_max] 131 | frame_kps = detect_face_mesh([frame_face])[0] 132 | if np.sum(frame_kps) == 0: 133 | print_log(task_id, 0, 2, "Frame num {} keypoint error".format(frame_index)) 134 | cap.release() # 释放视频对象 135 | return 0 136 | pts_3d[frame_index] = frame_kps + np.array([x_min, y_min, 0]) 137 | frame_index += 1 138 | 139 | if time.time() - start_time > 0.5: 140 | progress = int(1000 * frame_index / totalFrames * 0.99) 141 | print_log(task_id, progress, 0, "handling...") 142 | start_time = time.time() 143 | cap.release() # 释放视频对象 144 | if type(pts_3d) is np.ndarray and len(pts_3d) == totalFrames: 145 | pts_3d_main = pts_3d[:, main_keypoints_index] 146 | mat_list, pts_normalized_list, face_pts_mean_personal, face_mask_pts_normalized = video_pts_process(pts_3d_main) 147 | 148 | output = concat_output_2binfile(mat_list, pts_3d, face_pts_mean_personal, face_mask_pts_normalized) 149 | # print(output.shape) 150 | pts_normalized_list = np.array(pts_normalized_list)[:, INDEX_LIPS] 151 | # 找出此模特正面人脸的嘴巴区域范围 152 | x_max, x_min = np.max(pts_normalized_list[:, :, 0]), np.min(pts_normalized_list[:, :, 0]) 153 | y_max, y_min = np.max(pts_normalized_list[:, :, 1]), np.min(pts_normalized_list[:, :, 1]) 154 | y_min = y_min + (y_max - y_min) / 10. 155 | 156 | first_line = np.zeros([406]) 157 | first_line[:4] = np.array([x_min,x_max,y_min,y_max]) 158 | # print(first_line) 159 | # 160 | # pts_2d_main = pts_3d[:, main_keypoints_index, :2].reshape(len(pts_3d), -1) 161 | # smooth_array_ = np.array(mat_list).reshape(-1, 16)*100 162 | # 163 | # output = np.concatenate([smooth_array_, pts_2d_main], axis=1).astype(np.float32) 164 | output = np.concatenate([first_line.reshape(1,-1), output], axis=0).astype(np.float32) 165 | # print(smooth_array_.shape, pts_2d_main.shape, first_line.shape, output.shape) 166 | face_info_path = os.path.join(dir_, task_id, "video_info.bin") 167 | # np.savetxt(face_info_path, output, fmt='%.1f') 168 | # print(222) 169 | output.tofile(face_info_path) 170 | return 1 171 | else: 172 | print_log(task_id, 0, 2, "keypoint cant be saved") 173 | return 0 174 | 175 | def check_step0_audio(task_id, video_path): 176 | dir_ = "data/asset/Audio" 177 | wav_path = os.path.join(dir_, task_id + ".wav") 178 | ffmpeg_cmd = "ffmpeg -i {} -ac 1 -ar 16000 -loglevel quiet -y {}".format( 179 | video_path, wav_path) 180 | os.system(ffmpeg_cmd) 181 | if not os.path.isfile(wav_path): 182 | print_log(task_id, 0, 2, "audio convert failed", 2) 183 | return 0 184 | return 1 185 | 186 | def new_task(task_id, task_mode, video_path): 187 | # print(task_id, task_mode, video_path) 188 | if task_mode == "0": # "actor" 189 | print_log(task_id, 0, 0, "handling...") 190 | if check_step0(task_id, video_path): 191 | print_log(task_id, 0, 0, "handling...") 192 | if check_step1(task_id): 193 | print_log(task_id, 1000, 1, "process finished, click to confirm") 194 | if task_mode == "2": # "audio" 195 | print_log(task_id, 0, 0, "handling...", 2) 196 | if check_step0_audio(task_id, video_path): 197 | print_log(task_id, 1000, 1, "process finished, click to confirm", 2) 198 | 199 | if __name__ == '__main__': 200 | parser = argparse.ArgumentParser(description='Inference code to preprocess videos') 201 | parser.add_argument('--task_id', type=str, help='task_id') 202 | parser.add_argument('--task_mode', type=str, help='task_mode') 203 | parser.add_argument('--video_path', type=str, help='Filepath of video that contains faces to use') 204 | args = parser.parse_args() 205 | 
new_task(args.task_id, args.task_mode, args.video_path) -------------------------------------------------------------------------------- /talkingface/render_model.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torch 3 | import os 4 | import numpy as np 5 | import time 6 | from talkingface.run_utils import smooth_array, video_pts_process 7 | from talkingface.run_utils import mouth_replace, prepare_video_data 8 | from talkingface.utils import generate_face_mask, INDEX_LIPS_OUTER 9 | from talkingface.data.few_shot_dataset import select_ref_index,get_ref_images_fromVideo,generate_input, generate_input_pixels 10 | device = "cuda" if torch.cuda.is_available() else "cpu" 11 | import pickle 12 | import cv2 13 | 14 | 15 | face_mask = generate_face_mask() 16 | 17 | 18 | class RenderModel: 19 | def __init__(self): 20 | self.__net = None 21 | 22 | self.__pts_driven = None 23 | self.__mat_list = None 24 | self.__pts_normalized_list = None 25 | self.__face_mask_pts = None 26 | self.__ref_img = None 27 | self.__cap_input = None 28 | self.frame_index = 0 29 | self.__mouth_coords_array = None 30 | 31 | def loadModel(self, ckpt_path): 32 | from talkingface.models.DINet import DINet_five_Ref as DINet 33 | n_ref = 5 34 | source_channel = 6 35 | ref_channel = n_ref * 6 36 | self.__net = DINet(source_channel, ref_channel).cuda() 37 | checkpoint = torch.load(ckpt_path) 38 | self.__net.load_state_dict(checkpoint) 39 | self.__net.eval() 40 | 41 | def reset_charactor(self, video_path, Path_pkl, ref_img_index_list = None): 42 | if self.__cap_input is not None: 43 | self.__cap_input.release() 44 | 45 | self.__pts_driven, self.__mat_list,self.__pts_normalized_list, self.__face_mask_pts, self.__ref_img, self.__cap_input = \ 46 | prepare_video_data(video_path, Path_pkl, ref_img_index_list) 47 | 48 | ref_tensor = torch.from_numpy(self.__ref_img / 255.).float().permute(2, 0, 1).unsqueeze(0).cuda() 49 | self.__net.ref_input(ref_tensor) 50 | 51 | x_min, x_max = np.min(self.__pts_normalized_list[:, INDEX_LIPS_OUTER, 0]), np.max(self.__pts_normalized_list[:, INDEX_LIPS_OUTER, 0]) 52 | y_min, y_max = np.min(self.__pts_normalized_list[:, INDEX_LIPS_OUTER, 1]), np.max(self.__pts_normalized_list[:, INDEX_LIPS_OUTER, 1]) 53 | z_min, z_max = np.min(self.__pts_normalized_list[:, INDEX_LIPS_OUTER, 2]), np.max(self.__pts_normalized_list[:, INDEX_LIPS_OUTER, 2]) 54 | 55 | x_mid,y_mid,z_mid = (x_min + x_max)/2, (y_min + y_max)/2, (z_min + z_max)/2 56 | x_len, y_len, z_len = (x_max - x_min)/2, (y_max - y_min)/2, (z_max - z_min)/2 57 | x_min, x_max = x_mid - x_len*0.9, x_mid + x_len*0.9 58 | y_min, y_max = y_mid - y_len*0.9, y_mid + y_len*0.9 59 | z_min, z_max = z_mid - z_len*0.9, z_mid + z_len*0.9 60 | 61 | # print(face_personal.shape, x_min, x_max, y_min, y_max, z_min, z_max) 62 | coords_array = np.zeros([100, 150, 4]) 63 | for i in range(100): 64 | for j in range(150): 65 | coords_array[i, j, 0] = j/149 66 | coords_array[i, j, 1] = i/100 67 | # coords_array[i, j, 2] = int((-75 + abs(j - 75))*(2./3)) 68 | coords_array[i, j, 2] = ((j - 75)/ 75) ** 2 69 | coords_array[i, j, 3] = 1 70 | 71 | coords_array = coords_array*np.array([x_max - x_min, y_max - y_min, z_max - z_min, 1]) + np.array([x_min, y_min, z_min, 0]) 72 | self.__mouth_coords_array = coords_array.reshape(-1, 4).transpose(1, 0) 73 | 74 | 75 | 76 | def interface(self, mouth_frame): 77 | vid_frame_count = self.__cap_input.get(cv2.CAP_PROP_FRAME_COUNT) 78 | if self.frame_index % vid_frame_count == 0: 79 | 
self.__cap_input.set(cv2.CAP_PROP_POS_FRAMES, 0) # 设置要获取的帧号 80 | ret, frame = self.__cap_input.read() # 按帧读取视频 81 | 82 | epoch = self.frame_index // len(self.__mat_list) 83 | if epoch % 2 == 0: 84 | new_index = self.frame_index % len(self.__mat_list) 85 | else: 86 | new_index = -1 - self.frame_index % len(self.__mat_list) 87 | 88 | # print(self.__face_mask_pts.shape, "ssssssss") 89 | source_img, target_img, crop_coords = generate_input_pixels(frame, self.__pts_driven[new_index], self.__mat_list[new_index], 90 | mouth_frame, self.__face_mask_pts[new_index], 91 | self.__mouth_coords_array) 92 | 93 | # tensor 94 | source_tensor = torch.from_numpy(source_img / 255.).float().permute(2, 0, 1).unsqueeze(0).cuda() 95 | target_tensor = torch.from_numpy(target_img / 255.).float().permute(2, 0, 1).unsqueeze(0).cuda() 96 | 97 | source_tensor, source_prompt_tensor = source_tensor[:, :3], source_tensor[:, 3:] 98 | fake_out = self.__net.interface(source_tensor, source_prompt_tensor) 99 | 100 | image_numpy = fake_out.detach().squeeze(0).cpu().float().numpy() 101 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 102 | image_numpy = image_numpy.clip(0, 255) 103 | image_numpy = image_numpy.astype(np.uint8) 104 | 105 | image_numpy = target_img * face_mask + image_numpy * (1 - face_mask) 106 | 107 | img_bg = frame 108 | x_min, y_min, x_max, y_max = crop_coords 109 | 110 | img_face = cv2.resize(image_numpy, (x_max - x_min, y_max - y_min)) 111 | img_bg[y_min:y_max, x_min:x_max] = img_face 112 | self.frame_index += 1 113 | return img_bg 114 | 115 | def save(self, path): 116 | torch.save(self.__net.state_dict(), path) -------------------------------------------------------------------------------- /talkingface/run_utils.py: -------------------------------------------------------------------------------- 1 | from talkingface.utils import * 2 | import os 3 | import pickle 4 | import copy 5 | def Tensor2img(tensor_, channel_index): 6 | frame = tensor_[channel_index:channel_index + 3, :, :].detach().squeeze(0).cpu().float().numpy() 7 | frame = np.transpose(frame, (1, 2, 0)) * 255.0 8 | frame = frame.clip(0, 255) 9 | return frame.astype(np.uint8) 10 | INDEX_EYE_CORNER = [INDEX_LEFT_EYE[0], INDEX_LEFT_EYE[8], INDEX_RIGHT_EYE[0], INDEX_RIGHT_EYE[8]] 11 | # INDEX_FACE_OVAL_UPPER = INDEX_FACE_OVAL[7:-7] 12 | INDEX_FACE_OVAL_UPPER = INDEX_FACE_OVAL[:7] + INDEX_FACE_OVAL[-7:] 13 | def calc_face_mat(pts_array_origin, face_pts_mean): 14 | ''' 15 | 16 | :param pts_array_origin: mediapipe检测出的人脸关键点 17 | :return: 18 | ''' 19 | face_pts_mean_rotatePts = face_pts_mean[INDEX_EYEBROW + INDEX_NOSE + INDEX_EYE_CORNER + INDEX_FACE_OVAL_UPPER] 20 | A = np.zeros([len(face_pts_mean_rotatePts) * 3, 12]) 21 | for i in range(len(face_pts_mean_rotatePts)): 22 | A[3 * i + 0, 0:3] = face_pts_mean_rotatePts[i] 23 | A[3 * i + 0, 3] = 1 24 | A[3 * i + 1, 4:7] = face_pts_mean_rotatePts[i] 25 | A[3 * i + 1, 7] = 1 26 | A[3 * i + 2, 8:11] = face_pts_mean_rotatePts[i] 27 | A[3 * i + 2, 11] = 1 28 | A_inverse = np.linalg.pinv(A) 29 | pts_normalized_list = [] 30 | mat_list = [] 31 | for i in pts_array_origin: 32 | i_rotatePts = i[INDEX_EYEBROW + INDEX_NOSE + INDEX_EYE_CORNER + INDEX_FACE_OVAL_UPPER] 33 | B = i_rotatePts.flatten() 34 | x = A_inverse.dot(B) 35 | rotationMatrix = np.zeros([4, 4]) 36 | rotationMatrix[:3, :] = x.reshape(3, 4) 37 | rotationMatrix[3, 3] = 1 38 | mat_list.append(rotationMatrix) 39 | 40 | for index_, i in enumerate(pts_array_origin): 41 | rotationMatrix = mat_list[index_] 42 | keypoints = np.ones([4, len(i)]) 43 | 
keypoints[:3, :] = i.T 44 | keypoints_normalized = np.linalg.inv(rotationMatrix).dot(keypoints).T 45 | pts_normalized_list.append(keypoints_normalized[:, :3]) 46 | 47 | # mouth_normalized_adjust = np.mean(np.array(pts_normalized_list)[:, INDEX_LIPS], axis=0) 48 | # print(mouth_normalized_adjust.shape) 49 | 50 | face_pts_mean_personal = np.mean(np.array(pts_normalized_list), axis=0) 51 | # print("face_pts_mean_personal", face_pts_mean_personal.shape) 52 | face_pts_mean_rotatePts2 = face_pts_mean_personal[ 53 | INDEX_EYEBROW + INDEX_NOSE + INDEX_EYE_CORNER + INDEX_FACE_OVAL_UPPER] 54 | A2 = np.zeros([len(face_pts_mean_rotatePts2) * 3, 12]) 55 | for i in range(len(face_pts_mean_rotatePts2)): 56 | A2[3 * i + 0, 0:3] = face_pts_mean_rotatePts2[i] 57 | A2[3 * i + 0, 3] = 1 58 | A2[3 * i + 1, 4:7] = face_pts_mean_rotatePts2[i] 59 | A2[3 * i + 1, 7] = 1 60 | A2[3 * i + 2, 8:11] = face_pts_mean_rotatePts2[i] 61 | A2[3 * i + 2, 11] = 1 62 | A2_inverse = np.linalg.pinv(A2) 63 | pts_normalized_list = [] 64 | mat_list = [] 65 | for i in pts_array_origin: 66 | i_rotatePts = i[INDEX_EYEBROW + INDEX_NOSE + INDEX_EYE_CORNER + INDEX_FACE_OVAL_UPPER] 67 | B = i_rotatePts.flatten() 68 | x = A2_inverse.dot(B) 69 | rotationMatrix = np.zeros([4, 4]) 70 | rotationMatrix[:3, :] = x.reshape(3, 4) 71 | rotationMatrix[3, 3] = 1 72 | mat_list.append(rotationMatrix) 73 | 74 | mat_list = np.array(mat_list) 75 | # mat_list必须要平滑,注意是针对每个视频分别平滑 76 | sub_mat_list = mat_list 77 | smooth_array_ = sub_mat_list.reshape(-1, 16) 78 | smooth_array_ = smooth_array(smooth_array_) 79 | # print(smooth_array_, smooth_array_.shape) 80 | smooth_array_ = smooth_array_.reshape(-1, 4, 4) 81 | mat_list = smooth_array_ 82 | mat_list = [hh for hh in mat_list] 83 | 84 | for index_,i in enumerate(pts_array_origin): 85 | rotationMatrix = mat_list[index_] 86 | keypoints = np.ones([4, len(i)]) 87 | keypoints[:3, :] = i.T 88 | keypoints_normalized = np.linalg.inv(rotationMatrix).dot(keypoints).T 89 | pts_normalized_list.append(keypoints_normalized[:,:3]) 90 | 91 | return mat_list,pts_normalized_list,face_pts_mean_personal 92 | face_pts_mean = None 93 | def video_pts_process(pts_array_origin): 94 | global face_pts_mean 95 | if face_pts_mean is None: 96 | current_dir = os.path.dirname(os.path.abspath(__file__)) 97 | face_pts_mean = np.loadtxt(os.path.join(current_dir, "../data/face_pts_mean_mainKps.txt")) 98 | # 先根据pts_array_origin计算出旋转矩阵、去除旋转后的人脸关键点、面部mask、 99 | mat_list, pts_normalized_list, face_pts_mean_personal = calc_face_mat(pts_array_origin, face_pts_mean) 100 | pts_normalized_list = np.array(pts_normalized_list) 101 | face_mask_pts_normalized = face_pts_mean_personal[INDEX_FACE_OVAL].copy() 102 | face_mask_pts_normalized[:10, 1] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[:10], 1], 103 | axis=0) + np.arange(5, 25, 2) 104 | face_mask_pts_normalized[:10, 0] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[:10], 0], 105 | axis=0) - (9 - np.arange(0, 10)) 106 | face_mask_pts_normalized[-10:, 1] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[-10:], 1], 107 | axis=0) - np.arange(5, 25, 2) + 28 108 | face_mask_pts_normalized[-10:, 0] = np.min(pts_normalized_list[:, INDEX_FACE_OVAL[-10:], 0], 109 | axis=0) + np.arange(0, 10) 110 | face_mask_pts_normalized[10, 1] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[10], 1], axis=0) + 25 111 | 112 | face_mask_pts = np.zeros([len(mat_list), len(face_mask_pts_normalized), 2]) 113 | for index_ in range(len(mat_list)): 114 | rotationMatrix = mat_list[index_] 115 | 116 | keypoints = np.ones([4, 
len(face_mask_pts_normalized)]) 117 | keypoints[:3, :] = face_mask_pts_normalized.T 118 | driving_mask = rotationMatrix.dot(keypoints).T 119 | face_mask_pts[index_] = driving_mask[:,:2] 120 | 121 | return mat_list, pts_normalized_list, face_pts_mean_personal, face_mask_pts 122 | 123 | def mouth_replace(pts_array_origin, frames_num): 124 | ''' 125 | 126 | :param pts_array_origin: mediapipe检测出的人脸关键点 127 | :return: 128 | ''' 129 | if os.path.isfile("face_pts_mean_mainKps.txt"): 130 | face_pts_mean = np.loadtxt("face_pts_mean_mainKps.txt") 131 | else: 132 | face_pts_mean = np.loadtxt("data/face_pts_mean_mainKps.txt") 133 | mat_list,pts_normalized_list,face_pts_mean_personal = calc_face_mat(pts_array_origin, face_pts_mean) 134 | face_personal = face_pts_mean_personal.copy() 135 | pts_normalized_list = np.array(pts_normalized_list) 136 | # face_pts_mean_personal[INDEX_FACE_OVAL[:10], 1] = np.max(pts_normalized_list[:,INDEX_FACE_OVAL[:10], 1], axis = 0) + 20 137 | # face_pts_mean_personal[INDEX_FACE_OVAL[:10], 0] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[:10], 0], axis=0) + 10 138 | # face_pts_mean_personal[INDEX_FACE_OVAL[-10:], 1] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[-10:], 1], axis=0) + 20 139 | # face_pts_mean_personal[INDEX_FACE_OVAL[-10:], 0] = np.min(pts_normalized_list[:, INDEX_FACE_OVAL[-10:], 0], axis=0) - 10 140 | # face_pts_mean_personal[INDEX_FACE_OVAL[10], 1] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[10], 1], axis=0) + 20 141 | face_pts_mean_personal[INDEX_FACE_OVAL[:10], 1] = np.max(pts_normalized_list[:,INDEX_FACE_OVAL[:10], 1], axis = 0) + np.arange(5,25,2) 142 | face_pts_mean_personal[INDEX_FACE_OVAL[:10], 0] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[:10], 0], axis=0) - (9 - np.arange(0,10)) 143 | face_pts_mean_personal[INDEX_FACE_OVAL[-10:], 1] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[-10:], 1], axis=0) - np.arange(5,25,2) + 28 144 | face_pts_mean_personal[INDEX_FACE_OVAL[-10:], 0] = np.min(pts_normalized_list[:, INDEX_FACE_OVAL[-10:], 0], axis=0) + np.arange(0,10) 145 | 146 | face_pts_mean_personal[INDEX_FACE_OVAL[10], 1] = np.max(pts_normalized_list[:, INDEX_FACE_OVAL[10], 1], axis=0) + 25 147 | 148 | face_pts_mean_personal = face_pts_mean_personal[INDEX_FACE_OVAL] 149 | face_mask_pts = np.zeros([len(mat_list), len(face_pts_mean_personal), 2]) 150 | for index_ in range(len(mat_list)): 151 | rotationMatrix = mat_list[index_] 152 | 153 | keypoints = np.ones([4, len(face_pts_mean_personal)]) 154 | keypoints[:3, :] = face_pts_mean_personal.T 155 | driving_mask = rotationMatrix.dot(keypoints).T 156 | face_mask_pts[index_] = driving_mask[:,:2] 157 | 158 | iteration = frames_num // len(pts_array_origin) + 1 159 | if iteration == 1: 160 | pass 161 | else: 162 | pts_array_origin2 = copy.deepcopy(pts_array_origin) 163 | mat_list2 = copy.deepcopy(mat_list) 164 | face_mask_pts2 = copy.deepcopy(face_mask_pts) 165 | for i in range(iteration - 1): 166 | if i % 2 == 0: 167 | pts_array_origin2 = np.concatenate( 168 | [pts_array_origin2, pts_array_origin[::-1]], axis=0) 169 | mat_list2 += mat_list[::-1] 170 | face_mask_pts2 = np.concatenate( 171 | [face_mask_pts2, face_mask_pts[::-1]], axis=0) 172 | else: 173 | pts_array_origin2 = np.concatenate( 174 | [pts_array_origin2, pts_array_origin], axis=0) 175 | mat_list2 += mat_list 176 | face_mask_pts2 = np.concatenate( 177 | [face_mask_pts2, face_mask_pts], axis=0) 178 | pts_array_origin = pts_array_origin2 179 | mat_list = mat_list2 180 | face_mask_pts = face_mask_pts2 181 | 182 | pts_array_origin, mat_list, 
face_mask_pts = pts_array_origin[:frames_num], mat_list[:frames_num], face_mask_pts[:frames_num] 183 | return pts_array_origin, mat_list, face_mask_pts, face_personal, pts_normalized_list 184 | 185 | 186 | def concat_output_2binfile(mat_list, pts_3d, face_pts_mean_personal, face_mask_pts_normalized): 187 | face_stable_pts_2d = np.zeros([len(mat_list), len(INDEX_FACE_OVAL + INDEX_MUSCLE), 2]) # 法令纹和脸部外轮廓关键点 188 | face_mask_pts_2d = np.zeros([len(mat_list), face_mask_pts_normalized.shape[0], 2]) 189 | for index_, i in enumerate(mat_list): 190 | rotationMatrix = i 191 | # 法令纹和脸部外轮廓关键点 192 | driving_mouth_pts = face_pts_mean_personal[INDEX_FACE_OVAL + INDEX_MUSCLE] 193 | keypoints = np.ones([4, len(driving_mouth_pts)]) 194 | keypoints[:3, :] = driving_mouth_pts.T 195 | driving_mouth_pts = rotationMatrix.dot(keypoints).T 196 | face_stable_pts_2d[index_] = driving_mouth_pts[:, :2] 197 | 198 | # 脸部mask关键点 199 | driving_mouth_pts = face_mask_pts_normalized 200 | keypoints = np.ones([4, len(driving_mouth_pts)]) 201 | keypoints[:3, :] = driving_mouth_pts.T 202 | driving_mouth_pts = rotationMatrix.dot(keypoints).T 203 | face_mask_pts_2d[index_] = driving_mouth_pts[:, :2] 204 | 205 | 206 | pts_2d_main = pts_3d[:, main_keypoints_index, :2].reshape(len(pts_3d), -1) 207 | smooth_array_ = np.array(mat_list).reshape(-1, 16) * 100 208 | face_mask_pts_2d = face_mask_pts_2d.reshape(len(face_mask_pts_2d), -1) 209 | face_stable_pts_2d = face_stable_pts_2d.reshape(len(face_stable_pts_2d), -1) 210 | 211 | output = np.concatenate([smooth_array_, pts_2d_main, face_mask_pts_2d, face_stable_pts_2d], axis=1).astype(np.float32) 212 | return output 213 | from talkingface.data.few_shot_dataset import select_ref_index,get_ref_images_fromVideo 214 | def prepare_video_data(video_path, Path_pkl, ref_img_index_list, ref_img = None,save_ref = None): 215 | with open(Path_pkl, "rb") as f: 216 | images_info = pickle.load(f)[:, main_keypoints_index, :] 217 | 218 | pts_driven = images_info.reshape(len(images_info), -1) 219 | pts_driven = smooth_array(pts_driven).reshape(len(pts_driven), -1, 3) 220 | cap_input = cv2.VideoCapture(video_path) 221 | if ref_img is not None: 222 | ref_img = cv2.imread(ref_img).reshape(256, -1, 256, 3).transpose(0, 2, 1, 3).reshape(256, 256, -1) 223 | else: 224 | if ref_img_index_list is None: 225 | ref_img_index_list = select_ref_index(pts_driven, n_ref=5, ratio=1 / 2.) 
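# Note (inferred from the surrounding calls, not stated in the original source): when no
# reference indices are supplied, select_ref_index picks n_ref=5 frame indices from the
# driven keypoint sequence (ratio presumably limits the candidate range), and
# get_ref_images_fromVideo below crops those frames from the opened video around the
# tracked keypoints to build the reference image stack.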
226 | ref_img = get_ref_images_fromVideo(cap_input, ref_img_index_list, pts_driven[:, :, :2]) 227 | 228 | mat_list, pts_normalized_list, face_pts_mean_personal, face_mask_pts = video_pts_process(pts_driven) 229 | 230 | if save_ref is not None: 231 | h, w, c = ref_img.shape 232 | ref_img_ = ref_img.reshape(h, w, -1, 3).transpose(0, 2, 1, 3).reshape(h, -1, 3) 233 | # ref_path = "ref2.png" 234 | cv2.imwrite(save_ref, ref_img_) 235 | # logger.info("参考图片已存至{}.".format(ref_path)) 236 | 237 | return pts_driven, mat_list, pts_normalized_list, face_mask_pts, ref_img, cap_input -------------------------------------------------------------------------------- /talkingface/util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | -------------------------------------------------------------------------------- /talkingface/util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import tarfile 4 | import requests 5 | from warnings import warn 6 | from zipfile import ZipFile 7 | from bs4 import BeautifulSoup 8 | from os.path import abspath, isdir, join, basename 9 | 10 | 11 | class GetData(object): 12 | """A Python script for downloading CycleGAN or pix2pix datasets. 13 | 14 | Parameters: 15 | technique (str) -- One of: 'cyclegan' or 'pix2pix'. 16 | verbose (bool) -- If True, print additional information. 17 | 18 | Examples: 19 | >>> from util.get_data import GetData 20 | >>> gd = GetData(technique='cyclegan') 21 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 22 | 23 | Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh' 24 | and 'scripts/download_cyclegan_model.sh'. 
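    A non-interactive call can name the archive directly (the file name below is only an
    illustration; the real options are whatever the dataset server currently lists):

    >>> gd = GetData(technique='pix2pix', verbose=False)
    >>> new_data_path = gd.get(save_path='./datasets', dataset='facades.tar.gz')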
25 | """ 26 | 27 | def __init__(self, technique='cyclegan', verbose=True): 28 | url_dict = { 29 | 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', 30 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 31 | } 32 | self.url = url_dict.get(technique.lower()) 33 | self._verbose = verbose 34 | 35 | def _print(self, text): 36 | if self._verbose: 37 | print(text) 38 | 39 | @staticmethod 40 | def _get_options(r): 41 | soup = BeautifulSoup(r.text, 'lxml') 42 | options = [h.text for h in soup.find_all('a', href=True) 43 | if h.text.endswith(('.zip', 'tar.gz'))] 44 | return options 45 | 46 | def _present_options(self): 47 | r = requests.get(self.url) 48 | options = self._get_options(r) 49 | print('Options:\n') 50 | for i, o in enumerate(options): 51 | print("{0}: {1}".format(i, o)) 52 | choice = input("\nPlease enter the number of the " 53 | "dataset above you wish to download:") 54 | return options[int(choice)] 55 | 56 | def _download_data(self, dataset_url, save_path): 57 | if not isdir(save_path): 58 | os.makedirs(save_path) 59 | 60 | base = basename(dataset_url) 61 | temp_save_path = join(save_path, base) 62 | 63 | with open(temp_save_path, "wb") as f: 64 | r = requests.get(dataset_url) 65 | f.write(r.content) 66 | 67 | if base.endswith('.tar.gz'): 68 | obj = tarfile.open(temp_save_path) 69 | elif base.endswith('.zip'): 70 | obj = ZipFile(temp_save_path, 'r') 71 | else: 72 | raise ValueError("Unknown File Type: {0}.".format(base)) 73 | 74 | self._print("Unpacking Data...") 75 | obj.extractall(save_path) 76 | obj.close() 77 | os.remove(temp_save_path) 78 | 79 | def get(self, save_path, dataset=None): 80 | """ 81 | 82 | Download a dataset. 83 | 84 | Parameters: 85 | save_path (str) -- A directory to save the data to. 86 | dataset (str) -- (optional). A specific dataset to download. 87 | Note: this must include the file extension. 88 | If None, options will be presented for you 89 | to choose from. 90 | 91 | Returns: 92 | save_path_full (str) -- the absolute path to the downloaded data. 93 | 94 | """ 95 | if dataset is None: 96 | selected_dataset = self._present_options() 97 | else: 98 | selected_dataset = dataset 99 | 100 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 101 | 102 | if isdir(save_path_full): 103 | warn("\n'{0}' already exists. Voiding Download.".format( 104 | save_path_full)) 105 | else: 106 | self._print('Downloading Data...') 107 | url = "{0}/{1}".format(self.url, selected_dataset) 108 | self._download_data(url, save_path=save_path) 109 | 110 | return abspath(save_path_full) 111 | -------------------------------------------------------------------------------- /talkingface/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as (add a text header to the HTML file), 10 | (add a row of images to the HTML file), and (save the HTML to the disk). 11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. 
HTML file will be created at /index.html; images will be saved at 0: 32 | with self.doc.head: 33 | meta(http_equiv="refresh", content=str(refresh)) 34 | 35 | def get_image_dir(self): 36 | """Return the directory that stores images""" 37 | return self.img_dir 38 | 39 | def add_header(self, text): 40 | """Insert a header to the HTML file 41 | 42 | Parameters: 43 | text (str) -- the header text 44 | """ 45 | with self.doc: 46 | h3(text) 47 | 48 | def add_images(self, ims, txts, links, width=400): 49 | """add images to the HTML file 50 | 51 | Parameters: 52 | ims (str list) -- a list of image paths 53 | txts (str list) -- a list of image names shown on the website 54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 55 | """ 56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 57 | self.doc.add(self.t) 58 | with self.t: 59 | with tr(): 60 | for im, txt, link in zip(ims, txts, links): 61 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 62 | with p(): 63 | with a(href=os.path.join('images', link)): 64 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 65 | br() 66 | p(txt) 67 | 68 | def save(self): 69 | """save the current content to the HMTL file""" 70 | html_file = '%s/index.html' % self.web_dir 71 | f = open(html_file, 'wt') 72 | f.write(self.doc.render()) 73 | f.close() 74 | 75 | 76 | if __name__ == '__main__': # we show an example usage here. 77 | html = HTML('web/', 'test_html') 78 | html.add_header('hello world') 79 | 80 | ims, txts, links = [], [], [] 81 | for n in range(4): 82 | ims.append('image_%d.png' % n) 83 | txts.append('text_%d' % n) 84 | links.append('image_%d.png' % n) 85 | html.add_images(ims, txts, links) 86 | html.save() 87 | -------------------------------------------------------------------------------- /talkingface/util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ImagePool(): 6 | """This class implements an image buffer that stores previously generated images. 7 | 8 | This buffer enables us to update discriminators using a history of generated images 9 | rather than the ones produced by the latest generators. 10 | """ 11 | 12 | def __init__(self, pool_size): 13 | """Initialize the ImagePool class 14 | 15 | Parameters: 16 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 17 | """ 18 | self.pool_size = pool_size 19 | if self.pool_size > 0: # create an empty pool 20 | self.num_imgs = 0 21 | self.images = [] 22 | 23 | def query(self, images): 24 | """Return an image from the pool. 25 | 26 | Parameters: 27 | images: the latest generated images from the generator 28 | 29 | Returns images from the buffer. 30 | 31 | By 50/100, the buffer will return input images. 32 | By 50/100, the buffer will return images previously stored in the buffer, 33 | and insert the current images to the buffer. 
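        A minimal sketch of the intended use inside a discriminator update
        (the tensor and network names are illustrative):

        >>> pool = ImagePool(pool_size=50)
        >>> fake_for_D = pool.query(fake_images.detach())
        >>> pred_fake = net_d(fake_for_D)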
34 | """ 35 | if self.pool_size == 0: # if the buffer size is 0, do nothing 36 | return images 37 | return_images = [] 38 | for image in images: 39 | image = torch.unsqueeze(image.data, 0) 40 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer 41 | self.num_imgs = self.num_imgs + 1 42 | self.images.append(image) 43 | return_images.append(image) 44 | else: 45 | p = random.uniform(0, 1) 46 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 47 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 48 | tmp = self.images[random_id].clone() 49 | self.images[random_id] = image 50 | return_images.append(tmp) 51 | else: # by another 50% chance, the buffer will return the current image 52 | return_images.append(image) 53 | return_images = torch.cat(return_images, 0) # collect all the images and return 54 | return return_images 55 | -------------------------------------------------------------------------------- /talkingface/util/log_board.py: -------------------------------------------------------------------------------- 1 | def log( 2 | logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag="" 3 | ): 4 | if losses is not None: 5 | logger.add_scalar("Loss/d_loss", losses[0], step) 6 | logger.add_scalar("Loss/g_gan_loss", losses[1], step) 7 | logger.add_scalar("Loss/g_l1_loss", losses[2], step) 8 | 9 | if fig is not None: 10 | logger.add_image(tag, fig, 2, dataformats='HWC') 11 | 12 | if audio is not None: 13 | logger.add_audio( 14 | tag, 15 | audio / max(abs(audio)), 16 | sample_rate=sampling_rate, 17 | ) -------------------------------------------------------------------------------- /talkingface/util/smooth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def smooth_array(array, weight = [0.1,0.8,0.1]): 7 | ''' 8 | 9 | Args: 10 | array: [n_frames, n_values], 需要转换为[n_values, 1, n_frames] 11 | weight: Conv1d.weight, 一维卷积核权重 12 | Returns: 13 | array: [n_frames, n_values], 光滑后的array 14 | ''' 15 | input = torch.Tensor(np.transpose(array[:,np.newaxis,:], (2, 1, 0))) 16 | smooth_length = len(weight) 17 | assert smooth_length%2 == 1, "卷积核权重个数必须使用奇数" 18 | pad = (smooth_length//2, smooth_length//2) # 当pad只有两个参数时,仅改变最后一个维度, 左边扩充1列,右边扩充1列 19 | input = F.pad(input, pad, "replicate") 20 | 21 | with torch.no_grad(): 22 | conv1 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=smooth_length) 23 | # 卷积核的元素值初始化 24 | weight = torch.tensor(weight).view(1, 1, -1) 25 | conv1.weight = torch.nn.Parameter(weight) 26 | nn.init.constant_(conv1.bias, 0) # 偏置值为0 27 | # print(conv1.weight) 28 | out = conv1(input) 29 | return out.permute(2,1,0).squeeze().numpy() 30 | 31 | if __name__ == '__main__': 32 | model_id = "new_case" 33 | Path_output_pkl = "../preparation/{}/mouth_info.pkl".format(model_id + "/00001") 34 | import pickle 35 | with open(Path_output_pkl, "rb") as f: 36 | images_info = pickle.load(f) 37 | pts_array_normalized = np.array(images_info[2]) 38 | pts_array_normalized = pts_array_normalized.reshape(-1, 16) 39 | smooth_array_ = smooth_array(pts_array_normalized) 40 | print(smooth_array_, smooth_array_.shape) 41 | smooth_array_ = smooth_array_.reshape(-1, 4, 4) 42 | import pandas as pd 43 | 44 | pd.DataFrame(smooth_array_[:, :, 0]).to_csv("mat2.csv") 
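A quick sanity check of `smooth_array` on synthetic data (the shapes and values below are made up for illustration): every interior frame becomes a 0.1/0.8/0.1 weighted average of itself and its two neighbours, the first and last frames are handled by replication padding, and the output keeps the input's [n_frames, n_values] shape.

```python
import numpy as np
from talkingface.util.smooth import smooth_array

# 100 frames of a noisy 16-dimensional trajectory (e.g. a flattened 4x4 matrix per frame)
noisy = np.cumsum(np.random.randn(100, 16), axis=0)

# weighted moving average over a 3-frame window; shape stays (100, 16)
smoothed = smooth_array(noisy, weight=[0.1, 0.8, 0.1])
print(noisy.shape, smoothed.shape)
```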
-------------------------------------------------------------------------------- /talkingface/util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | def tensor2im(input_image, imtype=np.uint8): 10 | """"Converts a Tensor array into a numpy image array. 11 | 12 | Parameters: 13 | input_image (tensor) -- the input image tensor array 14 | imtype (type) -- the desired type of the converted numpy array 15 | """ 16 | if not isinstance(input_image, np.ndarray): 17 | if isinstance(input_image, torch.Tensor): # get the data from a variable 18 | image_tensor = input_image.data 19 | else: 20 | return input_image 21 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 22 | if image_numpy.shape[0] == 1: # grayscale to RGB 23 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 24 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 25 | else: # if it is a numpy array, do nothing 26 | image_numpy = input_image 27 | return image_numpy.astype(imtype) 28 | 29 | 30 | def diagnose_network(net, name='network'): 31 | """Calculate and print the mean of average absolute(gradients) 32 | 33 | Parameters: 34 | net (torch network) -- Torch network 35 | name (str) -- the name of the network 36 | """ 37 | mean = 0.0 38 | count = 0 39 | for param in net.parameters(): 40 | if param.grad is not None: 41 | mean += torch.mean(torch.abs(param.grad.data)) 42 | count += 1 43 | if count > 0: 44 | mean = mean / count 45 | print(name) 46 | print(mean) 47 | 48 | 49 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 50 | """Save a numpy image to the disk 51 | 52 | Parameters: 53 | image_numpy (numpy array) -- input numpy array 54 | image_path (str) -- the path of the image 55 | """ 56 | 57 | image_pil = Image.fromarray(image_numpy) 58 | h, w, _ = image_numpy.shape 59 | 60 | if aspect_ratio > 1.0: 61 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 62 | if aspect_ratio < 1.0: 63 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) 64 | image_pil.save(image_path) 65 | 66 | 67 | def print_numpy(x, val=True, shp=False): 68 | """Print the mean, min, max, median, std, and size of a numpy array 69 | 70 | Parameters: 71 | val (bool) -- if print the values of the numpy array 72 | shp (bool) -- if print the shape of the numpy array 73 | """ 74 | x = x.astype(np.float64) 75 | if shp: 76 | print('shape,', x.shape) 77 | if val: 78 | x = x.flatten() 79 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 80 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 81 | 82 | 83 | def mkdirs(paths): 84 | """create empty directories if they don't exist 85 | 86 | Parameters: 87 | paths (str list) -- a list of directory paths 88 | """ 89 | if isinstance(paths, list) and not isinstance(paths, str): 90 | for path in paths: 91 | mkdir(path) 92 | else: 93 | mkdir(paths) 94 | 95 | 96 | def mkdir(path): 97 | """create a single empty directory if it didn't exist 98 | 99 | Parameters: 100 | path (str) -- a single directory path 101 | """ 102 | if not os.path.exists(path): 103 | os.makedirs(path) 104 | -------------------------------------------------------------------------------- /talkingface/util/utils.py: 
-------------------------------------------------------------------------------- 1 | from torch.optim import lr_scheduler 2 | 3 | import torch.nn as nn 4 | import torch 5 | 6 | ######################################################### training utils########################################################## 7 | 8 | def get_scheduler(optimizer, niter,niter_decay,lr_policy='lambda',lr_decay_iters=50): 9 | ''' 10 | scheduler in training stage 11 | ''' 12 | if lr_policy == 'lambda': 13 | def lambda_rule(epoch): 14 | lr_l = 1.0 - max(0, epoch - niter) / float(niter_decay + 1) 15 | return lr_l 16 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 17 | elif lr_policy == 'step': 18 | scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_decay_iters, gamma=0.1) 19 | elif lr_policy == 'plateau': 20 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 21 | elif lr_policy == 'cosine': 22 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=niter, eta_min=0) 23 | else: 24 | return NotImplementedError('learning rate policy [%s] is not implemented', lr_policy) 25 | return scheduler 26 | 27 | def update_learning_rate(scheduler, optimizer): 28 | scheduler.step() 29 | lr = optimizer.param_groups[0]['lr'] 30 | print('learning rate = %.7f' % lr) 31 | 32 | class GANLoss(nn.Module): 33 | ''' 34 | GAN loss 35 | ''' 36 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): 37 | super(GANLoss, self).__init__() 38 | self.register_buffer('real_label', torch.tensor(target_real_label)) 39 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 40 | if use_lsgan: 41 | self.loss = nn.MSELoss() 42 | else: 43 | self.loss = nn.BCELoss() 44 | 45 | def get_target_tensor(self, input, target_is_real): 46 | if target_is_real: 47 | target_tensor = self.real_label 48 | else: 49 | target_tensor = self.fake_label 50 | return target_tensor.expand_as(input) 51 | 52 | def forward(self, input, target_is_real): 53 | target_tensor = self.get_target_tensor(input, target_is_real) 54 | return self.loss(input, target_tensor) 55 | 56 | 57 | 58 | import tqdm 59 | import numpy as np 60 | import cv2 61 | import glob 62 | import os 63 | import math 64 | import pickle 65 | import mediapipe as mp 66 | mp_face_mesh = mp.solutions.face_mesh 67 | landmark_points_68 = [162,234,93,58,172,136,149,148,152,377,378,365,397,288,323,454,389, 68 | 71,63,105,66,107,336,296,334,293,301, 69 | 168,197,5,4,75,97,2,326,305, 70 | 33,160,158,133,153,144,362,385,387,263,373, 71 | 380,61,39,37,0,267,269,291,405,314,17,84,181,78,82,13,312,308,317,14,87] 72 | def ExtractFaceFromFrameList(frames_list, vid_height, vid_width, out_size = 256): 73 | pts_3d = np.zeros([len(frames_list), 478, 3]) 74 | with mp_face_mesh.FaceMesh( 75 | static_image_mode=True, 76 | max_num_faces=1, 77 | refine_landmarks=True, 78 | min_detection_confidence=0.5) as face_mesh: 79 | 80 | for index, frame in tqdm.tqdm(enumerate(frames_list)): 81 | results = face_mesh.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 82 | if not results.multi_face_landmarks: 83 | print("****** WARNING! No face detected! 
******") 84 | pts_3d[index] = 0 85 | return 86 | # continue 87 | image_height, image_width = frame.shape[:2] 88 | for face_landmarks in results.multi_face_landmarks: 89 | for index_, i in enumerate(face_landmarks.landmark): 90 | x_px = min(math.floor(i.x * image_width), image_width - 1) 91 | y_px = min(math.floor(i.y * image_height), image_height - 1) 92 | z_px = min(math.floor(i.z * image_height), image_height - 1) 93 | pts_3d[index, index_] = np.array([x_px, y_px, z_px]) 94 | 95 | # 计算整个视频中人脸的范围 96 | 97 | x_min, y_min, x_max, y_max = np.min(pts_3d[:, :, 0]), np.min( 98 | pts_3d[:, :, 1]), np.max( 99 | pts_3d[:, :, 0]), np.max(pts_3d[:, :, 1]) 100 | new_w = int((x_max - x_min) * 0.55)*2 101 | new_h = int((y_max - y_min) * 0.6)*2 102 | center_x = int((x_max + x_min) / 2.) 103 | center_y = int(y_min + (y_max - y_min) * 0.6) 104 | size = max(new_h, new_w) 105 | x_min, y_min, x_max, y_max = int(center_x - size // 2), int(center_y - size // 2), int( 106 | center_x + size // 2), int(center_y + size // 2) 107 | 108 | # 确定裁剪区域上边top和左边left坐标 109 | top = y_min 110 | left = x_min 111 | # 裁剪区域与原图的重合区域 112 | top_coincidence = int(max(top, 0)) 113 | bottom_coincidence = int(min(y_max, vid_height)) 114 | left_coincidence = int(max(left, 0)) 115 | right_coincidence = int(min(x_max, vid_width)) 116 | 117 | scale = out_size / size 118 | pts_3d = (pts_3d - np.array([left, top, 0])) * scale 119 | pts_3d = pts_3d 120 | 121 | face_rect = np.array([center_x, center_y, size]) 122 | print(np.array([x_min, y_min, x_max, y_max])) 123 | 124 | img_array = np.zeros([len(frames_list), out_size, out_size, 3], dtype = np.uint8) 125 | for index, frame in tqdm.tqdm(enumerate(frames_list)): 126 | img_new = np.zeros([size, size, 3], dtype=np.uint8) 127 | img_new[top_coincidence - top:bottom_coincidence - top, left_coincidence - left:right_coincidence - left,:] = \ 128 | frame[top_coincidence:bottom_coincidence, left_coincidence:right_coincidence, :] 129 | img_new = cv2.resize(img_new, (out_size, out_size)) 130 | img_array[index] = img_new 131 | return pts_3d,img_array, face_rect 132 | 133 | -------------------------------------------------------------------------------- /talkingface/util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import ntpath 5 | import time 6 | from . import util, html 7 | from subprocess import Popen, PIPE 8 | 9 | 10 | try: 11 | import wandb 12 | except ImportError: 13 | print('Warning: wandb package cannot be found. The option "--use_wandb" will result in error.') 14 | 15 | if sys.version_info[0] == 2: 16 | VisdomExceptionBase = Exception 17 | else: 18 | VisdomExceptionBase = ConnectionError 19 | 20 | 21 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256, use_wandb=False): 22 | """Save images to the disk. 23 | 24 | Parameters: 25 | webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) 26 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs 27 | image_path (str) -- the string is used to create image paths 28 | aspect_ratio (float) -- the aspect ratio of saved images 29 | width (int) -- the images will be resized to width x width 30 | 31 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 
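    A sketch of a typical call (the result directory, the visuals dictionary and the
    image path are illustrative):

    >>> webpage = html.HTML('./results/exp/web', 'Experiment name = exp')
    >>> save_images(webpage, visuals, ['video_data/demo_0001.png'], width=256)
    >>> webpage.save()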
32 | """ 33 | image_dir = webpage.get_image_dir() 34 | short_path = ntpath.basename(image_path[0]) 35 | name = os.path.splitext(short_path)[0] 36 | 37 | webpage.add_header(name) 38 | ims, txts, links = [], [], [] 39 | ims_dict = {} 40 | for label, im_data in visuals.items(): 41 | im = util.tensor2im(im_data) 42 | image_name = '%s_%s.png' % (name, label) 43 | save_path = os.path.join(image_dir, image_name) 44 | util.save_image(im, save_path, aspect_ratio=aspect_ratio) 45 | ims.append(image_name) 46 | txts.append(label) 47 | links.append(image_name) 48 | if use_wandb: 49 | ims_dict[label] = wandb.Image(im) 50 | webpage.add_images(ims, txts, links, width=width) 51 | if use_wandb: 52 | wandb.log(ims_dict) 53 | 54 | 55 | class Visualizer(): 56 | """This class includes several functions that can display/save images and print/save logging information. 57 | 58 | It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. 59 | """ 60 | 61 | def __init__(self, opt): 62 | """Initialize the Visualizer class 63 | 64 | Parameters: 65 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions 66 | Step 1: Cache the training/test options 67 | Step 2: connect to a visdom server 68 | Step 3: create an HTML object for saveing HTML filters 69 | Step 4: create a logging file to store training losses 70 | """ 71 | self.opt = opt # cache the option 72 | self.display_id = opt.display_id 73 | self.use_html = opt.isTrain and not opt.no_html 74 | self.win_size = opt.display_winsize 75 | self.name = opt.name 76 | self.port = opt.display_port 77 | self.saved = False 78 | self.use_wandb = opt.use_wandb 79 | self.wandb_project_name = opt.wandb_project_name 80 | self.current_epoch = 0 81 | self.ncols = opt.display_ncols 82 | 83 | if self.display_id > 0: # connect to a visdom server given and 84 | import visdom 85 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) 86 | if not self.vis.check_connection(): 87 | self.create_visdom_connections() 88 | 89 | if self.use_wandb: 90 | self.wandb_run = wandb.init(project=self.wandb_project_name, name=opt.name, config=opt) if not wandb.run else wandb.run 91 | self.wandb_run._label(repo='CycleGAN-and-pix2pix') 92 | 93 | if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ 94 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 95 | self.img_dir = os.path.join(self.web_dir, 'images') 96 | print('create web directory %s...' % self.web_dir) 97 | util.mkdirs([self.web_dir, self.img_dir]) 98 | # create a logging file to store training losses 99 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 100 | with open(self.log_name, "a") as log_file: 101 | now = time.strftime("%c") 102 | log_file.write('================ Training Loss (%s) ================\n' % now) 103 | 104 | def reset(self): 105 | """Reset the self.saved status""" 106 | self.saved = False 107 | 108 | def create_visdom_connections(self): 109 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ 110 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port 111 | print('\n\nCould not connect to Visdom server. 
\n Trying to start a server....') 112 | print('Command: %s' % cmd) 113 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) 114 | 115 | def display_current_results(self, visuals, epoch, save_result): 116 | """Display current results on visdom; save current results to an HTML file. 117 | 118 | Parameters: 119 | visuals (OrderedDict) - - dictionary of images to display or save 120 | epoch (int) - - the current epoch 121 | save_result (bool) - - if save the current results to an HTML file 122 | """ 123 | if self.display_id > 0: # show images in the browser using visdom 124 | ncols = self.ncols 125 | if ncols > 0: # show all the images in one visdom panel 126 | ncols = min(ncols, len(visuals)) 127 | h, w = next(iter(visuals.values())).shape[:2] 128 | table_css = """""" % (w, h) # create a table css 132 | # create a table of images. 133 | title = self.name 134 | label_html = '' 135 | label_html_row = '' 136 | images = [] 137 | idx = 0 138 | for label, image in visuals.items(): 139 | image_numpy = util.tensor2im(image) 140 | label_html_row += '%s' % label 141 | images.append(image_numpy.transpose([2, 0, 1])) 142 | idx += 1 143 | if idx % ncols == 0: 144 | label_html += '%s' % label_html_row 145 | label_html_row = '' 146 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 147 | while idx % ncols != 0: 148 | images.append(white_image) 149 | label_html_row += '' 150 | idx += 1 151 | if label_html_row != '': 152 | label_html += '%s' % label_html_row 153 | try: 154 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 155 | padding=2, opts=dict(title=title + ' images')) 156 | label_html = '%s
' % label_html 157 | self.vis.text(table_css + label_html, win=self.display_id + 2, 158 | opts=dict(title=title + ' labels')) 159 | except VisdomExceptionBase: 160 | self.create_visdom_connections() 161 | 162 | else: # show each image in a separate visdom panel; 163 | idx = 1 164 | try: 165 | for label, image in visuals.items(): 166 | image_numpy = util.tensor2im(image) 167 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 168 | win=self.display_id + idx) 169 | idx += 1 170 | except VisdomExceptionBase: 171 | self.create_visdom_connections() 172 | 173 | if self.use_wandb: 174 | columns = [key for key, _ in visuals.items()] 175 | columns.insert(0, 'epoch') 176 | result_table = wandb.Table(columns=columns) 177 | table_row = [epoch] 178 | ims_dict = {} 179 | for label, image in visuals.items(): 180 | image_numpy = util.tensor2im(image) 181 | wandb_image = wandb.Image(image_numpy) 182 | table_row.append(wandb_image) 183 | ims_dict[label] = wandb_image 184 | self.wandb_run.log(ims_dict) 185 | if epoch != self.current_epoch: 186 | self.current_epoch = epoch 187 | result_table.add_data(*table_row) 188 | self.wandb_run.log({"Result": result_table}) 189 | 190 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 191 | self.saved = True 192 | # save images to the disk 193 | for label, image in visuals.items(): 194 | image_numpy = util.tensor2im(image) 195 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 196 | util.save_image(image_numpy, img_path) 197 | 198 | # update website 199 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) 200 | for n in range(epoch, 0, -1): 201 | webpage.add_header('epoch [%d]' % n) 202 | ims, txts, links = [], [], [] 203 | 204 | for label, image_numpy in visuals.items(): 205 | image_numpy = util.tensor2im(image) 206 | img_path = 'epoch%.3d_%s.png' % (n, label) 207 | ims.append(img_path) 208 | txts.append(label) 209 | links.append(img_path) 210 | webpage.add_images(ims, txts, links, width=self.win_size) 211 | webpage.save() 212 | 213 | def plot_current_losses(self, epoch, counter_ratio, losses): 214 | """display the current losses on visdom display: dictionary of error labels and values 215 | 216 | Parameters: 217 | epoch (int) -- current epoch 218 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 219 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 220 | """ 221 | if not hasattr(self, 'plot_data'): 222 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 223 | self.plot_data['X'].append(epoch + counter_ratio) 224 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 225 | try: 226 | self.vis.line( 227 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 228 | Y=np.array(self.plot_data['Y']), 229 | opts={ 230 | 'title': self.name + ' loss over time', 231 | 'legend': self.plot_data['legend'], 232 | 'xlabel': 'epoch', 233 | 'ylabel': 'loss'}, 234 | win=self.display_id) 235 | except VisdomExceptionBase: 236 | self.create_visdom_connections() 237 | if self.use_wandb: 238 | self.wandb_run.log(losses) 239 | 240 | # losses: same format as |losses| of plot_current_losses 241 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data): 242 | """print current losses on console; also save the losses to the disk 243 | 244 | Parameters: 245 | epoch (int) -- current epoch 246 | iters (int) -- 
current training iteration during this epoch (reset to 0 at the end of every epoch) 247 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 248 | t_comp (float) -- computational time per data point (normalized by batch_size) 249 | t_data (float) -- data loading time per data point (normalized by batch_size) 250 | """ 251 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) 252 | for k, v in losses.items(): 253 | message += '%s: %.3f ' % (k, v) 254 | 255 | print(message) # print the message 256 | with open(self.log_name, "a") as log_file: 257 | log_file.write('%s\n' % message) # save the message 258 | -------------------------------------------------------------------------------- /talkingface/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import cv2 4 | 5 | INDEX_LEFT_EYEBROW = [276, 283, 282, 295, 285, 336, 296, 334, 293, 300] 6 | INDEX_RIGHT_EYEBROW = [46, 53, 52, 65, 55, 107, 66, 105, 63, 70] 7 | INDEX_EYEBROW = INDEX_LEFT_EYEBROW + INDEX_RIGHT_EYEBROW 8 | 9 | INDEX_NOSE_EDGE = [343, 355, 358, 327, 326, 2, 97, 98, 129, 126, 114] 10 | INDEX_NOSE_MID = [6, 197,195,5,4] 11 | INDEX_NOSE = INDEX_NOSE_EDGE + INDEX_NOSE_MID 12 | 13 | INDEX_LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324,308,415,310,311,312,13,82,81,80,191] 14 | INDEX_LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291,409,270,269,267,0,37,39,40,185,] 15 | INDEX_LIPS = INDEX_LIPS_INNER + INDEX_LIPS_OUTER 16 | 17 | INDEX_LEFT_EYE = [263, 249, 390, 373, 374, 380, 381, 382, 362, 398, 384, 385, 386, 387, 388, 466] 18 | INDEX_RIGHT_EYE = [33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246] 19 | INDEX_EYE = INDEX_LEFT_EYE + INDEX_RIGHT_EYE 20 | # 下半边脸的轮廓 21 | INDEX_FACE_OVAL = [ 22 | 454, 323, 361, 288, 397, 365, 23 | 379, 378, 400, 377, 152, 148, 176, 149, 150, 24 | 136, 172, 58, 132, 93, 234, 25 | # 206, 426 26 | ] 27 | 28 | INDEX_MUSCLE = [ 29 | 371,266,425,427,434,394, 30 | 169,214,207,205,36,142 31 | ] 32 | 33 | main_keypoints_index = INDEX_EYEBROW + INDEX_NOSE + INDEX_LIPS + INDEX_EYE + INDEX_FACE_OVAL + INDEX_MUSCLE 34 | 35 | rotate_ref_index = INDEX_EYEBROW + INDEX_NOSE + INDEX_EYE + INDEX_FACE_OVAL + INDEX_MUSCLE 36 | # print(len(main_keypoints_index)) 37 | Normalized = True 38 | if Normalized: 39 | tmp = 0 40 | list_ = [] 41 | for i in [INDEX_LEFT_EYEBROW, INDEX_RIGHT_EYEBROW, INDEX_NOSE_EDGE, INDEX_NOSE_MID, INDEX_LIPS_INNER, INDEX_LIPS_OUTER, INDEX_LEFT_EYE, INDEX_RIGHT_EYE, INDEX_FACE_OVAL, INDEX_MUSCLE]: 42 | i = (tmp + np.arange(len(i))).tolist() 43 | list_.append(i) 44 | tmp += len(i) 45 | [INDEX_LEFT_EYEBROW, INDEX_RIGHT_EYEBROW, INDEX_NOSE_EDGE, INDEX_NOSE_MID, INDEX_LIPS_INNER, INDEX_LIPS_OUTER, INDEX_LEFT_EYE, INDEX_RIGHT_EYE, INDEX_FACE_OVAL, INDEX_MUSCLE] = list_ 46 | INDEX_EYEBROW = INDEX_LEFT_EYEBROW + INDEX_RIGHT_EYEBROW 47 | INDEX_NOSE = INDEX_NOSE_EDGE + INDEX_NOSE_MID 48 | INDEX_LIPS = INDEX_LIPS_INNER + INDEX_LIPS_OUTER 49 | INDEX_EYE = INDEX_LEFT_EYE + INDEX_RIGHT_EYE 50 | # print(INDEX_FACE_OVAL_STABLE) 51 | # print(len(main_keypoints_index)) 52 | # print(len(set(main_keypoints_index))) 53 | FACE_MASK_INDEX = INDEX_FACE_OVAL[2:-2] 54 | def crop_face(keypoints, is_train = False, size = [512, 512]): 55 | """ 56 | x_ratio: 裁剪出一个正方形,边长根据keypoints的宽度 * x_ratio决定 57 | """ 58 | x_min, y_min, x_max, y_max = np.min(keypoints[FACE_MASK_INDEX, 0]), np.min(keypoints[FACE_MASK_INDEX, 1]), 
np.max(keypoints[FACE_MASK_INDEX, 0]), np.max( 59 | keypoints[FACE_MASK_INDEX, 1]) 60 | y_min = keypoints[33, 1] # 两眼间的点开始y轴裁剪 61 | border_width_half = max(x_max - x_min, y_max - y_min) * 0.6 62 | center_x = int((x_min + x_max) / 2.0) 63 | center_y = int((y_min + y_max) / 2.0) 64 | if is_train: 65 | w_offset = random.randint(-2, 2) 66 | h_offset = random.randint(-2, 2) 67 | center_x = center_x + w_offset 68 | center_y = center_y + h_offset 69 | x_min, y_min, x_max, y_max = int(center_x - border_width_half), int(center_y - border_width_half), int( 70 | center_x + border_width_half), int(center_y + border_width_half) 71 | x_min = max(0, x_min) 72 | y_min = max(0, y_min) 73 | x_max = min(size[1], x_max) 74 | y_max = min(size[0], y_max) 75 | return [x_min, y_min, x_max, y_max] 76 | 77 | def draw_face_feature_maps(keypoints, mode = ["mouth", "nose", "eye", "oval"], size=(256, 256), im_edges = None, mouth_width = None, mouth_height = None): 78 | w, h = size 79 | # edge map for face region from keypoints 80 | if im_edges is None: 81 | im_edges = np.zeros((h, w, 3), np.uint8) # edge map for all edges 82 | 83 | if "mouth_bias" in mode: 84 | 85 | w0, w1, h0, h1 = (int(keypoints[INDEX_NOSE_EDGE[5], 0] - mouth_width / 2), 86 | int(keypoints[INDEX_NOSE_EDGE[5], 0] + mouth_width / 2), 87 | int(keypoints[INDEX_NOSE_EDGE[5], 1] + mouth_height / 4), 88 | int(keypoints[INDEX_NOSE_EDGE[5], 1] + mouth_height / 4 + mouth_height)) 89 | w0, h0 = max(0, w0), max(h0, 0) 90 | print(w0, w1, h0, h1) 91 | mouth_mask = np.zeros((h, w, 3), np.uint8) # edge map for all edges 92 | mouth_mask[h0:h1, w0:w1] = 255 93 | mouth_index = np.where(mouth_mask == 255) 94 | blur = (10, 10) 95 | img_mouth = cv2.cvtColor(im_edges, cv2.COLOR_BGR2GRAY) 96 | img_mouth = cv2.blur(img_mouth, blur) 97 | # print(mouth_index) 98 | mean_ = int(np.mean(img_mouth[(mouth_index[0], mouth_index[1])])) 99 | max_, min_ = random.randint(mean_ + 40, mean_ + 70), random.randint(mean_ - 70, mean_ - 40) 100 | img_mouth = (img_mouth.astype(np.float32) - min_) / (max_ - min_) * 255. 
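# The stretched values are then clipped to 8-bit range, downsampled to 100x50, perturbed
# with Gaussian noise (sigma = 8) and resized back to the full canvas before being copied
# into the mouth rectangle only; presumably this keeps just coarse brightness cues in the
# mouth prompt rather than exact pixel detail.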
101 | img_mouth = img_mouth.clip(0, 255).astype(np.uint8) 102 | print(img_mouth.shape[0], img_mouth.shape[1]) 103 | img_mouth = cv2.resize(img_mouth, (100, 50)) 104 | 105 | # 定义噪声的标准差 106 | sigma = 8 # 你可以根据需要调整这个值 107 | # 生成与图片相同大小和类型的噪声 108 | noise = sigma * np.random.randn(img_mouth.shape[0], img_mouth.shape[1]) 109 | # 将噪声添加到图片上 110 | img_mouth = img_mouth + noise 111 | img_mouth = cv2.resize(img_mouth.clip(0, 255).astype(np.uint8), (im_edges.shape[0], im_edges.shape[1])) 112 | 113 | img_mouth = np.concatenate( 114 | [img_mouth[:, :, np.newaxis], img_mouth[:, :, np.newaxis], img_mouth[:, :, np.newaxis]], axis=2) 115 | 116 | # bias_x = int(min(size[0] - max(mouth_index[0]), min(mouth_index[0]))) 117 | # bias_y = int(min(size[0] - max(mouth_index[1]), min(mouth_index[1]))) 118 | # print(bias_x, bias_y) 119 | # bias_x = random.randint(-bias_x // 10, bias_x // 10) 120 | # bias_y = random.randint(-bias_y // 10, bias_y // 10) 121 | # 122 | # mouth_index2 = np.array(mouth_index) 123 | # # print(mouth_index) 124 | # mouth_index2[0] = mouth_index[0] + bias_x 125 | # mouth_index2[1] = mouth_index[1] + bias_y 126 | # mouth_index2 = (mouth_index2[0], mouth_index2[1], mouth_index2[2]) 127 | # # print(mouth_index2.T, mouth_index2.shape) 128 | # 129 | output = np.zeros((h, w, 3), np.uint8) 130 | # output[mouth_index2] = img_mouth[mouth_index] 131 | # return im_edges 132 | output[mouth_index] = img_mouth[mouth_index] 133 | im_edges = output 134 | if "nose" in mode: 135 | for ii in range(len(INDEX_NOSE_EDGE) - 1): 136 | pt1 = [int(flt) for flt in keypoints[INDEX_NOSE_EDGE[ii]]][:2] 137 | pt2 = [int(flt) for flt in keypoints[INDEX_NOSE_EDGE[ii+1]]][:2] 138 | cv2.line(im_edges, tuple(pt1), tuple(pt2), (0, 255, 0), 2) 139 | for ii in range(len(INDEX_NOSE_MID) -1): 140 | pt1 = [int(flt) for flt in keypoints[INDEX_NOSE_MID[ii]]][:2] 141 | pt2 = [int(flt) for flt in keypoints[INDEX_NOSE_MID[ii + 1]]][:2] 142 | cv2.line(im_edges, tuple(pt1), tuple(pt2), (0, 255, 0), 2) 143 | if "eye" in mode: 144 | for ii in range(len(INDEX_LEFT_EYE)): 145 | pt1 = [int(flt) for flt in keypoints[INDEX_LEFT_EYE[ii]]][:2] 146 | pt2 = [int(flt) for flt in keypoints[INDEX_LEFT_EYE[(ii + 1)%len(INDEX_LEFT_EYE)]]][:2] 147 | cv2.line(im_edges, tuple(pt1), tuple(pt2), (0, 255, 0), 2) 148 | for ii in range(len(INDEX_RIGHT_EYE)): 149 | pt1 = [int(flt) for flt in keypoints[INDEX_RIGHT_EYE[ii]]][:2] 150 | pt2 = [int(flt) for flt in keypoints[INDEX_RIGHT_EYE[(ii + 1) % len(INDEX_RIGHT_EYE)]]][:2] 151 | cv2.line(im_edges, tuple(pt1), tuple(pt2), (0, 255, 0), 2) 152 | if "oval" in mode: 153 | tmp = INDEX_FACE_OVAL[:6] 154 | for ii in range(len(tmp) -1): 155 | pt1 = [int(flt) for flt in keypoints[tmp[ii]]][:2] 156 | pt2 = [int(flt) for flt in keypoints[tmp[ii + 1]]][:2] 157 | cv2.line(im_edges, tuple(pt1), tuple(pt2), (0, 0, 255), 2) 158 | tmp = INDEX_FACE_OVAL[-6:] 159 | for ii in range(len(tmp) - 1): 160 | pt1 = [int(flt) for flt in keypoints[tmp[ii]]][:2] 161 | pt2 = [int(flt) for flt in keypoints[tmp[ii + 1]]][:2] 162 | cv2.line(im_edges, tuple(pt1), tuple(pt2), (0, 0, 255), 2) 163 | 164 | # if "mouth_outer" in mode: 165 | # pts = keypoints[INDEX_LIPS_OUTER] 166 | # pts = pts.reshape((-1, 1, 2)).astype(np.int32) 167 | # cv2.fillPoly(im_edges, [pts], color=(255, 0, 0)) 168 | # if "mouth" in mode: 169 | # pts = keypoints[INDEX_LIPS_OUTER][:,2] 170 | # pts = pts.reshape((-1, 1, 2)).astype(np.int32) 171 | # cv2.fillPoly(im_edges, [pts], color=(255, 0, 0)) 172 | # pts = keypoints[INDEX_LIPS_INNER][:,2] 173 | # pts = pts.reshape((-1, 1, 
2)).astype(np.int32) 174 | # cv2.fillPoly(im_edges, [pts], color=(0, 0, 0)) 175 | # if "mouth_outer" in mode: 176 | # for ii in range(len(INDEX_LIPS_OUTER)): 177 | # pt1 = [int(flt) for flt in keypoints[INDEX_LIPS_OUTER[ii]]][:2] 178 | # pt2 = [int(flt) for flt in keypoints[INDEX_LIPS_OUTER[(ii + 1)%len(INDEX_LIPS_OUTER)]]][:2] 179 | # cv2.line(im_edges, tuple(pt1), tuple(pt2), (255, 0, 0), 2) 180 | 181 | 182 | 183 | if "mouth" in mode: 184 | for ii in range(len(INDEX_LIPS_OUTER)): 185 | pt1 = [int(flt) for flt in keypoints[INDEX_LIPS_OUTER[ii]]][:2] 186 | pt2 = [int(flt) for flt in keypoints[INDEX_LIPS_OUTER[(ii + 1)%len(INDEX_LIPS_OUTER)]]][:2] 187 | cv2.line(im_edges, tuple(pt1), tuple(pt2), (255, 0, 0), 2) 188 | for ii in range(len(INDEX_LIPS_INNER)): 189 | pt1 = [int(flt) for flt in keypoints[INDEX_LIPS_INNER[ii]]][:2] 190 | pt2 = [int(flt) for flt in keypoints[INDEX_LIPS_INNER[(ii + 1)%len(INDEX_LIPS_INNER)]]][:2] 191 | cv2.line(im_edges, tuple(pt1), tuple(pt2), (255, 0, 0), 2) 192 | if "muscle" in mode: 193 | for ii in range(len(INDEX_MUSCLE) - 1): 194 | pt1 = [int(flt) for flt in keypoints[INDEX_MUSCLE[ii]]][:2] 195 | pt2 = [int(flt) for flt in keypoints[INDEX_MUSCLE[(ii + 1) % len(INDEX_MUSCLE)]]][:2] 196 | cv2.line(im_edges, tuple(pt1), tuple(pt2), (255, 255, 255), 2) 197 | 198 | if "oval_all" in mode: 199 | tmp = INDEX_FACE_OVAL 200 | for ii in range(len(tmp) -1): 201 | pt1 = [int(flt) for flt in keypoints[tmp[ii]]][:2] 202 | pt2 = [int(flt) for flt in keypoints[tmp[ii + 1]]][:2] 203 | cv2.line(im_edges, tuple(pt1), tuple(pt2), (0, 0, 255), 2) 204 | 205 | return im_edges 206 | 207 | def smooth_array(array, weight = [0.1,0.8,0.1], mode = "numpy"): 208 | ''' 209 | 210 | Args: 211 | array: [n_frames, n_values], 需要转换为[n_values, 1, n_frames] 212 | weight: Conv1d.weight, 一维卷积核权重 213 | Returns: 214 | array: [n_frames, n_values], 光滑后的array 215 | ''' 216 | if mode == "torch": 217 | import torch 218 | import torch.nn as nn 219 | import torch.nn.functional as F 220 | input = torch.Tensor(np.transpose(array[:, np.newaxis, :], (2, 1, 0))) 221 | smooth_length = len(weight) 222 | assert smooth_length % 2 == 1, "卷积核权重个数必须使用奇数" 223 | pad = (smooth_length // 2, smooth_length // 2) # 当pad只有两个参数时,仅改变最后一个维度, 左边扩充1列,右边扩充1列 224 | input = F.pad(input, pad, "replicate") 225 | 226 | with torch.no_grad(): 227 | conv1 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=smooth_length) 228 | # 卷积核的元素值初始化 229 | weight = torch.tensor(weight).view(1, 1, -1) 230 | conv1.weight = torch.nn.Parameter(weight) 231 | nn.init.constant_(conv1.bias, 0) # 偏置值为0 232 | # print(conv1.weight) 233 | out = conv1(input) 234 | return out.permute(2, 1, 0).squeeze().numpy() 235 | else: 236 | # out = np.zeros([array.shape[0] + 2, array.shape[1]]) 237 | # input = np.zeros([array.shape[0] + 2, array.shape[1]]) 238 | # input[0] = array[0] 239 | # input[-1] = array[-1] 240 | # input[1:-1] = array 241 | # for i in range(out.shape[1]): 242 | # out[:, i] = np.convolve(input[:, i], weight, mode="same") 243 | # out0 = out[1:-1] 244 | fliter = np.array([weight]).T 245 | x0 = array 246 | fliter = np.repeat(fliter, x0.shape[1], axis=1) 247 | out0 = np.zeros_like(x0) 248 | for i in range(len(x0)): 249 | if i == 0 or i == len(x0) - 1: 250 | out0[i] = x0[i] 251 | else: 252 | tmp = x0[i - 1:i + 2] * fliter 253 | out0[i] = np.sum(tmp, axis=0) 254 | return out0 255 | 256 | def generate_face_mask(): 257 | face_mask = np.zeros([256, 256], dtype=np.uint8) 258 | for i in range(20): 259 | ii = 19 - i 260 | face_mask[ii, :] = 13 * i 261 | face_mask[255 
- ii, :] = 13 * i 262 | face_mask[:, ii] = 13 * i 263 | face_mask[:, 255 - ii] = 13 * i 264 | face_mask = np.array([face_mask, face_mask, face_mask]).transpose(1, 2, 0).astype(float) / 255. 265 | print(face_mask.shape) 266 | return face_mask 267 | 268 | from math import cos,sin,radians 269 | def RotateAngle2Matrix(tmp): #tmp为xyz的旋转角,角度值 270 | tmp = [radians(i) for i in tmp] 271 | matX = np.array([[1.0, 0, 0], 272 | [0.0, cos(tmp[0]), -sin(tmp[0])], 273 | [0.0, sin(tmp[0]), cos(tmp[0])]]) 274 | matY = np.array([[cos(tmp[1]), 0, sin(tmp[1])], 275 | [0.0, 1, 0], 276 | [-sin(tmp[1]), 0, cos(tmp[1])]]) 277 | matZ = np.array([[cos(tmp[2]), -sin(tmp[2]), 0], 278 | [sin(tmp[2]), cos(tmp[2]), 0], 279 | [0, 0, 1]]) 280 | matRotate = np.matmul(matZ, matY) 281 | matRotate = np.matmul(matRotate, matX) 282 | return matRotate -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["kmp_duplicate_lib_ok"] = "true" 3 | from talkingface.models.common.Discriminator import Discriminator 4 | from talkingface.models.common.VGG19 import Vgg19 5 | from talkingface.models.DINet import DINet_five_Ref 6 | from talkingface.util.utils import GANLoss,get_scheduler, update_learning_rate 7 | from talkingface.config.config import DINetTrainingOptions 8 | from torch.utils.tensorboard import SummaryWriter 9 | from talkingface.util.log_board import log 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from torch.utils.data import DataLoader 14 | import random 15 | import numpy as np 16 | import os 17 | import pandas as pd 18 | import torch.nn.functional as F 19 | import cv2 20 | from talkingface.data.few_shot_dataset import Few_Shot_Dataset,data_preparation 21 | 22 | def Tensor2img(tensor_, channel_index): 23 | frame = tensor_[channel_index:channel_index + 3, :, :].detach().squeeze(0).cpu().float().numpy() 24 | frame = np.transpose(frame, (1, 2, 0)) * 255.0 25 | frame = frame.clip(0, 255) 26 | return frame.astype(np.uint8) 27 | 28 | if __name__ == "__main__": 29 | ''' 30 | training code of person image generation 31 | ''' 32 | # load config 33 | opt = DINetTrainingOptions().parse_args() 34 | n_ref = 5 35 | opt.source_channel = 3 * 2 36 | opt.target_channel = 3 37 | opt.ref_channel = n_ref * 3 * 2 38 | opt.batch_size = 4 39 | opt.result_path = "checkpoint/Dinet_five_ref" 40 | opt.resume = False 41 | opt.resume_path = None 42 | 43 | # set seed 44 | random.seed(opt.seed) 45 | np.random.seed(opt.seed) 46 | torch.cuda.manual_seed(opt.seed) 47 | 48 | 49 | video_list = [] 50 | path_ = r"../preparation_bilibili" 51 | video_list += [os.path.join(path_, i) for i in os.listdir(path_)] 52 | 53 | print("video_selected final: ", len(video_list)) 54 | video_list.sort() 55 | train_dict_info = data_preparation(video_list[:]) 56 | train_set = Few_Shot_Dataset(train_dict_info, n_ref=n_ref, is_train=True) 57 | training_data_loader = DataLoader(dataset=train_set, num_workers=0, batch_size=opt.batch_size, shuffle=True) 58 | train_log_path = "train_log.txt" 59 | train_data_length = len(training_data_loader) 60 | # init network 61 | net_g = DINet_five_Ref(opt.source_channel,opt.ref_channel).cuda() 62 | net_d = Discriminator(opt.target_channel, opt.D_block_expansion, opt.D_num_blocks, opt.D_max_features).cuda() 63 | net_vgg = Vgg19().cuda() 64 | 65 | # set optimizer 66 | optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr_g) 67 | optimizer_d = 
optim.Adam(net_d.parameters(), lr=opt.lr_d) 68 | 69 | if opt.resume: 70 | print('loading checkpoint {}'.format(opt.resume_path)) 71 | checkpoint = torch.load(opt.resume_path) 72 | # opt.start_epoch = checkpoint['epoch'] 73 | # opt.start_epoch = 200 74 | net_g_static = checkpoint['state_dict']['net_g'] 75 | net_g.load_state_dict(net_g_static) 76 | net_d.load_state_dict(checkpoint['state_dict']['net_d']) 77 | optimizer_g.load_state_dict(checkpoint['optimizer']['net_g']) 78 | optimizer_d.load_state_dict(checkpoint['optimizer']['net_d']) 79 | 80 | # set criterion 81 | criterionGAN = GANLoss().cuda() 82 | criterionL1 = nn.L1Loss().cuda() 83 | # set scheduler 84 | net_g_scheduler = get_scheduler(optimizer_g, opt.non_decay, opt.decay) 85 | net_d_scheduler = get_scheduler(optimizer_d, opt.non_decay, opt.decay) 86 | 87 | 88 | 89 | train_log_path = os.path.join("checkpoint/{}/log".format("DiNet_five_ref"), "train") 90 | os.makedirs(train_log_path, exist_ok=True) 91 | train_logger = SummaryWriter(train_log_path) 92 | 93 | # start train 94 | for epoch in range(opt.start_epoch, opt.non_decay + opt.decay + 1): 95 | net_g.train() 96 | avg_loss_g_perception = 0 97 | avg_Loss_DI = 0 98 | avg_Loss_GI = 0 99 | for iteration, data in enumerate(training_data_loader): 100 | # read data 101 | source_tensor, ref_tensor, target_tensor = data 102 | source_tensor = source_tensor.float().cuda() 103 | ref_tensor = ref_tensor.float().cuda() 104 | target_tensor = target_tensor.float().cuda() 105 | 106 | source_tensor, source_prompt_tensor = source_tensor[:, :3], source_tensor[:, 3:] 107 | # network forward 108 | fake_out = net_g(source_tensor, source_prompt_tensor, ref_tensor) 109 | # down sample output image and real image 110 | fake_out_half = F.avg_pool2d(fake_out, 3, 2, 1, count_include_pad=False) 111 | target_tensor_half = F.interpolate(target_tensor, scale_factor=0.5, mode='bilinear') 112 | # (1) Update D network 113 | optimizer_d.zero_grad() 114 | # compute fake loss 115 | _,pred_fake_d = net_d(fake_out) 116 | loss_d_fake = criterionGAN(pred_fake_d, False) 117 | # compute real loss 118 | _,pred_real_d = net_d(target_tensor) 119 | loss_d_real = criterionGAN(pred_real_d, True) 120 | # Combine D loss 121 | loss_dI = (loss_d_fake + loss_d_real) * 0.5 122 | loss_dI.backward(retain_graph=True) 123 | optimizer_d.step() 124 | # (2) Update G network 125 | _, pred_fake_dI = net_d(fake_out) 126 | optimizer_g.zero_grad() 127 | # compute perception loss 128 | perception_real = net_vgg(target_tensor) 129 | perception_fake = net_vgg(fake_out) 130 | perception_real_half = net_vgg(target_tensor_half) 131 | perception_fake_half = net_vgg(fake_out_half) 132 | loss_g_perception = 0 133 | for i in range(len(perception_real)): 134 | loss_g_perception += criterionL1(perception_fake[i], perception_real[i]) 135 | loss_g_perception += criterionL1(perception_fake_half[i], perception_real_half[i]) 136 | loss_g_perception = (loss_g_perception / (len(perception_real) * 2)) * opt.lamb_perception 137 | # gan dI loss 138 | loss_g_dI = criterionGAN(pred_fake_dI, True) 139 | # combine perception loss and gan loss 140 | loss_g = loss_g_perception + loss_g_dI 141 | loss_g.backward() 142 | optimizer_g.step() 143 | message = "===> Epoch[{}]({}/{}): Loss_DI: {:.4f} Loss_GI: {:.4f} Loss_perception: {:.4f} lr_g = {:.7f} lr_d = {:.7f}".format( 144 | epoch, iteration, len(training_data_loader), float(loss_dI), float(loss_g_dI), 145 | float(loss_g_perception), optimizer_g.param_groups[0]['lr'], optimizer_d.param_groups[0]['lr']) 146 | print(message) 147 | # 
with open("train_log.txt", "a") as f: 148 | # f.write(message + "\n") 149 | 150 | if iteration%200 == 0: 151 | inference_out = fake_out * 255 152 | inference_out = inference_out[0].cpu().permute(1, 2, 0).float().detach().numpy().astype(np.uint8) 153 | inference_in = (target_tensor[0, :3]* 255).cpu().permute(1, 2, 0).float().detach().numpy().astype(np.uint8) 154 | inference_in_prompt = (source_prompt_tensor[0, :3] * 255).cpu().permute(1, 2, 0).float().detach().numpy().astype( 155 | np.uint8) 156 | frame2 = Tensor2img(ref_tensor[0], 0) 157 | frame3 = Tensor2img(ref_tensor[0], 3) 158 | inference_out = np.concatenate([inference_in, inference_in_prompt, inference_out, frame2, frame3], axis=1) 159 | inference_out = cv2.cvtColor(inference_out, cv2.COLOR_RGB2BGR) 160 | 161 | log(train_logger, fig=inference_out, tag="Training/epoch_{}_{}".format(epoch, iteration)) 162 | 163 | real_iteration = epoch * len(training_data_loader) + iteration 164 | message1 = "Step {}/{}, ".format(real_iteration, (epoch + 1) * len(training_data_loader)) 165 | message2 = "" 166 | losses = [loss_dI.item(), loss_g_perception.item(), loss_g_dI.item()] 167 | train_logger.add_scalar("Loss/loss_dI", losses[0], real_iteration) 168 | train_logger.add_scalar("Loss/loss_g_perception", losses[1], real_iteration) 169 | train_logger.add_scalar("Loss/loss_g_dI", losses[2], real_iteration) 170 | 171 | avg_loss_g_perception += loss_g_perception.item() 172 | avg_Loss_DI += loss_dI.item() 173 | avg_Loss_GI += loss_g_dI.item() 174 | train_logger.add_scalar("Loss/{}".format("epoch_g_perception"), avg_loss_g_perception / len(training_data_loader), epoch) 175 | train_logger.add_scalar("Loss/{}".format("epoch_DI"), 176 | avg_Loss_DI / len(training_data_loader), epoch) 177 | train_logger.add_scalar("Loss/{}".format("epoch_GI"), 178 | avg_Loss_GI / len(training_data_loader), epoch) 179 | update_learning_rate(net_g_scheduler, optimizer_g) 180 | update_learning_rate(net_d_scheduler, optimizer_d) 181 | 182 | # checkpoint 183 | if epoch % opt.checkpoint == 0: 184 | if not os.path.exists(opt.result_path): 185 | os.mkdir(opt.result_path) 186 | model_out_path = os.path.join(opt.result_path, 'epoch_{}.pth'.format(epoch)) 187 | states = { 188 | 'epoch': epoch + 1, 189 | 'state_dict': {'net_g': net_g.state_dict(), 'net_d': net_d.state_dict()}, 190 | 'optimizer': {'net_g': optimizer_g.state_dict(), 'net_d': optimizer_d.state_dict()} 191 | } 192 | torch.save(states, model_out_path) 193 | print("Checkpoint saved to {}".format(epoch)) -------------------------------------------------------------------------------- /train_input_validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["kmp_duplicate_lib_ok"] = "true" 3 | import pickle 4 | import cv2 5 | import numpy as np 6 | import random 7 | 8 | import glob 9 | import copy 10 | import torch 11 | from torch.utils.data import DataLoader 12 | from talkingface.data.few_shot_dataset import Few_Shot_Dataset,data_preparation 13 | from talkingface.utils import * 14 | path_ = r"../preparation_mix" 15 | video_list = [os.path.join(path_, i) for i in os.listdir(path_)] 16 | point_size = 1 17 | point_color = (0, 0, 255) # BGR 18 | thickness = 4 # 0 、4、8 19 | video_list = video_list[125:135] 20 | 21 | dict_info = data_preparation(video_list) 22 | 23 | device = torch.device("cuda:0") 24 | test_set = Few_Shot_Dataset(dict_info, is_train=True, n_ref = 1) 25 | testing_data_loader = DataLoader(dataset=test_set, num_workers=0, batch_size=1, shuffle=False) 26 | 
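# Sanity-check loop: for every sample the code below shows, side by side, the first three
# source channels, the source prompt channels (3:6), one reference image and its prompt,
# and the ground-truth target, so the training inputs can be inspected by eye before
# launching train.py (press any key to step to the next sample).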
27 | def Tensor2img(tensor_, channel_index): 28 | frame = tensor_[channel_index:channel_index + 3, :, :].detach().squeeze(0).cpu().float().numpy() 29 | frame = np.transpose(frame, (1, 2, 0)) * 255.0 30 | frame = frame.clip(0, 255) 31 | return frame.astype(np.uint8) 32 | size_ = 256 33 | for iteration, batch in enumerate(testing_data_loader): 34 | # source_tensor, source_prompt_tensor, ref_tensor, ref_prompt_tensor, target_tensor = [iii.to(device) for iii in batch] 35 | source_tensor, ref_tensor, target_tensor = [iii.to(device) for iii in batch] 36 | print(source_tensor.size(), ref_tensor.size(), target_tensor.size()) 37 | 38 | frame0 = Tensor2img(source_tensor[0], 0) 39 | frame1 = Tensor2img(source_tensor[0], 3) 40 | frame2 = Tensor2img(ref_tensor[0], 0) 41 | frame3 = Tensor2img(ref_tensor[0], 3) 42 | frame4 = Tensor2img(target_tensor[0], 0) 43 | frame = np.concatenate([frame0, frame1, frame2, frame3, frame4], axis=1) 44 | 45 | cv2.imshow("ss", frame) 46 | # if iteration > 840: 47 | # cv2.waitKey(-1) 48 | cv2.waitKey(-1) 49 | # break 50 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /video_data/audio0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/video_data/audio0.wav -------------------------------------------------------------------------------- /video_data/audio1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/video_data/audio1.wav -------------------------------------------------------------------------------- /video_data/circle.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/video_data/circle.mp4 -------------------------------------------------------------------------------- /video_data/demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/video_data/demo.mp4 -------------------------------------------------------------------------------- /video_data/keypoint_rotate.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/video_data/keypoint_rotate.pkl -------------------------------------------------------------------------------- /video_data/test/circle.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/video_data/test/circle.mp4 -------------------------------------------------------------------------------- /video_data/test/keypoint_rotate.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingxing2233/DH-Live-Web-UI/a26458ab769702cc57f39d004f3f974524f4ac32/video_data/test/keypoint_rotate.pkl -------------------------------------------------------------------------------- /webapp.py: -------------------------------------------------------------------------------- 1 | import time 2 | import gradio as gr 3 | import os 4 | import cv2 5 | 
import uuid
import shutil
import subprocess
import sys
from talkingface.audio_model import AudioModel
from talkingface.render_model import RenderModel
from scipy.io import wavfile
import sounddevice as sd
from data_preparation import CirculateVideo, ExtractFromVideo
from demo import merge_audio_video
import pickle
import mediapipe as mp
import numpy as np

import requests

import base64
import io
import pygame
import edge_tts
import asyncio
from io import BytesIO
import tqdm
from datetime import datetime


mp_face_mesh = mp.solutions.face_mesh
mp_face_detection = mp.solutions.face_detection


def display_selected_folder(folder):
    return f"你选择了: {folder}"


def process_video(video_path):
    """Extract face keypoints from the uploaded video and stage the files for rendering."""
    try:
        pts_3d = ExtractFromVideo(video_path)
        if isinstance(pts_3d, np.ndarray) and pts_3d.ndim == 3:
            result_message = "人脸点位生成成功!"
        elif pts_3d == -1:
            result_message = "第一帧人脸检测异常,请检查视频。"
            return result_message, None, None, None
        elif pts_3d == -2:
            result_message = "人脸区域变化幅度太大,请检查视频。"
            return result_message, None, None, None
        else:
            result_message = "检查点生成失败,可能是人脸检测异常或其他问题。"
            return result_message, None, None, None

        # Use a fixed working folder name
        folder_name = "circle"

        # If the folder already exists, remove it first
        if os.path.exists(folder_name):
            shutil.rmtree(folder_name)

        # Create the folder
        os.makedirs(folder_name)

        # Save the keypoint file
        checkpoint_path = os.path.join(folder_name, "keypoint_rotate.pkl")
        with open(checkpoint_path, "wb") as f:
            pickle.dump(pts_3d, f)

        # Move the uploaded video into the new folder
        new_video_path = os.path.join(folder_name, "circle.mp4")
        shutil.move(video_path, new_video_path)

        print("folder_name:", folder_name)
        print("new_video_path:", new_video_path)

        # Return exactly the four values expected by the Gradio outputs:
        # [video_output, folder_name, pkl_path, video_file_path]
        return result_message, folder_name, checkpoint_path, new_video_path
    except Exception as e:
        return f"处理视频时出错:{str(e)}", None, None, None


def convert_audio_format(audio_path):
    rate, wav = wavfile.read(audio_path, mmap=False)
    converted_audio_path = audio_path.replace(".wav", "_converted.wav")
    wavfile.write(converted_audio_path, rate, wav)
    return f"{converted_audio_path} 转换成功!"
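# Note: the helper above simply re-saves the uploaded WAV file and is not wired to any UI
# control below; later in this file an async text-to-speech function reuses the same
# convert_audio_format name and therefore shadows this definition.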


# Synthesize the digital-human video from the prepared character folder and an audio file.
# Note: this local definition shadows the merge_audio_video imported from demo above, and its
# signature matches the four Gradio inputs wired to it: [folder_name, wav_upload, pkl_path, video_file_path].
def merge_audio_video(folder_name, audio_path, pkl_path=None, video_file_path=None, output_video_name=None):
    try:
        # If no audio path was passed in, a default could be used here
        # if audio_path is None:
        #     audio_path = "2bj.wav"  # replace with your default audio path

        # Check that the audio file exists
        if not os.path.exists(audio_path):
            return f"音频文件路径错误,未找到音频文件:{audio_path}", None

        # Use a unique output video file name
        task_id = str(uuid.uuid1())
        unique_output_video_name = f"{task_id}.mp4"
        output_video_name = os.path.join("output", unique_output_video_name)
        print(f"output video name is set to: {output_video_name}")

        # Load the audio and render models
        audioModel = AudioModel()
        audioModel.loadModel("checkpoint/audio.pkl")

        renderModel = RenderModel()
        renderModel.loadModel("checkpoint/render.pth")

        # The keypoint file and character video are always taken from the character folder
        pkl_path = os.path.join(folder_name, "keypoint_rotate.pkl")
        video_file_path = os.path.join(folder_name, "circle.mp4")
        renderModel.reset_charactor(video_file_path, pkl_path)

        # Run the audio model and render the mouth frames into a silent video
        wavpath = audio_path
        mouth_frame = audioModel.interface_wav(wavpath)
        cap_input = cv2.VideoCapture(video_file_path)
        vid_width = cap_input.get(cv2.CAP_PROP_FRAME_WIDTH)
        vid_height = cap_input.get(cv2.CAP_PROP_FRAME_HEIGHT)
        cap_input.release()

        os.makedirs(f"output/{task_id}", exist_ok=True)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        save_path = f"output/{task_id}/silence.mp4"
        videoWriter = cv2.VideoWriter(save_path, fourcc, 25, (int(vid_width), int(vid_height)))

        for frame in tqdm.tqdm(mouth_frame):
            frame = renderModel.interface(frame)
            videoWriter.write(frame)

        videoWriter.release()

        # Use ffmpeg to mux the audio and the silent video, overwriting any existing output file
        output_video_path = os.path.join("output", unique_output_video_name)
        os.system(
            f"ffmpeg -y -i {save_path} -i {wavpath} -c:v libx264 -pix_fmt yuv420p -loglevel quiet {output_video_path}"
            # alternative with an audio sync filter:
            # f"ffmpeg -y -i {save_path} -i {wavpath} -c:v libx264 -pix_fmt yuv420p -af 'async=1' -loglevel quiet {output_video_name}"
        )
        shutil.rmtree(f"output/{task_id}")

        return f"数字人生成成功!", output_video_path  # output file: {output_video_path}

    except Exception as e:
        return f"合成视频时出错:{str(e)}", None


# Text-to-speech (note: this reuses the convert_audio_format name defined earlier)
async def convert_audio_format(text, voice):
    try:
        # Use edge_tts to convert the text to audio
        communicate = edge_tts.Communicate(text, voice)

        # Collect the streamed audio data
        audio_data = bytearray()
        async for message in communicate.stream():
            if message["type"] == "audio":
                audio_data.extend(message["data"])

        # Save the audio data to a file (edge_tts streams MP3-encoded audio by default,
        # so the .wav extension here is only nominal)
        audio_file_path = "output_audio.wav"
        with open(audio_file_path, "wb") as f:
            f.write(audio_data)

        # Return a success message and the audio file path
        return "文本转语音成功!", audio_file_path
    except Exception as e:
        return str(e), None

# Gradio wrapper: run the async TTS function and return its result
def tts_interface(text, voice_selector):
    result = asyncio.run(convert_audio_format(text, voice_selector))
    return result


def play_audio(audio_path):
    return audio_path


def save_audio_locally(wav_upload):
    # Move the uploaded audio file out of the temporary directory into a folder under the project root
    audio_folder = "audio_files"
    os.makedirs(audio_folder, exist_ok=True)
    audio_filename = os.path.basename(wav_upload)
    new_audio_path = os.path.join(audio_folder, audio_filename)
    shutil.move(wav_upload, new_audio_path)
    return new_audio_path


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("数字人生成工具
") 229 | gr.Markdown("第一步") 230 | with gr.Row():#视频处理 231 | 232 | # 打开视频并展示预览 233 | video_upload_for_slicing = gr.Video(label="上传视频", include_audio=True, sources=["upload", "webcam"]) 234 | video_output = gr.Textbox(label="人脸点位生成结果") 235 | process_video_btn = gr.Button("生成人脸点位文件", variant="primary") 236 | 237 | folder_name = gr.State() # 用于存储检查点生成路径 238 | pkl_path = gr.State() 239 | video_file_path = gr.State() 240 | 241 | # 在点击按钮时,获取视频处理结果并保存检查点生成路径 242 | def process_and_store_video(video_path): 243 | # 处理视频并保存路径 244 | result_message, folder_name_value, pkl_path_value, video_file_value = process_video(video_path) 245 | return result_message, folder_name_value, pkl_path_value, video_file_value 246 | 247 | process_video_btn.click( 248 | fn=process_video, 249 | inputs=video_upload_for_slicing, 250 | #outputs=video_output 251 | outputs=[video_output, folder_name, pkl_path, video_file_path] 252 | ) 253 | gr.Markdown("第二步") 254 | with gr.Row(): 255 | # 音频处理部分 256 | 257 | # 定义发音人选项 258 | voices = [ 259 | "zh-CN-XiaoxiaoNeural", 260 | "zh-CN-YunxiNeural", 261 | "zh-CN-YunjianNeural", 262 | "zh-CN-YunyangNeural", 263 | "zh-CN-shaanxi-XiaoniNeural", 264 | "zh-HK-WanLungNeural", 265 | "zh-CN-liaoning-XiaobeiNeural", 266 | "zh-TW-HsiaoYuNeural" 267 | 268 | ] 269 | 270 | 271 | wav_path_input = gr.Textbox(label="输入文本") 272 | voice_selector = gr.Dropdown(label="选择发音人", choices=voices, value=voices[0]) 273 | 274 | convert_btn = gr.Button("文本到语音转换", variant="primary") 275 | convert_result = gr.Textbox(label="转换结果") 276 | 277 | # 音频文件加载和播放窗口 278 | audio_playback = gr.Audio(label="音频播放窗口") 279 | #vc_output2 = gr.Audio(label="Output Audio", interactive=False) 280 | 281 | 282 | # 绑定转换按钮点击事件 283 | convert_btn.click( 284 | 285 | fn=tts_interface, 286 | inputs=[wav_path_input, voice_selector], 287 | outputs=[convert_result, audio_playback] # 输出转换结果文本和音频文件路径 288 | 289 | ) 290 | 291 | # 绑定音频文件加载和播放事件 292 | wav_path_input.change( 293 | fn=play_audio, 294 | inputs=wav_path_input, 295 | outputs=audio_playback 296 | ) 297 | 298 | 299 | 300 | gr.Markdown("第三步") 301 | with gr.Row(): 302 | 303 | 304 | # 音频文件上传并播放 305 | wav_upload = gr.Audio(label="上传合成音频文件", type="filepath") 306 | checkpoint_path_output = gr.Textbox(label="数字人生成结果") 307 | 308 | process_btn = gr.Button("生成数字人", variant="primary") 309 | output_folder_list = gr.Video(label="生成的数字人", width=720, height=480) 310 | 311 | 312 | 313 | # 在视频合成按钮点击时,将生成的视频路径和检查点路径关联显示 314 | def merge_and_display_checkpoint(folder_name, audio_path, wav_upload, pkl_path, video_file_path): 315 | audio_path = save_audio_locally(wav_upload) 316 | # 调用封装的merge_audio_video函数 317 | output_video_path = merge_audio_video(folder_name, audio_path, wav_upload, pkl_path, video_file_path) 318 | return folder_name, output_video_path 319 | 320 | 321 | process_btn.click( 322 | fn=merge_audio_video, 323 | inputs=[folder_name, wav_upload, pkl_path, video_file_path], 324 | 325 | outputs=[checkpoint_path_output, output_folder_list] 326 | ) 327 | 328 | 329 | 330 | 331 | 332 | gr.Markdown("

麦克风实时驱动数字人")

    # Display the uploaded video (guard against None when the video component is cleared)
    video_upload_for_slicing.change(
        fn=lambda file_path: file_path if file_path and file_path.endswith(('.mp4', '.avi')) else None,
        inputs=video_upload_for_slicing,
        outputs=video_upload_for_slicing
    )

demo.launch()
--------------------------------------------------------------------------------