├── .gitignore
├── README.md
├── app.py
├── checkpoints
│   ├── generator
│   │   └── README.md
│   ├── image2image
│   │   └── README.md
│   └── syncnet
│       └── README.md
├── dummy
│   ├── README.md
│   ├── face.gif
│   ├── landmark.gif
│   ├── me.png
│   └── test_audio.mp3
├── env.yml
├── evaluation
│   ├── README.md
│   ├── calc_flop.py
│   ├── eval_vdo
│   │   └── README.md
│   ├── gen_eval_vdo.py
│   └── temp
│       └── README.md
├── filelists
│   └── README.md
├── hparams.py
├── logs
│   ├── generator
│   │   └── README.md
│   └── syncnet
│       └── README.md
├── preprocess_data.py
├── requirements.txt
├── run_inference.py
├── run_train_generator.py
├── run_train_syncnet.py
├── src
│   ├── __init__.py
│   ├── dataset
│   │   ├── generator.py
│   │   └── syncnet.py
│   ├── main
│   │   ├── generator.py
│   │   ├── inference.py
│   │   └── syncnet.py
│   └── models
│       ├── __init__.py
│       ├── attnlstm.py
│       ├── image2image.py
│       ├── lstmgen.py
│       └── syncnet.py
├── temp
│   └── README.md
└── utils
    ├── audio.py
    ├── loss.py
    ├── plot.py
    ├── utils.py
    └── wav2lip.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

# checkpoints
checkpoints/front/*
checkpoints/generator/*
checkpoints/image2image/*
checkpoints/syncnet/*
checkpoints/end2end/*

# old folder
old/*

# filelist (dataset)
filelists/*.txt

# eval
evaluation/temp/temp.wav
evaluation/test_filelists/*

# results
results/*
temp/*
# dummy/*
face_front/
results.mp4
logs/generator/*
logs/syncnet/*

# excluding all .md
!*.md

# ignore
*.pyc
*.mp4
*.pth

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Audio-Visual Lip Synthesis via Intermediate Landmark Representation | Final Year Project (Dissertation) of Wish Suharitdamrong

This is the code implementation of Wish Suharitdamrong's Year 3 BSc Computer Science Final Year Project at the University of Surrey, on the topic of audio-visual lip synthesis via an intermediate landmark representation.

![Alt Text](./dummy/face.gif)
![Alt Text](./dummy/landmark.gif)

# Demo

An online demonstration is available at 🤗 [HuggingFace](https://huggingface.co/spaces/peterwisu/lip_synthesis)

## Installation

There are two ways of installing the required packages: conda or pip.

1. Create a conda virtual environment from `env.yml`.

2. Use pip to install the packages (make sure you use `python 3.7` or above, since older versions might not support some libraries).

### Use Conda

```bash
# Create virtual environment from .yml file
conda env create -f env.yml

# activate virtual environment
conda activate fyp
```

### Use pip

```bash
# Use pip to install required packages
pip install -r requirements.txt
```

## Dataset

The audio-visual datasets used in this project are LRS2 and LRS3. LRS2 data was used for both model training and evaluation; LRS3 data was used only for model evaluation.

| Dataset | Page|
|---------- |:-------------:|
| LRS2 | [Link](https://www.robots.ox.ac.uk/~vgg/data/lip_reading/lrs2.html)|
| LRS3 | [Link](https://paperswithcode.com/dataset/lrs3-ted) |

## Pre-trained weights

### Generator model

Download weights for the Generator model:

| Model | Download Link |
|---------- |:-------------:|
| Generator | [Link](https://drive.google.com/file/d/19-zLzCKeH6tp5grxoRYnEEKLgIZrj-4f/view?usp=sharing)|
| Generator + SyncLoss | [Link](https://drive.google.com/file/d/1Ck-54fOBeY87c6_CFXfMF0FwgWK92DqG/view?usp=sharing) |
| Attention Generator + SyncLoss | [Link](https://drive.google.com/file/d/1sEM7Aqrg-YILx8dyuT2zxQkU5xRJXc_T/view?usp=sharing) |

### Landmark SyncNet discriminator

Download weights for the landmark-based SyncNet model: [Download Link](https://drive.google.com/file/d/1fJj-zYkfr1gSGgq5ISWGCE1byxNc6Mdp/view?usp=sharing)

### Image-to-Image Translation

Pre-trained weights for the Image2Image Translation model can be downloaded from the pre-trained models section of the MakeItTalk repository: [Repo Link](https://github.com/yzhou359/MakeItTalk).

### Directory

```bash
├── checkpoints            # Directory for model checkpoints
│   └── generator          # put Generator model weights here
│   └── syncnet            # put Landmark SyncNet model weights here
│   └── image2image        # put Image2Image Translation model weights here
```

## Run Inference

```
python run_inference.py --generator_checkpoint <generator_checkpoint> --image2image_checkpoint <image2image_checkpoint> --input_face <input_face> --input_audio <input_audio>
```
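The same pipeline can also be driven from Python instead of the CLI; `app.py` does exactly this by building an argument namespace and handing it to `src.main.inference.Inference`. A minimal sketch is shown below; the argument set mirrors the defaults in `run_inference.py`, and the checkpoint and input paths are placeholders for whichever weights and media you actually use:

```python
# Minimal programmatic inference sketch (paths are placeholders; arguments mirror run_inference.py).
from argparse import Namespace

from src.main.inference import Inference

args = Namespace(
    model_type='attn_lstm',                    # 'lstm' or 'attn_lstm'
    generator_checkpoint='./checkpoints/generator/attn_lstmgen_syncloss.pth',
    image2image_checkpoint='./checkpoints/image2image/image2image.pth',
    input_face='./dummy/me.png',               # image or video containing a face
    input_audio='./dummy/test_audio.mp3',      # input speech
    fps=25,                                    # only relevant for static-image input
    fl_detector_batchsize=64,                  # batch size for facial-landmark detection
    generator_batchsize=2,                     # batch size for the generator
    seq_len=5,                                 # sequence length fed to the generator
    output_name='results.mp4',                 # where the synthesized video is written
    vis_fl=False,                              # also visualize the reconstructed landmarks?
    only_fl=False,                             # output only the landmark animation?
    test_img2img=False,                        # test the image2image module without lip generation
)

Inference(args=args).start()
```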
## Data Preprocessing

I used the same data preprocessing pipeline as [Wav2Lip](https://github.com/Rudrabha/Wav2Lip); more details about the expected folder structure can be found in their repository [Here](https://github.com/Rudrabha/Wav2Lip).

```
python preprocess_data.py --data_root data_root/main --preprocessed_root preprocessed_lrs2_landmark/
```

## Train Model

### Generator

```
# CLI for training the attention generator with a pretrained landmark SyncNet discriminator
python run_train_generator.py --model_type attn_lstm --train_type pretrain --data_root preprocessed_lrs2_landmark/ --checkpoint_dir <checkpoint_dir>
```

### Landmark SyncNet

```
# CLI for pretraining the landmark SyncNet discriminator
python run_train_syncnet.py --data_root preprocessed_lrs2_landmark/ --checkpoint_dir <checkpoint_dir>
```

## Generate videos for evaluation & benchmark from LRS2 and LRS3

This project uses data from the LRS2 and LRS3 datasets for quantitative evaluation; the list of evaluation data is provided by [Wav2Lip](https://github.com/Rudrabha/Wav2Lip). The filelists (video and audio data used for evaluation) and details about the lip-sync benchmark are available in their repository [Here](https://github.com/Rudrabha/Wav2Lip).

### Generate evaluation videos from a filelist

```
cd evaluation
# generate evaluation videos
python gen_eval_vdo.py --filelist <filelist> --data_root <data_root> --model_type <model_type> --result_dir <result_dir> --generator_checkpoint <generator_checkpoint> --image2image_checkpoint <image2image_checkpoint>
```

# Acknowledgement

The code base of this project was inspired by [Wav2Lip](https://github.com/Rudrabha/Wav2Lip) and [MakeItTalk](https://github.com/yzhou359/MakeItTalk).
I would like to thanks author of both project for making code implementation of their amazing work available online. 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Gradio app 2 | 3 | import gradio as gr 4 | import os 5 | import argparse 6 | from src.main.inference import Inference 7 | 8 | MODEL_TYPE = ['lstm','attn_lstm'] 9 | 10 | MODEL_NAME = { 'lstm':('./checkpoints/generator/benchmark/pure_lstmgen_l1.pth','lstm'), 11 | 'lstm_syncnet' : ('./checkpoints/generator/benchmark/pretrain_lstmgen_l1.pth','lstm'), 12 | 'attn_lstm_syncnet': ('./checkpoints/generator/benchmark/attn_generator_020_l1_1e_2.pth','attn_lstm')} 13 | 14 | print(MODEL_NAME.keys()) 15 | 16 | def func(video,audio,check,drop_down): 17 | 18 | path , model_type = MODEL_NAME[drop_down] 19 | 20 | print(path) 21 | 22 | print(model_type) 23 | 24 | parser = argparse.ArgumentParser(description="File for running Inference") 25 | 26 | parser.add_argument('--model_type', help='Type of generator model', default=model_type, type=str) 27 | 28 | parser.add_argument('--generator_checkpoint', type=str ,default=path) 29 | 30 | parser.add_argument('--image2image_checkpoint', type=str, default='./checkpoints/image2image/image2image.pth',required=False) 31 | 32 | parser.add_argument('--input_face', type=str,default=video, required=False) 33 | 34 | parser.add_argument('--input_audio', type=str, default=audio, required=False) 35 | 36 | # parser.add_argument('--output_path', type=str, help="Path for saving the result", default='result.mp4', required=False) 37 | 38 | parser.add_argument('--fps', type=float, default=25,required=False) 39 | 40 | parser.add_argument('--fl_detector_batchsize', type=int , default = 2) 41 | 42 | parser.add_argument('--generator_batchsize', type=int, default=2) 43 | 44 | parser.add_argument('--output_name', type=str , default="results.mp4") 45 | 46 | 47 | parser.add_argument('--only_fl', type=bool , default=False) 48 | 49 | parser.add_argument('--vis_fl', type=bool, default=check) 50 | 51 | parser.add_argument('--test_img2img', type=bool, help="Testing image2image module with no lip generation" , default=False) 52 | 53 | args = parser.parse_args() 54 | 55 | 56 | Inference(args=args).start() 57 | 58 | 59 | return './results.mp4' 60 | 61 | 62 | 63 | def gui(): 64 | with gr.Blocks() as video_tab: 65 | 66 | with gr.Row(): 67 | 68 | with gr.Column(): 69 | video = gr.Video().style() 70 | 71 | audio = gr.Audio(source="upload", type="filepath") 72 | 73 | with gr.Column(): 74 | outputs = gr.PlayableVideo() 75 | 76 | 77 | 78 | with gr.Row(): 79 | 80 | 81 | with gr.Column(): 82 | 83 | check_box = gr.Checkbox(value=False,label="Do you want to visualize reconstructed facial landmark??") 84 | 85 | 86 | drop_down = gr.Dropdown(list(MODEL_NAME.keys()), label="Select Model") 87 | 88 | with gr.Row(): 89 | with gr.Column(): 90 | 91 | inputs = [video,audio,check_box,drop_down] 92 | gr.Button("Sync").click( 93 | 94 | fn=func, 95 | inputs=inputs, 96 | outputs=outputs 97 | ) 98 | 99 | 100 | 101 | 102 | with gr.Blocks() as image_tab: 103 | 104 | 105 | with gr.Row(): 106 | 107 | with gr.Column(): 108 | video = gr.Image(type="filepath") 109 | 110 | audio = gr.Audio(source="upload", type="filepath") 111 | 112 | 113 | 114 | with gr.Column(): 115 | outputs = gr.PlayableVideo() 116 | 117 | 118 | with gr.Row(): 119 | 120 | with gr.Column(): 121 | 122 | check_box = gr.Checkbox(value=False,label="Do 
you want to visualize reconstructed facial landmark??") 123 | 124 | drop_down = gr.Dropdown(list(MODEL_NAME.keys()), label="Select Model") 125 | 126 | with gr.Row(): 127 | with gr.Column(): 128 | 129 | inputs = [video,audio,check_box,drop_down] 130 | gr.Button("Sync").click( 131 | 132 | fn=func, 133 | inputs=inputs, 134 | outputs=outputs 135 | ) 136 | 137 | 138 | 139 | with gr.Blocks() as main: 140 | 141 | gr.Markdown( 142 | """ 143 | # Audio-Visual Lip Synthesis! 144 | 145 | ### Creator : Wish Suharitdamrong 146 | 147 | 148 | Start typing below to see the output. 149 | """ 150 | ) 151 | gui = gr.TabbedInterface([video_tab,image_tab],['Using Video as input','Using Image as input']) 152 | 153 | 154 | main.launch() 155 | 156 | 157 | 158 | 159 | if __name__ == "__main__": 160 | gui() 161 | -------------------------------------------------------------------------------- /checkpoints/generator/README.md: -------------------------------------------------------------------------------- 1 | This folder shall contains the checkpoint of of generator model 2 | -------------------------------------------------------------------------------- /checkpoints/image2image/README.md: -------------------------------------------------------------------------------- 1 | This folde shall contains the checkpoint for Image2Image translation model from Makeittalk by **Yang Zhou** at repo: https://github.com/yzhou359/MakeItTalk 2 | -------------------------------------------------------------------------------- /checkpoints/syncnet/README.md: -------------------------------------------------------------------------------- 1 | This folder should contain checkpoint for lip landmark syncnet 2 | -------------------------------------------------------------------------------- /dummy/README.md: -------------------------------------------------------------------------------- 1 | put dummy input file here 2 | -------------------------------------------------------------------------------- /dummy/face.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterwisu/lip-synthesis/423fa4e2857fc420533481b01d1c163184d0d516/dummy/face.gif -------------------------------------------------------------------------------- /dummy/landmark.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterwisu/lip-synthesis/423fa4e2857fc420533481b01d1c163184d0d516/dummy/landmark.gif -------------------------------------------------------------------------------- /dummy/me.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterwisu/lip-synthesis/423fa4e2857fc420533481b01d1c163184d0d516/dummy/me.png -------------------------------------------------------------------------------- /dummy/test_audio.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterwisu/lip-synthesis/423fa4e2857fc420533481b01d1c163184d0d516/dummy/test_audio.mp3 -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: fyp 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_gnu 9 | - aom=3.5.0=h27087fc_0 10 | - blas=1.0=mkl 11 | - brotlipy=0.7.0=py37h540881e_1004 12 | - bzip2=1.0.8=h7f98852_4 13 | - 
ca-certificates=2022.9.24=ha878542_0 14 | - cffi=1.15.1=py37h43b0acd_1 15 | - charset-normalizer=2.1.1=pyhd8ed1ab_0 16 | - cryptography=38.0.2=py37h5994e8b_1 17 | - cudatoolkit=11.3.1=h2bc3f7f_2 18 | - expat=2.5.0=h27087fc_0 19 | - ffmpeg=5.1.2=gpl_hc51e5dc_103 20 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 21 | - font-ttf-inconsolata=3.000=h77eed37_0 22 | - font-ttf-source-code-pro=2.038=h77eed37_0 23 | - font-ttf-ubuntu=0.83=hab24e00_0 24 | - fontconfig=2.14.1=hc2a2eb6_0 25 | - fonts-conda-ecosystem=1=0 26 | - fonts-conda-forge=1=0 27 | - freetype=2.12.1=hca18f0e_0 28 | - gettext=0.21.1=h27087fc_0 29 | - gmp=6.2.1=h58526e2_0 30 | - gnutls=3.7.8=hf3e180e_0 31 | - icu=70.1=h27087fc_0 32 | - idna=3.4=pyhd8ed1ab_0 33 | - intel-openmp=2021.4.0=h06a4308_3561 34 | - jpeg=9e=h166bdaf_2 35 | - lame=3.100=h166bdaf_1003 36 | - lcms2=2.12=hddcbb42_0 37 | - ld_impl_linux-64=2.39=hc81fddc_0 38 | - lerc=4.0.0=h27087fc_0 39 | - libdeflate=1.14=h166bdaf_0 40 | - libdrm=2.4.113=h166bdaf_0 41 | - libffi=3.4.2=h7f98852_5 42 | - libgcc-ng=12.2.0=h65d4601_19 43 | - libgomp=12.2.0=h65d4601_19 44 | - libiconv=1.17=h166bdaf_0 45 | - libidn2=2.3.4=h166bdaf_0 46 | - libnsl=2.0.0=h7f98852_0 47 | - libpciaccess=0.16=h516909a_0 48 | - libpng=1.6.38=h753d276_0 49 | - libsqlite=3.39.4=h753d276_0 50 | - libstdcxx-ng=12.2.0=h46fd767_19 51 | - libtasn1=4.19.0=h166bdaf_0 52 | - libtiff=4.4.0=h55922b4_4 53 | - libunistring=0.9.10=h7f98852_0 54 | - libuuid=2.32.1=h7f98852_1000 55 | - libva=2.16.0=h166bdaf_0 56 | - libvpx=1.11.0=h9c3ff4c_3 57 | - libwebp-base=1.2.4=h166bdaf_0 58 | - libxcb=1.13=h7f98852_1004 59 | - libxml2=2.10.3=h7463322_0 60 | - libzlib=1.2.13=h166bdaf_4 61 | - mkl=2021.4.0=h06a4308_640 62 | - mkl-service=2.4.0=py37h402132d_0 63 | - mkl_fft=1.3.1=py37h3e078e5_1 64 | - mkl_random=1.2.2=py37h219a48f_0 65 | - ncurses=6.3=h27087fc_1 66 | - nettle=3.8.1=hc379101_1 67 | - numpy-base=1.21.5=py37ha15fc14_3 68 | - openh264=2.3.1=h27087fc_1 69 | - openjpeg=2.5.0=h7d73246_1 70 | - openssl=3.0.5=h166bdaf_2 71 | - p11-kit=0.24.1=hc5aa10d_0 72 | - pillow=9.2.0=py37h850a105_2 73 | - pip=22.3=pyhd8ed1ab_0 74 | - pthread-stubs=0.4=h36c2ea0_1001 75 | - pycparser=2.21=pyhd8ed1ab_0 76 | - pyopenssl=22.1.0=pyhd8ed1ab_0 77 | - pysocks=1.7.1=py37h89c1867_5 78 | - python=3.7.12=hf930737_100_cpython 79 | - python_abi=3.7=2_cp37m 80 | - pytorch=1.12.1=py3.7_cuda11.3_cudnn8.3.2_0 81 | - pytorch-mutex=1.0=cuda 82 | - readline=8.1.2=h0f457ee_0 83 | - requests=2.28.1=pyhd8ed1ab_1 84 | - setuptools=65.5.0=pyhd8ed1ab_0 85 | - six=1.16.0=pyh6c4a22f_0 86 | - sqlite=3.39.4=h4ff8645_0 87 | - svt-av1=1.3.0=h27087fc_0 88 | - tk=8.6.12=h27826a3_0 89 | - torchaudio=0.12.1=py37_cu113 90 | - typing_extensions=4.4.0=pyha770c72_0 91 | - wheel=0.37.1=pyhd8ed1ab_0 92 | - x264=1!164.3095=h166bdaf_2 93 | - x265=3.5=h924138e_3 94 | - xorg-fixesproto=5.0=h7f98852_1002 95 | - xorg-kbproto=1.0.7=h7f98852_1002 96 | - xorg-libx11=1.7.2=h7f98852_0 97 | - xorg-libxau=1.0.9=h7f98852_0 98 | - xorg-libxdmcp=1.1.3=h7f98852_0 99 | - xorg-libxext=1.3.4=h7f98852_1 100 | - xorg-libxfixes=5.0.3=h7f98852_1004 101 | - xorg-xextproto=7.3.0=h7f98852_1002 102 | - xorg-xproto=7.0.31=h7f98852_1007 103 | - xz=5.2.6=h166bdaf_0 104 | - zstd=1.5.2=h6239696_4 105 | - pip: 106 | - absl-py==1.3.0 107 | - appdirs==1.4.4 108 | - attrs==22.1.0 109 | - audioread==3.0.0 110 | - cachetools==5.2.0 111 | - certifi==2022.9.24 112 | - cycler==0.11.0 113 | - decorator==5.1.1 114 | - face-alignment==1.3.5 115 | - fonttools==4.38.0 116 | - glob2==0.7 117 | - google-auth==2.13.0 118 | - 
google-auth-oauthlib==0.4.6 119 | - grpcio==1.50.0 120 | - imageio==2.22.2 121 | - importlib-metadata==5.0.0 122 | - joblib==1.2.0 123 | - kiwisolver==1.4.4 124 | - librosa==0.9.2 125 | - llvmlite==0.39.1 126 | - markdown==3.4.1 127 | - markupsafe==2.1.1 128 | - matplotlib==3.5.3 129 | - mediapipe==0.8.11 130 | - networkx==2.6.3 131 | - numba==0.56.3 132 | - numpy==1.21.6 133 | - oauthlib==3.2.2 134 | - opencv-contrib-python==4.6.0.66 135 | - opencv-python==4.6.0.66 136 | - packaging==21.3 137 | - pandas==1.3.5 138 | - pooch==1.6.0 139 | - protobuf==3.19.6 140 | - pyasn1==0.4.8 141 | - pyasn1-modules==0.2.8 142 | - pyparsing==3.0.9 143 | - python-dateutil==2.8.2 144 | - pytz==2022.5 145 | - pywavelets==1.3.0 146 | - requests-oauthlib==1.3.1 147 | - resampy==0.3.1 148 | - rsa==4.9 149 | - scikit-image==0.19.3 150 | - scikit-learn==1.0.2 151 | - scipy==1.7.3 152 | - seaborn==0.12.1 153 | - soundfile==0.11.0 154 | - tensorboard==2.10.1 155 | - tensorboard-data-server==0.6.1 156 | - tensorboard-plugin-wit==1.8.1 157 | - threadpoolctl==3.1.0 158 | - tifffile==2021.11.2 159 | - torch==1.12.1 160 | - torchvision==0.13.1 161 | - tqdm==4.64.1 162 | - urllib3==1.26.12 163 | - werkzeug==2.2.2 164 | - zipp==3.10.0 165 | prefix: /home/peter/anaconda3/envs/wav2lip 166 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterwisu/lip-synthesis/423fa4e2857fc420533481b01d1c163184d0d516/evaluation/README.md -------------------------------------------------------------------------------- /evaluation/calc_flop.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../') 4 | from pthflops import count_ops 5 | import torch 6 | from src.models.attnlstm import LstmGen as Attnlstm 7 | from src.models.lstmgen import LstmGen as Lstm 8 | from src.models.syncnet import SyncNet 9 | from fvcore.nn import FlopCountAnalysis 10 | from ptflops import get_model_complexity_info 11 | 12 | 13 | 14 | def prepare_input_gen(resolution): 15 | au = torch.FloatTensor(1,5,1,80,18) 16 | lip = torch.FloatTensor(1,5,20,3) 17 | 18 | return dict({'au' :au , 'lip': lip}) 19 | 20 | 21 | def prepare_input_syncnet(resolution): 22 | 23 | audio = torch.FloatTensor(1,1,80,18) 24 | lip = torch.FloatTensor(1,5,60) 25 | 26 | return dict({'audio' :audio , 'lip': lip}) 27 | 28 | with torch.cuda.device(0): 29 | attnlstm_model = Attnlstm() 30 | 31 | lstm_model = Lstm() 32 | 33 | syncnet = SyncNet() 34 | 35 | attn_lstm_macs, attn_lstm_params = get_model_complexity_info(attnlstm_model, ((1,5,1,80,18),(1,5,20,3)), 36 | input_constructor=prepare_input_gen, 37 | as_strings=True, print_per_layer_stat=True, verbose=True, ) 38 | 39 | lstm_macs, lstm_params = get_model_complexity_info(lstm_model, ((1,5,1,80,18),(1,5,20,3)), 40 | input_constructor=prepare_input_gen, 41 | as_strings=True, print_per_layer_stat=True, verbose=True, ) 42 | 43 | 44 | syncnet_macs, syncnet_params = get_model_complexity_info(syncnet, ((1,5,1,80,18),(1,5,20,3)), 45 | input_constructor=prepare_input_syncnet, 46 | as_strings=True, print_per_layer_stat=True, verbose=True, ) 47 | 48 | 49 | 50 | print('{:<30} {:<8}'.format('Syncnet Computational complexity: ', syncnet_macs)) 51 | print('{:<30} {:<8}'.format('Syncnet Number of parameters: ', syncnet_params)) 52 | print() 53 | print('{:<30} {:<8}'.format('Attn_lstm Computational complexity: ', lstm_macs)) 54 | 
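# Note: lstm_macs / lstm_params (printed here) describe the plain LSTM generator, while
# attn_lstm_macs / attn_lstm_params below describe the attention variant, even though both
# print labels currently read 'Attn_lstm'.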
print('{:<30} {:<8}'.format('Attn_lstm Number of parameters: ', lstm_params)) 55 | print() 56 | print('{:<30} {:<8}'.format('Attn_lstm Computational complexity: ', attn_lstm_macs)) 57 | print('{:<30} {:<8}'.format('Attn_lstm Number of parameters: ', attn_lstm_params)) 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /evaluation/eval_vdo/README.md: -------------------------------------------------------------------------------- 1 | Folder for storing the result video in the test set (For benchmarks) 2 | -------------------------------------------------------------------------------- /evaluation/gen_eval_vdo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | from tqdm import tqdm 5 | import sys 6 | 7 | sys.path.append('../') 8 | 9 | from src.main.inference import Inference 10 | 11 | 12 | use_cuda = torch.cuda.is_available() 13 | device = "cuda" if use_cuda else "cpu" 14 | 15 | main_parser = argparse.ArgumentParser("Generate lip sync video for evaluation") 16 | 17 | main_parser.add_argument('--filelist', type=str, help="File path for filelist", default="./test_filelists/lrs3.txt") 18 | main_parser.add_argument('--data_root', type=str,default='/media/peter/peterwish/dataset/lrs3/test/') 19 | main_parser.add_argument('--model_type', type=str, default='attn_lstm') 20 | main_parser.add_argument('--result_dir', type=str, help="Directory to save result", default="./eval_vdo/Lrs3_adversarial") 21 | main_parser.add_argument('--generator_checkpoint', type=str, help="File path for Generator model checkpoint weights" ,default='/home/peter/Peter/audio-visual/fyp/checkpoints/generator/adversarial_generator.pth') 22 | main_parser.add_argument('--image2image_checkpoint', type=str, help="File path for Image2Image Translation model checkpoint weigths", default='../checkpoints/image2image/image2image.pth',required=False) 23 | 24 | main_args = main_parser.parse_args() 25 | 26 | 27 | 28 | def call_inference(model_type, gen_ckpt, img2img_ckpt, video, audio, save_path): 29 | 30 | parser = argparse.ArgumentParser(description="File for running Inference") 31 | 32 | parser.add_argument('--model_type', help='Type of generator model', default=model_type, type=str) 33 | 34 | parser.add_argument('--generator_checkpoint', type=str, help="File path for Generator model checkpoint weights" ,default=gen_ckpt) 35 | 36 | parser.add_argument('--image2image_checkpoint', type=str, help="File path for Image2Image Translation model checkpoint weigths", default=img2img_ckpt,required=False) 37 | 38 | parser.add_argument('--input_face', type=str, help="File path for input videos/images contain face",default=video, required=False) 39 | 40 | parser.add_argument('--input_audio', type=str, help="File path for input audio/speech as .wav files", default=audio, required=False) 41 | 42 | parser.add_argument('--fps', type=float, help= "Can only be specified only if using static image default(25 FPS)", default=25,required=False) 43 | 44 | parser.add_argument('--fl_detector_batchsize', type=int , help='Batch size for landmark detection', default = 32) 45 | 46 | parser.add_argument('--generator_batchsize', type=int, help="Batch size for Generator model", default=5) 47 | 48 | parser.add_argument('--output_name', type=str , help="Name and path of the output file", default=save_path) 49 | 50 | parser.add_argument('--vis_fl', type=bool, help="Visualize Facial Landmark ??", default=False) 51 | 
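# Note: argparse's type=bool turns any non-empty string into True, so boolean flags declared this
# way cannot be switched off from the command line; utils.utils.str2bool (used in
# run_train_syncnet.py) is the usual workaround.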
52 | parser.add_argument('--test_img2img', type=bool, help="Testing image2image module with no lip generation" , default=False) 53 | 54 | parser.add_argument('--only_fl', type=bool, help="Visualize only Facial Landmark ??", default=False) 55 | 56 | args = parser.parse_args() 57 | 58 | eval_result = Inference(args=args) 59 | 60 | eval_result.start() 61 | 62 | 63 | 64 | def main(): 65 | 66 | data_root = main_args.data_root 67 | result_folder = main_args.result_dir 68 | model_type = main_args.model_type 69 | gen_ckpt = main_args.generator_checkpoint 70 | img2img_ckpt = main_args.image2image_checkpoint 71 | 72 | 73 | 74 | if not os.path.exists(result_folder): 75 | 76 | os.mkdir(result_folder) 77 | 78 | 79 | 80 | with open(main_args.filelist, 'r') as filelist: 81 | 82 | lines = filelist.readlines() 83 | 84 | for idx, line in enumerate(tqdm(lines)): 85 | 86 | 87 | 88 | audio , video = line.strip().split() 89 | 90 | audio = os.path.join(data_root, audio) + '.mp4' 91 | 92 | video = os.path.join(data_root, video) + '.mp4' 93 | 94 | save_path = os.path.join(result_folder, "{}.mp4".format(idx)) 95 | 96 | print(save_path) 97 | 98 | print(audio) 99 | 100 | print(video) 101 | 102 | 103 | call_inference(model_type,gen_ckpt,img2img_ckpt,video,audio,save_path) 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | main() 112 | 113 | 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /evaluation/temp/README.md: -------------------------------------------------------------------------------- 1 | Folder for storing a temporay file during a video generation during evaluation 2 | -------------------------------------------------------------------------------- /filelists/README.md: -------------------------------------------------------------------------------- 1 | Put LRS2 filelists here for training (pretrain.txt, train.txt, val.txt, test.txt) 2 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | This code is orginally from ***Wav2Lip*** repository 4 | 5 | Link : https://github.com/Rudrabha/Wav2Lip 6 | 7 | """ 8 | 9 | 10 | 11 | from glob import glob 12 | import os 13 | 14 | 15 | 16 | class HParams: 17 | def __init__(self, **kwargs): 18 | self.data = {} 19 | 20 | for key, value in kwargs.items(): 21 | self.data[key] = value 22 | 23 | def __getattr__(self, key): 24 | if key not in self.data: 25 | raise AttributeError("'HParams' object has no attribute %s" % key) 26 | return self.data[key] 27 | 28 | def set_hparam(self, key, value): 29 | self.data[key] = value 30 | 31 | 32 | # Default hyperparameters 33 | hparams = HParams( 34 | 35 | ###### Audio preprocessing ############# 36 | num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality 37 | # network 38 | rescale=True, # Whether to rescale audio prior to preprocessing 39 | rescaling_max=0.9, # Rescaling value 40 | # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction 41 | # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder 42 | # Does not work if n_ffit is not multiple of hop_size!! 
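# Note: with sample_rate=16000 and hop_size=200 below, the mel spectrogram has
# 16000 / 200 = 80 frames per second, which is why the dataset code maps a video frame
# index f to the mel index int(80 * f / fps).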
43 | use_lws=False, 44 | 45 | n_fft=800, # Extr window size is filled with 0 paddings to match this parameter 46 | hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) 47 | win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) 48 | sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) 49 | frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) 50 | # Mel and Linear spectrograms normalization/scaling and clipping 51 | signal_normalization=True, 52 | # Whether to normalize mel spectrograms to some predefined range (following below parameters) 53 | allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True 54 | symmetric_mels=True, 55 | # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, 56 | # faster and cleaner convergence) 57 | max_abs_value=4., 58 | # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not 59 | # be too big to avoid gradient explosion, 60 | # not too small for fast convergence) 61 | # Contribution by @begeekmyfriend 62 | # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude 63 | # levels. Also allows for better G&L phase reconstruction) 64 | preemphasize=True, # whether to apply filter 65 | preemphasis=0.97, # filter coefficient. 66 | # Limits 67 | min_level_db=-100, 68 | ref_level_db=20, 69 | fmin=55, 70 | # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To 71 | # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 72 | fmax=7600, # To be increased/reduced depending on data. 73 | 74 | 75 | 76 | 77 | ###################### Our training parameters ################################# 78 | save_optimizer_state=True, 79 | ########### Dataset & Dataloder ######### 80 | fps = 25, 81 | num_workers =8, 82 | # window_size = 18, 83 | # window_step = 5, 84 | ########## SyncNet hyperparameters ############### 85 | syncnet_batch_size = 128, 86 | syncnet_lr = 1e-4, 87 | syncnet_nepochs = 500, 88 | ######### Generator hyperparameters ################ 89 | gen_batch_size = 16, 90 | gen_gen_lr = 1e-3, 91 | gen_disc_lr = 5e-6, # refer to syncnet_lr for training discrimiantor together with generator ** only for adversarial training** 92 | gen_nepochs = 100, 93 | gen_recon_coeff = 0.8, # only use when trainig with pretrain or end2end 94 | gen_sync_coeff = 0.2 # only use when trainig with pretrain or end2end 95 | 96 | ) 97 | 98 | 99 | def hparams_debug_string(): 100 | values = hparams.values() 101 | hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] 102 | return "Hyperparameters:\n" + "\n".join(hp) 103 | -------------------------------------------------------------------------------- /logs/generator/README.md: -------------------------------------------------------------------------------- 1 | This folder shall contain log for generator 2 | -------------------------------------------------------------------------------- /logs/syncnet/README.md: -------------------------------------------------------------------------------- 1 | This folder shall contain log for syncnet 2 | -------------------------------------------------------------------------------- /preprocess_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code originally from ***Wav2Lip*** repository 3 | 4 | Link repo: https://github.com/Rudrabha/Wav2Lip 5 | 6 | This 
code has been modified to preprocess Facial Landmark dataset instead of face images 7 | 8 | """ 9 | 10 | import sys 11 | 12 | if sys.version_info[0] < 3 and sys.version_info[1] < 2: 13 | raise Exception("Must be using >= Python 3.2") 14 | 15 | from os import listdir, path 16 | import multiprocessing as mp 17 | from concurrent.futures import ThreadPoolExecutor, as_completed 18 | import numpy as np 19 | import argparse, os, cv2, traceback, subprocess 20 | from tqdm import tqdm 21 | from glob import glob 22 | import utils.audio as audio 23 | from hparams import hparams as hp 24 | import face_alignment 25 | 26 | import torch 27 | import matplotlib.pyplot as plt 28 | from utils.plot import vis_landmark_on_img 29 | 30 | print('Running Preprocess FL ') 31 | parser = argparse.ArgumentParser() 32 | 33 | parser.add_argument('--ngpu', help='Number of GPUs across which to run in parallel', default=1, type=int) 34 | parser.add_argument('--batch_size', help='Single GPU Face detection batch size', default=16, type=int) 35 | parser.add_argument("--data_root", help="Root folder of the LRS2 dataset" , default="/media/peter/peterwish/dataset/lrs2_v1/mvlrs_v1/main/") 36 | parser.add_argument("--preprocessed_root", help="Root folder of the preprocessed dataset", default="../lrs2_test/") 37 | 38 | args = parser.parse_args() 39 | 40 | fa = [face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=False, 41 | device='cuda:{}'.format(id)) for id in range(args.ngpu)] 42 | 43 | 44 | 45 | fa_landmark = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, device='cuda', flip_input=True) 46 | template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}' 47 | # template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}' 48 | 49 | 50 | 51 | 52 | def detect_bug_fl_batch(fls): 53 | 54 | for i in range(len(fls)): 55 | 56 | if len(fls[i]) > 68: 57 | 58 | bug = fls[i] 59 | 60 | fl = bug[:68] 61 | 62 | fls[i] = fl 63 | 64 | return fls 65 | 66 | 67 | 68 | 69 | def process_video_file(vfile, args, gpu_id): 70 | video_stream = cv2.VideoCapture(vfile) 71 | 72 | frames = [] 73 | while 1: 74 | still_reading, frame = video_stream.read() 75 | if not still_reading: 76 | video_stream.release() 77 | break 78 | #print(np.array(frames).shape) 79 | frame = cv2.resize(frame, (256,256)) 80 | frames.append(frame) 81 | 82 | vidname = os.path.basename(vfile).split('.')[0] 83 | dirname = vfile.split('/')[-2] 84 | 85 | fulldir = path.join(args.preprocessed_root, dirname, vidname) 86 | os.makedirs(fulldir, exist_ok=True) 87 | 88 | batches = [frames[i:i + args.batch_size] for i in range(0, len(frames), args.batch_size)] 89 | i = -1 90 | 91 | for batch in batches: 92 | 93 | batch = np.array(batch) 94 | 95 | batch = np.transpose(batch, (0,3,1,2)) 96 | 97 | batch = torch.from_numpy(batch) 98 | 99 | preds = fa[gpu_id].get_landmarks_from_batch(batch) 100 | 101 | preds = detect_bug_fl_batch(preds) 102 | 103 | for j, fl in enumerate(preds): 104 | 105 | i +=1 106 | 107 | if len(fl) == 0: 108 | 109 | continue 110 | 111 | np.savetxt(path.join(fulldir,'{}.txt'.format(i)),fl ,fmt='%.4f') 112 | 113 | 114 | 115 | def process_audio_file(vfile, args): 116 | vidname = os.path.basename(vfile).split('.')[0] 117 | dirname = vfile.split('/')[-2] 118 | 119 | fulldir = path.join(args.preprocessed_root, dirname, vidname) 120 | os.makedirs(fulldir, exist_ok=True) 121 | 122 | wavpath = path.join(fulldir, 'audio.wav') 123 | 124 | command = template.format(vfile, wavpath) 125 | 
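# The template above expands to 'ffmpeg -loglevel panic -y -i <video.mp4> -strict -2 <audio.wav>',
# extracting the clip's audio track into audio.wav inside its preprocessed folder.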
subprocess.call(command, shell=True) 126 | 127 | 128 | def mp_handler(job): 129 | vfile, args, gpu_id = job 130 | try: 131 | process_video_file(vfile, args, gpu_id) 132 | except KeyboardInterrupt: 133 | exit(0) 134 | except: 135 | traceback.print_exc() 136 | 137 | 138 | def main(args): 139 | print('Started processing for {} with {} GPUs'.format(args.data_root, args.ngpu)) 140 | 141 | print("Saving Path at {}".format(args.preprocessed_root)) 142 | 143 | filelist = glob(path.join(args.data_root, '*/*.mp4')) 144 | 145 | jobs = [(vfile, args, i%args.ngpu) for i, vfile in enumerate(filelist)] 146 | 147 | 148 | p = ThreadPoolExecutor(args.ngpu) 149 | 150 | 151 | futures = [p.submit(mp_handler, j) for j in jobs] 152 | 153 | _ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))] 154 | 155 | print('Dumping audios...') 156 | 157 | for vfile in tqdm(filelist): 158 | try: 159 | process_audio_file(vfile, args) 160 | except KeyboardInterrupt: 161 | exit(0) 162 | except: 163 | traceback.print_exc() 164 | continue 165 | 166 | 167 | if __name__ == '__main__': 168 | main(args) 169 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.3.0 2 | aiohttp==3.8.3 3 | aiosignal==1.3.1 4 | altair==4.2.0 5 | anyio==3.6.2 6 | appdirs==1.4.4 7 | argon2-cffi==21.3.0 8 | argon2-cffi-bindings==21.2.0 9 | async-timeout==4.0.2 10 | asynctest==0.13.0 11 | attrs==22.1.0 12 | audioread==3.0.0 13 | backcall==0.2.0 14 | beautifulsoup4==4.11.2 15 | bleach==6.0.0 16 | brotlipy==0.7.0 17 | cachetools==5.2.0 18 | certifi==2022.12.7 19 | cffi==1.15.1 20 | charset-normalizer==2.1.1 21 | click==8.1.3 22 | cryptography==38.0.2 23 | cycler==0.11.0 24 | debugpy==1.5.1 25 | decorator==5.1.1 26 | defusedxml==0.7.1 27 | entrypoints==0.4 28 | face-alignment==1.3.5 29 | fastapi==0.88.0 30 | fastjsonschema==2.16.2 31 | ffmpy==0.3.0 32 | filelock==3.9.0 33 | fonttools==4.38.0 34 | frozenlist==1.3.3 35 | fsspec==2022.11.0 36 | fvcore==0.1.5.post20221221 37 | glob2==0.7 38 | google-auth==2.13.0 39 | google-auth-oauthlib==0.4.6 40 | gradio==3.15.0 41 | grpcio==1.50.0 42 | h11==0.14.0 43 | httpcore==0.16.3 44 | httpx==0.23.1 45 | huggingface-hub==0.12.1 46 | idna==3.4 47 | imageio==2.22.2 48 | importlib-metadata==5.0.0 49 | importlib-resources==5.10.1 50 | iopath==0.1.10 51 | ipykernel==6.15.2 52 | ipython==7.31.1 53 | ipython-genutils==0.2.0 54 | jedi==0.18.1 55 | Jinja2==3.1.2 56 | joblib==1.2.0 57 | jsonschema==4.17.3 58 | jupyter_client==7.4.8 59 | jupyter_core==4.12.0 60 | jupyter-server==1.23.6 61 | jupyterlab-pygments==0.2.2 62 | kaggle==1.5.12 63 | kiwisolver==1.4.4 64 | librosa==0.9.2 65 | linkify-it-py==1.0.3 66 | llvmlite==0.39.1 67 | Markdown==3.4.1 68 | markdown-it-py==2.1.0 69 | MarkupSafe==2.1.1 70 | matplotlib==3.5.3 71 | matplotlib-inline==0.1.6 72 | mdit-py-plugins==0.3.3 73 | mdurl==0.1.2 74 | mediapipe==0.8.11 75 | mistune==2.0.5 76 | multidict==6.0.4 77 | nbclassic==0.5.1 78 | nbclient==0.7.2 79 | nbconvert==7.2.9 80 | nbformat==5.7.3 81 | nest-asyncio==1.5.6 82 | networkx==2.6.3 83 | notebook==6.5.2 84 | notebook_shim==0.2.2 85 | numba==0.56.3 86 | numpy==1.21.6 87 | oauthlib==3.2.2 88 | opencv-contrib-python==4.6.0.66 89 | opencv-python==4.6.0.66 90 | orjson==3.8.3 91 | packaging==22.0 92 | pandas==1.3.5 93 | pandocfilters==1.5.0 94 | parso==0.8.3 95 | pexpect==4.8.0 96 | pickleshare==0.7.5 97 | Pillow==9.2.0 98 | pip==22.3 99 | pkgutil_resolve_name==1.3.10 100 | 
pooch==1.6.0 101 | portalocker==2.7.0 102 | prometheus-client==0.16.0 103 | prompt-toolkit==3.0.36 104 | protobuf==3.19.6 105 | psutil==5.9.0 106 | ptyprocess==0.7.0 107 | pyasn1==0.4.8 108 | pyasn1-modules==0.2.8 109 | pycparser==2.21 110 | pycryptodome==3.16.0 111 | pydantic==1.10.2 112 | pydub==0.25.1 113 | Pygments==2.11.2 114 | pyOpenSSL==22.1.0 115 | pyparsing==3.0.9 116 | pyrsistent==0.19.2 117 | PySocks==1.7.1 118 | python-dateutil==2.8.2 119 | python-multipart==0.0.5 120 | python-slugify==8.0.0 121 | pytorch-fid==0.3.0 122 | pytz==2022.5 123 | PyWavelets==1.3.0 124 | PyYAML==6.0 125 | pyzmq==23.2.0 126 | regex==2022.10.31 127 | requests==2.28.1 128 | requests-oauthlib==1.3.1 129 | resampy==0.3.1 130 | rfc3986==1.5.0 131 | rsa==4.9 132 | scikit-image==0.19.3 133 | scikit-learn==1.0.2 134 | scipy==1.7.3 135 | seaborn==0.12.1 136 | Send2Trash==1.8.0 137 | setuptools==65.5.0 138 | six==1.16.0 139 | sniffio==1.3.0 140 | soundfile==0.11.0 141 | soupsieve==2.4 142 | starlette==0.22.0 143 | tabulate==0.9.0 144 | tensorboard==2.10.1 145 | tensorboard-data-server==0.6.1 146 | tensorboard-plugin-wit==1.8.1 147 | termcolor==2.2.0 148 | terminado==0.17.1 149 | text-unidecode==1.3 150 | threadpoolctl==3.1.0 151 | tifffile==2021.11.2 152 | tinycss2==1.2.1 153 | tokenizers==0.13.2 154 | toolz==0.12.0 155 | torch==1.12.1 156 | torchaudio==0.12.1 157 | torchvision==0.13.1 158 | tornado==6.2 159 | tqdm==4.64.1 160 | traitlets==5.7.1 161 | transformers==4.26.1 162 | typing_extensions==4.4.0 163 | uc-micro-py==1.0.1 164 | urllib3==1.26.14 165 | uvicorn==0.20.0 166 | wcwidth==0.2.5 167 | webencodings==0.5.1 168 | websocket-client==1.5.1 169 | websockets==10.4 170 | Werkzeug==2.2.2 171 | wheel==0.37.1 172 | yacs==0.1.8 173 | yarl==1.8.2 174 | zipp==3.10.0 175 | -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | # Inference model 2 | 3 | 4 | 5 | import argparse 6 | 7 | from src.main.inference import Inference 8 | import os 9 | 10 | 11 | MODEL_TYPE = ['lstm','attn_lstm'] 12 | 13 | parser = argparse.ArgumentParser(description="File for running Inference") 14 | 15 | parser.add_argument('--model_type', help='Type of generator model', default='attn_lstm', type=str) 16 | 17 | parser.add_argument('--generator_checkpoint', type=str, help="File path for Generator model checkpoint weights" ,default='./checkpoints/generator/attn_lstmgen_syncloss.pth') 18 | 19 | parser.add_argument('--image2image_checkpoint', type=str, help="File path for Image2Image Translation model checkpoint weigths", default='./checkpoints/image2image/ckpt_116_i2i_comb.pth',required=False) 20 | 21 | parser.add_argument('--input_face', type=str, help="File path for input videos/images contain face",default='./dummy/me.png', required=False) 22 | 23 | parser.add_argument('--input_audio', type=str, help="File path for input audio/speech as .wav files", default='./dummy/test_audio.mp3', required=False) 24 | 25 | # parser.add_argument('--output_path', type=str, help="Path for saving the result", default='result.mp4', required=False) 26 | 27 | parser.add_argument('--fps', type=float, help= "Can only be specified only if using static image default(25 FPS)", default=25,required=False) 28 | 29 | parser.add_argument('--fl_detector_batchsize', type=int , help='Batch size for landmark detection', default = 64) 30 | 31 | parser.add_argument('--generator_batchsize', type=int, help="Batch size for Generator model", default=2) 32 
| 33 | parser.add_argument('--seq_len', type=int, help="Sequence length for Generator model", default=5) 34 | 35 | parser.add_argument('--output_name', type=str , help="Name and path of the output file", default="results.mp4") 36 | 37 | parser.add_argument('--vis_fl', type=bool, help="Visualize Facial Landmark ??", default=False) 38 | 39 | parser.add_argument('--only_fl', type=bool, help="Visualize only Facial Landmark ??", default=False) 40 | 41 | parser.add_argument('--test_img2img', type=bool, help="Testing image2image module with no lip generation" , default=False) 42 | 43 | 44 | 45 | 46 | args = parser.parse_args() 47 | 48 | 49 | def main(args): 50 | 51 | 52 | if (args.model_type not in MODEL_TYPE): 53 | 54 | raise ValueError("Argument --model_type mus be in {}".format(MODEL_TYPE)) 55 | 56 | import time 57 | 58 | 59 | 60 | start_time = time.time() 61 | 62 | inference = Inference(args=args) 63 | 64 | inference.start() 65 | 66 | end_time = time.time() 67 | 68 | duration = end_time - start_time 69 | 70 | 71 | print("Time Taken {}".format(duration)) 72 | 73 | 74 | if __name__ == "__main__": 75 | 76 | main(args=args) 77 | -------------------------------------------------------------------------------- /run_train_generator.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | 4 | from src.main.generator import TrainGenerator 5 | import os 6 | #from utils.utils import str2bool 7 | 8 | 9 | TRAIN_TYPE = ["pretrain","normal","adversarial"] 10 | MODEL_TYPE = ['lstm','attn_lstm'] 11 | 12 | parser = argparse.ArgumentParser(description='Code for training a lip sync generator via landmark') 13 | """ ---------- Dataset --------""" 14 | parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", default=None, type=str) 15 | 16 | """ --------- Generator --------""" 17 | parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', default='./checkpoints/generator/', type=str) 18 | parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoints', default=None, type=str) 19 | parser.add_argument('--model_type', help='Type of generator model', default='attn_lstm', type=str) 20 | 21 | """---------- SyncNet ----------""" 22 | 23 | parser.add_argument('--train_type', 24 | help='--train_type select "pretrain" for training generator with pretrain SyncNet, "normal" for training only generator without SyncNet, and "adversarial" for training generator and SyncNet together', 25 | default="pretrain", type=str) 26 | parser.add_argument('--pretrain_syncnet_path', help="Path of pretrain syncnet", default='./checkpoints/syncnet/landmark_syncnet.pth') 27 | 28 | """---------- Save name --------""" 29 | 30 | parser.add_argument("--checkpoint_interval", help="Checkpoint interval and eval video", default=20, type=int) 31 | parser.add_argument('--save_name', help='name of a save', default="train_attn_generator.pth", type=str) 32 | 33 | 34 | args = parser.parse_args() 35 | 36 | 37 | def main(): 38 | 39 | 40 | if (args.model_type not in MODEL_TYPE): 41 | 42 | raise ValueError("Argument --model_type mus be in {}".format(MODEL_TYPE)) 43 | 44 | if (args.train_type not in TRAIN_TYPE): 45 | 46 | raise ValueError("Argument --train_type mus be in {}".format(TRAIN_TYPE)) 47 | 48 | if not os.path.exists(args.data_root): 49 | raise ValueError("Data root does not exist") 50 | 51 | if args.checkpoint_path is not None and not os.path.exists(args.checkpoint_path): 52 | 53 | raise ValueError("Checkpoint for 
Generator does not exists") 54 | 55 | if args.save_name is None: 56 | 57 | raise ValueError('Please provide a save name') 58 | 59 | if args.checkpoint_dir is None: 60 | 61 | raise ValueError("Please provide a checkpoint_dir") 62 | 63 | if args.train_type == "pretrain" and not os.path.exists(args.pretrain_syncnet_path): 64 | 65 | raise ValueError("Please provide a checkpoint_path for pretrain_syncnet for using pretrain discriminator") 66 | 67 | # if create checkpoint dir if it does not exist 68 | if not os.path.exists(args.checkpoint_dir): 69 | os.mkdir(args.checkpoint_dir) 70 | 71 | model = TrainGenerator(args=args) 72 | 73 | model.start_training() 74 | 75 | 76 | if __name__ == "__main__": 77 | 78 | main() 79 | -------------------------------------------------------------------------------- /run_train_syncnet.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import warnings 4 | from src.main.syncnet import TrainSyncNet 5 | import os 6 | from utils.utils import str2bool 7 | from datetime import date 8 | warnings.simplefilter(action='ignore', category=FutureWarning) 9 | 10 | 11 | parser = argparse.ArgumentParser(description="Code for training SyncNet with the lip key points ") 12 | 13 | parser.add_argument("--data_root", help="Root for preprocessed lip keypoint dataset", default='None') 14 | 15 | parser.add_argument("--checkpoint_dir", help="dir to save checkpoints for SyncNet for lip keypoint", default='./checkpoints/syncnet/', 16 | type=str) 17 | 18 | parser.add_argument('--checkpoint_path', help="Resume from checkpoints or testing a model from checkpoints", default=None) 19 | 20 | parser.add_argument('--save_name', help="name of a save", default="landmark_syncnet",type=str) 21 | 22 | parser.add_argument('--do_train' , help="Train a mode or testing a model", default='True' , type=str2bool) 23 | 24 | 25 | args = parser.parse_args() 26 | 27 | def main(): 28 | 29 | if args.do_train : 30 | 31 | if not os.path.exists(args.data_root): 32 | raise ValueError("Data root does not exist") 33 | 34 | 35 | if args.checkpoint_path is not None and not os.path.exists(args.checkpoint_path): 36 | 37 | raise ValueError("Checkpoint for SyncNet does not exists") 38 | 39 | 40 | if args.checkpoint_dir is None: 41 | 42 | raise ValueError("Please provide a checkpoint_dir") 43 | 44 | # if create checkpoint dir if it does not exist 45 | if not os.path.exists(args.checkpoint_dir): 46 | os.mkdir(args.checkpoint_dir) 47 | 48 | 49 | model = TrainSyncNet(args=args) 50 | model.start_training() 51 | 52 | else : 53 | 54 | if args.checkpoint_path is None: 55 | 56 | raise ValueError("Required the path of model's checkpoint for Testing model --checkpoint_path") 57 | 58 | if not os.path.exists(args.checkpoint_path): 59 | 60 | raise ValueError("Give path for model checkpoint does not exists") 61 | 62 | 63 | model = TrainSyncNet(args=args) 64 | model.start_testing() 65 | 66 | 67 | 68 | if __name__ == "__main__": 69 | 70 | main() 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # This file is empty but required for Python to treat this folder as a package Do not delete this file !!!!!! 
-------------------------------------------------------------------------------- /src/dataset/generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | This code is originally from ***Wav2Lip*** repository 4 | 5 | link : https://github.com/Rudrabha/Wav2Lip 6 | 7 | This code has been modified to load a datasets containing Facial Landmarks instead face images 8 | 9 | """ 10 | 11 | from hparams import hparams 12 | from utils.wav2lip import get_fl_list 13 | from os.path import basename, isfile, dirname, join 14 | #import cv2 15 | import os 16 | import utils.audio as audio 17 | import time 18 | import random 19 | from glob import glob 20 | import torch 21 | import numpy as np 22 | import warnings 23 | from utils.utils import procrustes 24 | from utils.plot import plot_scatter_facial_landmark , plot_lip_landmark 25 | 26 | 27 | 28 | warnings.simplefilter(action='ignore', category=FutureWarning) 29 | syncnet_T = 5 30 | syncnet_mel_step_size = 16 31 | 32 | 33 | class Dataset(object): 34 | def __init__(self, split, args): 35 | 36 | self.split = split 37 | self.all_videos = get_fl_list(args.data_root, split) 38 | 39 | def get_frame_id(self, frame): 40 | return int(basename(frame).split('.')[0]) 41 | 42 | def get_window(self, start_frame): 43 | start_id = self.get_frame_id(start_frame) 44 | vidname = dirname(start_frame) 45 | 46 | window_fnames = [] 47 | for frame_id in range(start_id, start_id + syncnet_T): 48 | frame = join(vidname, '{}.txt'.format(frame_id)) 49 | if not isfile(frame): 50 | return None 51 | window_fnames.append(frame) 52 | return window_fnames 53 | 54 | def read_window(self, window_fnames): 55 | if window_fnames is None: return None 56 | window = [] 57 | for fname in window_fnames: 58 | 59 | fl = np.loadtxt(fname) # load facial landmark at that particular frame 60 | 61 | # check whether is facial landmark can be frontalize or not? 62 | try: 63 | 64 | 65 | #fl = fl[48:,:] # take only lip 66 | 67 | #norm_lip, _ = procrustes(lip) # normalize only lip 68 | 69 | #fl[48:,:] = norm_lip # note that in the fl array containing 3d facial landamrks but lip are in normalize form 70 | 71 | #from utils.plot import plot_lip_landmark 72 | 73 | #fl = fl[48:,:] 74 | fl ,_ =procrustes(fl) 75 | 76 | #plot_lip_landmark(fl) 77 | fl = fl[48:,:] 78 | 79 | except Exception as e: 80 | 81 | fl = None 82 | 83 | if fl is None: 84 | return None 85 | 86 | window.append(fl) 87 | 88 | return window 89 | 90 | def crop_audio_window(self, spec, start_frame): 91 | if type(start_frame) == int: 92 | start_frame_num = start_frame 93 | else: 94 | start_frame_num = self.get_frame_id(start_frame) 95 | start_idx = int(80. 
* (start_frame_num / float(hparams.fps))) 96 | 97 | end_idx = start_idx + syncnet_mel_step_size 98 | 99 | 100 | return spec[start_idx: end_idx, :] 101 | 102 | def get_segmented_mels(self, spec, start_frame): 103 | mels = [] 104 | assert syncnet_T == 5 105 | start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing 106 | if start_frame_num - 2 < 0: 107 | return None 108 | 109 | for i in range(start_frame_num, start_frame_num + syncnet_T): 110 | m = self.crop_audio_window(spec, i - 2) 111 | if m.shape[0] != syncnet_mel_step_size: 112 | return None 113 | mels.append(m.T) 114 | 115 | 116 | mels = np.asarray(mels) 117 | 118 | return mels 119 | 120 | def __len__(self): 121 | return len(self.all_videos) 122 | 123 | def __getitem__(self, idx): 124 | 125 | 126 | while 1: 127 | 128 | idx = random.randint(0, len(self.all_videos) - 1) 129 | vidname = self.all_videos[idx] 130 | 131 | self.vidname=vidname 132 | 133 | 134 | img_names = list(glob(join(vidname, '*.txt'))) 135 | 136 | if len(img_names) <= 3 * syncnet_T: 137 | continue 138 | 139 | img_name = random.choice(img_names) 140 | wrong_img_name = random.choice(img_names) 141 | while wrong_img_name == img_name: 142 | wrong_img_name = random.choice(img_names) 143 | 144 | window_fnames = self.get_window(img_name) 145 | wrong_window_fnames = self.get_window(wrong_img_name) 146 | if window_fnames is None or wrong_window_fnames is None: 147 | continue 148 | 149 | window = self.read_window(window_fnames) 150 | if window is None: 151 | continue 152 | 153 | wrong_window = self.read_window(wrong_window_fnames) 154 | if wrong_window is None: 155 | continue 156 | 157 | try: 158 | wavpath = join(vidname, "audio.wav") 159 | wav = audio.load_wav(wavpath, hparams.sample_rate) 160 | 161 | orig_mel = audio.melspectrogram(wav).T 162 | except Exception as e: 163 | continue 164 | 165 | mel = self.crop_audio_window(orig_mel.copy(), img_name) 166 | 167 | if mel.shape[0] != syncnet_mel_step_size: 168 | 169 | continue 170 | 171 | indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name) 172 | if indiv_mels is None: 173 | 174 | continue 175 | 176 | y = np.array(window.copy()) 177 | 178 | x = np.array(wrong_window) 179 | x = torch.FloatTensor(x[:, :, :]) 180 | x = [x[0] for i in range(x.size(0))] 181 | x = torch.stack(x , dim=0) 182 | mel = torch.FloatTensor(mel.T).unsqueeze(0) 183 | indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1) 184 | y = torch.FloatTensor(y[:,:, :]) 185 | 186 | return x, indiv_mels, mel, y 187 | 188 | 189 | 190 | -------------------------------------------------------------------------------- /src/dataset/syncnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | This code is originally from ***Wav2Lip*** repository 4 | 5 | Link : https://github.com/Rudrabha/Wav2Lip 6 | 7 | The code have been modified to load a dataset containing Facial Landmarks instead of face images 8 | 9 | """ 10 | import os 11 | from os.path import dirname, join, basename, isfile 12 | import random 13 | from hparams import hparams 14 | from glob import glob 15 | import argparse 16 | import torch 17 | import numpy as np 18 | import utils.audio as audio 19 | import math 20 | import matplotlib.pyplot as plt 21 | from utils.wav2lip import get_fl_list 22 | from utils.utils import procrustes 23 | import warnings 24 | 25 | 26 | 27 | 28 | syncnet_T = 5 29 | syncnet_mel_step_size = 16 30 | 31 | 32 | class Dataset(torch.utils.data.Dataset): 33 | 34 | def __init__(self, split, args): 35 | 36 | self.all_videos 
= get_fl_list(args.data_root, split) 37 | 38 | 39 | def get_frame_id(self, frame): 40 | 41 | return int(basename(frame).split('.')[0]) 42 | 43 | def get_window(self, start_frame): 44 | 45 | start_id = self.get_frame_id(start_frame) 46 | vidname = dirname(start_frame) 47 | 48 | window_fnames = [] 49 | 50 | for frame_id in range(start_id, start_id + syncnet_T): 51 | 52 | frame = join(vidname, '{}.txt'.format(frame_id)) 53 | 54 | if not isfile(frame): 55 | return None 56 | window_fnames.append(frame) 57 | 58 | return window_fnames 59 | 60 | # Get audio of five frame 61 | def crop_audio_window(self, spec, start_frame): 62 | 63 | start_frame_num = self.get_frame_id(start_frame) 64 | start_idx = int(80. * (start_frame_num / float(hparams.fps))) 65 | 66 | end_idx = start_idx + syncnet_mel_step_size 67 | 68 | return spec[start_idx:end_idx, :] 69 | 70 | 71 | # Lenght of dataset 72 | def __len__(self): 73 | 74 | return len(self.all_videos) 75 | 76 | def __getitem__(self, idx): 77 | 78 | while 1: 79 | 80 | # Random idx between len of dataset 81 | 82 | idx = random.randint(0, len(self.all_videos) - 1) 83 | 84 | # get a video at idx 85 | 86 | vidname = self.all_videos[idx] 87 | 88 | 89 | # get all image from dir vidname (11231/*.txt) in a list 90 | 91 | fl_names = list(glob(join(vidname, '*.txt'))) 92 | 93 | if len(fl_names) <= 3 * syncnet_T: 94 | continue 95 | 96 | fl_name = random.choice(fl_names) 97 | 98 | wrong_fl_name = random.choice(fl_names) 99 | 100 | while wrong_fl_name == fl_name: 101 | wrong_fl_name = random.choice(fl_names) 102 | 103 | # choose which image to be x or y 104 | 105 | if random.choice([True, False]): 106 | 107 | y = torch.ones(1).float() 108 | chosen = fl_name 109 | 110 | else: 111 | 112 | y = torch.zeros(1).float() 113 | chosen = wrong_fl_name 114 | 115 | # get all the fl path of each frame in vdo 116 | 117 | window_fnames = self.get_window(chosen) 118 | 119 | if window_fnames is None: 120 | continue 121 | 122 | # load all frame of vdo 123 | window = [] 124 | all_read = True 125 | 126 | for fname in window_fnames: 127 | 128 | fl = np.loadtxt(fname) 129 | 130 | #fl = fl[48:,:] 131 | fl,_ = procrustes(fl) 132 | fl = fl[48:,:] 133 | 134 | #from utils.plot import plot_scatter_facial_landmark, plot_lip_landmark 135 | #plot_lip_landmark(fl) 136 | 137 | if fl is None: 138 | all_read = False 139 | break 140 | 141 | 142 | window.append(fl) 143 | 144 | if not all_read: continue 145 | 146 | # get audio 147 | 148 | try: 149 | wavpath = join(vidname, "audio.wav") 150 | wav = audio.load_wav(wavpath, hparams.sample_rate) 151 | 152 | orig_mel = audio.melspectrogram(wav).T #(timestep, 80) 153 | except Exception as e: 154 | continue 155 | 156 | # get audio of 5 frames 157 | 158 | mel = self.crop_audio_window(orig_mel.copy(), fl_name) 159 | 160 | if (mel.shape[0] != syncnet_mel_step_size): 161 | continue 162 | 163 | window = np.array(window).reshape(5, -1) 164 | x = torch.FloatTensor(window) 165 | 166 | mel = torch.FloatTensor(mel.T).unsqueeze(0) 167 | 168 | 169 | 170 | 171 | return x, mel, y 172 | -------------------------------------------------------------------------------- /src/main/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parallel import DataParallel 4 | from torch import optim 5 | import numpy as np 6 | from hparams import hparams 7 | import os 8 | from tqdm import tqdm 9 | import matplotlib.pyplot as plt 10 | from src.dataset.generator import Dataset 11 | from torch.utils 
import data as data_utils 12 | from src.models.syncnet import SyncNet 13 | from torch.utils.tensorboard import SummaryWriter 14 | from utils.plot import plot_comp, plot_lip_comparision 15 | from utils.wav2lip import load_checkpoint , save_checkpoint 16 | from utils.utils import save_logs,load_logs 17 | from utils.loss import CosineBCELoss 18 | 19 | 20 | use_cuda = torch.cuda.is_available() 21 | device = torch.device("cuda" if use_cuda else "cpu") 22 | 23 | 24 | class TrainGenerator(): 25 | """ 26 | ************************ 27 | Training Generator Model 28 | ************************ 29 | """ 30 | def __init__ (self, args): 31 | # arguement and hyperparameters 32 | self.save_name = args.save_name 33 | self.checkpoint_dir = args.checkpoint_dir 34 | self.checkpoint_path = args.checkpoint_path 35 | self.batch_size = hparams.gen_batch_size 36 | self.global_epoch = 0 37 | self.nepochs = hparams.gen_nepochs 38 | self.sync_coeff = hparams.gen_sync_coeff 39 | self.recon_coeff = hparams.gen_recon_coeff 40 | self.gen_lr = hparams.gen_gen_lr 41 | self.disc_lr = hparams.gen_disc_lr 42 | self.train_type = args.train_type 43 | self.pretrain_path = args.pretrain_syncnet_path 44 | self.model_type = args.model_type 45 | 46 | # if not using Discriminator 47 | if self.train_type == "normal": 48 | 49 | self.recon_coeff = 1.0 50 | self.sync_coeff = 0 51 | 52 | 53 | 54 | self.checkpoint_interval = args.checkpoint_interval 55 | 56 | if (self.train_type != "normal") and (self.recon_coeff + self.sync_coeff) != 1 : 57 | 58 | raise ValueError("Sum of the loss coeff should be sum up to 1, the recon_coeff is {} and sync_coeff is {}".format(self.recon_coeff,self.sync_coeff)) 59 | 60 | 61 | # Tensorboard for log and result visualization 62 | self.writer = SummaryWriter("../tensorboard/{}".format(self.save_name)) 63 | 64 | 65 | """<---------------------------Dataset -------------------------------------->""" 66 | self.train_dataset = Dataset(split='train', args=args) 67 | 68 | self.vali_dataset = Dataset(split='val', args=args) 69 | 70 | self.train_loader = data_utils.DataLoader(self.train_dataset, 71 | batch_size=self.batch_size, 72 | shuffle=True, 73 | num_workers=hparams.num_workers) 74 | 75 | self.vali_loader = data_utils.DataLoader(self.vali_dataset, 76 | batch_size=self.batch_size, 77 | shuffle=False, 78 | num_workers=hparams.num_workers) 79 | 80 | 81 | 82 | 83 | """ <------------------------------SyncNet Discriminator ------------------------------------->""" 84 | if self.train_type != "normal": 85 | # load Syncnet model 86 | self.syncnet = SyncNet().to(device=device) 87 | 88 | #check if using pretrain Discriminator 89 | if self.train_type == "pretrain": 90 | print("######################") 91 | print("Using Pretrain Syncnet") 92 | print("######################") 93 | 94 | self.syncnet = load_checkpoint(path=self.pretrain_path, 95 | model=self.syncnet, 96 | optimizer=None, 97 | use_cuda=use_cuda, 98 | reset_optimizer=True, 99 | pretrain=True 100 | ) 101 | 102 | for params in self.syncnet.parameters(): 103 | 104 | params.requires_grad = False 105 | 106 | self.syncnet.to(device) 107 | #self.syncnet.eval() 108 | 109 | 110 | 111 | else: 112 | 113 | print("##################################################################") 114 | print("Not using pretrain Syncnet and training it together with generator") 115 | print("##################################################################") 116 | 117 | self.disc_optimizer = optim.Adam([params for params in self.syncnet.parameters() if params.requires_grad], lr=self.disc_lr) 
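# (added comment) The three supported train_type values drive how the SyncNet is used:
#   "normal"      - no SyncNet; recon_coeff is forced to 1.0 and the generator is trained on reconstruction loss only.
#   "pretrain"    - a frozen, pre-trained SyncNet scores lip/audio sync; only the generator is updated.
#   "adversarial" - the SyncNet is trained jointly with the generator via disc_optimizer above.
# In the non-"normal" modes the generator objective is a convex combination of the two losses,
#   gen_loss = recon_coeff * recon_loss + sync_coeff * sync_loss,
# which is why the constructor requires recon_coeff + sync_coeff == 1 (for example 0.5 and 0.5).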
118 | 119 | 120 | print("Finish loading Syncnet !!") 121 | else: 122 | 123 | print("##############################################################") 124 | print("Not using Discriminator(SyncNet), training only generator") 125 | print("##############################################################") 126 | 127 | 128 | 129 | """<----------------------------Generator------------------------------------------->""" 130 | 131 | if self.model_type == "lstm": 132 | 133 | from src.models.lstmgen import LstmGen as Generator 134 | 135 | print("Import LSTM generator") 136 | 137 | elif self.model_type == "attn_lstm": 138 | 139 | from src.models.attnlstm import LstmGen as Generator 140 | 141 | 142 | print("Import Attention LSTM generator") 143 | 144 | else: 145 | 146 | raise ValueError("please put the valid type of model") 147 | 148 | 149 | 150 | 151 | 152 | # load lip generator 153 | self.generator = Generator().to(device=device) 154 | 155 | print(self.generator) 156 | 157 | print("Number of Parameters : ", sum([params.numel() for params in self.generator.parameters() ])) 158 | 159 | self.gen_optimizer = optim.Adam([params for params in self.generator.parameters() if params.requires_grad], lr=self.gen_lr) 160 | 161 | # load checkpoint if the path is given 162 | self.continue_ckpt = False 163 | if self.checkpoint_path is not None: 164 | 165 | self.continue_ckpt =True 166 | 167 | self.generator, self.gen_optimizer, self.global_epoch = load_checkpoint(path = self.checkpoint_path, 168 | model = self.generator, 169 | optimizer = self.optimizer, 170 | use_cuda = use_cuda, 171 | reset_optimizer = False, 172 | pretain=False 173 | ) 174 | print("Load generator checkpoint") 175 | 176 | 177 | if self.continue_ckpt: 178 | 179 | self.train_loss , self.vali_loss = load_logs(model_name="generator", savename="{}.csv".format(self.save_name),epoch=self.global_epoch, type_model='generator') 180 | 181 | self.global_epoch +=1 182 | 183 | else: 184 | 185 | print("Not continue form Checkpoint") 186 | self.train_loss = np.array([]) 187 | self.vali_loss = np.array([]) 188 | 189 | 190 | """<-----------------------Parallel Trainining-------------------------------->""" 191 | # If GPU detect more that one then train model in parallel 192 | if torch.cuda.device_count() > 1: 193 | 194 | self.generator = DataParallel(self.generator) 195 | self.batch_size = self.batch_size * torch.cuda.device_count() 196 | print("Training or Testing model with {} GPU " .format(torch.cuda.device_count())) 197 | 198 | self.generator.to(device) 199 | 200 | 201 | """<----------List of loss funtion--------->""" 202 | # SyncLoss 203 | self.sync_loss = CosineBCELoss() 204 | # Mean Square Error loss 205 | self.mse_loss = nn.MSELoss() 206 | # L1 loss 207 | self.l1_loss = nn.L1Loss() 208 | # chosen reconstruction loss 209 | self.recon_loss = self.l1_loss 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | def __get_disc_loss__ (self,disc_pred, y): 218 | 219 | """ 220 | Calculate SyncLoss from Syncnet 221 | 222 | """ 223 | 224 | # prdicted embedding 225 | s , v = disc_pred 226 | 227 | # create a torch tensor for the groundtruth 228 | y= torch.ones(s.shape[0],1).to(device) if y == 1 else torch.zeros(s.shape[0],1).to(device) 229 | 230 | 231 | # caculate sync loss 232 | loss , _, _,_ = self.sync_loss(s,v,y) 233 | 234 | return loss 235 | 236 | 237 | 238 | def __train_model__ (self): 239 | 240 | 241 | running_gen_disc_loss = 0. 242 | running_recon_loss = 0. 243 | running_disc_loss = 0. 
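# (added comment) Per-batch flow of this training loop, as implemented below:
#   1. Discriminator step ("adversarial" mode only): the SyncNet scores a detached generated lip window
#      as out-of-sync (y=0) and the ground-truth window as in-sync (y=1), and disc_optimizer steps on the
#      summed loss.
#   2. Generator step: the generator re-predicts the lip window; unless train_type is "normal", the
#      (frozen or jointly trained) SyncNet provides a sync loss, which is blended with the L1
#      reconstruction loss as recon_coeff * recon_loss + sync_coeff * sync_loss before gen_optimizer steps.
# Each batch from the loader is (con_fl, seq_mels, mel, gt_fl): reference lip landmarks, per-frame mel
# windows, the mel window covering the 5-frame sequence (for the SyncNet), and ground-truth lip landmarks.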
244 | 245 | running_gen_loss =0 246 | iter_inbatch = 0 247 | 248 | prog_bar = tqdm(self.train_loader) 249 | 250 | for (con_fl, seq_mels, mel, gt_fl) in prog_bar: 251 | 252 | con_lip = con_fl.to(device) 253 | gt_lip = gt_fl.to(device) 254 | seq_mels = seq_mels.to(device) 255 | mel = mel.to(device) 256 | 257 | 258 | 259 | ###################### Discriminator ############################# 260 | if self.train_type == "adversarial": 261 | 262 | 263 | self.syncnet.train() 264 | self.disc_optimizer.zero_grad() 265 | # generate a fake lip from generator 266 | fake_lip, _ = self.generator(seq_mels, con_lip) 267 | disc_fake_pred = self.syncnet(mel,fake_lip.detach()) 268 | disc_real_pred = self.syncnet(mel,gt_lip) 269 | 270 | 271 | disc_fake_loss = self.__get_disc_loss__(disc_fake_pred, y=0) 272 | disc_real_loss = self.__get_disc_loss__(disc_real_pred, y=1) 273 | 274 | disc_loss = disc_fake_loss + disc_real_loss 275 | disc_loss.backward(retain_graph=True) 276 | self.disc_optimizer.step() 277 | 278 | 279 | else : 280 | 281 | disc_loss = torch.zeros(1) 282 | 283 | running_disc_loss += disc_loss.item() 284 | ################################################################# 285 | 286 | 287 | 288 | ####################### Generator ############################### 289 | 290 | self.gen_optimizer.zero_grad() 291 | self.generator.train() 292 | gen_lip, _ = self.generator(seq_mels, con_lip) 293 | 294 | 295 | if self.train_type != "normal": 296 | 297 | disc_gen_pred = self.syncnet(mel, gen_lip) 298 | gen_disc_loss = self.__get_disc_loss__(disc_gen_pred,y=1) 299 | 300 | else: 301 | 302 | gen_disc_loss = torch.zeros(1) 303 | 304 | gt_lip = gt_lip.reshape(gt_lip.size(0),-1) 305 | gen_lip = gen_lip.reshape(gen_lip.size(0),-1) 306 | recon_loss = self.recon_loss(gen_lip,gt_lip) 307 | 308 | 309 | if self.train_type != "normal": 310 | 311 | gen_loss = (self.recon_coeff * recon_loss) + (self.sync_coeff * gen_disc_loss) 312 | 313 | else : 314 | 315 | gen_loss = recon_loss 316 | 317 | gen_loss.backward() 318 | self.gen_optimizer.step() 319 | 320 | #################################################################### 321 | 322 | 323 | running_recon_loss += recon_loss.item() * self.recon_coeff if self.train_type != "gen" else recon_loss.item() 324 | 325 | running_gen_disc_loss += gen_disc_loss.item() * self.sync_coeff if self.train_type != "gen" else gen_disc_loss.item() 326 | running_gen_loss += gen_loss.item() 327 | 328 | iter_inbatch+=1 329 | 330 | 331 | 332 | 333 | 334 | prog_bar.set_description("TRAIN Epochs: {} || Generator Loss : {:.5f} , Recon({})({}) : {:.5f}, Sync({}) : {:.5f} || Disciminator : {:.5f}".format( 335 | self.global_epoch, 336 | running_gen_loss/iter_inbatch, 337 | self.recon_coeff, 338 | self.recon_loss, 339 | running_recon_loss/iter_inbatch, 340 | self.sync_coeff, 341 | running_gen_disc_loss/iter_inbatch, 342 | running_disc_loss/iter_inbatch)) 343 | 344 | avg_gen_loss = running_gen_loss / iter_inbatch 345 | avg_recon_loss = running_recon_loss / iter_inbatch 346 | avg_gen_disc_loss = running_gen_disc_loss/ iter_inbatch 347 | avg_disc_loss = running_disc_loss/iter_inbatch 348 | 349 | 350 | return avg_gen_loss, avg_recon_loss, avg_gen_disc_loss, avg_disc_loss 351 | 352 | 353 | def __eval_model__ (self, split=None): 354 | 355 | running_gen_disc_loss = 0. 356 | running_recon_loss = 0. 357 | running_disc_loss = 0. 
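# (added comment) Validation pass: the same losses as in __train_model__ are computed on
# self.vali_loader under torch.no_grad(), with the models in eval mode and no optimizer steps,
# so the returned averages are directly comparable with the training ones.
# Note: the `self.train_type != "gen"` checks further down appear to be a leftover from an older
# naming scheme ("normal" / "pretrain" / "adversarial" are the values used elsewhere), so the sync
# branch is effectively always taken here and would fail in "normal" mode, where no SyncNet is built.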
358 | 359 | running_gen_loss =0 360 | iter_inbatch = 0 361 | 362 | prog_bar = tqdm(self.vali_loader) 363 | 364 | with torch.no_grad(): 365 | for (con_fl, seq_mels, mel, gt_fl) in prog_bar: 366 | 367 | con_lip = con_fl.to(device) 368 | gt_lip = gt_fl.to(device) 369 | seq_mels = seq_mels.to(device) 370 | mel = mel.to(device) 371 | 372 | 373 | if self.train_type == "adversarial": 374 | ################### Discriminator ########################## 375 | 376 | self.syncnet.eval() 377 | # generate a fake lip from generator 378 | fake_lip, _ = self.generator(seq_mels, con_lip) 379 | disc_fake_pred = self.syncnet(mel,fake_lip.detach()) 380 | disc_real_pred = self.syncnet(mel,gt_lip) 381 | 382 | 383 | disc_fake_loss = self.__get_disc_loss__(disc_fake_pred, y=0) 384 | disc_real_loss = self.__get_disc_loss__(disc_real_pred, y=1) 385 | 386 | disc_loss = disc_fake_loss + disc_real_loss 387 | 388 | running_disc_loss += disc_loss.item() 389 | 390 | else : 391 | 392 | 393 | disc_loss = torch.zeros(1) 394 | 395 | ################### Generator ############################## 396 | 397 | self.generator.eval() 398 | 399 | gen_lip, _ = self.generator(seq_mels, con_lip) 400 | 401 | 402 | if self.train_type != "gen": 403 | disc_gen_pred = self.syncnet(mel, gen_lip) 404 | gen_disc_loss = self.__get_disc_loss__(disc_gen_pred,y=1) 405 | 406 | else : 407 | 408 | gen_disc_loss = torch.zeros(1) 409 | 410 | gt_lip = gt_lip.reshape(gt_lip.size(0),-1) 411 | gen_lip = gen_lip.reshape(gen_lip.size(0),-1) 412 | recon_loss = self.recon_loss(gen_lip,gt_lip) 413 | ############################################################ 414 | 415 | 416 | 417 | if self.train_type != "gen": 418 | 419 | gen_loss = (self.recon_coeff * recon_loss) + (self.sync_coeff * gen_disc_loss) 420 | 421 | else : 422 | 423 | gen_loss = recon_loss 424 | 425 | 426 | 427 | running_recon_loss += recon_loss.item() * self.recon_coeff if self.train_type != "gen" else recon_loss.item() 428 | running_gen_disc_loss += gen_disc_loss.item() * self.sync_coeff if self.train_type != "gen" else gen_disc_loss.item() 429 | running_gen_loss += gen_loss.item() 430 | 431 | iter_inbatch+=1 432 | 433 | 434 | prog_bar.set_description("VALI Epochs : {} || Generator Loss : {:.5f} , Recon({})({}) : {:.5f}, Sync({}) : {:.5f} || Disciminator : {:.5f}".format( 435 | self.global_epoch, 436 | running_gen_loss/iter_inbatch, 437 | self.recon_coeff, 438 | self.recon_loss, 439 | running_recon_loss/iter_inbatch, 440 | self.sync_coeff, 441 | running_gen_disc_loss/iter_inbatch, 442 | running_disc_loss/iter_inbatch)) 443 | 444 | avg_gen_loss = running_gen_loss / iter_inbatch 445 | avg_recon_loss = running_recon_loss / iter_inbatch 446 | avg_gen_disc_loss = running_gen_disc_loss/ iter_inbatch 447 | avg_disc_loss = running_disc_loss/iter_inbatch 448 | 449 | 450 | return avg_gen_loss, avg_recon_loss, avg_gen_disc_loss, avg_disc_loss 451 | 452 | 453 | 454 | def __update_logs__ (self, 455 | cur_train_gen_loss, 456 | cur_vali_gen_loss, 457 | cur_train_recon_loss, 458 | cur_vali_recon_loss, 459 | cur_train_sync_loss, 460 | cur_vali_sync_loss, 461 | cur_train_disc_loss, 462 | cur_vali_disc_loss, 463 | com_fig, com_seq_fig): 464 | 465 | 466 | #self.train_loss = np.append(self.train_loss, cur_train_loss) 467 | #self.vali_loss = np.append(self.vali_loss, cur_vali_loss) 468 | 469 | # save_logs(train_loss=self.train_loss, 470 | # vali_loss=self.vali_loss, 471 | # model_name="generator", 472 | # savename='{}.csv'.format(self.save_name) 473 | # ) 474 | 475 | # ******* plot metrics ********* 476 | # plot metrics 
comparison (train vs validation) 477 | loss_comp = plot_comp(self.train_loss,self.vali_loss, name="Loss") 478 | # ***** Tensorboard **** 479 | # Figure 480 | self.writer.add_figure('Comp/loss', loss_comp, self.global_epoch) 481 | self.writer.add_scalar("Loss/train_gen", cur_train_gen_loss, self.global_epoch) 482 | self.writer.add_scalar("Loss/val_gen" , cur_vali_gen_loss ,self.global_epoch) 483 | 484 | self.writer.add_scalar("Loss/train_disc", cur_train_disc_loss, self.global_epoch) 485 | self.writer.add_scalar("Loss/val_disc" , cur_vali_disc_loss ,self.global_epoch) 486 | 487 | self.writer.add_scalar("Loss/train_recon" , cur_train_recon_loss ,self.global_epoch) 488 | self.writer.add_scalar("Loss/val_recon" , cur_vali_recon_loss ,self.global_epoch) 489 | self.writer.add_scalar("Loss/train_sync" , cur_train_sync_loss ,self.global_epoch) 490 | self.writer.add_scalar("Loss/val_sync" , cur_vali_sync_loss ,self.global_epoch) 491 | self.writer.add_figure("Vis/compare", com_fig, self.global_epoch) 492 | 493 | 494 | for frame in range(len(com_seq_fig)): 495 | 496 | self.writer.add_figure("Vis/seq", com_seq_fig[frame], frame) 497 | 498 | 499 | 500 | 501 | def __training_stage__ (self): 502 | 503 | 504 | while self.global_epoch < self.nepochs: 505 | 506 | 507 | cur_train_gen_loss , cur_train_recon_loss , cur_train_sync_loss, cur_train_disc_loss= self.__train_model__() 508 | 509 | 510 | 511 | cur_vali_gen_loss , cur_vali_recon_loss , cur_vali_sync_loss, cur_vali_disc_loss = self.__eval_model__() 512 | 513 | com_fig, com_seq_fig = self.__vis_lip_result__() 514 | 515 | self.__update_logs__(cur_train_gen_loss, cur_vali_gen_loss, cur_train_recon_loss, cur_vali_recon_loss, cur_train_sync_loss, cur_vali_sync_loss, cur_train_disc_loss, cur_vali_disc_loss, com_fig, com_seq_fig) 516 | 517 | 518 | if (((self.global_epoch % self.checkpoint_interval == 0) or (self.global_epoch == self.nepochs-1)) or self.global_epoch == 5) and (self.global_epoch != 0): 519 | 520 | # save checkpoint 521 | save_checkpoint(self.generator, self.gen_optimizer, self.checkpoint_dir, self.global_epoch, '{}.pth'.format(self.save_name)) 522 | 523 | if self.train_type == "adversarial": 524 | 525 | save_checkpoint(self.syncnet, self.disc_optimizer, self.checkpoint_dir, self.global_epoch,'disc_{}.pth'.format(self.save_name)) 526 | 527 | self.__vis_vdo_result__() 528 | 529 | self.global_epoch +=1 530 | 531 | 532 | def start_training(self): 533 | 534 | print("Save name : {}".format(self.save_name)) 535 | print("Using CUDA : {} ".format(use_cuda)) 536 | 537 | if use_cuda: print ("Using {} GPU".format(torch.cuda.device_count())) 538 | 539 | print("Training dataset {}".format(len(self.train_dataset))) 540 | print("Validation dataset {}".format(len(self.vali_dataset))) 541 | 542 | 543 | #self.__vis_comp_graph__() 544 | print("Start training generator") 545 | 546 | self.__training_stage__() 547 | 548 | print("Finish Trainig generator") 549 | 550 | self.writer.close() 551 | 552 | 553 | def __vis_comp_graph__ (self): 554 | """ 555 | ******************************************************************* 556 | __vis_comp_graph__ : visualising computational graph on tensorboard 557 | ******************************************************************* 558 | """ 559 | 560 | # iterate over data loader 561 | data = iter(self.vali_loader) 562 | # get the the first iteration 563 | (con_fl, seq_mels, _, _ ) = next(data) 564 | # computational graph visualization on tensorboard 565 | self.writer.add_graph(self.generator, (seq_mels.to(device),con_fl.to(device))) 566 
| 567 | del data # remove data iterator 568 | 569 | 570 | 571 | def __vis_lip_result__ (self): 572 | """ 573 | ****************************************************** 574 | __vis_lip_result__ : visualising result on tensorboard 575 | ****************************************************** 576 | """ 577 | 578 | data = iter(self.vali_loader) 579 | 580 | (con_fl, seq_mels , mel , gt_fl) = next(data) 581 | 582 | 583 | with torch.no_grad() : 584 | 585 | self.generator.eval() 586 | 587 | con_lip = con_fl.to(device) 588 | gt_lip = gt_fl.to(device) 589 | 590 | seq_mels = seq_mels.to(device) 591 | mel = mel.to(device) 592 | 593 | seq_len = con_lip.size(1) 594 | 595 | gen_lip, _ = self.generator(seq_mels, con_lip) 596 | 597 | # get one sample from a batch and convert to numpy array 598 | gen,ref, gt = gen_lip[0].detach().clone().cpu().numpy() ,con_lip[0].detach().clone().cpu().numpy(), gt_lip[0].detach().clone().cpu().numpy() 599 | # reshape to seq 600 | gen = gen.reshape(seq_len,20,-1) 601 | ref = ref.reshape(seq_len,20,-1) 602 | gt = gt.reshape(seq_len,20,-1) 603 | 604 | 605 | # plot single lip landmark 606 | single_fig = plot_lip_comparision(gen[0],ref[0],gt[0]) 607 | # plot sequence of lip landmark 608 | seq_fig = [] 609 | for idx in range(seq_len): 610 | 611 | vis_fig = plot_lip_comparision(gen[idx],ref[idx],gt[idx]) 612 | 613 | seq_fig.append(vis_fig) 614 | 615 | 616 | return single_fig, seq_fig 617 | 618 | def __vis_vdo_result__ (self): 619 | """ 620 | ********************************************************************* 621 | __vis_vdo_result__ : visualising the result of the model on inference 622 | ********************************************************************* 623 | """ 624 | from .inference import Inference 625 | import argparse 626 | 627 | folder = "./results/training/{}/".format(self.save_name) 628 | 629 | if not os.path.exists(folder): 630 | 631 | os.mkdir(folder) 632 | 633 | save_path = os.path.join(folder, "eval_epoch{}.mp4".format(self.global_epoch)) 634 | print(save_path) 635 | 636 | parser = argparse.ArgumentParser(description="File for running Inference") 637 | 638 | 639 | parser.add_argument('--model_type', help='Type of generator model', default=self.model_type, type=str) 640 | 641 | parser.add_argument('--generator_checkpoint', type=str, help="File path for Generator model checkpoint weights" ,default='./checkpoints/generator/{}.pth'.format(self.save_name)) 642 | 643 | parser.add_argument('--image2image_checkpoint', type=str, help="File path for Image2Image Translation model checkpoint weigths", default='./checkpoints/image2image/image2image.pth',required=False) 644 | 645 | parser.add_argument('--input_face', type=str, help="File path for input videos/images contain face",default='dummy/me.jpeg', required=False) 646 | 647 | parser.add_argument('--input_audio', type=str, help="File path for input audio/speech as .wav files", default='./dummy/main_testing.wav', required=False) 648 | 649 | parser.add_argument('--fps', type=float, help= "Can only be specified only if using static image default(25 FPS)", default=25,required=False) 650 | 651 | parser.add_argument('--fl_detector_batchsize', type=int , help='Batch size for landmark detection', default = 32) 652 | 653 | parser.add_argument('--generator_batchsize', type=int, help="Batch size for Generator model", default=5) 654 | 655 | parser.add_argument('--output_name', type=str , help="Name and path of the output file", default=save_path) 656 | 657 | parser.add_argument('--vis_fl', type=bool, help="Visualize Facial Landmark ??", 
default=True) 658 | 659 | parser.add_argument('--only_fl', type=bool, help="Visualize only Facial Landmark ??", default=False) 660 | 661 | parser.add_argument('--test_img2img', type=bool, help="Testing image2image module with no lip generation" , default=False) 662 | 663 | args = parser.parse_args() 664 | 665 | eval_result = Inference(args=args) 666 | 667 | eval_result.start() 668 | 669 | 670 | 671 | 672 | 673 | -------------------------------------------------------------------------------- /src/main/inference.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.parallel import DataParallel 5 | import face_alignment 6 | from tqdm import tqdm 7 | import numpy as np 8 | from utils.plot import vis_landmark_on_img 9 | from utils import audio 10 | import cv2 11 | import subprocess 12 | import platform 13 | import os 14 | from src.models.image2image import ResUnetGenerator 15 | #from src.models.lstmgen import LstmGen as Lip_Gen 16 | #from src.models.lstmattn import LstmGen as Lip_Gen 17 | 18 | #from src.models.transgen import TransformerGenerator as Lip_Gen 19 | from utils.wav2lip import prepare_audio, prepare_video, load_checkpoint 20 | from utils.utils import procrustes 21 | import matplotlib.pyplot as plt 22 | from utils.plot import plot_scatter_facial_landmark 23 | 24 | use_cuda = torch.cuda.is_available() 25 | 26 | device = "cuda" if use_cuda else "cpu" 27 | 28 | class Inference(): 29 | 30 | 31 | def __init__ (self, args): 32 | 33 | self.fl_batchsize = args.fl_detector_batchsize 34 | self.gen_batchsize = args.generator_batchsize 35 | self.image2image_ckpt = args.image2image_checkpoint 36 | self.generator_ckpt = args.generator_checkpoint 37 | self.input_face = args.input_face 38 | self.fps = args.fps 39 | self.input_audio = args.input_audio 40 | self.vis_fl = args.vis_fl 41 | self.only_fl = args.only_fl 42 | self.output_name = args.output_name 43 | self.test_img2img = args.test_img2img 44 | self.seq_len = 5#args.seq_len 45 | self.model_type = args.model_type 46 | 47 | 48 | 49 | 50 | self.all_frames , self.fps = prepare_video(args.input_face, args.fps) 51 | self.mel_chunk = prepare_audio(args.input_audio, self.fps) 52 | 53 | # crop timestamp of a video incase video is longer than audio 54 | self.all_frames = self.all_frames[:len(self.mel_chunk)] 55 | 56 | 57 | # Image2Image translation model 58 | self.image2image = ResUnetGenerator(input_nc=6,output_nc=3,num_downs=6,use_dropout=False).to(device) 59 | 60 | # Load pretrained weights to image2image model 61 | image2image_weight = torch.load(self.image2image_ckpt, map_location=torch.device(device))['G'] 62 | # Since the checkpoint of model was trained using DataParallel with multiple GPU 63 | # It required to wrap a model with DataParallel wrapper class 64 | self.image2image = DataParallel(self.image2image).to(device) 65 | # assgin weight to model 66 | self.image2image.load_state_dict(image2image_weight) 67 | 68 | self.image2image = self.image2image.module # access model (remove DataParallel) 69 | 70 | 71 | 72 | if self.model_type == "lstm": 73 | 74 | from src.models.lstmgen import LstmGen as Lip_Gen 75 | 76 | print("Import LSTM generator") 77 | 78 | elif self.model_type == "attn_lstm": 79 | 80 | from src.models.attnlstm import LstmGen as Lip_Gen 81 | 82 | print("Import Attention LSTM generator") 83 | 84 | else: 85 | 86 | raise ValueError("please put the valid type of model") 87 | 88 | 89 | 90 | self.generator = Lip_Gen().to(device=device) 91 | 92 | 
self.generator = load_checkpoint(model=self.generator, 93 | path= self.generator_ckpt, 94 | optimizer=None, 95 | use_cuda=use_cuda, 96 | reset_optimizer=True, 97 | pretrain=True) 98 | 99 | 100 | def __landmark_detection__(self,images, batch_size): 101 | """ 102 | *************************************************************************************** 103 | Detect 3D Facial Landmark from images using Landmark Detector Tools from Face_Alignment 104 | Link repo : https://github.com/1adrianb/face-alignment 105 | *************************************************************************************** 106 | @author : Wish Suharitdamrong 107 | -------- 108 | arguments 109 | --------- 110 | images : list of images 111 | ------ 112 | return 113 | ------ 114 | """ 115 | 116 | def detect_bug_136(fls): 117 | """ 118 | Some times when using detector.get_landmarks_from_batch it does has some bug. Instead of returning facial landmarks (68,3) for single person in image it instead 119 | return (136,3) or (204,3). The first 68 point still a valid facial landamrk of that image (as I visualised). So this fuction basically removed the extra 68 point in landmarks. 120 | This can cause from the image that have more than one face in the image 121 | """ 122 | 123 | 124 | for i in range(len(fls)): 125 | 126 | print(np.array(fls[i]).shape) 127 | if len(fls[i]) != 68: 128 | 129 | bug = fls[i] 130 | 131 | fl1 = bug[:68] 132 | 133 | #fl2 = bug[68:] 134 | 135 | 136 | 137 | fls[i] = fl1 138 | if len(fls[i]) == 0: 139 | 140 | fls[i] = fls[i-1] 141 | 142 | 143 | detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=False, device=device) 144 | images = np.array(images) # shape (Batch , height, width, 3) 145 | images = np.transpose(images,(0,3,1,2)) # shape (Batch, 3, height, width) 146 | images = torch.from_numpy(images) 147 | """ 148 | fls = detector.get_landmarks_from_batch(images) 149 | fls = np.array(fls) 150 | """ 151 | 152 | 153 | fls = [] 154 | transforms = [] 155 | 156 | for i in tqdm(range(0, len(images), batch_size)): 157 | 158 | img = images[i:i+batch_size] 159 | 160 | 161 | fl_batch = detector.get_landmarks_from_batch(img) 162 | 163 | 164 | 165 | detect_bug_136(fl_batch) 166 | 167 | 168 | 169 | fl_batch = np.array(fl_batch)#[:,:,:] # take only 3d 170 | 171 | 172 | 173 | fl = [] 174 | for idx in range(fl_batch.shape[0]): 175 | 176 | fl_inbatch, trans_info = procrustes(fl_batch[idx]) 177 | fl.append(fl_inbatch) 178 | transforms.append(trans_info) 179 | 180 | fl = np.array(fl) 181 | 182 | fls.append(fl) 183 | 184 | 185 | fls = np.concatenate(fls, axis=0) 186 | transforms = np.array(transforms) 187 | 188 | 189 | return fls, transforms 190 | 191 | def __keypoints2landmarks__(self,fls): 192 | """ 193 | 194 | """ 195 | 196 | frames = [] 197 | for fl in fls: 198 | 199 | img = np.ones(shape=(256,256,3)) * 255 # blank white image 200 | 201 | fl = fl.astype(int) 202 | 203 | img = vis_landmark_on_img(img,fl).astype(int) 204 | 205 | frames.append(img) 206 | 207 | frames = np.stack(frames, axis=0) 208 | 209 | return frames 210 | 211 | 212 | def __reverse_trans__(self,fl , tran): 213 | 214 | scale = tran['scale'] 215 | translate = tran['translate'] 216 | 217 | fl = fl * scale # reverse scaling 218 | fl = fl + translate # reverse translation 219 | 220 | return fl 221 | 222 | def __reverse_trans_batch__ (self, fl , trans) : 223 | 224 | trans_fls =[] 225 | 226 | for idx in range(fl.shape[0]): 227 | 228 | trans_fl = self.__reverse_trans__(fl[idx], trans[idx]) 229 | 230 | trans_fls.append(trans_fl) 231 
| 232 | trans_fls = np.array(trans_fls) 233 | 234 | return trans_fls 235 | 236 | 237 | def __data_generator__(self): 238 | """ 239 | 240 | """ 241 | 242 | fl_batch , trans_batch, mel_batch, frame_batch = [],[],[],[] 243 | 244 | fl_seq , trans_seq, mel_seq, frame_seq = [],[],[],[] 245 | 246 | frames = self.all_frames 247 | mels = self.mel_chunk 248 | 249 | 250 | print("Detecting Facial Landmark ....") 251 | fl_detected, transformation = self.__landmark_detection__(frames, self.fl_batchsize) 252 | print("Finish detecting Facial Landmark !!!") 253 | 254 | for i, m in enumerate(mels): 255 | 256 | idx = i % len(frames) # if input if static image only select frame and landmark at index 0 257 | 258 | frame_to_trans = frames[idx].copy() 259 | fl = fl_detected[idx].copy() 260 | transforms = transformation[idx].copy() 261 | 262 | fl_seq.append(fl) 263 | trans_seq.append(transforms) 264 | mel_seq.append(m) 265 | frame_seq.append(frame_to_trans) 266 | 267 | 268 | if len(fl_seq) >= self.seq_len: 269 | 270 | fl_batch.append(fl_seq) 271 | trans_batch.append(trans_seq) 272 | mel_batch.append(mel_seq) 273 | frame_batch.append(frame_seq) 274 | 275 | fl_seq , trans_seq, mel_seq, frame_seq = [],[],[],[] 276 | 277 | 278 | if len(fl_batch) >= self.gen_batchsize: 279 | 280 | fl_batch = np.array(fl_batch) 281 | trans_batch = np.array(trans_batch) # this might cause error by wrapping a dict in np 282 | mel_batch = np.array(mel_batch) 283 | mel_batch = np.reshape(mel_batch, [len(mel_batch), self.seq_len , 1 , mel_batch.shape[2], mel_batch.shape[3]]) # b ,s ,1 , 80 , 18 (old 80,18,1) 284 | frame_batch = np.array(frame_batch) 285 | 286 | 287 | 288 | yield fl_batch, trans_batch, mel_batch, frame_batch 289 | 290 | fl_batch, trans_batch, mel_batch, frame_batch = [], [], [], [] 291 | 292 | #print(np.array(fl_batch).shape) 293 | #print(np.array(fl_seq).shape) 294 | 295 | 296 | if len(fl_batch) > 0 : 297 | #print("tt") 298 | fl_batch = np.array(fl_batch) 299 | #print(fl_batch.shape) 300 | trans_batch = np.array(trans_batch) # this might cause error by wrapping a dict in np 301 | #print(trans_batch.shape) 302 | mel_batch = np.array(mel_batch) 303 | #print(mel_batch.shape) 304 | mel_batch = np.reshape(mel_batch, [len(mel_batch), self.seq_len,1 ,mel_batch.shape[2], mel_batch.shape[3]]) 305 | #print(mel_batch.shape) 306 | frame_batch = np.array(frame_batch) 307 | 308 | yield fl_batch, trans_batch, mel_batch, frame_batch 309 | 310 | fl_batch, trans_batch, mel_batch, frame_batch = [], [], [], [] 311 | 312 | if len(fl_seq) > 0: 313 | 314 | #print("hello") 315 | 316 | 317 | fl_batch = np.expand_dims(np.array(fl_seq),axis=0) 318 | #print(fl_batch.shape) 319 | trans_batch = np.expand_dims(np.array(trans_seq),axis=0) # this might cause error by wrapping a dict in np 320 | #print(trans_batch.shape) 321 | mel_batch = np.expand_dims(np.array(mel_seq),axis=0) 322 | curr_mel_seq = mel_batch.shape[1] 323 | #print(mel_batch.shape) 324 | mel_batch = np.reshape(mel_batch, [len(mel_batch), curr_mel_seq,1 ,mel_batch.shape[2], mel_batch.shape[3]]) 325 | #print(mel_batch.shape) 326 | frame_batch = np.expand_dims(np.array(frame_seq),axis=0) 327 | 328 | #exit() 329 | 330 | yield fl_batch, trans_batch, mel_batch, frame_batch 331 | 332 | fl_batch, trans_batch, mel_batch, frame_batch = [], [], [], [] 333 | 334 | 335 | def start(self): 336 | """ 337 | """ 338 | 339 | self.data = self.__data_generator__() 340 | 341 | 342 | if self.vis_fl and not self.only_fl: 343 | writer = cv2.VideoWriter('./temp/out.mp4', cv2.VideoWriter_fourcc(*'mjpg'), self.fps, 
(256*3,256)) 344 | else : 345 | writer = cv2.VideoWriter('./temp/out.mp4', cv2.VideoWriter_fourcc(*'mjpg'), self.fps, (256,256)) 346 | 347 | for (fl, trans, mel, ref_frame) in tqdm(self.data): 348 | 349 | # fl shape (B, 68, 3) 350 | # mel shape (B, 80, 18, 1) 351 | # ref frame (B, 256, 256, 3) 352 | lip_fl = torch.FloatTensor(fl).to(device) 353 | 354 | lip_fl = lip_fl[:,:,48:,:] # take only lip keypoints 355 | 356 | lip_seq = lip_fl.size(0) 357 | lip_fl = torch.stack([lip_fl[0] for _ in range(lip_seq)], dim=0) 358 | lip_fl = lip_fl.reshape(lip_fl.shape[0],lip_fl.shape[1],-1) 359 | mel = torch.FloatTensor(mel).to(device) 360 | #print(mel.shape) 361 | #mel = mel.reshape(-1,80,18) 362 | 363 | if not self.test_img2img: # check if not testing image2image translation module only no lip generator 364 | with torch.no_grad(): 365 | 366 | self.generator.eval() 367 | out_fl,_ = self.generator(mel, lip_fl) 368 | 369 | 370 | out_fl = out_fl.detach().cpu().numpy() # convert output to numpy array 371 | out_fl = out_fl.reshape(out_fl.shape[0],out_fl.shape[1],20,-1) 372 | 373 | out_fl = out_fl 374 | fl[:,:,48:,:] = out_fl 375 | 376 | 377 | fl = fl.reshape(-1,fl.shape[2],fl.shape[3]) 378 | #ref_frame = ref_frame.reshape(-1,ref_frame.shape[2], ref_frame[3]) 379 | trans = trans.reshape(-1) 380 | fl = self.__reverse_trans_batch__(fl , trans) 381 | 382 | 383 | # plot a image of landmarks 384 | fl_image = self.__keypoints2landmarks__(fl) 385 | 386 | 387 | fl_image = fl_image.reshape(ref_frame.shape[0],ref_frame.shape[1],ref_frame.shape[2],ref_frame.shape[3],ref_frame.shape[4]) 388 | 389 | 390 | if not self.only_fl: 391 | # image translation 392 | for (img_batch,ref_batch) in zip(fl_image, ref_frame): 393 | 394 | for img, ref in zip(img_batch, ref_batch): 395 | 396 | trans_in = np.concatenate((img,ref), axis=2).astype(np.float32)/255.0 397 | trans_in = trans_in.transpose((2, 0, 1)) 398 | trans_in = torch.tensor(trans_in, requires_grad=False) 399 | trans_in = trans_in.reshape(-1, 6, 256, 256) 400 | trans_in = trans_in.to(device) 401 | 402 | with torch.no_grad(): 403 | self.image2image.eval() 404 | trans_out = self.image2image(trans_in) 405 | trans_out = torch.tanh(trans_out) 406 | 407 | trans_out = trans_out.detach().cpu().numpy().transpose((0,2,3,1)) 408 | trans_out[trans_out<0] = 0 409 | trans_out = trans_out * 255.0 410 | 411 | if self.vis_fl: 412 | frame = np.concatenate((ref,img,trans_out[0]),axis=1) 413 | else : 414 | frame = trans_out[0] 415 | writer.write(frame.astype(np.uint8)) 416 | 417 | 418 | if self.only_fl: 419 | 420 | for fl_batch in fl_image: 421 | 422 | 423 | for fl in fl_batch: 424 | 425 | writer.write(fl.astype(np.uint8)) 426 | 427 | 428 | # Write video and close writer 429 | writer.release() 430 | 431 | command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(self.input_audio, 'temp/out.mp4', self.output_name) 432 | subprocess.call(command, shell=platform.system() != 'Windows') 433 | 434 | 435 | 436 | 437 | 438 | -------------------------------------------------------------------------------- /src/main/syncnet.py: -------------------------------------------------------------------------------- 1 | 2 | #from os.path import dirname, join, basename, isfile 3 | from tqdm import tqdm 4 | import torch 5 | from torch import nn 6 | from torch import optim 7 | from torch.utils import data as data_utils 8 | import numpy as np 9 | import os 10 | import pandas as pd 11 | from hparams import hparams 12 | from torch.nn import DataParallel 13 | from torch.utils.tensorboard import SummaryWriter 14 | from 
src.models.syncnet import SyncNet 15 | 16 | #from src.models.attn_syncnet import SyncNet 17 | from src.dataset.syncnet import Dataset 18 | from utils.wav2lip import save_checkpoint, load_checkpoint 19 | from utils.utils import save_logs, load_logs 20 | from utils.plot import plot_comp, plot_roc, plot_cm , plot_single 21 | from utils.loss import CosineBCELoss 22 | 23 | 24 | use_cuda = torch.cuda.is_available() 25 | device = torch.device("cuda" if use_cuda else "cpu") 26 | 27 | class TrainSyncNet(): 28 | """ 29 | ***************************************************************************** 30 | TrainSyncNet : Training pretrain SyncNet model (Expert LipSync Discriminator) 31 | ***************************************************************************** 32 | @author : Wish Suharitdarmong 33 | """ 34 | 35 | def __init__(self,args): 36 | 37 | # arguments and hyperparameters 38 | self.save_name = args.save_name 39 | self.checkpoint_dir = args.checkpoint_dir 40 | self.checkpoint_path = args.checkpoint_path 41 | self.batch_size = hparams.syncnet_batch_size 42 | self.global_epoch = 0 43 | self.nepochs = hparams.syncnet_nepochs 44 | self.do_train = args.do_train 45 | 46 | 47 | 48 | # if create checkpoint dir if it does not exist 49 | if not os.path.exists(self.checkpoint_dir): 50 | os.mkdir(self.checkpoint_dir) 51 | 52 | 53 | # Tensorboard 54 | self.writer = SummaryWriter("../tensorboard/{}".format(self.save_name)) 55 | 56 | # Dataset 57 | 58 | 59 | # if not training stage then do not load training and validations set 60 | if self.do_train: 61 | self.train_dataset = Dataset(split='train', args=args) 62 | 63 | self.vali_dataset = Dataset(split='val', args=args) 64 | 65 | self.train_loader = data_utils.DataLoader(self.train_dataset, 66 | 67 | batch_size=self.batch_size, 68 | shuffle=True, 69 | num_workers=hparams.num_workers) 70 | 71 | self.vali_loader = data_utils.DataLoader(self.vali_dataset, 72 | batch_size=self.batch_size, 73 | shuffle=True, 74 | num_workers=hparams.num_workers) 75 | 76 | # Load Testing Set 77 | self.test_dataset = Dataset(split='test',args=args) 78 | 79 | 80 | self.test_loader = data_utils.DataLoader(self.test_dataset, 81 | batch_size=self.batch_size, 82 | shuffle=True, 83 | num_workers=hparams.num_workers) 84 | 85 | 86 | # SyncNet Model 87 | self.model = SyncNet().to(device) 88 | 89 | #print(self.model) 90 | # optimizer 91 | self.optimizer = optim.Adam([p for p in self.model.parameters() if p.requires_grad], 92 | lr=hparams.syncnet_lr, 93 | ) 94 | 95 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=self.optimizer,mode='min', factor=0.1, patience=10) 96 | 97 | # Loss/Cost/Objective function 98 | self.bce_loss = nn.BCELoss() 99 | 100 | if self.scheduler: 101 | 102 | print("Using LR scheduler") 103 | 104 | 105 | 106 | 107 | 108 | 109 | # load checkpoint if the path is given 110 | self.continue_ckpt = False 111 | if self.checkpoint_path is not None: 112 | self.continue_ckpt = True 113 | self.model, self.optimizer, self.global_epoch = load_checkpoint(self.checkpoint_path, self.model, self.optimizer,use_cuda, reset_optimizer=False) 114 | 115 | print("Load Model from Checkpoint Path") 116 | 117 | 118 | # If not training dont load training logs 119 | if self.do_train : 120 | 121 | # if contutinuing from checkpoint then log previous log if not create new empty logs 122 | if self.continue_ckpt: 123 | 124 | print("Starting Checkpoint from epochs : ", self.global_epoch) 125 | 126 | self.train_loss , self.train_acc, self.vali_loss, self.vali_acc = load_logs(model_name = 
"syncnet", savename="{}.csv".format(self.save_name), epoch=self.global_epoch, type_model="syncnet") 127 | self.global_epoch +=1 128 | else: 129 | 130 | print("Not Continue from Checkpoint") 131 | self.train_loss = np.array([]) 132 | self.vali_loss = np.array([]) 133 | self.train_acc = np.array([]) 134 | self.vali_acc = np.array([]) 135 | 136 | # If GPU detect more that one then train model in parallel 137 | if torch.cuda.device_count() > 1: 138 | 139 | self.model = DataParallel(self.model) 140 | self.batch_size = self.batch_size * torch.cuda.device_count() 141 | print("Training or Testing model with {} GPU " .format(torch.cuda.device_count())) 142 | 143 | self.model.to(device) 144 | 145 | 146 | 147 | self.cosine_bce = CosineBCELoss() 148 | 149 | 150 | 151 | def __train_model__(self): 152 | """ 153 | ******************************************** 154 | __train_model__ : Training stage for SyncNet 155 | ******************************************** 156 | @author : Wish Suharitdarmong 157 | """ 158 | 159 | running_loss = 0. 160 | running_acc = 0. 161 | iter_inbatch = 0 162 | prog_bar = tqdm(self.train_loader) 163 | 164 | for (x, mel, y) in prog_bar: 165 | 166 | # X shape : (B, 5, 20, 3) -- 5 consecutive lip landmarks 167 | # Mel shape : (B, 1 , 80, 18) -- Melspectrogram features 168 | 169 | self.model.train() 170 | self.optimizer.zero_grad() 171 | # allocate data to CUDA 172 | 173 | x = x.to(device) 174 | mel = mel.to(device) 175 | y = y.to(device) 176 | 177 | a , v = self.model(mel, x) 178 | 179 | loss, acc,_ ,_ = self.cosine_bce(a,v,y) 180 | 181 | loss.backward() # Backprop 182 | self.optimizer.step() # Gradient descent step 183 | 184 | running_loss += loss.item() 185 | running_acc += acc 186 | iter_inbatch +=1 187 | 188 | 189 | prog_bar.set_description('TRAIN EPOCHS : {} Loss: {:.3f} Accuracy: {:.3f}'.format(self.global_epoch, 190 | running_loss / iter_inbatch, 191 | running_acc / iter_inbatch)) 192 | 193 | avg_loss = running_loss / iter_inbatch 194 | avg_acc = running_acc / iter_inbatch 195 | 196 | 197 | return avg_loss, avg_acc 198 | 199 | 200 | def __eval_model__ (self,split=None): 201 | """ 202 | 203 | """ 204 | 205 | running_loss = 0. 206 | running_acc = 0. 
207 | iter_inbatch = 0 208 | 209 | # check which data spilt will be use for validation 210 | if split == 'test': 211 | 212 | prog_bar = tqdm(self.test_loader) 213 | vis_emb = [] 214 | au_emb = [] 215 | 216 | elif split == 'val': 217 | 218 | prog_bar = tqdm(self.vali_loader) 219 | 220 | else : 221 | # throw error if spilt does not exist 222 | raise ValueError("Wrong data spilt for __eval_model__ , only vali and test are accept in model evaluation") 223 | 224 | # array for visualizing CM and ROC 225 | y_pred_label = np.array([]) 226 | y_pred_proba = np.array([]) 227 | y_gt = np.array([]) 228 | 229 | 230 | 231 | 232 | 233 | with torch.no_grad(): 234 | 235 | for (x, mel, y) in prog_bar: 236 | 237 | self.model.eval() 238 | 239 | x = x.to(device) 240 | mel = mel.to(device) 241 | y = y.to(device) 242 | 243 | a, v = self.model(mel, x) 244 | 245 | 246 | loss, acc,pred_label,pred_proba = self.cosine_bce(a,v,y) 247 | 248 | # add label and proba to a array 249 | y_pred_label = np.append(y_pred_label, pred_label) 250 | y_pred_proba = np.append(y_pred_proba, pred_proba) 251 | y_gt = np.append(y_gt, y.clone().detach().cpu().numpy()) 252 | 253 | if split == 'test': 254 | 255 | a , v = a.clone().detach().cpu().numpy().tolist(), v.clone().detach().cpu().numpy().tolist() 256 | 257 | vis_emb.extend(v) 258 | 259 | au_emb.extend(a) 260 | 261 | 262 | running_loss += loss.item() 263 | running_acc += acc 264 | iter_inbatch +=1 265 | 266 | 267 | prog_bar.set_description('EVAL EPOCHS : {} Loss: {:.3f} Accuracy: {:.3f}'.format(self.global_epoch, 268 | running_loss / iter_inbatch, 269 | running_acc / iter_inbatch 270 | )) 271 | 272 | if split == 'test': 273 | 274 | vis_emb = np.stack(vis_emb, axis=0) 275 | 276 | au_emb = np.stack(au_emb, axis=0) 277 | 278 | 279 | 280 | # plot roc and cm 281 | roc_fig = plot_roc(y_pred_proba,y_gt) 282 | cm_fig = plot_cm(y_pred_label,y_gt) 283 | 284 | avg_loss = running_loss / iter_inbatch 285 | avg_acc = running_acc / iter_inbatch 286 | 287 | if split == 'test': 288 | 289 | return avg_loss, avg_acc , cm_fig, roc_fig , (vis_emb, au_emb, y_gt) 290 | 291 | else: 292 | 293 | return avg_loss, avg_acc , cm_fig, roc_fig 294 | 295 | 296 | def __update_logs__ (self, cur_train_loss, cur_train_acc, cur_vali_loss, cur_vali_acc, cm_fig, roc_fig): 297 | """ 298 | ************************************* 299 | __update_logs__ : update training log 300 | ************************************* 301 | @author : Wish Suharitdarmong 302 | """ 303 | 304 | # Logs in array 305 | self.train_loss = np.append(self.train_loss, cur_train_loss) 306 | self.train_acc = np.append(self.train_acc , cur_train_acc) 307 | self.vali_loss = np.append(self.vali_loss, cur_vali_loss) 308 | self.vali_acc = np.append(self.vali_acc, cur_vali_acc) 309 | 310 | # save logs in csv file 311 | save_logs(train_loss = self.train_loss, 312 | train_acc = self.train_acc, 313 | vali_loss = self.vali_loss, 314 | vali_acc = self.vali_acc, 315 | model_name="syncnet",savename='{}.csv'.format(self.save_name)) 316 | 317 | # ******* plot metrics ********* 318 | # plot metrics comparison (train vs validation) 319 | loss_comp = plot_comp(self.train_loss,self.vali_loss, name="Loss") 320 | acc_comp = plot_comp(self.train_acc, self.vali_acc, name="Accuracy") 321 | # plot individually 322 | train_loss_plot = plot_single(self.train_loss, 'train_loss',name="Train Loss") 323 | vali_loss_plot = plot_single(self.vali_loss, 'vali_loss', name='Validation Loss') 324 | train_acc_plot = plot_single(self.train_acc, 'train_acc', name="Train Accuracy") 325 | vali_acc_plot 
= plot_single(self.vali_acc, 'vali_acc' , name='Validation Accuracy') 326 | 327 | # ******** Tensorboard ********** 328 | # Scalar 329 | self.writer.add_scalar("Optim/Lr", self.optimizer.param_groups[0]['lr'],self.global_epoch) 330 | self.writer.add_scalar('Loss/train', cur_train_loss , self.global_epoch) 331 | self.writer.add_scalar('Loss/vali', cur_vali_loss, self.global_epoch) 332 | self.writer.add_scalar('Acc/train', cur_train_acc, self.global_epoch) 333 | self.writer.add_scalar('Acc/vali', cur_vali_acc, self.global_epoch) 334 | # Figure 335 | self.writer.add_figure('Comp/loss', loss_comp, self.global_epoch) 336 | self.writer.add_figure('Comp/acc', acc_comp, self.global_epoch) 337 | self.writer.add_figure('Train/acc', train_acc_plot, self.global_epoch) 338 | self.writer.add_figure('Train/loss', train_loss_plot, self.global_epoch) 339 | self.writer.add_figure('Vali/acc' , vali_acc_plot, self.global_epoch) 340 | self.writer.add_figure('Vali/loss', vali_loss_plot, self.global_epoch) 341 | self.writer.add_figure("Vis/confusion_matrix", cm_fig, self.global_epoch) 342 | self.writer.add_figure("Vis/ROC_curve", roc_fig, self.global_epoch) 343 | 344 | 345 | 346 | def __training_stage__ (self): 347 | """ 348 | 349 | """ 350 | 351 | while self.global_epoch < self.nepochs: 352 | 353 | # train model 354 | cur_train_loss , cur_train_acc = self.__train_model__() 355 | # validate model 356 | cur_vali_loss , cur_vali_acc , cm_fig, roc_fig = self.__eval_model__(split='val') 357 | 358 | if self.scheduler: 359 | 360 | self.scheduler.step(cur_vali_loss) 361 | 362 | self.__update_logs__(cur_train_loss, cur_train_acc, cur_vali_loss, cur_vali_acc, cm_fig, roc_fig) 363 | 364 | # Save model checkpoint 365 | save_checkpoint(self.model, self.optimizer, self.checkpoint_dir, self.global_epoch, self.save_name) 366 | # increment global epoch 367 | self.global_epoch +=1 368 | 369 | 370 | def start_training(self): 371 | """ 372 | 373 | Training Model 374 | 375 | """ 376 | print("Save name : {}".format(self.save_name)) 377 | print("Using CUDA : {} ".format(use_cuda)) 378 | 379 | if use_cuda: print ("Using {} GPU".format(torch.cuda.device_count())) 380 | 381 | print("Training dataset {}".format(len(self.train_dataset))) 382 | print("Validation dataset {}".format(len(self.vali_dataset))) 383 | print("Testing dataset {}".format(len(self.test_dataset))) 384 | 385 | print("Start training SyncNet") 386 | 387 | self.__training_stage__() 388 | 389 | print("Finish Trainig SyncNet") 390 | 391 | print(" Evaluating SyncNet with Test Set") 392 | 393 | # evaluate model on test set 394 | test_loss, test_acc, cm_fig, roc_fig, vec_emb = self.__eval_model__(split='test') 395 | 396 | self.__save_vec_emb__(vec_emb[0],vec_emb[1],vec_emb[2]) 397 | 398 | print("Testing Stage") 399 | print("Loss : {} , Acc : {} ".format(test_loss, test_acc)) 400 | # plot metrics for test set on tensorboard 401 | self.writer.add_figure("Test/confusion_matrix",cm_fig,0) 402 | self.writer.add_figure("Test/ROC_curve",roc_fig,0) 403 | self.writer.add_scalar("Test/loss",test_loss,0) 404 | self.writer.add_scalar("Test/acc",test_acc,0) 405 | 406 | self.writer.close() 407 | 408 | 409 | 410 | def __save_vec_emb__ (self, vis_emb, au_emb, y): 411 | 412 | 413 | 414 | #df = pd.DataFrame({'vis' : vis_emb.tolist(), 'au' : au_emb.tolist() ,'gt':y}) 415 | 416 | save_logs(vis = vis_emb.tolist(), 417 | au = au_emb.tolist(), 418 | gt = y, 419 | model_name="syncnet",savename='vec_emb_{}.csv'.format(self.save_name)) 420 | 421 | 422 | def start_testing(self): 423 | """ 424 | 425 | 
Testing model 426 | 427 | """ 428 | 429 | print("Using CUDA : {} ".format(use_cuda)) 430 | 431 | if use_cuda: print ("Using {} GPU".format(torch.cuda.device_count())) 432 | 433 | print("Testing dataset {}".format(len(self.test_dataset))) 434 | 435 | test_loss, test_acc, cm_fig, roc_fig, vec_emb = self.__eval_model__(split='test') 436 | 437 | self.__save_vec_emb__(vec_emb[0],vec_emb[1],vec_emb[2]) 438 | 439 | print("Testing Stage") 440 | print("Loss : {} , Acc : {} ".format(test_loss, test_acc)) 441 | 442 | 443 | 444 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | -------------------------------------------------------------------------------- /src/models/attnlstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | 7 | class LstmGen(nn.Module): 8 | 9 | def __init__ (self): 10 | super(LstmGen, self).__init__() 11 | 12 | self.n_values = 60 13 | 14 | self.au_encoder = nn.Sequential( #input (1,80,18) 15 | ConvBlock(1, 64, kernel_size=(3,3),stride=1, padding=0), 16 | ResidualBlock(64,64, kernel_size=(3,3),stride=1, padding=1), 17 | 18 | ConvBlock(64,128, kernel_size=(5,3), stride=(3,1), padding=1), 19 | ResidualBlock(128,128, kernel_size=(3,3), stride=(1,1), padding=1), 20 | 21 | ConvBlock(128,256, kernel_size=(5,3), stride=(3,3), padding=0), 22 | ResidualBlock(256,256, kernel_size=(3,3), stride=(1,1), padding=1), 23 | 24 | ConvBlock(256,256, kernel_size=(3,3), stride=(3,3), padding=1), 25 | ResidualBlock(256,256, kernel_size=(3,3), stride=(1,1), padding=1), 26 | 27 | ConvBlock(256,256, kernel_size=(3,2), stride=(1,1), padding=0), 28 | 29 | nn.Flatten() 30 | ) 31 | 32 | 33 | self.lstm_encoder = Encoder(input_size=256, hidden_size=256, num_layers=4, dropout=0.25, bidirectional=True, batch_first=False) 34 | 35 | self.pe = PositionalEncoding(d_model=512, dropout=0.1, max_len=5) 36 | 37 | self.self_attn = nn.MultiheadAttention(512, num_heads=8) 38 | 39 | self.feed = nn.Sequential( 40 | LinearBlock(512+self.n_values, 256), 41 | LinearBlock(256, 128), 42 | LinearBlock(128, 60 , dropout=False, batchnorm=False, activation=False), 43 | ) 44 | 45 | 46 | def forward(self, au, lip): 47 | 48 | 49 | 50 | # AU input shape : (B, Seq, 1, 80 , 16) 51 | 52 | # inputs shape ( B , seq, 20 , 3) 53 | lip = lip.reshape(lip.size(0),lip.size(1),-1) # outshape(Seq, B , 60) 54 | 55 | # list for seq of extract features 56 | in_feats = [] 57 | # length of sequence 58 | seq_len = au.size(1) 59 | # batch_size 60 | batch_size = au.size(0) 61 | 62 | au = au.reshape(batch_size * seq_len , 1 , 80 , -1) # (Batchsize * seq , 1 , 80 (num mel) , segments ) 63 | 64 | in_feats = self.au_encoder(au) 65 | 66 | in_feats = in_feats.reshape(seq_len,batch_size, -1) 67 | 68 | lstm_outs , hidden, cell = self.lstm_encoder(in_feats) 69 | 70 | pos_out = self.pe(lstm_outs) 71 | 72 | 73 | attn_out, attn_weight = self.self_attn(pos_out,pos_out,pos_out)#[0] 74 | 75 | 76 | 77 | attn_out = attn_out.reshape(-1,attn_out.shape[-1]) 78 | 79 | lip = lip.reshape(-1,lip.shape[-1]) 80 | 81 | concat_input = torch.concat((attn_out,lip),dim=1) 82 | 83 | pred = self.feed(concat_input) 84 | 85 | pred = pred.reshape(batch_size, seq_len, self.n_values) 86 | 87 | 88 | return pred , lip 89 | 90 | 91 | class Encoder(nn.Module): 92 | 93 | def __init__ 
(self, input_size, hidden_size, num_layers, dropout,bidirectional=True,batch_first=False): 94 | 95 | super(Encoder,self).__init__() 96 | 97 | self.lstm = nn.LSTM( 98 | input_size=input_size, 99 | hidden_size=hidden_size, 100 | num_layers=num_layers, 101 | dropout=dropout, 102 | bidirectional=bidirectional, 103 | batch_first=batch_first 104 | ) 105 | 106 | def forward(self, inputs): 107 | 108 | out, (hidden, cell) = self.lstm(inputs) 109 | 110 | return out , hidden , cell 111 | 112 | 113 | class PositionalEncoding(nn.Module): 114 | """" 115 | Positional Encoding from Pytorch website 116 | """ 117 | 118 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 119 | super().__init__() 120 | self.dropout = nn.Dropout(p=dropout) 121 | 122 | position = torch.arange(max_len).unsqueeze(1) 123 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 124 | pe = torch.zeros(max_len, 1, d_model) 125 | pe[:, 0, 0::2] = torch.sin(position * div_term) 126 | pe[:, 0, 1::2] = torch.cos(position * div_term) 127 | self.register_buffer('pe', pe) 128 | 129 | def forward(self, x): 130 | """ 131 | Args: 132 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 133 | """ 134 | 135 | x = x + self.pe[:x.size(0)] 136 | return self.dropout(x) 137 | 138 | 139 | 140 | class LinearBlock(nn.Module): 141 | """ 142 | Custom Linear Layer block with regularization (Dropout and Batchnorm) and Activation function 143 | """ 144 | def __init__ (self, in_features, out_features, dropout=True, dropout_rate=0.2, batchnorm=True, activation=True): 145 | 146 | super().__init__() 147 | 148 | self.mlp = nn.Linear(in_features = in_features,out_features = out_features) # Linear Layer 149 | self.activation = nn.LeakyReLU(0.2) # activation function layer 150 | self.batchnorm = nn.BatchNorm1d(out_features) # Batch Normalization 1D layer 151 | self.do_dropout = dropout # perform dropout 152 | self.do_batchnorm = batchnorm # perform batchnorm 153 | self.do_activation = activation # perform activation 154 | self.dropout = nn.Dropout(dropout_rate) # Dropout rate 155 | 156 | def forward(self, x): 157 | """ 158 | forward propagation of this layer 159 | """ 160 | 161 | outs = self.mlp(x) 162 | 163 | 164 | if self.do_batchnorm: 165 | 166 | outs = self.batchnorm(outs) 167 | 168 | if self.do_activation: 169 | 170 | outs = self.activation(outs) 171 | 172 | if self.do_dropout: 173 | 174 | outs = self.dropout(outs) 175 | 176 | return outs 177 | 178 | 179 | 180 | class ConvBlock(nn.Module): 181 | """ 182 | Convolutional Layer (With batchnorm and activation) 183 | """ 184 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): 185 | 186 | super().__init__() 187 | 188 | self.conv_layer = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 189 | nn.BatchNorm2d(out_channels), 190 | ) 191 | 192 | self.activation = nn.ReLU() 193 | 194 | def forward(self, inputs): 195 | 196 | cnn_out = self.conv_layer(inputs) 197 | cnn_out = self.activation(cnn_out) 198 | 199 | return cnn_out 200 | 201 | class ResidualBlock(nn.Module): 202 | """ 203 | Convolutional Layers with Residual connection 204 | """ 205 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): 206 | 207 | super().__init__() 208 | 209 | self.conv_layer1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 210 | nn.BatchNorm2d(out_channels), 
211 | ) 212 | 213 | 214 | self.conv_layer2 = nn.Sequential(nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 215 | nn.BatchNorm2d(out_channels), 216 | ) 217 | 218 | self.activation = nn.ReLU() 219 | 220 | 221 | def forward(self,x): 222 | 223 | residual = x 224 | # first conv layer 225 | out = self.activation(self.conv_layer1(x)) 226 | # second conv layer 227 | out = self.activation(self.conv_layer2(out)) 228 | # residual connection 229 | out = out + residual 230 | 231 | return out 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | -------------------------------------------------------------------------------- /src/models/image2image.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 4 | This file contains a code from ***MakeItTalk*** 5 | 6 | This code is originally from the project name MakeItTalk from "" at MakeItTalk repository 7 | 8 | 9 | Link :https://github.com/yzhou359/MakeItTalk 10 | 11 | 12 | """ 13 | 14 | import torch 15 | import torch.nn as nn 16 | #import torch.nn.parallel 17 | #from torch.autograd import Variable 18 | import torch.nn.functional as F 19 | #from torchvision import models 20 | #import torch.utils.model_zoo as model_zoo 21 | #from torch.nn import init 22 | import os 23 | import numpy as np 24 | 25 | 26 | 27 | class ResUnetGenerator(nn.Module): 28 | 29 | """ 30 | 31 | Main Image2Image Translation Network 32 | 33 | """ 34 | 35 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 36 | norm_layer=nn.BatchNorm2d, use_dropout=False): 37 | super(ResUnetGenerator, self).__init__() 38 | # construct unet structure 39 | unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, 40 | innermost=True) 41 | 42 | for i in range(num_downs - 5): 43 | unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, 44 | norm_layer=norm_layer, use_dropout=use_dropout) 45 | unet_block = ResUnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, 46 | norm_layer=norm_layer) 47 | unet_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, 48 | norm_layer=norm_layer) 49 | unet_block = ResUnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, 50 | norm_layer=norm_layer) 51 | unet_block = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, 52 | norm_layer=norm_layer) 53 | 54 | self.model = unet_block 55 | 56 | def forward(self, input): 57 | 58 | output = self.model(input) 59 | 60 | return output 61 | 62 | 63 | 64 | 65 | 66 | # Defines the submodule with skip connection. 
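# (added note) Each block downsamples with a stride-2 conv followed by two residual blocks, recurses
# into its submodule, then upsamples with nearest-neighbour Upsample + conv; except for the outermost
# block, forward() concatenates the block input with its output along the channel axis, as sketched below: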
67 | # X -------------------identity---------------------- X 68 | # |-- downsampling -- |submodule| -- upsampling --| 69 | class ResUnetSkipConnectionBlock(nn.Module): 70 | 71 | """ 72 | 73 | Unet Layers with Residual Connection 74 | 75 | """ 76 | 77 | def __init__(self, outer_nc, inner_nc, input_nc=None, 78 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 79 | super(ResUnetSkipConnectionBlock, self).__init__() 80 | self.outermost = outermost 81 | use_bias = norm_layer == nn.InstanceNorm2d 82 | 83 | if input_nc is None: 84 | input_nc = outer_nc 85 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3, 86 | stride=2, padding=1, bias=use_bias) 87 | # add two resblock 88 | res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)] 89 | res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)] 90 | 91 | # res_downconv = [ResidualBlock(inner_nc)] 92 | # res_upconv = [ResidualBlock(outer_nc)] 93 | 94 | downrelu = nn.ReLU(True) 95 | uprelu = nn.ReLU(True) 96 | if norm_layer != None: 97 | downnorm = norm_layer(inner_nc) 98 | upnorm = norm_layer(outer_nc) 99 | 100 | if outermost: 101 | upsample = nn.Upsample(scale_factor=2, mode='nearest') 102 | upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) 103 | down = [downconv, downrelu] + res_downconv 104 | # up = [uprelu, upsample, upconv, upnorm] 105 | up = [upsample, upconv] 106 | model = down + [submodule] + up 107 | elif innermost: 108 | upsample = nn.Upsample(scale_factor=2, mode='nearest') 109 | upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) 110 | down = [downconv, downrelu] + res_downconv 111 | if norm_layer == None: 112 | up = [upsample, upconv, uprelu] + res_upconv 113 | else: 114 | up = [upsample, upconv, upnorm, uprelu] + res_upconv 115 | model = down + up 116 | else: 117 | upsample = nn.Upsample(scale_factor=2, mode='nearest') 118 | upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) 119 | if norm_layer == None: 120 | down = [downconv, downrelu] + res_downconv 121 | up = [upsample, upconv, uprelu] + res_upconv 122 | else: 123 | down = [downconv, downnorm, downrelu] + res_downconv 124 | up = [upsample, upconv, upnorm, uprelu] + res_upconv 125 | 126 | if use_dropout: 127 | model = down + [submodule] + up + [nn.Dropout(0.5)] 128 | else: 129 | model = down + [submodule] + up 130 | 131 | self.model = nn.Sequential(*model) 132 | 133 | def forward(self, x): 134 | if self.outermost: 135 | return self.model(x) 136 | else: 137 | return torch.cat([x, self.model(x)], 1) 138 | 139 | 140 | 141 | 142 | 143 | 144 | class ResidualBlock(nn.Module): 145 | 146 | """ 147 | 148 | Residual Connection Layers 149 | 150 | """ 151 | 152 | def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d): 153 | super(ResidualBlock, self).__init__() 154 | self.relu = nn.ReLU(True) 155 | if norm_layer == None: 156 | # hard to converge with out batch or instance norm 157 | self.block = nn.Sequential( 158 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), 159 | nn.ReLU(inplace=True), 160 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), 161 | ) 162 | else: 163 | self.block = nn.Sequential( 164 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), 165 | norm_layer(in_features), 166 | nn.ReLU(inplace=True), 167 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), 168 | norm_layer(in_features) 169 | ) 170 | 171 | 
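    # Note (added): the forward pass defined below computes out = relu(block(x) + x);
    # the identity skip connection requires `block` to keep the channel count at `in_features`.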
def forward(self, x): 172 | residual = x 173 | out = self.block(x) 174 | out += residual 175 | out = self.relu(out) 176 | return out 177 | -------------------------------------------------------------------------------- /src/models/lstmgen.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import time 8 | 9 | 10 | class LstmGen(nn.Module): 11 | 12 | def __init__ (self, is3D=True): 13 | super(LstmGen, self).__init__() 14 | 15 | self.n_values = 60 16 | 17 | self.au_encoder = nn.Sequential( #input (1,80,18) 18 | ConvBlock(1, 64, kernel_size=(3,3),stride=1, padding=0), 19 | ResidualBlock(64,64, kernel_size=(3,3),stride=1, padding=1), 20 | 21 | ConvBlock(64,128, kernel_size=(5,3), stride=(3,1), padding=1), 22 | ResidualBlock(128,128, kernel_size=(3,3), stride=(1,1), padding=1), 23 | 24 | ConvBlock(128,256, kernel_size=(5,3), stride=(3,3), padding=0), 25 | ResidualBlock(256,256, kernel_size=(3,3), stride=(1,1), padding=1), 26 | 27 | ConvBlock(256,256, kernel_size=(3,3), stride=(3,3), padding=1), 28 | ResidualBlock(256,256, kernel_size=(3,3), stride=(1,1), padding=1), 29 | 30 | ConvBlock(256,256, kernel_size=(3,2), stride=(1,1), padding=0), 31 | 32 | nn.Flatten() 33 | ) 34 | 35 | 36 | self.lstm_encoder = Encoder(input_size=256, hidden_size=256, num_layers=4, dropout=0.25, bidirectional=True, batch_first=False) 37 | 38 | self.feed = nn.Sequential( 39 | LinearBlock(512+self.n_values, 256), 40 | LinearBlock(256, 128), 41 | LinearBlock(128, 60 , dropout=False, batchnorm=False, activation=False), 42 | ) 43 | 44 | def forward(self, au, lip): 45 | 46 | 47 | # AU input shape : (B, Seq, 1, 80 , 16) 48 | 49 | # inputs shape ( B , seq, 20 , 3) 50 | lip = lip.reshape(lip.size(0),lip.size(1),-1) # outshape(Seq, B , 60) 51 | 52 | # list for seq of extract features 53 | in_feats = [] 54 | # length of sequence 55 | seq_len = au.size(1) 56 | # batch_size 57 | batch_size = au.size(0) 58 | 59 | 60 | au = au.reshape(batch_size * seq_len , 1 , 80 , -1) # (Batchsize * seq , 1 , 80 (num mel) , segments ) 61 | 62 | in_feats = self.au_encoder(au) 63 | 64 | in_feats = in_feats.reshape(batch_size, seq_len, -1) 65 | 66 | 67 | lstm_out, hidden, cell = self.lstm_encoder(in_feats) 68 | 69 | """ 70 | # Out Shape of lstm encoder 71 | # 72 | # en : (B, seq, , hidden_size *2 )) 73 | # hidden : (num_lay * 2, B, hidden_size) 74 | # cell : (num_lay * 2 , B , hidden_size) 75 | """ 76 | 77 | 78 | # concat feature with lip 79 | concat_input =torch.cat((lstm_out,lip),dim=2) 80 | 81 | concat_input = concat_input.reshape(-1,concat_input.size(2)) # (B * seq , 1064) 82 | 83 | outs = self.feed(concat_input) 84 | 85 | outs = outs.reshape(batch_size ,seq_len, -1) 86 | 87 | return outs , lip 88 | 89 | 90 | class Encoder(nn.Module): 91 | 92 | def __init__ (self, input_size, hidden_size, num_layers, dropout,bidirectional=True,batch_first=False): 93 | 94 | super(Encoder,self).__init__() 95 | 96 | self.lstm = nn.LSTM( 97 | input_size=input_size, 98 | hidden_size=hidden_size, 99 | num_layers=num_layers, 100 | dropout=dropout, 101 | bidirectional=bidirectional, 102 | batch_first=batch_first 103 | ) 104 | 105 | def forward(self, inputs): 106 | 107 | out, (hidden, cell) = self.lstm(inputs) 108 | 109 | return out , hidden , cell 110 | 111 | 112 | class LinearBlock(nn.Module): 113 | """ 114 | Custom Linear Layer block with regularization (Dropout and Batchnorm) and Activation function 115 | """ 116 | def __init__ (self, in_features, out_features, 
dropout=True, dropout_rate=0.2, batchnorm=True, activation=True): 117 | 118 | super().__init__() 119 | 120 | self.mlp = nn.Linear(in_features = in_features,out_features = out_features) # Linear Layer 121 | self.activation = nn.LeakyReLU(0.2) # activation function layer 122 | self.batchnorm = nn.BatchNorm1d(out_features) # Batch Normalization 1D layer 123 | self.do_dropout = dropout # perform dropout 124 | self.do_batchnorm = batchnorm # perform batchnorm 125 | self.do_activation = activation # perform activation 126 | self.dropout = nn.Dropout(dropout_rate) # Dropout rate 127 | 128 | def forward(self, x): 129 | """ 130 | forward propagation of this layer 131 | """ 132 | 133 | outs = self.mlp(x) 134 | 135 | 136 | if self.do_batchnorm: 137 | 138 | outs = self.batchnorm(outs) 139 | 140 | if self.do_activation: 141 | 142 | outs = self.activation(outs) 143 | 144 | if self.do_dropout: 145 | 146 | outs = self.dropout(outs) 147 | 148 | return outs 149 | 150 | 151 | 152 | class ConvBlock(nn.Module): 153 | """ 154 | Convolutional Layer (With batchnorm and activation) 155 | """ 156 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): 157 | 158 | super().__init__() 159 | 160 | self.conv_layer = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 161 | nn.BatchNorm2d(out_channels), 162 | ) 163 | 164 | self.activation = nn.ReLU() 165 | 166 | def forward(self, inputs): 167 | 168 | cnn_out = self.conv_layer(inputs) 169 | cnn_out = self.activation(cnn_out) 170 | 171 | return cnn_out 172 | 173 | class ResidualBlock(nn.Module): 174 | """ 175 | Convolutional Layers with Residual connection 176 | """ 177 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): 178 | 179 | super().__init__() 180 | 181 | self.conv_layer1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 182 | nn.BatchNorm2d(out_channels), 183 | ) 184 | 185 | 186 | self.conv_layer2 = nn.Sequential(nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 187 | nn.BatchNorm2d(out_channels), 188 | ) 189 | 190 | self.activation = nn.ReLU() 191 | 192 | 193 | def forward(self,x): 194 | 195 | residual = x 196 | # first conv layer 197 | out = self.activation(self.conv_layer1(x)) 198 | # second conv layer 199 | out = self.activation(self.conv_layer2(out)) 200 | # residual connection 201 | out = out + residual 202 | 203 | return out 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | -------------------------------------------------------------------------------- /src/models/syncnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.plot import plot_scatter_facial_landmark 4 | """ 5 | 6 | SyncNet version for lip landmarks 7 | 8 | """ 9 | class SyncNet(nn.Module): 10 | def __init__(self,bilstm=True): 11 | super(SyncNet,self).__init__() 12 | 13 | self.bilstm = bilstm 14 | self.lip_hidden = 128 15 | self.n_values = 60 16 | 17 | self.audio_encoder = nn.Sequential( #input (1,80,18) 18 | ConvBlock(1, 64, kernel_size=(3,3),stride=1, padding=0), # (78,16) 19 | ResidualBlock(64,64, kernel_size=(3,3),stride=1, padding=1,), 20 | ResidualBlock(64,64, kernel_size=(3,3),stride=1, padding=1,), 21 | 22 | ConvBlock(64,128, kernel_size=(5,3), stride=(3,1), padding=1,), # 
(26,16) 23 | ResidualBlock(128,128, kernel_size=(3,3), stride=(1,1), padding=1,), 24 | ResidualBlock(128,128, kernel_size=(3,3), stride=(1,1), padding=1,), 25 | 26 | ConvBlock(128,256, kernel_size=(5,3), stride=(3,3), padding=0,), # (8,5) 27 | ResidualBlock(256,256, kernel_size=(3,3), stride=(1,1), padding=1,), 28 | ResidualBlock(256,256, kernel_size=(3,3), stride=(1,1), padding=1,), 29 | 30 | ConvBlock(256,512, kernel_size=(3,3), stride=(3,3), padding=1,), # (3,2) 31 | ConvBlock(512,512, kernel_size=(3,2), stride=(1,1), padding=0,), # (3,2) 32 | nn.Flatten() 33 | ) 34 | 35 | 36 | 37 | self.lip_encoder = nn.Sequential( 38 | LinearBlock(self.n_values, 512), 39 | nn.Dropout(0.2), 40 | LinearBlock(512,256), 41 | ) 42 | 43 | 44 | self.visual_encoder =nn.LSTM(input_size=256, 45 | hidden_size=self.lip_hidden, 46 | num_layers=4, #2, 47 | batch_first=True, 48 | bidirectional=bilstm, 49 | ) 50 | 51 | 52 | self.lip_size = 2 * self.lip_hidden if self.bilstm else self.lip_hidden # output hidden size of lip lstm 53 | 54 | 55 | self.lip_fc = nn.Sequential( 56 | 57 | LinearBlock(self.lip_size, 1024), 58 | nn.Dropout(0.2), 59 | LinearBlock(1024, 512), 60 | ) 61 | 62 | def forward(self, audio, lip): 63 | """ 64 | forward propagation of this layer 65 | """ 66 | 67 | # lip shape (batch,seq,60) 68 | # audio shape (batch,1,80,18) 69 | 70 | # extract features from melspectrogram 71 | au = self.audio_encoder(audio) 72 | 73 | 74 | 75 | lip_seq = lip.shape[1] 76 | batch_size= lip.shape[0] 77 | lip = lip.reshape(-1,self.n_values) 78 | 79 | # extract features from landmarks 80 | lip = self.lip_encoder(lip) 81 | lip = lip.reshape(batch_size,lip_seq, -1) 82 | 83 | # pass extracted lip features to BiLSTM 84 | vis_hidden, _ =self.visual_encoder(lip) 85 | 86 | # last hidden layers of lstm 87 | vis_hidden = vis_hidden[:,-1,:] 88 | 89 | # embeddings 90 | lip = self.lip_fc(vis_hidden) 91 | #au = self.audio_fc(au) 92 | 93 | # apply Euclidean(L2) norm 94 | au = nn.functional.normalize(au, p=2, dim=1) 95 | lip = nn.functional.normalize(lip, p=2, dim=1) 96 | 97 | return au, lip 98 | 99 | 100 | class LinearBlock(nn.Module): 101 | """ 102 | Custom Linear Layer block with regularization (Dropout and Batchnorm) and Activation function 103 | """ 104 | def __init__ (self, in_features, out_features,): 105 | 106 | super().__init__() 107 | 108 | self.linear_layer = nn.Sequential( 109 | nn.Linear(in_features,out_features), 110 | nn.BatchNorm1d(out_features), 111 | ) 112 | 113 | self.activation = nn.ReLU() 114 | 115 | 116 | def forward(self, x): 117 | """ 118 | forward propagation of this layer 119 | """ 120 | outs = self.activation(self.linear_layer(x)) 121 | 122 | return outs 123 | 124 | 125 | 126 | class ConvBlock(nn.Module): 127 | """ 128 | Convolutional Layer (With batchnorm and activation) 129 | """ 130 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dropout=False, dropout_rate=0.2): 131 | 132 | super().__init__() 133 | 134 | self.conv_layer = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 135 | nn.BatchNorm2d(out_channels), 136 | ) 137 | 138 | self.activation = nn.ReLU() 139 | 140 | self.dropout = nn.Dropout2d(dropout_rate) if dropout else None 141 | 142 | 143 | def forward(self, inputs): 144 | """ 145 | forward propagation of this layer 146 | """ 147 | cnn_out = self.conv_layer(inputs) 148 | cnn_out = self.activation(cnn_out) 149 | 150 | 151 | if self.dropout: 152 | 153 | cnn_out = self.dropout(cnn_out) 154 | 
155 | return cnn_out 156 | 157 | class ResidualBlock(nn.Module): 158 | """ 159 | Convolutional Layers with Residual connection 160 | """ 161 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dropout=False, dropout_rate=0.2): 162 | 163 | super().__init__() 164 | 165 | self.conv_layer1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 166 | nn.BatchNorm2d(out_channels), 167 | ) 168 | 169 | 170 | self.conv_layer2 = nn.Sequential(nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 171 | nn.BatchNorm2d(out_channels), 172 | ) 173 | 174 | self.activation = nn.ReLU() 175 | self.dropout = nn.Dropout2d(dropout_rate) if dropout else None 176 | 177 | 178 | def forward(self,x): 179 | """ 180 | forward propagation of this layer 181 | """ 182 | residual = x 183 | # first conv layer 184 | out = self.activation(self.conv_layer1(x)) 185 | # second conv layer 186 | out = self.activation(self.conv_layer2(out)) 187 | # residual connection 188 | out = out + residual 189 | 190 | if self.dropout: 191 | 192 | out = self.dropout(out) 193 | 194 | return out 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | -------------------------------------------------------------------------------- /temp/README.md: -------------------------------------------------------------------------------- 1 | This folder shall contain temporary file for model inference results 2 | -------------------------------------------------------------------------------- /utils/audio.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | This file is from ***Wav2Lip*** 4 | 5 | Link: https://github.com/Rudrabha/Wav2Lip 6 | 7 | 8 | """ 9 | import librosa 10 | import librosa.filters 11 | import numpy as np 12 | # import tensorflow as tf 13 | from scipy import signal 14 | from scipy.io import wavfile 15 | from hparams import hparams as hp 16 | 17 | def load_wav(path, sr): 18 | return librosa.core.load(path, sr=sr)[0] 19 | 20 | def save_wav(wav, path, sr): 21 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 22 | #proposed by @dsmiller 23 | wavfile.write(path, sr, wav.astype(np.int16)) 24 | 25 | def save_wavenet_wav(wav, path, sr): 26 | librosa.output.write_wav(path, wav, sr=sr) 27 | 28 | def preemphasis(wav, k, preemphasize=True): 29 | if preemphasize: 30 | return signal.lfilter([1, -k], [1], wav) 31 | return wav 32 | 33 | def inv_preemphasis(wav, k, inv_preemphasize=True): 34 | if inv_preemphasize: 35 | return signal.lfilter([1], [1, -k], wav) 36 | return wav 37 | 38 | def get_hop_size(): 39 | hop_size = hp.hop_size 40 | if hop_size is None: 41 | assert hp.frame_shift_ms is not None 42 | hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) 43 | return hop_size 44 | 45 | def linearspectrogram(wav): 46 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 47 | S = _amp_to_db(np.abs(D)) - hp.ref_level_db 48 | 49 | if hp.signal_normalization: 50 | return _normalize(S) 51 | return S 52 | 53 | def melspectrogram(wav): 54 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 55 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db 56 | 57 | if hp.signal_normalization: 58 | return _normalize(S) 59 | return S 60 | 61 | def _lws_processor(): 62 | import lws 63 | return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") 64 | 65 | def _stft(y): 66 | if 
hp.use_lws: 67 | 68 | 69 | return _lws_processor().stft(y).T 70 | else: 71 | 72 | return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) 73 | 74 | ########################################################## 75 | #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 76 | def num_frames(length, fsize, fshift): 77 | """Compute number of time frames of spectrogram 78 | """ 79 | pad = (fsize - fshift) 80 | if length % fshift == 0: 81 | M = (length + pad * 2 - fsize) // fshift + 1 82 | else: 83 | M = (length + pad * 2 - fsize) // fshift + 2 84 | return M 85 | 86 | 87 | def pad_lr(x, fsize, fshift): 88 | """Compute left and right padding 89 | """ 90 | M = num_frames(len(x), fsize, fshift) 91 | pad = (fsize - fshift) 92 | T = len(x) + 2 * pad 93 | r = (M - 1) * fshift + fsize - T 94 | return pad, pad + r 95 | ########################################################## 96 | #Librosa correct padding 97 | def librosa_pad_lr(x, fsize, fshift): 98 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] 99 | 100 | # Conversions 101 | _mel_basis = None 102 | 103 | def _linear_to_mel(spectrogram): 104 | global _mel_basis 105 | if _mel_basis is None: 106 | _mel_basis = _build_mel_basis() 107 | return np.dot(_mel_basis, spectrogram) 108 | 109 | def _build_mel_basis(): 110 | assert hp.fmax <= hp.sample_rate // 2 111 | return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, 112 | fmin=hp.fmin, fmax=hp.fmax) 113 | 114 | def _amp_to_db(x): 115 | min_level = np.exp(hp.min_level_db / 20 * np.log(10)) 116 | return 20 * np.log10(np.maximum(min_level, x)) 117 | 118 | def _db_to_amp(x): 119 | return np.power(10.0, (x) * 0.05) 120 | 121 | def _normalize(S): 122 | if hp.allow_clipping_in_normalization: 123 | if hp.symmetric_mels: 124 | return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, 125 | -hp.max_abs_value, hp.max_abs_value) 126 | else: 127 | return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) 128 | 129 | assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 130 | if hp.symmetric_mels: 131 | return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value 132 | else: 133 | return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) 134 | 135 | def _denormalize(D): 136 | if hp.allow_clipping_in_normalization: 137 | if hp.symmetric_mels: 138 | return (((np.clip(D, -hp.max_abs_value, 139 | hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) 140 | + hp.min_level_db) 141 | else: 142 | return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 143 | 144 | if hp.symmetric_mels: 145 | return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) 146 | else: 147 | return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 148 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | class ContrastiveLoss(torch.nn.Module): 6 | 7 | def __init__(self, margin=2.0, metric='euclidean'): 8 | 9 | super(ContrastiveLoss, self).__init__() 10 | 11 | self.margin = margin 12 | 13 | self.metric = metric 14 | 15 | def forward(self,x1,x2,y): 16 | 17 | # calculate Euclidean 
distance 18 | distance = torch.nn.functional.pairwise_distance(x1,x2, p=2) # p=2 is norm degree for Euclidean distance 19 | 20 | term1 = (1-y) * torch.pow(distance,2) # 21 | 22 | term2 = (y) * torch.pow(torch.clamp(self.margin-distance, min=0),2) 23 | 24 | loss = torch.mean(term1 + term2) 25 | 26 | pred = np.array([ 1 if i > self.margin else 0 for i in distance ]) 27 | label = y.detach().clone().cpu().numpy().reshape(-1).astype(int) 28 | acc = np.mean([1 if p == l else 0 for p,l in zip(pred,label) ]) * 100 29 | 30 | 31 | 32 | return loss , acc 33 | 34 | class CosineBCELoss(torch.nn.Module): 35 | """ 36 | This loss was proposed by Wav2lip 37 | 38 | """ 39 | 40 | def __init__(self): 41 | super(CosineBCELoss, self).__init__() 42 | 43 | self.bce = torch.nn.BCELoss() 44 | 45 | 46 | def forward(self,x1,x2,y): 47 | 48 | # Get similarity score 49 | sim_score = torch.nn.functional.cosine_similarity(x1,x2) 50 | 51 | # compute BCE loss 52 | loss = self.bce(sim_score.unsqueeze(1),y) 53 | 54 | # get prediction 55 | pred = np.array([1 if i >0.5 else 0 for i in sim_score]) 56 | # clone gt from tensor 57 | label = y.detach().clone().cpu().numpy().reshape(-1).astype(int) 58 | # calculate accuracy 59 | acc = np.mean([1 if p == l else 0 for p,l in zip(pred,label) ]) * 100 60 | 61 | return loss, acc , pred, sim_score.detach().clone().cpu().numpy() 62 | 63 | class WingLoss(torch.nn.Module): 64 | 65 | 66 | def __init__(self, w, eps ): 67 | 68 | super(WingLoss,self).__init__() 69 | 70 | self.W = w 71 | 72 | self.E = eps 73 | 74 | self.C = self.W - (self.W * np.log(1+ self.W/self.E)) 75 | 76 | 77 | def forward(self,pred,target): 78 | 79 | diff = torch.abs(target-pred) # absolute value of differences 80 | 81 | first = diff[diff < self.W] # value that fall under first condition 82 | 83 | second = diff[diff>= self.W] # value that fall under second condition 84 | 85 | term1 = self.W * torch.log(1+ (first/self.E)) 86 | 87 | term2 = second - self.C 88 | 89 | N = len(term1) + len(term2) 90 | 91 | loss = torch.sum(term1) + torch.sum(term2) 92 | 93 | loss = loss / N 94 | 95 | return loss 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.metrics import confusion_matrix 6 | from sklearn.metrics import roc_curve 7 | from sklearn.metrics import auc 8 | import cv2 9 | 10 | 11 | def plot_single(value, label, name): 12 | """ 13 | *********************************************************** 14 | plot_single : Plot a single line plot for metrics (accuracy or loss) 15 | *********************************************************** 16 | @author: Wish Suharitdarmong 17 | ------ 18 | inputs 19 | ------ 20 | vali : array containg a log value 21 | ------- 22 | returns 23 | ------- 24 | figure : Single line plot figure 25 | """ 26 | fig , (ax) = plt.subplots(nrows=1,ncols=1) 27 | 28 | ax.plot(value, color='r', label=label, linestyle='-') 29 | if "loss" in name.lower(): 30 | ax.legend(loc="upper right") 31 | if "accuracy" in name.lower(): 32 | ax.legend(loc="lower right") 33 | ax.set_title("{}".format(name)) 34 | figure = ax.get_figure() 35 | plt.close(fig) 36 | 37 | return figure 38 | 39 | 40 | def plot_comp(train,vali, name): 41 | """ 42 | ************************************************************ 43 | plot_comp : Plot comparision between training and evaluation 
44 | ************************************************************ 45 | @author: Wish Suharitdamrong 46 | ------ 47 | inputs 48 | ------ 49 | train : array containing training log values (accuracy or loss) 50 | vali : array containing evaluation log values 51 | ------- 52 | returns 53 | ------- 54 | figure : comparison plot figure 55 | """ 56 | 57 | 58 | fig, (ax) = plt.subplots(nrows=1,ncols=1) 59 | ax.plot(train, color='r' , label='Train', linestyle='-') 60 | ax.plot(vali, color='b', label='Validation', linestyle='-') 61 | if "loss" in name.lower(): 62 | ax.legend(loc="upper right") 63 | if "accuracy" in name.lower(): 64 | ax.legend(loc="lower right") 65 | ax.set_title("{}".format(name)) 66 | figure = ax.get_figure() 67 | plt.close(fig) 68 | 69 | return figure 70 | 71 | 72 | 73 | 74 | 75 | 76 | def plot_cm(y_pred_label,y_gt): 77 | """ 78 | **************************** 79 | Plot Confusion Matrix Figure 80 | **************************** 81 | @author: Wish Suharitdamrong 82 | ---------- 83 | parameters 84 | ---------- 85 | y_pred_label : labels predicted from a model 86 | y_gt : ground truth of each label 87 | ------- 88 | returns 89 | ------- 90 | dia_cm : Figure of Confusion Matrix 91 | """ 92 | 93 | df_cols_rows = ["Positive", "Negative"] 94 | 95 | cm = confusion_matrix(y_gt, y_pred_label,normalize="all") 96 | cm_df = pd.DataFrame(cm , index = df_cols_rows , columns = df_cols_rows) 97 | 98 | # Plot Fig 99 | fig, (ax) = plt.subplots(nrows=1, ncols=1) 100 | ax.set_title("Confusion Matrix") 101 | dia_cm = sns.heatmap(cm_df, annot =True).get_figure() 102 | plt.close(fig) 103 | 104 | return dia_cm 105 | 106 | 107 | 108 | 109 | def plot_roc(y_pred_proba,y_gt): 110 | """ 111 | ************************************************** 112 | Receiver Operating Characteristic (ROC) Curve plot 113 | ************************************************** 114 | @author: Wish Suharitdamrong 115 | --------- 116 | parameter 117 | --------- 118 | y_pred_proba : probabilities of each label predicted from a model 119 | y_gt : ground truth of each label 120 | ------- 121 | returns 122 | ------- 123 | dia_roc : ROC Curve Figure 124 | """ 125 | 126 | # Calculate False and True positive rate 127 | fpr, tpr, threshold = roc_curve(y_gt, y_pred_proba) 128 | 129 | # Calculate area under a curve 130 | roc_auc = auc(fpr, tpr) 131 | 132 | # Plot Fig 133 | fig, (ax) = plt.subplots(nrows=1, ncols=1) 134 | # Plot curve 135 | line1 = ax.plot(fpr, tpr, 'b', label= "AUC = {:.2f}".format(roc_auc)) 136 | # Plot Threshold 137 | line2 = ax.plot([0,1],[0,1],'--r', label = "Random Classifier") 138 | ax.legend(loc="lower right") 139 | ax.set_title("Receiver Operating Characteristic (ROC) Curve") 140 | ax.set_xlim([0,1]) 141 | ax.set_ylim([0,1]) 142 | ax.set_xlabel("False Positive Rate") 143 | ax.set_ylabel("True Positive Rate") 144 | dia_roc = ax.get_figure() 145 | plt.close(fig) 146 | 147 | return dia_roc 148 | 149 | # 150 | # 151 | #def plot_compareLip(predict,gt_lip): 152 | # """ 153 | # ************************************************************ 154 | # plot_compareLip : Plot comparison between the generated lip and the actual lip 155 | # ************************************************************ 156 | # @author: Wish Suharitdamrong 157 | # ------ 158 | # inputs 159 | # ------ 160 | # predict : predicted lip in the form of a torch tensor 161 | # gt_lip : ground truth lip in the form of a torch tensor 162 | # ------- 163 | # returns 164 | # ------- 165 | # com_fig : comparison figure 166 | # """ 167 | # 168 | # 169 | # # convert torch 
tensor into numpy array 170 | # lip_pred = predict.reshape(predict.shape[0], predict.shape[1], 20, -1)[0, 0].detach().clone().cpu().numpy() 171 | # fl = gt_lip.reshape(predict.shape[0], predict.shape[1], 20, -1)[0, 0].detach().clone().cpu().numpy() 172 | # 173 | # fig, (ax) = plt.subplots(nrows=1, ncols=1) 174 | # 175 | # # Plot Generated Lip 176 | # ax.scatter(lip_pred[:, 0], lip_pred[:, 1], s=20, c='r', linewidths=2) 177 | # ax.plot(lip_pred[0:7, 0], lip_pred[0:7, 1], c="tab:red", linewidth=3 , label='Generated') 178 | # ax.plot(np.append(lip_pred[6:12, 0], lip_pred[0, 0]), np.append(lip_pred[6:12, 1], lip_pred[0, 1]), c="tab:red", linewidth=3) 179 | # ax.plot(lip_pred[12:17, 0], lip_pred[12:17, 1], c="tab:red", linewidth=3) 180 | # ax.plot(np.append(lip_pred[16:20, 0], lip_pred[12, 0]), np.append(lip_pred[16:20, 1], lip_pred[12, 1]), c="tab:red", linewidth=3) 181 | # 182 | # # Plot Ground Truth Lip 183 | # ax.scatter(fl[:, 0], fl[:, 1], s=20, c='g', linewidths=2) 184 | # ax.plot(fl[0:7, 0], fl[0:7, 1], c="tab:blue", linewidth=3, label='Ground Truth') 185 | # ax.plot(np.append(fl[6:12, 0], fl[0, 0]), np.append(fl[6:12, 1], fl[0, 1]), c="tab:blue", linewidth=3) 186 | # ax.plot(fl[12:17, 0], fl[12:17, 1], c="tab:blue", linewidth=3) 187 | # ax.plot(np.append(fl[16:20, 0], fl[12, 0]), np.append(fl[16:20, 1], fl[12, 1]), c="tab:blue", linewidth=3) 188 | # ax.legend(loc="upper left") 189 | # ax.invert_yaxis() 190 | # com_fig = ax.get_figure() 191 | # plt.close(fig) 192 | # 193 | # return com_fig 194 | # 195 | # 196 | 197 | def plot_lip_landmark(fl): 198 | """ 199 | ************************************************************ 200 | plot_visLip : Visualize lip landmark mark 201 | ************************************************************ 202 | @author: Wish Suharitdamrong 203 | ------ 204 | inputs 205 | ------ 206 | fl : lip landmark in form of numpy array 207 | ------- 208 | returns 209 | ------- 210 | com_fig : comparision figure 211 | """ 212 | 213 | # invert yaxis value so that the lip is wont display upside down 214 | fig, ax = plt.subplots() 215 | ax.scatter(fl[:,0],fl[:,1],s=20, c='r',linewidths=4) 216 | ax.plot(fl[0:7,0],fl[0:7,1], c="tab:pink", linewidth=3 ) 217 | ax.plot(np.append(fl[6:12,0],fl[0,0]),np.append(fl[6:12,1],fl[0,1]), c="tab:pink", linewidth=3 ) 218 | ax.plot(fl[12:17,0],fl[12:17,1], c="tab:pink", linewidth=3 ) 219 | ax.plot(np.append(fl[16:20,0],fl[12,0]),np.append(fl[16:20,1],fl[12,1] ), c="tab:pink", linewidth=3) 220 | ax.invert_yaxis() 221 | plt.show() 222 | 223 | 224 | 225 | def plot_scatter_facial_landmark(fl): 226 | 227 | fig = plt.figure() 228 | ax = fig.add_subplot() 229 | ax.scatter(fl[:,0],fl[:,1],s=20,c='r') 230 | ax.invert_yaxis() 231 | fig.tight_layout() 232 | plt.show() 233 | 234 | 235 | def plot_lip_comparision(pred,ref,gt): 236 | """ 237 | ************************************************************ 238 | plot_seqlip_comp : Visualize gena 239 | ************************************************************ 240 | @author: Wish Suharitdamrong 241 | ------ 242 | inputs 243 | ------ 244 | fl : lip landmark in form of numpy array 245 | ------- 246 | returns 247 | ------- 248 | com_fig : comparision figure 249 | """ 250 | 251 | fig, (ax1, ax2) = plt.subplots(nrows=1 , ncols=2, figsize=(15,6)) 252 | 253 | fl = pred 254 | ax1.scatter(fl[:,0],fl[:,1],s=20, c='r',linewidths=4) 255 | ax1.plot(fl[0:7,0],fl[0:7,1], c="tab:red", linewidth=3 , label='Generated') 256 | ax1.plot(np.append(fl[6:12,0],fl[0,0]),np.append(fl[6:12,1],fl[0,1]), c="tab:red", linewidth=3 ) 257 | 
ax1.plot(fl[12:17,0],fl[12:17,1], c="tab:red", linewidth=3 ) 258 | ax1.plot(np.append(fl[16:20,0],fl[12,0]),np.append(fl[16:20,1],fl[12,1] ), c="tab:red", linewidth=3) 259 | 260 | fl = ref 261 | ax1.scatter(fl[:,0],fl[:,1],s=20, c='b',linewidths=4) 262 | ax1.plot(fl[0:7,0],fl[0:7,1], c="tab:cyan", linewidth=3 , label='Reference') 263 | ax1.plot(np.append(fl[6:12,0],fl[0,0]),np.append(fl[6:12,1],fl[0,1]), c="tab:cyan", linewidth=3 ) 264 | ax1.plot(fl[12:17,0],fl[12:17,1], c="tab:cyan", linewidth=3 ) 265 | ax1.plot(np.append(fl[16:20,0],fl[12,0]),np.append(fl[16:20,1],fl[12,1] ), c="tab:cyan", linewidth=3) 266 | ax1.legend(loc="upper left") 267 | ax1.invert_yaxis() 268 | 269 | ################################################################################# 270 | 271 | fl = pred 272 | ax2.scatter(fl[:,0],fl[:,1],s=20, c='r',linewidths=4) 273 | ax2.plot(fl[0:7,0],fl[0:7,1], c="tab:red", linewidth=3 , label='Generated') 274 | ax2.plot(np.append(fl[6:12,0],fl[0,0]),np.append(fl[6:12,1],fl[0,1]), c="tab:red", linewidth=3 ) 275 | ax2.plot(fl[12:17,0],fl[12:17,1], c="tab:red", linewidth=3 ) 276 | ax2.plot(np.append(fl[16:20,0],fl[12,0]),np.append(fl[16:20,1],fl[12,1] ), c="tab:red", linewidth=3) 277 | 278 | fl = gt 279 | ax2.scatter(fl[:,0],fl[:,1],s=20, c='g',linewidths=4) 280 | ax2.plot(fl[0:7,0],fl[0:7,1], c="tab:green", linewidth=3 , label='Ground Truth') 281 | ax2.plot(np.append(fl[6:12,0],fl[0,0]),np.append(fl[6:12,1],fl[0,1]), c="tab:green", linewidth=3 ) 282 | ax2.plot(fl[12:17,0],fl[12:17,1], c="tab:green", linewidth=3 ) 283 | ax2.plot(np.append(fl[16:20,0],fl[12,0]),np.append(fl[16:20,1],fl[12,1] ), c="tab:green", linewidth=3) 284 | 285 | 286 | ax2.legend(loc="upper left") 287 | ax2.invert_yaxis() 288 | 289 | 290 | seqlip_fig = ax1.get_figure() 291 | 292 | plt.close(fig) 293 | 294 | return seqlip_fig 295 | 296 | 297 | # Visualize Facial Landmark (Whole Face) 298 | def vis_landmark_on_img(img, shape, linewidth=2): 299 | 300 | """ 301 | 302 | @author: This function is originally from ***MakeItTalk*** 303 | 304 | Link: https://github.com/yzhou359/MakeItTalk 305 | 306 | """ 307 | def draw_curve(idx_list, color=(0, 255, 0), loop=False, lineWidth=linewidth): 308 | for i in idx_list: 309 | cv2.line(img, (shape[i, 0], shape[i, 1]), (shape[i + 1, 0], shape[i + 1, 1]), color, lineWidth) 310 | if (loop): 311 | cv2.line(img, (shape[idx_list[0], 0], shape[idx_list[0], 1]), 312 | (shape[idx_list[-1] + 1, 0], shape[idx_list[-1] + 1, 1]), color, lineWidth) 313 | 314 | draw_curve(list(range(0, 16)), color=(255, 144, 25)) # jaw 315 | draw_curve(list(range(17, 21)), color=(50, 205, 50)) # eye brow 316 | draw_curve(list(range(22, 26)), color=(50, 205, 50)) 317 | draw_curve(list(range(27, 35)), color=(208, 224, 63)) # nose 318 | draw_curve(list(range(36, 41)), loop=True, color=(71, 99, 255)) # eyes 319 | draw_curve(list(range(42, 47)), loop=True, color=(71, 99, 255)) 320 | draw_curve(list(range(48, 59)), loop=True, color=(238, 130, 238)) # mouth 321 | draw_curve(list(range(60, 67)), loop=True, color=(238, 130, 238)) 322 | 323 | return img 324 | 325 | 326 | 327 | 328 | 329 | 330 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | ****This file contain utlities functions**** 4 | 5 | @author : Wish Suharitdamrong 6 | 7 | """ 8 | 9 | import os 10 | import pandas as pd 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | import argparse 14 | import math 15 | 
from sklearn.metrics import accuracy_score 16 | 17 | 18 | def save_logs(model_name,savename,**kwargs): 19 | """ 20 | ********* 21 | save_logs : Save a logs in .csv files 22 | ********* 23 | @author: Wish Suharitdamrong 24 | ------ 25 | inputs : 26 | ------ 27 | model_name : name of a model 28 | savename : name of a file saving as a .csv 29 | **kwagrs : array containing values of metrics such as accuracy and loss 30 | ------- 31 | outputs : 32 | ------- 33 | """ 34 | 35 | path = "./logs/{}/".format(model_name) 36 | 37 | if not os.path.exists(path): 38 | os.mkdir(path) 39 | 40 | df = pd.DataFrame() 41 | for log in kwargs: 42 | 43 | df[log] = kwargs[log] 44 | 45 | savepath = os.path.join(path,savename) 46 | 47 | df.to_csv(savepath, index=False) 48 | 49 | 50 | def load_logs(model_name,savename, epoch, type_model=None): 51 | """ 52 | ********* 53 | load_logs : Load a logs file from csv 54 | ********* 55 | @author: Wish Suharitdamrong 56 | ------ 57 | inputs : 58 | ------ 59 | model_name : name of a model 60 | savename : name of a logs in .csv files 61 | epoch : number of iteration continue from checkpoint 62 | ------- 63 | outputs : 64 | ------- 65 | train_loss : array containing training loss 66 | train_acc : array containing training accuracy 67 | vali_loss : array containing validation loss 68 | vali_acc : array containing validation accuracy 69 | """ 70 | 71 | path = "./logs/{}/".format(model_name) 72 | 73 | savepath = os.path.join(path,savename) 74 | 75 | if type_model is None: 76 | 77 | raise ValueError("Type of model should be specified Generator or SyncNet") 78 | 79 | if not os.path.exists(savepath): 80 | 81 | print("Logs file does not exists !!!!") 82 | 83 | exit() 84 | 85 | df = pd.read_csv(savepath)[:epoch+1] 86 | 87 | if type_model == "syncnet": 88 | 89 | train_loss = df["train_loss"] 90 | train_acc = df['train_acc'] 91 | vali_loss = df['vali_loss'] 92 | vali_acc = df['vali_acc'] 93 | 94 | 95 | return train_loss, train_acc, vali_loss, vali_acc 96 | 97 | elif type_model == "generator": 98 | 99 | train_loss = df["train_loss"] 100 | vali_loss = df["vali_loss"] 101 | 102 | return train_loss, vali_loss 103 | 104 | else : 105 | 106 | 107 | raise ValueError(" Argument type of model (type_model) should be either 'generator' or 'syncnet' !!!!") 108 | 109 | 110 | 111 | 112 | 113 | def str2bool(v): 114 | """ 115 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 116 | """ 117 | if isinstance(v, bool): 118 | return v 119 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 120 | return True 121 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 122 | return False 123 | else: 124 | raise argparse.ArgumentTypeError('Boolean value expected.') 125 | 126 | 127 | 128 | def norm_lip2d(fl_lip, distance=None): 129 | 130 | if distance is None: 131 | 132 | distance_x = fl_lip[6,0] - fl_lip[0,0] 133 | distance_y = fl_lip[6,1] - fl_lip[0,1] 134 | 135 | distance = math.sqrt(pow(distance_x,2)+pow(distance_y,2)) 136 | 137 | fl_x=(fl_lip[:,0]-fl_lip[0,0])/distance 138 | fl_y=(fl_lip[:,1]-fl_lip[0,1])/distance 139 | 140 | 141 | else: 142 | 143 | fl_x=(fl_lip[:,0]-fl_lip[0,0]) / distance 144 | fl_y=(fl_lip[:,1]-fl_lip[0,1]) / distance 145 | 146 | 147 | return np.stack((fl_x,fl_y),axis=1) , distance 148 | 149 | 150 | def get_accuracy(y_pred,y_true): 151 | """ 152 | ********* 153 | get_accuracy : calcualte accuracy of a model 154 | ********* 155 | @author: Wish Suharitdamrong 156 | ------ 157 | inputs : 158 | ------ 159 | y_pred : predicted label 160 | y_true : ground truth of a label 
161 | ------- 162 | outputs : 163 | ------- 164 | acc : accuracy of a model 165 | 166 | """ 167 | 168 | acc = accuracy_score(y_pred,y_true, normalize=True) * 100 169 | 170 | return acc 171 | 172 | 173 | def procrustes(fl): 174 | 175 | transformation = {} 176 | 177 | fl, mean = translation(fl) 178 | 179 | fl, scale = scaling(fl) 180 | 181 | #fl , rotate = rotation(fl) 182 | 183 | transformation['translate'] = mean 184 | 185 | transformation['scale'] = scale 186 | 187 | #transformation['rotate'] = rotate 188 | 189 | return fl , transformation 190 | 191 | 192 | def translation(fl): 193 | 194 | mean = np.mean(fl, axis=0) 195 | 196 | fl = fl - mean 197 | 198 | return fl , mean 199 | 200 | def scaling(fl): 201 | 202 | scale = np.sqrt(np.mean(np.sum(fl**2, axis=1))) 203 | 204 | fl = fl/scale 205 | 206 | return fl , scale 207 | 208 | def rotation(fl): 209 | 210 | left_eye , right_eye = get_eye(fl) 211 | 212 | dx = right_eye[0] - left_eye[0] 213 | dy = right_eye[1] - left_eye[1] 214 | dz = right_eye[2] - left_eye[2] 215 | 216 | # Roll : rotate through z axis 217 | if dx!=0 : 218 | f = dy/dx 219 | a = np.arctan(f) 220 | roll = np.array([ 221 | [math.cos(a), -math.sin(a) , 0], 222 | [math.sin(a), math.cos(a) , 0], 223 | [0, 0 , 1] 224 | ]) 225 | else: 226 | roll = np.array([ 227 | [1 , 0 , 0], 228 | [0 , 1 , 0], 229 | [0 , 0 , 1] 230 | ]) 231 | 232 | # Yaw : rotate through y axis 233 | f = dx/dz 234 | 235 | a = np.arctan(f) 236 | yaw = np.array([ 237 | [math.cos(a),0, math.sin(a)], 238 | [0,1,0], 239 | [-math.sin(a),0,math.cos(a)] 240 | ]) 241 | 242 | # 243 | # f= dz/dy 244 | # a = np.arctan(f) 245 | # pitch = np.array([ 246 | # [1,0,0], 247 | # [0,math.cos(a), -math.sin(a)], 248 | # [0,math.sin(a),math.cos(a)], 249 | # 250 | # ]) 251 | 252 | # Roate face in frontal pose 253 | a = np.arctan(90) 254 | frontal = np.array([ 255 | [math.cos(a),0, math.sin(a)], 256 | [0,1,0], 257 | [-math.sin(a),0,math.cos(a)] 258 | ]) 259 | 260 | # transformation for rotation 261 | rotate = np.matmul(np.matmul(roll,yaw),frontal) 262 | 263 | fl = np.matmul(fl,rotate) 264 | 265 | 266 | return fl , rotate 267 | 268 | def get_eye(fl): 269 | """ 270 | get_eye : get center of both eye on the facial landmarks 271 | """ 272 | 273 | left_eye = np.mean(fl[36:42,:], axis=0) 274 | right_eye = np.mean(fl[42:48,:], axis=0) 275 | 276 | return left_eye, right_eye 277 | 278 | 279 | 280 | 281 | 282 | 283 | -------------------------------------------------------------------------------- /utils/wav2lip.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 4 | This file contain a Utility function from ***Wav2lip*** 5 | 6 | 7 | Links : https://github.com/Rudrabha/Wav2Lip 8 | 9 | 10 | """ 11 | 12 | from os.path import join 13 | import torch 14 | import os 15 | from torch.nn import DataParallel 16 | import cv2 17 | import subprocess 18 | import numpy as np 19 | from utils import audio 20 | 21 | 22 | def save_checkpoint(model, optimizer, checkpoint_dir,epoch, savename): 23 | checkpoint_path = join( 24 | checkpoint_dir, savename) 25 | 26 | optimizer_state = optimizer.state_dict() 27 | torch.save({ 28 | "state_dict": model.state_dict(), 29 | "optimizer": optimizer_state, 30 | "global_epoch": epoch, 31 | }, checkpoint_path) 32 | print("Saved checkpoints:", checkpoint_path) 33 | 34 | 35 | def _load(checkpoint_path, use_cuda): 36 | if use_cuda: 37 | checkpoint = torch.load(checkpoint_path) 38 | else: 39 | checkpoint = torch.load(checkpoint_path, 40 | map_location=lambda storage, loc: storage) 41 | return 
checkpoint 42 | 43 | def load_checkpoint(path, model, optimizer,use_cuda, reset_optimizer=False, pretrain=False): 44 | 45 | global global_epoch 46 | 47 | print("Load checkpoints from: {}".format(path)) 48 | 49 | checkpoint = _load(path, use_cuda) 50 | try : 51 | model.load_state_dict(checkpoint["state_dict"]) 52 | except Exception as e: 53 | 54 | model = DataParallel(model) 55 | model.load_state_dict(checkpoint["state_dict"]) 56 | 57 | if not pretrain: 58 | if not reset_optimizer: 59 | optimizer_state = checkpoint["optimizer"] 60 | if optimizer_state is not None: 61 | print("Load optimizer state from {}".format(path)) 62 | optimizer.load_state_dict(checkpoint["optimizer"]) 63 | 64 | if not pretrain: 65 | global_epoch = checkpoint["global_epoch"] 66 | return model, optimizer, global_epoch 67 | 68 | else: 69 | 70 | return model 71 | 72 | 73 | 74 | # This function was originally named **get_image_list** 75 | def get_fl_list(data_root, split): 76 | 77 | filelist = [] 78 | with open('filelists/{}.txt'.format(split)) as f: 79 | for line in f: 80 | line = line.strip() 81 | if ' ' in line: line = line.split()[0] 82 | filelist.append(os.path.join(data_root, line)) 83 | 84 | return filelist 85 | 86 | 87 | 88 | 89 | 90 | ################ 91 | 92 | 93 | 94 | 95 | def prepare_video(path,in_fps): 96 | """ 97 | ******************************************************************************* 98 | Prepare input image/video : Detect the FPS of the input and split the video into frames 99 | ******************************************************************************* 100 | 101 | The code in this function was originally from ***Wav2Lip*** 102 | 103 | """ 104 | 105 | # Check that the given file exists 106 | if not os.path.exists(path): 107 | 108 | raise ValueError("Cannot locate an input file at the given path {} in argument --input_face".format(path)) 109 | 110 | # Check that the input face is a valid file 111 | elif not os.path.isfile(path): 112 | 113 | raise ValueError("Input must be a valid file at the given path {} in argument --input_face".format(path)) 114 | 115 | # Check if the input is an image 116 | elif path.split('.')[-1].lower() in ['jpg','jpeg','png']: 117 | 118 | img = cv2.imread(path) 119 | img = cv2.resize(img,(256,256),interpolation=cv2.INTER_LINEAR) 120 | 121 | # read images 122 | all_frames = [img] 123 | # set FPS for inference equal to the input fps 124 | fps = in_fps 125 | 126 | else: 127 | 128 | 129 | # read video 130 | video = cv2.VideoCapture(path) 131 | # set FPS for inference equal to FPS of input video 132 | fps = video.get(cv2.CAP_PROP_FPS) 133 | all_frames = [] 134 | while 1: 135 | # read each frame of video 136 | still_reading, frame= video.read() 137 | # if no further frame, stop reading 138 | if not still_reading: 139 | 140 | video.release() # close video reader 141 | break 142 | 143 | frame = cv2.resize(frame,(256,256)) # resize input image to 256x256 (width and height) 144 | #frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 145 | 146 | 147 | 148 | 149 | 150 | all_frames.append(frame) 151 | 152 | return all_frames, fps 153 | 154 | def prepare_audio(path, fps): 155 | """ 156 | *************************************************************** 157 | Prepare input audio : Transform audio input to Melspectrogram 158 | *************************************************************** 159 | 160 | The code in this function was originally from ***Wav2Lip*** 161 | 162 | """ 163 | # if the input audio is not a .wav 
file then convert it to .wav 164 | if not path.endswith('.wav'): 165 | # command using ffmpeg to convert audio to .wav and store temporary .wav file {temp/temp.wav} 166 | command = 'ffmpeg -y -i {} -strict -2 {}'.format(path, 'temp/temp.wav') 167 | subprocess.call(command, shell=True) 168 | # change path of the input file to temporary wav file 169 | path = 'temp/temp.wav' 170 | 171 | wav = audio.load_wav(path, 16000) # load wave audio with sample rate of 16000 172 | mel = audio.melspectrogram(wav) # transform wav to melspectorgram 173 | 174 | if np.isnan(mel.reshape(-1)).sum() > 0: 175 | raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') 176 | 177 | mel_chunks = [] 178 | mel_idx_multiplier = 80./fps 179 | mel_step_size = 18 #time step in spectrogram 180 | i = 0 181 | while 1: 182 | 183 | start_idx = int(i * mel_idx_multiplier) 184 | 185 | if start_idx + mel_step_size > len(mel[0]): 186 | 187 | mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) 188 | break 189 | 190 | mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) 191 | i += 1 192 | 193 | return mel_chunks 194 | 195 | 196 | 197 | 198 | 199 | 200 | --------------------------------------------------------------------------------
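A minimal usage sketch (not part of the repository) showing how the landmark SyncNet in `src/models/syncnet.py` produces audio and lip embeddings whose cosine similarity acts as a sync score. The tensor shapes follow the comments in that file; the 5-frame sequence length, the random inputs, and the import path are illustrative assumptions, and the real training loop lives in `run_train_syncnet.py`.

```python
# Hypothetical sketch -- assumes the repository root is on PYTHONPATH and that
# torch plus the plotting dependencies imported by the model files are installed.
import torch
import torch.nn.functional as F

from src.models.syncnet import SyncNet

model = SyncNet(bilstm=True)
model.eval()  # use running BatchNorm stats and disable dropout for this sketch

batch_size, seq_len = 2, 5                           # seq_len of 5 is an assumption
mel_window = torch.randn(batch_size, 1, 80, 18)      # (B, 1, num_mels, mel_step_size)
lip_window = torch.randn(batch_size, seq_len, 60)    # (B, T, 20 lip landmarks x 3 values)

with torch.no_grad():
    audio_emb, lip_emb = model(mel_window, lip_window)   # each (B, 512), L2-normalised

# cosine similarity of the normalised embeddings serves as the audio-lip sync score
sync_score = F.cosine_similarity(audio_emb, lip_emb, dim=1)
print(sync_score.shape)   # torch.Size([2])
```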