├── .gitignore ├── .gitmodules ├── README.md ├── setup.py ├── wav2mov-docs ├── gan_setup.PNG └── gen_arch.PNG └── wav2mov ├── __init__.py ├── config.json ├── config.py ├── core ├── data │ ├── __init__.py │ ├── collates.py │ ├── dataloaders.py │ ├── datasets.py │ ├── note.txt │ ├── raw_datasets.py │ ├── transforms.py │ └── utils.py ├── engine │ ├── __init__.py │ ├── callbacks.py │ └── engine.py ├── models │ ├── __init__.py │ ├── base_model.py │ └── template_model.py └── utils │ ├── __init__.py │ ├── average_meter.py │ ├── checkpoints.py │ ├── logger.py │ ├── misc.py │ └── os_utils.py ├── datasets └── create_file_list.py ├── inference ├── __init__.py ├── audio_utils.py ├── generate.py ├── generate.sh ├── image_utils.py ├── model_utils.py ├── models │ ├── __init__.py │ ├── generator │ │ ├── __init__.py │ │ ├── audio_encoder.py │ │ ├── frame_generator.py │ │ ├── id_decoder.py │ │ └── id_encoder.py │ ├── layers │ │ ├── conv_layers.py │ │ └── debug_layers.py │ └── utils.py ├── params.py ├── quantize.py └── utils.py ├── logger.py ├── losses ├── ReadMe.md ├── __init__.py ├── gan_loss.py ├── l1_loss.py └── sync_loss.py ├── main ├── callbacks.py ├── data.py ├── engine.py ├── main.py ├── options.py ├── preprocess.py ├── test.py ├── train.py ├── trained.txt └── validate_params.py ├── models ├── README.md ├── __init__.py ├── discriminators │ ├── identity_discriminator.py │ ├── patch_disc.py │ ├── sequence_discriminator.py │ ├── sync_discriminator.py │ └── utils.py ├── generator │ ├── audio_encoder.py │ ├── frame_generator.py │ ├── id_decoder.py │ ├── id_encoder.py │ └── noise_encoder.py ├── layers │ ├── conv_layers.py │ └── debug_layers.py ├── utils.py ├── wav2mov_inferencer.py ├── wav2mov_template.py └── wav2mov_trainer.py ├── params.json ├── params.py ├── plans ├── README.md ├── extras.md ├── images │ ├── components.png │ ├── gen_arch.png │ ├── plan_v1.png │ └── system.png └── observations.md ├── preprocess.bat ├── preprocess.sh ├── pylintrc ├── requirements.txt ├── run.bat ├── run.sh ├── run_sync_expert.sh ├── settings.py ├── test.bat ├── test.sh ├── tests ├── old │ ├── log.json │ ├── test.py │ ├── test.txt │ ├── test_3d_cnn.py │ ├── test_audio_util.py │ ├── test_batching.py │ ├── test_config.py │ ├── test_dataset.py │ ├── test_discriminators.py │ ├── test_file_logger.py │ ├── test_logger.py │ ├── test_main_data.py │ ├── test_module_level_logger.py │ ├── test_packing.py │ ├── test_settings.py │ ├── test_stored_video.py │ ├── test_tensorboard_logger.py │ ├── test_unet.py │ ├── test_utils.py │ ├── test_wav2mov.py │ ├── test_wav2mov_v7.py │ └── utils.py ├── test_audio_encoder.py ├── test_audio_frames.py ├── test_generator.py ├── test_main_dataloader.py ├── test_seq_disc.py ├── test_sync_disc.py └── test_video.py └── utils ├── __init__.py ├── audio.py ├── cnn_shape_calc.py ├── files.py ├── misc.py └── plots.py /.gitignore: -------------------------------------------------------------------------------- 1 | shape_predictor_68_face_landmarks.dat 2 | wav2mov/datasets/* 3 | !wav2mov/datasets/create_file_list.py 4 | wav2mov/runs/ 5 | wav2mov/logs/ 6 | wav2mov/tests/res 7 | wav2mov/plans/*.txt 8 | *.png 9 | !wav2mov/plans/images/*.png 10 | !wav2mov-docs/*.png 11 | *.gif 12 | *.pt 13 | *.jfif 14 | *.wav 15 | *.avi 16 | archieves/ 17 | # Vs code 18 | .vscode 19 | #vim files 20 | *.vimrc 21 | *.*~ 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # C extensions 28 | *.so 29 | 30 | # Distribution / packaging 31 | .Python 32 | build/ 33 | develop-eggs/ 34 | 
dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | wheels/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .nox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | *.py,cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | cover/ 74 | 75 | # Translations 76 | *.mo 77 | *.pot 78 | 79 | # Django stuff: 80 | *.log 81 | local_settings.py 82 | db.sqlite3 83 | db.sqlite3-journal 84 | 85 | # Flask stuff: 86 | instance/ 87 | .webassets-cache 88 | 89 | # Scrapy stuff: 90 | .scrapy 91 | 92 | # Sphinx documentation 93 | docs/_build/ 94 | 95 | # PyBuilder 96 | .pybuilder/ 97 | target/ 98 | 99 | # Jupyter Notebook 100 | .ipynb_checkpoints 101 | 102 | # IPython 103 | profile_default/ 104 | ipython_config.py 105 | 106 | # pyenv 107 | # For a library or package, you might want to ignore these files since the code is 108 | # intended to run in multiple environments; otherwise, check them in: 109 | # .python-version 110 | 111 | # pipenv 112 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 113 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 114 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 115 | # install all needed dependencies. 116 | #Pipfile.lock 117 | 118 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [![website](https://img.shields.io/static/v1?label=&message=wav2mov&color=blue&style=for-the-badge)](https://wav2mov.vercel.app) 2 | 3 | ## Speech To Facial Animation Using GANs 4 | 5 | 6 | [![python](https://img.shields.io/badge/Python-3776AB?style=for-the-badge&logo=python&logoColor=white)](https://www.python.org/) [![pytorch](https://img.shields.io/badge/PyTorch-EE4C2C?style=for-the-badge&logo=PyTorch&logoColor=white)](https://pytorch.org/) [![GANs](https://img.shields.io/badge/GANs-4BB749?style=for-the-badge&logo=&logoColor=white)](#1) 7 | 8 | This repo contains the pytorch implementation of achieving facial animation from given face image and speech input using Generative Adversarial Nets (See [References](#1)). 9 | 10 | 11 | ## Results 12 | 13 | Some of the generated videos are found [here](https://wav2mov-examples.vercel.app/examples). 14 | 15 | ## Implementation 16 | ### GAN setup 17 | ![gan_setup](/wav2mov-docs/gan_setup.PNG) 18 | ![generator_architecture](/wav2mov-docs/gen_arch.PNG) 19 | ## References 20 | 21 | [1] Generative Adversarial Nets 22 | ```bibtex 23 | @article{goodfellow2014generative, 24 | title={Generative adversarial networks}, 25 | author={Goodfellow, Ian J and Pouget-Abadie, Jean and Mirza, Mehdi and Xu, Bing and Warde-Farley, David and Ozair, Sherjil and Courville, Aaron and Bengio, Yoshua}, 26 | journal={arXiv preprint arXiv:1406.2661}, 27 | year={2014} 28 | } 29 | ``` 30 | 31 | [2] The Audio-Visual Lombard Grid Speech Corpus 32 | 33 | [![Github stars](https://img.shields.io/badge/Dataset-LombardGrid-.svg)](http://spandh.dcs.shef.ac.uk/avlombard/) 34 | 35 | ```bibtex 36 | @article{Alghamdi_2018, 37 | doi = {10.1121/1.5042758}, 38 | url = {https://doi.org/10.1121%2F1.5042758}, 39 | year = 2018, 40 | month = {jun}, 41 | publisher = {Acoustical Society of America ({ASA})}, 42 | volume = {143}, 43 | number = {6}, 44 | pages = {EL523--EL529}, 45 | author = {Najwa Alghamdi and Steve Maddock and Ricard Marxer and Jon Barker and Guy J. 
Brown}, 46 | title = {A corpus of audio-visual Lombard speech with frontal and profile views}, 47 | journal = {The Journal of the Acoustical Society of America} 48 | } 49 | ``` 50 | 51 | [3] Realistic Facial Animation using GANs 52 | 53 | [![Github stars](https://img.shields.io/badge/Github-sda-.svg)](https://github.com/DinoMan/speech-driven-animation) 54 | 55 | ```bibtex 56 | @article{Vougioukas_2019, 57 | doi = {10.1007/s11263-019-01251-8}, 58 | url = {https://doi.org/10.1007%2Fs11263-019-01251-8}, 59 | year = 2019, 60 | month = {oct}, 61 | publisher = {Springer Science and Business Media {LLC}}, 62 | volume = {128}, 63 | number = {5}, 64 | pages = {1398--1413}, 65 | author = {Konstantinos Vougioukas and Stavros Petridis and Maja Pantic}, 66 | title = {Realistic Speech-Driven Facial Animation with {GANs}}, 67 | journal = {International Journal of Computer Vision} 68 | } 69 | ``` 70 | 71 | [4] End to End Facial Animation using Temporal GANs 72 | ```bibtex 73 | @article{vougioukas2018end, 74 | title={End-to-end speech-driven facial animation with temporal gans}, 75 | author={Vougioukas, Konstantinos and Petridis, Stavros and Pantic, Maja}, 76 | journal={arXiv preprint arXiv:1805.09313}, 77 | year={2018} 78 | } 79 | ``` 80 | 81 | [5] A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild 82 | 83 | 84 | [![Github stars](https://img.shields.io/badge/Github-wav2Lip-.svg)](https://github.com/Rudrabha/Wav2Lip) 85 | ```bibtex 86 | @inproceedings{10.1145/3394171.3413532, 87 | author = {Prajwal, K R and Mukhopadhyay, Rudrabha and Namboodiri, Vinay P. and Jawahar, C.V.}, 88 | title = {A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild}, 89 | year = {2020}, 90 | isbn = {9781450379885}, 91 | publisher = {Association for Computing Machinery}, 92 | address = {New York, NY, USA}, 93 | url = {https://doi.org/10.1145/3394171.3413532}, 94 | doi = {10.1145/3394171.3413532}, 95 | booktitle = {Proceedings of the 28th ACM International Conference on Multimedia}, 96 | pages = {484–492}, 97 | numpages = {9}, 98 | keywords = {lip sync, talking face generation, video generation}, 99 | location = {Seattle, WA, USA}, 100 | series = {MM '20} 101 | } 102 | ``` 103 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup,find_packages 2 | setup(name="wav2mov",version="1.0.0",packages= find_packages()) 3 | -------------------------------------------------------------------------------- /wav2mov-docs/gan_setup.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov-docs/gan_setup.PNG -------------------------------------------------------------------------------- /wav2mov-docs/gen_arch.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov-docs/gen_arch.PNG -------------------------------------------------------------------------------- /wav2mov/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/__init__.py -------------------------------------------------------------------------------- /wav2mov/config.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "fixed": { 3 | "seed": 10, 4 | "grid_dataset_dir": "D:\\dataset_lip\\GRID" 5 | }, 6 | "runtime": { 7 | "runs_dir": "%(base_dir)s\\runs\\%(v)s\\%(version)s", 8 | "log_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\%(log_filename)s.log", 9 | "train_test_dataset_dir": "%(base_dir)s\\datasets\\grid_dataset_a5_500_a10to14", 10 | "filenames_txt": "%(base_dir)s\\datasets\\grid_dataset_a5_500_a10to14\\filenames.txt", 11 | "filenames_train_txt": "%(base_dir)s\\datasets\\grid_dataset_a5_500_a10to14\\filenames_train.txt", 12 | "filenames_test_txt": "%(base_dir)s\\datasets\\grid_dataset_a5_500_a10to14\\filenames_test.txt", 13 | "gen_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\gen_%(version)s.pt", 14 | "seq_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\seq_disc_%(version)s.pt", 15 | "sync_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\sync_disc_%(version)s.pt", 16 | "sync_expert_checkpoint_fullpath": "/content/drive/MyDrive/Colab Notebooks/wav2mov-dev_phase_8/wav2mov/runs/sync_expert/Run_11_7_2021__13_36", 17 | "id_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\id_disc_%(version)s.pt", 18 | "params_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\hparams_%(version)s.json", 19 | "optim_gen_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\optim_gen_%(version)s.pt", 20 | "optim_seq_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\optim_seq_disc_%(version)s.pt", 21 | "optim_sync_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\optim_sync_disc_%(version)s.pt", 22 | "optim_id_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\optim_id_disc_%(version)s.pt", 23 | "optim_params_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\optim_hparams_%(version)s.json", 24 | "scheduler_gen_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\scheduler_gen_%(version)s.pt", 25 | "scheduler_seq_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\scheduler_seq_disc_%(version)s.pt", 26 | "scheduler_sync_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\scheduler_sync_disc_%(version)s.pt", 27 | "scheduler_id_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\scheduler_id_disc_%(version)s.pt", 28 | "out_dir": "%(base_dir)s\\out\\%(v)s" 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /wav2mov/config.py: -------------------------------------------------------------------------------- 1 | """Config""" 2 | import json 3 | import logging 4 | import os 5 | import re 6 | import pytz 7 | from datetime import datetime 8 | from wav2mov.logger import get_module_level_logger 9 | 10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | logger = get_module_level_logger(__name__) 12 | logger.setLevel(logging.WARNING) 13 | 14 | def get_curr_run_str(): 15 | now = datetime.now().astimezone(pytz.timezone("Asia/Calcutta")) 16 | date,time = now.date(),now.time() 17 | day,month,year = date.day,date.month,date.year 18 | hour , minutes = time.hour,time.minute 19 | return 'Run_{}_{}_{}__{}_{}'.format(day,month,year,hour,minutes) 20 | 21 | class Config : 22 | def __init__(self,v): 23 | self.vals = {'base_dir':BASE_DIR} 24 | self.version = get_curr_run_str() 25 | #these two are used to decide whether to create a path using os.makedirs 26 | self.fixed_paths = set() 27 
| self.runtime_paths = set() 28 | self.v = v 29 | 30 | def _rectify_paths(self,val): 31 | return val % {'base_dir': self.vals['base_dir'], 32 | 'version': self.version, 33 | 'v':self.v, 34 | 'checkpoint_filename': 'checkpoint_'+self.version, 35 | 'log_filename': 'log_' + self.version 36 | } 37 | 38 | def _flatten_dict(self,d): 39 | flattened= {} 40 | for k,v in d.items(): 41 | if isinstance(v,dict): 42 | inner_d = self._flatten_dict(v) 43 | for k_inner in inner_d.keys(): 44 | if k== 'fixed': self.fixed_paths.add(k_inner) 45 | else: self.runtime_paths.add(k_inner) 46 | 47 | flattened = {**flattened,**inner_d} 48 | else: 49 | flattened[k] = v 50 | return flattened 51 | 52 | def _update_vals_from_dict(self,d:dict): 53 | d = self._flatten_dict(d) 54 | for key,val in d.items(): 55 | if isinstance(val,str): 56 | value = self._rectify_paths(val) 57 | 58 | else: 59 | value = val 60 | self.vals[key] = value 61 | 62 | @classmethod 63 | def from_json(cls,json_file_fullpath,v): 64 | with open(json_file_fullpath,'r') as file: 65 | configs = json.load(file) 66 | obj = cls(v) 67 | obj._update_vals_from_dict(configs) 68 | return obj 69 | 70 | def update(self,key,value): 71 | if key in self.vals: 72 | logger.warning(f'Updating existing parameter {key} : changing from {self.vals[key]} with {value}') 73 | self.vals[key] = value 74 | 75 | def __getitem__(self,item): 76 | if item not in self.vals: 77 | logger.error(f'{self.__class__.__name__} object has no key called {item}') 78 | raise KeyError('No key called ', item) 79 | if item in self.fixed_paths: 80 | return self.vals[item] 81 | if os.sep != '\\': 82 | self.vals[item] = re.sub(r'(\\)+', os.sep, self.vals[item]) 83 | if 'fullpath' in item or 'dir' in item: 84 | path = os.path.dirname(self.vals[item]) if '.' in os.path.basename(self.vals[item]) else self.vals[item] 85 | os.makedirs(path,exist_ok=True) 86 | logger.debug(f'accessing {self.vals[item]}') 87 | return self.vals[item] 88 | 89 | def get_config(v): 90 | config = Config.from_json(os.path.join(BASE_DIR,'config.json'),v) 91 | return config 92 | 93 | if __name__=='__main__': 94 | pass 95 | -------------------------------------------------------------------------------- /wav2mov/core/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ Datasets utils 2 | """ 3 | from .raw_datasets import RawDataset,GridDataset,RavdessDataset 4 | 5 | -------------------------------------------------------------------------------- /wav2mov/core/data/collates.py: -------------------------------------------------------------------------------- 1 | """ collate functions """ 2 | from collections import namedtuple 3 | import torch 4 | from wav2mov.core.data.utils import Sample,SampleWithFrames 5 | 6 | from wav2mov.core.utils.logger import get_module_level_logger 7 | logger = get_module_level_logger(__name__) 8 | from wav2mov.core.data.utils import AudioUtil 9 | Lens = namedtuple('lens', ('audio', 'video')) 10 | 11 | def get_frames_limit(audio_lens,video_lens,stride): 12 | audio_frames_lens = [audio_len//stride for audio_len in audio_lens] 13 | return min(min(audio_frames_lens),min(video_lens)) 14 | 15 | def video_frange_start(num_frames,req_frames): 16 | return (num_frames-req_frames)//2 17 | 18 | def video_frange_end(num_frames,req_frames): 19 | return (num_frames + req_frames)//2 20 | 21 | def get_batch_collate(hparams): 22 | stride = hparams['audio_sf']// hparams['video_fps'] 23 | audio_util = AudioUtil(hparams['audio_sf'],hparams['coarticulation_factor'],stride) 24 | 
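    # `stride` is the number of raw audio samples spanned by one video frame
    # (audio_sf // video_fps).  The collate below trims every clip in the batch to the
    # largest frame count they all share, takes that window from the centre of each
    # video, extracts the matching per-frame MFCCs (alongside the trimmed raw audio),
    # and stacks the results into a single SampleWithFrames batch.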
def collate_fn(batch): 25 | videos = [(sample.video,sample.video.shape[0]) for sample in batch] 26 | videos,video_lens = list(zip(*videos)) 27 | 28 | audios = [(sample.audio.unsqueeze(0),sample.audio.shape[0]) for sample in batch] 29 | audios,audio_lens = list(zip(*audios)) 30 | 31 | req_frames = get_frames_limit(audio_lens,video_lens,stride) 32 | ranges = [(video_frange_start(video.shape[0],req_frames),video_frange_end(video.shape[0],req_frames)) for video in videos] 33 | #[<------- total_frames-------->] 34 | #[<---><---req_frames----><---->] 35 | # ^ ^ 36 | # frange_start frange_end 37 | videos = [video[ranges[i][0]:ranges[i][1],... ].unsqueeze(0) for i,video in enumerate(videos)] 38 | 39 | audio_frames = [audio_util.get_audio_frames(audio,num_frames=req_frames,get_mfccs=True).unsqueeze(0) for i,audio in enumerate(audios)] 40 | audios = [audio_util.get_limited_audio(audio,num_frames=req_frames,get_mfccs=False) for audio in audios] 41 | return SampleWithFrames(torch.cat(audios),torch.cat(audio_frames),torch.cat(videos)) 42 | 43 | return collate_fn -------------------------------------------------------------------------------- /wav2mov/core/data/dataloaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import NamedTuple 3 | from collections import namedtuple 4 | from torch.utils.data import DataLoader,Dataset,random_split 5 | 6 | 7 | class DataloadersPack: 8 | 9 | 10 | def get_dataloaders(self,dataset:Dataset,splits,batch_size,seed): 11 | 12 | train_set,test_set,validation_set = self.__split_dataset(dataset,splits,seed) 13 | train_dataloader = self.__get_dataloader(train_set,batch_size) 14 | test_dataloader = self.__get_dataloader(test_set,batch_size) 15 | validation_dataloader = self.__get_dataloader(validation_set,batch_size) 16 | 17 | dataloaders_pack = namedtuple('dataloaders_pack',['train','validation','test']) 18 | return dataloaders_pack(train_dataloader,validation_dataloader,test_dataloader) 19 | 20 | def __get_dataloader(self,dataset,batch_size): 21 | 22 | return DataLoader(dataset, 23 | batch_size=batch_size,#,shuffle=True) 24 | num_workers=2, #[WARNING] runs project.py this number of times 25 | shuffle=True) 26 | 27 | def __split_dataset(self,dataset,splits,seed): 28 | N = len(dataset) 29 | train_size,test_size,val_size = self.__get_split_sizes(N,splits) 30 | return random_split(dataset,[train_size,test_size,val_size], 31 | generator=torch.Generator().manual_seed(seed)) 32 | 33 | def __get_split_sizes(self,N,splits): 34 | train_size = int((splits.train_size/10)*N) 35 | test_size = int((splits.test_size/10)*N) 36 | validation_size = N-train_size-test_size 37 | return train_size,test_size,validation_size 38 | 39 | 40 | 41 | 42 | def splitted_dataloaders(dataset:Dataset,splits:namedtuple,batch_size,seed)->NamedTuple: 43 | 44 | return DataloadersPack().get_dataloaders(dataset,splits,batch_size,seed) -------------------------------------------------------------------------------- /wav2mov/core/data/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import os 4 | import torch 5 | from torch.utils.data import Dataset 6 | from collections import namedtuple 7 | 8 | from wav2mov.core.data.utils import Sample 9 | 10 | 11 | class AudioVideoDataset(Dataset): 12 | """ 13 | Dataset for audio and video numpy files 14 | """ 15 | def __init__(self, 16 | root_dir, 17 | filenames_text_filepath, 18 | audio_sf, 19 | video_fps, 20 | num_videos, 
21 | transform=None): 22 | self.root_dir = root_dir 23 | self.audio_fs = audio_sf 24 | self.video_fps = video_fps 25 | self.stride = math.floor(self.audio_fs / self.video_fps) 26 | self.transform = transform 27 | self.filenames = [] 28 | with open(filenames_text_filepath,'r') as file: 29 | self.filenames = file.read().split('\n') 30 | self.filenames = self.filenames[:num_videos] 31 | 32 | def __len__(self): 33 | return len(self.filenames) 34 | 35 | def __load_from_np(self,path): 36 | return np.load(path) 37 | 38 | def __get_folder_name(self,curr_file): 39 | return os.path.join(self.root_dir,curr_file) 40 | 41 | def get_audio(self,idx): 42 | folder = self.__get_folder_name(self.filenames[idx]) 43 | audio_filepath = os.path.join(folder,'audio.npy') 44 | return self.__load_from_np(audio_filepath) 45 | 46 | def get_video_frames(self,idx): 47 | folder = self.__get_folder_name(self.filenames[idx]) 48 | video_filepath = os.path.join(folder,'video_frames.npy') 49 | return self.__load_from_np(video_filepath) 50 | 51 | def __getitem__(self,idx): 52 | if torch.is_tensor(idx): 53 | idx = idx.tolist() 54 | 55 | audio = self.get_audio(idx) 56 | video = self.get_video_frames(idx) 57 | audio = torch.from_numpy(audio) 58 | video = torch.from_numpy(video).permute(0,3,1,2)#F,H,W,C ==> F,C,H,W 59 | video = video/255 60 | sample = Sample(audio,video) 61 | if self.transform: 62 | sample = self.transform(sample) 63 | return sample 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /wav2mov/core/data/note.txt: -------------------------------------------------------------------------------- 1 | original GRID DATASET has videos of shape (H,W,C)==(480, 720, 3) -------------------------------------------------------------------------------- /wav2mov/core/data/raw_datasets.py: -------------------------------------------------------------------------------- 1 | """Raw Datasets Classes : Grid and Ravdess""" 2 | import os 3 | from abc import abstractmethod 4 | from collections import namedtuple 5 | from tqdm import tqdm 6 | 7 | from wav2mov.core.data.utils import get_video_frames,get_audio,get_audio_from_video 8 | 9 | 10 | SampleContainer = namedtuple('SampleContainer',['audio','video']) 11 | Sample = namedtuple('Sample',['path','val']) 12 | 13 | 14 | 15 | class RawDataset: 16 | """ 17 | Description of RawDataset 18 | 19 | Args: 20 | location (undefined): 21 | audio_sampling_rate=None (undefined): 22 | video_frame_rate=None (undefined): 23 | website=None (undefined): 24 | 25 | """ 26 | 27 | 28 | def __init__(self, 29 | location, 30 | audio_sampling_rate=None, 31 | video_frame_rate=None, 32 | samples_count = None, 33 | website=None): 34 | """ 35 | Description of __init__ 36 | 37 | Args: 38 | self (undefined): 39 | location (undefined): 40 | audio_sampling_rate=None (undefined): 41 | video_frame_rate=None (undefined): 42 | samples_count=None 43 | website=None (undefined): 44 | 45 | """ 46 | self.root_location = location 47 | self.audio_sampling_rate = audio_sampling_rate 48 | self.video_frame_rate = video_frame_rate 49 | self.samples_count = samples_count 50 | self.website = website 51 | 52 | def info(self): 53 | return print(f'self.__class__.__name__ : {self.website}') 54 | 55 | @abstractmethod 56 | def generator(self): 57 | raise NotImplementedError(f'{self.__class__.__name__} should implement generator method') 58 | 59 | class RavdessDataset(RawDataset): 60 | 61 | """Dataset class for RAVDESS dataset. 
62 | 63 | [LINK] https://zenodo.org/record/1188976#.YBlpZOgzZPY 64 | 65 | Folder structure 66 | ------------------ 67 | 68 | ` 69 | root_location:folder 70 | | 71 | |___actor1:folder 72 | | | 73 | | |_ video 1:file 74 | | |_ video 2:file 75 | | |_ ... 76 | | 77 | |___actor2:folder 78 | | 79 | |_ video 1:file 80 | |... 81 | ` 82 | 83 | """ 84 | link = r"https://zenodo.org/record/1188976#.YBlpZOgzZPY" 85 | name = 'ravdess_dataset' 86 | def __init__(self,location,audio_sampling_rate,video_frame_rate,samples_count=None,link=None): 87 | super().__init__(location,audio_sampling_rate,video_frame_rate,samples_count,link) 88 | self.sub_folders = (folder for folder in os.listdir(location) 89 | if os.path.isdir(os.path.join(location,folder))) 90 | 91 | 92 | 93 | def generator(self)->Sample: 94 | """yields audio and video filenames one by one 95 | 96 | Returns: 97 | Sample: containes audio and video filpaths 98 | 99 | Yields: 100 | Iterator[Sample]: 101 | 102 | Examples: 103 | 104 | >>>dataset = RavdessDataset(...) 105 | >>>for sample in dataset.generator(): 106 | >>> audio_filepath = sample.audio 107 | 108 | """ 109 | for actor in self.sub_folders: 110 | actor_path = os.path.join(self.root_location,actor) 111 | videos = (video for _,_,video in os.walk(actor_path)) 112 | limit = self.samples_count if self.samples_count!=None else len(videos[0]) 113 | for video in tqdm(videos[:limit], 114 | desc="Ravdess Dataset", 115 | total=len(videos),ascii=True,colour="green"): 116 | audio_path = get_audio_from_video(os.path.join(actor_path,video)) 117 | video_path = os.path.join(actor_path,video) 118 | yield SampleContainer(video=Sample(video_path,val=None), 119 | audio=Sample(audio_path,val=None)) 120 | 121 | 122 | class GridDataset(RawDataset): 123 | """Dataset class for Grid dataset. 124 | 125 | [LINK] http://spandh.dcs.shef.ac.uk/avlombard/ 126 | 127 | [LINK] [PAPER] https://asa.scitation.org/doi/10.1121/1.5042758 128 | 129 | `Folder structure`:: 130 | 131 | root_location:folder 132 | | 133 | |___audio:folder 134 | | | 135 | | |_ .wav:file 136 | | |_ .wav:file 137 | | |_ ... 138 | | 139 | |___video:folder 140 | | 141 | |_ .mov:file 142 | |_ .mov:file 143 | |_ ... 
144 | 145 | """ 146 | link = "http://spandh.dcs.shef.ac.uk/avlombard/" 147 | name = 'grid_dataset' 148 | def __init__(self,location,audio_sampling_rate,video_frame_rate,samples_count=None,link=None): 149 | super().__init__(location, 150 | audio_sampling_rate, 151 | video_frame_rate, 152 | samples_count, 153 | link) 154 | 155 | 156 | def generator(self,get_filepath_only=True,img_size=(256,256),show_progress_bar=True)->Sample: 157 | video_folder = os.path.join(self.root_location,'video/') 158 | audio_folder = os.path.join(self.root_location,'audio/') 159 | videos = [video for _,_,video in os.walk(video_folder)] 160 | if self.samples_count is None : 161 | self.samples_count = len(videos[0]) 162 | self.samples_count = min(len(videos[0]),self.samples_count) 163 | limit = self.samples_count 164 | progress_bar = tqdm(enumerate(videos[0][:limit]), 165 | desc="Processing Grid Dataset", 166 | total=len(videos[0][:limit]), ascii=True, colour="green") if show_progress_bar else enumerate(videos[0][:limit]) 167 | 168 | for idx,video_filename in progress_bar: 169 | if isinstance(progress_bar,tqdm): 170 | progress_bar.set_postfix({'file':video_filename}) 171 | audio_filename = video_filename.split('.')[0] + '.wav' 172 | video_path = os.path.join(video_folder, video_filename) 173 | audio_path = os.path.join(audio_folder, audio_filename) 174 | 175 | video_val = None if get_filepath_only else get_video_frames(video_path, img_size=img_size) 176 | audio_val = None if get_filepath_only else get_audio(audio_path) 177 | 178 | res = SampleContainer(video=Sample(video_path, val=video_val), 179 | audio=Sample(audio_path, val=audio_val)) 180 | # if idx+1==limit:#what if the actual length of videos is less than the limit passed by talaharte hudga 181 | # return res 182 | # else: 183 | yield res 184 | 185 | 186 | -------------------------------------------------------------------------------- /wav2mov/core/data/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as vtransforms 2 | 3 | from wav2mov.core.data.utils import Sample 4 | 5 | class ResizeGrayscale: 6 | def __init__(self,target_shape): 7 | _,*img_size = target_shape 8 | self.transform = vtransforms.Compose( 9 | [ 10 | # vtransforms.Grayscale(1), 11 | vtransforms.Resize(img_size), 12 | # vtransforms.Normalize([0.5]*img_channels, [0.5]*img_channels) 13 | ]) 14 | def __call__(self,sample): 15 | video = sample.video 16 | video = self.transform(video) 17 | return Sample(sample.audio,video) 18 | 19 | class Normalize: 20 | def __init__(self,mean_std): 21 | audio_mean,audio_std = mean_std['audio'] 22 | video_mean,video_std = mean_std['video'] 23 | 24 | self.audio_transform = lambda x: (x-audio_mean)/audio_std 25 | self.video_transform = vtransforms.Normalize(video_mean,video_std) 26 | 27 | def __call__(self,sample): 28 | audio,video = sample.audio,sample.video 29 | audio = self.audio_transform(audio) 30 | video = self.video_transform(video) 31 | return Sample(audio,video) -------------------------------------------------------------------------------- /wav2mov/core/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .callbacks import Callbacks,CallbackEvents,CallbackDispatcher 2 | from .engine import TemplateEngine 3 | -------------------------------------------------------------------------------- /wav2mov/core/engine/callbacks.py: -------------------------------------------------------------------------------- 1 | 2 | from enum 
import Enum 3 | class CallbackEvents(Enum): 4 | TRAIN_START = "on_train_start" 5 | TRAIN_END = "on_train_end" 6 | EPOCH_START = "on_epoch_start" 7 | EPOCH_END = "on_epoch_end" 8 | BATCH_START = "on_batch_start" 9 | BATCH_END = "on_batch_end" 10 | RUN_START = "on_run_start" 11 | RUN_END = "on_run_end" 12 | NONE = "none" 13 | 14 | class Callbacks: 15 | def on_train_start(self, *args,**kwargs): 16 | pass 17 | def on_train_end(self, *args,**kwargs): 18 | pass 19 | def on_epoch_start(self, *args,**kwargs): 20 | pass 21 | def on_epoch_end(self,*args, **kwargs): 22 | pass 23 | def on_batch_start(self,*args, **kwargs): 24 | pass 25 | def on_batch_end(self, *args,**kwargs): 26 | pass 27 | def on_run_start(self,*args,**kwargs): 28 | pass 29 | def on_run_end(self,*args,**kwargs): 30 | pass 31 | 32 | class CallbackDispatcher: 33 | def __init__(self): 34 | self.callbacks = [] 35 | 36 | def register(self,callbacks): 37 | self.callbacks = callbacks 38 | 39 | def dispatch(self,event,*args,**kwargs): 40 | for callback in self.callbacks: 41 | # print(callback.__class__.__name__,args) 42 | getattr(callback,event.value)(*args,**kwargs) 43 | 44 | -------------------------------------------------------------------------------- /wav2mov/core/engine/engine.py: -------------------------------------------------------------------------------- 1 | from wav2mov.core.engine.callbacks import CallbackEvents,CallbackDispatcher 2 | from wav2mov.core.utils.logger import get_module_level_logger 3 | logger = get_module_level_logger(__name__) 4 | 5 | class TemplateEngine: 6 | 7 | def __init__(self): 8 | self.__event = CallbackEvents.NONE 9 | self.__dispatcher = CallbackDispatcher() 10 | 11 | @property 12 | def event(self): 13 | return self.__event 14 | 15 | @event.setter 16 | def event(self,event): 17 | self.__event = event 18 | 19 | def register(self,callbacks): 20 | self.__dispatcher.register(callbacks) 21 | 22 | def dispatch(self,event,*args,**kwargs): 23 | self.event = event 24 | # logger.debug(f'{self.event},args : {args} kwargs : {kwargs}') 25 | self.__dispatcher.dispatch(event,*args,**kwargs) 26 | 27 | def on_run_start(self,*args,**kwargs): 28 | pass 29 | def on_run_end(self,*args,**kwargs): 30 | pass 31 | def on_train_start(self,*args,**kwargs): 32 | pass 33 | def on_epoch_start(self,*args,**kwargs): 34 | pass 35 | def on_batch_start(self,*args,**kwargs): 36 | pass 37 | def on_batch_end(self,*args,**kwargs): 38 | pass 39 | def log(self,*args,**kwargs): 40 | pass 41 | def validate(self,*args,**kwargs): 42 | pass 43 | def on_epoch_end(self,*args,**kwargs): 44 | pass 45 | def on_train_end(self,*args,**kwargs): 46 | pass 47 | def run(self,*args,**kwargs): 48 | """ 49 | training script goes here 50 | """ 51 | print("[TEMPLATE ENGINE] 'run' function not implemented") 52 | pass 53 | def test(self,*args,**kwargs): 54 | """ 55 | test script goes here 56 | """ 57 | print("[TEMPLATE ENGINE] 'testing' function not implemented") 58 | pass -------------------------------------------------------------------------------- /wav2mov/core/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from .template_model import TemplateModel -------------------------------------------------------------------------------- /wav2mov/core/models/base_model.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torch 3 | from torch import nn 4 | 5 | import logging 6 | logger = 
logging.getLogger(__name__) 7 | 8 | class BaseModel(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | 13 | def save_to(self,checkpoint_fullpath): 14 | torch.save(self,checkpoint_fullpath) 15 | logger.log(f'Model saved at {checkpoint_fullpath}','INFO') 16 | 17 | def load_from(self,checkpoint_fullpath): 18 | try: 19 | self.load_statedict(torch.load(checkpoint_fullpath)) 20 | except: 21 | logger.log(f'Cannot load checkpoint from {checkpoint_fullpath}',type="ERROR") 22 | 23 | @abstractmethod 24 | def forward(self,*args): 25 | raise NotImplementedError(f'Forward method is not defined in {self.__class__.__name__}') 26 | 27 | def freeze_learning(self): 28 | for p in self.parameters(): 29 | p.require_grad = False 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /wav2mov/core/models/template_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class TemplateModel(nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | 7 | def forward(self,*args,**kwargs): 8 | raise NotImplementedError("Forward method must be implemented") 9 | 10 | def on_run_start(self,*args,**kwargs): 11 | pass 12 | def on_run_end(self,*args,**kwargs): 13 | pass 14 | def on_train_start(self,*args,**kwargs): 15 | pass 16 | def on_epoch_start(self,*args,**kwargs): 17 | pass 18 | def on_batch_start(self,*args,**kwargs): 19 | pass 20 | def setup_input(self,*args,**kwargs): 21 | pass 22 | def optimize_parameters(self,*args,**kwargs): 23 | pass 24 | def on_batch_end(self,*args,**kwargs): 25 | pass 26 | def on_epoch_end(self,*args,**kwargs): 27 | pass 28 | def log(self,*args,**kwargs): 29 | pass 30 | def validate(self,*args,**kwargs): 31 | pass 32 | def on_train_end(self,*args,**kwargs): 33 | pass 34 | 35 | def load(self,checkpoint): 36 | self.load_state_dict(checkpoint) -------------------------------------------------------------------------------- /wav2mov/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoints import save_checkpoint,save_fully_trained_model -------------------------------------------------------------------------------- /wav2mov/core/utils/average_meter.py: -------------------------------------------------------------------------------- 1 | 2 | class AverageMeter(): 3 | """class for tracking metrics generated during training 4 | """ 5 | def __init__(self,name,fmt=':0.2f'): 6 | self.name = name 7 | self.fmt = fmt 8 | self.reset() 9 | 10 | def reset(self): 11 | self.curr_val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self,val,n=1): 17 | self.curr_val = val 18 | self.sum += val*n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | -------------------------------------------------------------------------------- /wav2mov/core/utils/checkpoints.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | def save_checkpoint(model,loss_val,location,save_params=True): 4 | statedict = { 5 | 'model_state_dict': model.state_dict(), 6 | 'min_val_loss' : loss_val, 7 | 'model_arch': str(model) 8 | } 9 | os.makedirs(os.path.dirname(location),exist_ok=True) 10 | torch.save(statedict,location) 11 | 12 | if save_params: 13 | hparams_dir = os.path.dirname(location) 14 | hparams_filename = 'hparams_' + os.path.basename(location).split('.')[0] + '.json' 15 | 
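        # the hyperparameters are written to a sibling hparams_<name>.json file in the
        # same directory, so every saved checkpoint stays self-describing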
model.hparams.to_json(os.path.join(hparams_dir,hparams_filename)) 16 | # with open(os.path.join(hparams_dir,hparams_filename),'a+') as file: 17 | # json.dump(model.hparams.asdict(),) 18 | 19 | 20 | 21 | 22 | def load_checkpoint(location): 23 | return torch.load(location) 24 | 25 | def save_fully_trained_model(model,loss_val,location,save_params=True): 26 | basename,file_extension = os.path.basename(location).split('.') 27 | new_name = basename +'_fully_trained.' + file_extension 28 | new_location = os.path.join(os.path.dirname(location),new_name) 29 | 30 | save_checkpoint(model,loss_val,new_location,save_params=save_params) 31 | -------------------------------------------------------------------------------- /wav2mov/core/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.basicConfig(level=logging.DEBUG,format="%(levelname)s : %(name)s : %(asctime)s | %(msg)s ") 3 | 4 | 5 | def get_module_level_logger(name): 6 | logger = logging.getLogger(name) 7 | logger.setLevel(logging.DEBUG) 8 | logger.propagate = False 9 | return logger 10 | -------------------------------------------------------------------------------- /wav2mov/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | from colorama import init,Fore,Back,Style 2 | 3 | from .os_utils import is_windows 4 | from tqdm import tqdm 5 | 6 | init() # colorama 7 | 8 | 9 | class COLOURS: 10 | RED = Fore.RED 11 | GREEN = Fore.GREEN 12 | BLUE = Fore.BLUE 13 | YELLOW = Fore.YELLOW 14 | WHITE = Fore.WHITE 15 | 16 | # def colored(text,color=None,reset=True): 17 | # if color is None: 18 | # color = COLOURS.WHITE 19 | # return f"{color}{text}{Style.RESET_ALL}" if reset else f"{color}{text}" 20 | 21 | def __coloured(text:str,colour)->str: 22 | return f"{colour}{text}{Style.RESET_ALL}" 23 | 24 | def error(text:str)->str: 25 | return __coloured(text,COLOURS.RED) 26 | 27 | def debug(text:str)->str: 28 | return __coloured(text,COLOURS.BLUE) 29 | 30 | def success(text:str)->str: 31 | return __coloured(text,COLOURS.GREEN) 32 | 33 | 34 | 35 | 36 | 37 | def string_wrapper(msg): 38 | if is_windows(): 39 | return debug(msg) 40 | else: 41 | return msg 42 | 43 | def get_tqdm_iterator(iterable,description,colour="green"): 44 | """setups the tqdm progress bar 45 | 46 | Arguments: 47 | `iterable` {Iterable} 48 | `description` {str} -- Description to be shown in progress bar 49 | 50 | Keyword Arguments: 51 | `colour` {str} -- colour of progress bar (default: {"green"}) 52 | 53 | Returns: 54 | `tqdm` 55 | """ 56 | return tqdm(iterable,ascii=True,desc=description,total=len(iterable),colour=colour) 57 | 58 | -------------------------------------------------------------------------------- /wav2mov/core/utils/os_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | def is_windows(): 5 | return sys.platform.startswith('win') -------------------------------------------------------------------------------- /wav2mov/datasets/create_file_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from tqdm import tqdm 4 | 5 | FILENAME = 'filenames.txt' 6 | FILENAME_TRAIN = 'filenames_train.txt' 7 | FILENAME_TEST = 'filenames_test.txt' 8 | DATASET = 'grid_dataset_a5_500_a10to14' 9 | FROM_DIR = '' 10 | TO_DIR = '' 11 | 12 | VIDEOS_PER_ACTOR = 30 13 | 14 | if len(FROM_DIR)==0: 15 | FROM_DIR = 
os.path.dirname(os.path.abspath(__file__)) 16 | FROM_DIR = os.path.join(FROM_DIR,DATASET) 17 | TO_DIR = FROM_DIR 18 | 19 | def get_folders_list(): 20 | folders = [folder for _,folder,_ in os.walk(FROM_DIR)][0] 21 | folders_strs = [] 22 | for folder in tqdm(folders): 23 | if os.path.isdir(os.path.join(FROM_DIR,folder)): 24 | folders_strs.append(folder+'\n') 25 | return folders_strs 26 | 27 | def write_to_text_file(folders_list): 28 | with open(os.path.join(TO_DIR,FILENAME),'w') as file: 29 | file.writelines(folders_list) 30 | print(f'list written to {os.path.join(TO_DIR,FILENAME)}') 31 | 32 | 33 | def write_to_file(file_path,content_list): 34 | with open(file_path,'w') as file: 35 | file.writelines(content_list) 36 | 37 | def get_actors(folders_list): 38 | actors = set() 39 | for folder in folders_list: 40 | actors.add(folder.split('_')[0]) 41 | return sorted(actors) 42 | 43 | def get_train_test_list(folders_list): 44 | actors = get_actors(folders_list) 45 | train_actors = set(actors[:-1]) 46 | train_dict = {} 47 | test_list = [] 48 | for folder in folders_list: 49 | actor = folder.split('_')[0] 50 | if actor in train_actors: 51 | if len(train_dict.get(actor,[])) possible_num_frames: 97 | raise ValueError(f'given audio has {possible_num_frames} frames but {num_frames} frames requested.') 98 | start_idx = (possible_num_frames-num_frames)//2 99 | end_idx = (possible_num_frames+num_frames)//2 #start_idx + (num_frames) 100 | padding = torch.zeros((1,self.coarticulation_factor*self.stride),device=self.device) 101 | audio = torch.cat([padding,audio,padding],dim=1) 102 | if get_mfccs: 103 | frames = [self.get_frame_from_idx(audio,idx) for idx in range(start_idx,end_idx)] 104 | frames = [self.extract_mfccs(frame) for frame in frames]# each of shape [t,13] 105 | # frames = [((frame-mean[i])/(std[i]+1e-7)) for i,frame in enumerate(frames)] 106 | frames = torch.stack(frames,axis=0)# 1,num_frames,(t,13) 107 | # logger.warning(f'frames {frames.shape} mean : {mean.shape}') 108 | return (frames-mean)/(std+1e-7) 109 | frames = [self.get_frame_from_idx(audio,idx) for idx in range(start_idx,end_idx)] 110 | #each frame is of shape (1,frame_size) so can be catenated along zeroth dimension . 111 | return torch.cat(frames,dim=0) 112 | 113 | def get_limited_audio(self,audio,num_frames,start_frame=None,get_mfccs=False) : 114 | possible_num_frames = audio.shape[-1]//self.stride 115 | if num_frames>possible_num_frames: 116 | logger.error(f'Given num_frames {num_frames} is larger the possible_num_frames {possible_num_frames}') 117 | 118 | mean,std = self.get_mfccs_mean_std(audio) 119 | padding = torch.zeros((audio.shape[0],self.coarticulation_factor*self.stride),device=self.device) 120 | audio = torch.cat([padding,audio,padding],dim=1) 121 | 122 | # possible_num_frames = audio.shape[-1]//self.stride 123 | actual_start_frame = (possible_num_frames-num_frames)//2 124 | # [......................................................] 125 | # [................................] 126 | # |<-----num_frames---------------->| 127 | #.........^ 128 | # actual start frame 129 | if start_frame is None: 130 | start_frame = actual_start_frame 131 | 132 | if start_frame+num_frames>possible_num_frames:#[why > not >=]think if possible num_frames is 50 and 50 is the required num_frames and start_frame is zero 133 | logger.warning(f'Given Audio has {possible_num_frames} frames. Given starting frame {start_frame} cannot be consider for getting {num_frames} frames. 
Changing startframes to {actual_start_frame} frame.') 134 | start_frame = actual_start_frame 135 | 136 | end_frame = start_frame + (num_frames) #exclusive 137 | 138 | start_pos = self.__get_center_idx(start_frame) 139 | end_pos = self.__get_center_idx(end_frame-1) 140 | 141 | audio = audio[:,self.__get_start_idx(start_pos):self.__get_end_idx(end_pos)] 142 | if get_mfccs: 143 | mfccs = list(map(self.extract_mfccs,audio)) 144 | # mfccs = [(mfcc-mean[i]/(std[i]+1e-7)) for i,mfcc in enumerate(mfccs)] 145 | mfccs = torch.stack(mfccs,axis=0) 146 | return (mfccs-mean)/(std+1e-7) 147 | return audio 148 | 149 | def load_audio(audio_path): 150 | audio,sr = librosa.load(audio_path,sr=params.AUDIO_SF) 151 | return torch.from_numpy(audio) 152 | 153 | def get_num_frames(audio): 154 | audio_len = audio.squeeze().shape[0] 155 | return audio_len//params.STRIDE 156 | 157 | def preprocess_audio(audio): 158 | if not isinstance(audio,Tensor): 159 | audio = torch.from_numpy(audio) 160 | audio = (audio-params.AUDIO_MEAN)/(params.AUDIO_STD + params.EPSILON) 161 | framewise_mfccs = AudioUtil(params.AUDIO_SF,params.COARTICULATION_FACTOR, 162 | params.STRIDE,params.DEVICE).get_audio_frames(audio,get_mfccs=True) 163 | return framewise_mfccs 164 | -------------------------------------------------------------------------------- /wav2mov/inference/generate.py: -------------------------------------------------------------------------------- 1 | # get audio and image 2 | # get audio_frames and extract face from image and resize 3 | # apply transforms 4 | # generator forward pass 5 | import argparse 6 | import os 7 | import torch 8 | 9 | from inference import audio_utils,image_utils,utils,model_utils 10 | logger = utils.get_module_level_logger(__name__) 11 | 12 | def is_exists(file): 13 | return os.path.isfile(file) 14 | 15 | def create_batch(sample): 16 | return sample.unsqueeze(0) 17 | 18 | def save_video(audio,video_frames): 19 | video_path = os.path.join(DIR,'out','fake_video_without_audio.avi') 20 | ##################################################################### 21 | #⚠ IMPORTANT TO DE NORMALiZE ELSE OUTPUT WILL BE BLACK! 
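    # (the generator ends in a tanh, so frames arrive in [-1, 1];
    #  ((x * 0.5) + 0.5) * 255 maps them back to [0, 255] pixel values)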
22 | ##################################################################### 23 | video_frames = ((video_frames*0.5)+0.5)*255 24 | utils.save_video(video_path,audio,video_frames) 25 | 26 | DIR = os.path.dirname(os.path.abspath(__file__)) 27 | 28 | def generate_video(image,audio): 29 | audio_feats = audio_utils.preprocess_audio(audio) 30 | num_frames = audio_feats.shape[0] 31 | image = image_utils.preprocess_image(image) 32 | images = image_utils.repeat_img(image,num_frames) 33 | logger.debug(f'✅ Preprocessing Done.') 34 | model = model_utils.get_model() 35 | logger.debug(f'✅ Model Loaded.') 36 | with torch.no_grad(): 37 | model.eval() 38 | gen_frames = model(create_batch(audio_feats),create_batch(images)) 39 | logger.debug(f'✅ Frames Generated.') 40 | B,T,*img_shape = gen_frames.shape 41 | gen_frames = gen_frames.reshape(B*T,*img_shape) 42 | save_video(audio,gen_frames) 43 | 44 | if __name__ == '__main__': 45 | logger.debug(f'✅ Modules loaded.') 46 | # default_image = r'ref_frame_Run_27_5_2021__16_6.png' 47 | # default_image = r'ref_frame_Run_19_6_2021__12_14.png' 48 | # default_image = r'image.jfif' 49 | # default_image = r'musk.jfif' 50 | default_image = r'train_ref_frame_Run_9_7_2021__14_52.png' 51 | default_image = r'train_ref_frame_Run_11_7_2021__19_18.png' 52 | # default_audio = r'03-01-01-01-02-02-01.wav' 53 | default_audio = r'audio.wav' 54 | arg_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, 55 | description="Wav2Mov | Speech Controlled Facial Animation") 56 | arg_parser.add_argument('--image','-i',type=str,required=False, 57 | help='image containing face of the person', 58 | default=os.path.join(DIR,'inputs',default_image)) 59 | arg_parser.add_argument('--audio','-a',type=str,required=False, 60 | help='speech for which face should be animated', 61 | default=os.path.join(DIR,'inputs',default_audio)) 62 | options = arg_parser.parse_args() 63 | 64 | 65 | image_path = options.image 66 | audio_path = options.audio 67 | if not (is_exists(image_path)): 68 | raise FileNotFoundError(f'[ERROR] ❌ image path is incorrect :{image_path}') 69 | if not (is_exists(audio_path)): 70 | raise FileNotFoundError(f'[ERROR] ❌ audio path is incorrect :{audio_path}') 71 | image = image_utils.load_image(image_path) 72 | audio = audio_utils.load_audio(audio_path) 73 | generate_video(image,audio) 74 | logger.debug('✅ Video saved') 75 | -------------------------------------------------------------------------------- /wav2mov/inference/generate.sh: -------------------------------------------------------------------------------- 1 | python -m inference.generate 2 | -------------------------------------------------------------------------------- /wav2mov/inference/image_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import dlib 3 | import torch 4 | from torchvision import utils as vutils 5 | from torchvision.transforms import Normalize 6 | from inference import params 7 | 8 | face_detector = None 9 | def load_face_detector(): 10 | return dlib.get_frontal_face_detector() 11 | 12 | def convert_and_trim_bb(image, rect): 13 | """ from pyimagesearch 14 | https://www.pyimagesearch.com/2021/04/19/face-detection-with-dlib-hog-and-cnn/ 15 | """ 16 | # extract the starting and ending (x, y)-coordinates of the 17 | # bounding box 18 | start_x = rect.left() 19 | start_y = rect.top() 20 | endX = rect.right() 21 | endY = rect.bottom() 22 | # ensure the bounding box coordinates fall within the spatial 23 | # dimensions of the 
image 24 | start_x = max(0, start_x) 25 | start_y = max(0, start_y) 26 | endX = min(endX, image.shape[1]) 27 | endY = min(endY, image.shape[0]) 28 | # compute the width and height of the bounding box 29 | w = endX - start_x 30 | h = endY - start_y 31 | # return our bounding box coordinates 32 | return (start_x, start_y, w, h) 33 | 34 | 35 | def load_image(image_path): 36 | img = cv2.imread(image_path,cv2.IMREAD_COLOR) 37 | if params.IMAGE_CHANNELS==1: 38 | return cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) 39 | # height,width = params.IMAGE_SIZE 40 | # img = cv2.resize(img,(width,height),cv2.INTER_CUBIC) 41 | return cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 42 | 43 | def save_image(image,path,normalize=False): 44 | vutils.save_image(image,path,normalize=normalize) 45 | 46 | def show_image(image): 47 | cv2.imshow('image',image) 48 | cv2.waitKey(0) 49 | cv2.destroyAllWindows() 50 | 51 | def preprocess_image(image): 52 | global face_detector 53 | if face_detector is None: 54 | face_detector = load_face_detector() 55 | face = face_detector(image)[0] 56 | x,y,w,h = convert_and_trim_bb(image,face) 57 | target_height,target_width = params.IMAGE_SIZE 58 | image = cv2.resize(image[y:y+h,x:x+w],(target_width,target_height),interpolation=cv2.INTER_CUBIC) 59 | if len(image.shape)==2: 60 | image = image.reshape(image.shape[0],image.shape[1],1) 61 | # show_image(image) 62 | image = torch.from_numpy(image).float().permute(2,0,1)/255 63 | image= Normalize(params.VIDEO_MEAN,params.VIDEO_STD)(image) 64 | # show_image(image.permute(1,2,0).numpy()) 65 | return image 66 | 67 | def repeat_img(img,n): 68 | if img.shape==3: 69 | img = img.unsqueeze(0) 70 | 71 | return img.repeat(n,1,1,1) 72 | -------------------------------------------------------------------------------- /wav2mov/inference/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from inference.models import Generator 4 | from inference import params 5 | from inference.utils import get_module_level_logger 6 | logger = get_module_level_logger(__name__) 7 | 8 | gen_hparams = { 9 | "in_channels": 3, 10 | "chs": [64, 128, 256, 512, 1024], 11 | "latent_dim": 272, 12 | "latent_dim_id": [8, 8], 13 | "comment": "laten_dim not eq latent_dim_id + latent_dim_audio, its 4x4 + 256", 14 | "latent_dim_audio": 256, 15 | "device": "cpu", 16 | "lr": 2e-4 17 | } 18 | 19 | 20 | def get_model(checkpoint_path=None): 21 | if checkpoint_path is None: 22 | checkpoint_path = params.GEN_CHECKPOINT_PATH 23 | 24 | if not os.path.isfile(checkpoint_path): 25 | logger.error(f'NO FILE : {checkpoint_path}') 26 | raise FileNotFoundError( 27 | 'Please make sure to put generator file in pt_files folder ') 28 | state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict'] 29 | model = Generator(gen_hparams) 30 | model.load_state_dict(state_dict) 31 | model.eval() 32 | return model 33 | -------------------------------------------------------------------------------- /wav2mov/inference/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator.frame_generator import Generator 2 | -------------------------------------------------------------------------------- /wav2mov/inference/models/generator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/inference/models/generator/__init__.py 
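
# A minimal, illustrative sketch of how the inference entry points above fit together,
# mirroring inference/generate.py; it is not a file in this repository, and the two
# input paths are hypothetical placeholders.
import torch
from inference import audio_utils, image_utils, model_utils

audio = audio_utils.load_audio('inputs/speech.wav')            # hypothetical path
image = image_utils.load_image('inputs/face.png')              # hypothetical path
audio_feats = audio_utils.preprocess_audio(audio)              # per-frame MFCCs: (num_frames, t, 13)
face = image_utils.preprocess_image(image)                     # detect, crop, resize, normalize the face
faces = image_utils.repeat_img(face, audio_feats.shape[0])     # one reference frame per audio frame
model = model_utils.get_model()                                # loads params.GEN_CHECKPOINT_PATH on CPU
with torch.no_grad():
    frames = model(audio_feats.unsqueeze(0), faces.unsqueeze(0))  # (1, num_frames, C, H, W), tanh range
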
-------------------------------------------------------------------------------- /wav2mov/inference/models/generator/audio_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from inference.models.layers.conv_layers import Conv1dBlock,Conv2dBlock 4 | from inference.models.layers.debug_layers import IdentityDebugLayer 5 | from inference.models.utils import squeeze_batch_frames,get_same_padding 6 | from inference.utils import get_module_level_logger 7 | 8 | logger = get_module_level_logger(__name__) 9 | 10 | class AudioEnocoder(nn.Module): 11 | def __init__(self,hparams): 12 | super().__init__() 13 | self.hparams = hparams 14 | padding_31 = get_same_padding(3,1) 15 | padding_32 = get_same_padding(3,2) 16 | padding_42 = get_same_padding(4,2) 17 | # each frame has mfcc of shape (7,13) 18 | self.conv_encoder = nn.Sequential( 19 | Conv2dBlock(1,32,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU()),#7,13 20 | Conv2dBlock(32,32,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 21 | Conv2dBlock(32,32,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 22 | 23 | Conv2dBlock(32,64,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU()), #7,13 24 | Conv2dBlock(64,64,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 25 | Conv2dBlock(64,64,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 26 | 27 | Conv2dBlock(64,128,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU()), #7,13 28 | Conv2dBlock(128,128,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 29 | Conv2dBlock(128,128,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 30 | 31 | Conv2dBlock(128,256,3,(2,1),padding=padding_32,use_norm=True,use_act=True,act=nn.ReLU()), #4,13 32 | Conv2dBlock(256,256,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 33 | Conv2dBlock(256,256,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 34 | 35 | Conv2dBlock(256,512,(4,3),(2,1),padding=1,use_norm=True,use_act=True,act=nn.ReLU()), #2,13 36 | Conv2dBlock(512,512,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 37 | Conv2dBlock(512,512,(2,5),(1,2),padding=(0,1),use_norm=True,use_act=True,act=nn.ReLU()),#1,6, 38 | ) 39 | 40 | self.features_len = 6*512#out of conv layers 41 | self.num_layers = 1 42 | self.hidden_size = self.hparams['latent_dim_audio'] 43 | self.gru = nn.GRU(input_size=self.features_len, 44 | hidden_size=self.hidden_size, 45 | num_layers=1, 46 | batch_first=True) 47 | self.final_act = nn.Tanh() 48 | 49 | def forward(self, x): 50 | """ x : audio frames of shape B,T,t,13""" 51 | batch_size,num_frames,*_ = x.shape 52 | x = self.conv_encoder(squeeze_batch_frames(x).unsqueeze(1))#add channel dimension 53 | x = x.reshape(batch_size,num_frames,self.features_len) 54 | x,_ = self.gru(x) 55 | return self.final_act(x) # shape (batch_size,num_frames,hidden_size=latent_dim_audio) 56 | -------------------------------------------------------------------------------- /wav2mov/inference/models/generator/frame_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn,optim 3 | 4 | from inference.models.generator.audio_encoder import AudioEnocoder 5 | from inference.models.generator.id_encoder import IdEncoder 6 
| from inference.models.generator.id_decoder import IdDecoder 7 | 8 | from inference.models.utils import squeeze_batch_frames 9 | from inference.utils import get_module_level_logger 10 | logger = get_module_level_logger(__name__) 11 | 12 | class Generator(nn.Module): 13 | def __init__(self,hparams): 14 | super().__init__() 15 | self.hparams = hparams 16 | self.id_encoder = IdEncoder(self.hparams) 17 | self.id_decoder = IdDecoder(self.hparams) 18 | self.audio_encoder = AudioEnocoder(self.hparams) 19 | # self.noise_encoder = NoiseEncoder(self.hparams) 20 | 21 | def forward(self,audio_frames,ref_frames): 22 | batch_size,num_frames,*_ = ref_frames.shape 23 | assert num_frames == audio_frames.shape[1] 24 | encoded_id , intermediates = self.id_encoder(squeeze_batch_frames(ref_frames)) 25 | encoded_id = encoded_id.reshape(batch_size*num_frames,-1,1,1) 26 | encoded_audio = self.audio_encoder(audio_frames).reshape(batch_size*num_frames,-1,1,1) 27 | # encoded_noise = self.noise_encoder(batch_size,num_frames).reshape(batch_size*num_frames,-1,1,1) 28 | # logger.debug(f'encoded_id {encoded_id.shape} encoded_audio {encoded_audio.shape} encoded_noise {encoded_noise.shape}') 29 | encoded = torch.cat([encoded_id,encoded_audio],dim=1)#along channel dimension 30 | gen_frames = self.id_decoder(encoded,intermediates) 31 | _,*img_shape = gen_frames.shape 32 | return gen_frames.reshape(batch_size,num_frames,*img_shape) 33 | 34 | def get_optimizer(self): 35 | return optim.Adam(self.parameters(), lr=self.hparams['lr'], betas=(0.5, 0.999)) 36 | -------------------------------------------------------------------------------- /wav2mov/inference/models/generator/id_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from inference.models.layers.conv_layers import Conv2dBlock,ConvTranspose2dBlock,DoubleConvTranspose2d 5 | from inference.models.utils import get_same_padding 6 | from inference.utils import get_module_level_logger 7 | 8 | logger = get_module_level_logger(__name__) 9 | 10 | class IdDecoder(nn.Module): 11 | def __init__(self,hparams): 12 | super().__init__() 13 | self.hparams = hparams 14 | self.in_channels = hparams['in_channels'] 15 | self.latent_channels = hparams['latent_dim'] 16 | latent_id_height,latent_id_width = hparams['latent_dim_id'] 17 | chs = self.hparams['chs'][::-1]#1024,512,256,128,64 18 | chs = [self.latent_channels] + chs + [self.in_channels] 19 | padding = get_same_padding(kernel_size=4,stride=2) 20 | self.convs = nn.ModuleList() 21 | self.convs.append(ConvTranspose2dBlock(in_ch=chs[0], 22 | out_ch=chs[1], 23 | kernel_size=(latent_id_height,latent_id_width), 24 | stride=2, 25 | padding=0, 26 | use_norm=True, 27 | use_act=True, 28 | act=nn.ReLU() 29 | )) 30 | for i in range(1,len(chs)-2): 31 | self.convs.append( DoubleConvTranspose2d(in_ch=chs[i], 32 | skip_ch=chs[i], 33 | out_ch=chs[i+1], 34 | kernel_size=(4,4), 35 | stride=2, 36 | padding=padding, 37 | use_norm=True, 38 | use_act=True, 39 | act=nn.ReLU()) 40 | ) 41 | self.convs.append(ConvTranspose2dBlock(in_ch=chs[-2], 42 | out_ch=chs[-1], 43 | kernel_size=(4,4), 44 | stride=2, 45 | padding=padding, 46 | use_norm=True, 47 | use_act=True, 48 | act=nn.Tanh() 49 | )) 50 | def forward(self,encoded,skip_outs): 51 | """[summary] 52 | 53 | Args: 54 | skip_outs ([type]): 64,128,256,512,1024 55 | encoded ([type]): latent_channels, 56 | """ 57 | encoded = encoded.reshape(encoded.shape[0],self.latent_channels,1,1) 58 | skip_outs = skip_outs[::-1] 59 | for 
i,block in enumerate(self.convs[:-2]): 60 | x = block(encoded) 61 | encoded = torch.cat([x,skip_outs[i]],dim=1)#cat along channel axis 62 | 63 | encoded = self.convs[-2](encoded) 64 | return self.convs[-1](encoded) 65 | -------------------------------------------------------------------------------- /wav2mov/inference/models/generator/id_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from inference.models.layers.conv_layers import Conv2dBlock 5 | from inference.models.utils import get_same_padding 6 | from inference.utils import get_module_level_logger 7 | 8 | logger = get_module_level_logger(__name__) 9 | 10 | class IdEncoder(nn.Module): 11 | def __init__(self,hparams): 12 | super().__init__() 13 | self.hparams = hparams 14 | in_channels = hparams['in_channels'] 15 | chs = self.hparams['chs'] 16 | chs = [in_channels] + chs +[1]# 1 is added here not in params because see how channels are being used in id_decoder 17 | padding = get_same_padding(kernel_size=4,stride=2) 18 | self.conv_blocks = nn.ModuleList(Conv2dBlock(chs[i],chs[i+1], 19 | kernel_size=(4,4), 20 | stride=2, 21 | padding=padding, 22 | use_norm=True, 23 | use_act=True, 24 | act=nn.ReLU() 25 | ) for i in range(len(chs)-2) 26 | ) 27 | 28 | self.conv_blocks.append(Conv2dBlock(chs[-2],chs[-1], 29 | kernel_size=(4,4), 30 | stride=2, 31 | padding=padding, 32 | use_norm=False, 33 | use_act=True, 34 | act=nn.Tanh())) 35 | def forward(self,images): 36 | intermediates = [] 37 | for block in self.conv_blocks[:-1]: 38 | images = block(images) 39 | intermediates.append(images) 40 | encoded = self.conv_blocks[-1](images) 41 | return encoded,intermediates 42 | -------------------------------------------------------------------------------- /wav2mov/inference/models/layers/conv_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from inference.models.utils import get_same_padding 5 | 6 | class DoubleConv2dBlock(nn.Module): 7 | """height and width are halved in feature map""" 8 | def __init__(self, 9 | in_ch, 10 | out_ch, 11 | use_norm=True, 12 | act=None): 13 | super().__init__() 14 | self.use_norm = use_norm 15 | self.conv1 = nn.Conv2d(in_ch, out_ch, 3,padding=get_same_padding(3,1),bias=not self.use_norm) 16 | self.batch_norm = nn.BatchNorm2d(out_ch) 17 | self.relu = nn.LeakyReLU(0.2) if act is None else act() 18 | self.conv2 = nn.Conv2d(out_ch, out_ch, 3) 19 | 20 | def forward(self, x): 21 | # print('double conv block',x.shape,type(x),x.device,next(self.parameters()).device) 22 | x = self.conv1(x) 23 | if self.use_norm : 24 | x = self.batch_norm(x) 25 | return self.relu(self.conv2(self.relu(x))) 26 | 27 | class Conv1dBlock(nn.Module): 28 | def __init__(self, 29 | in_ch, 30 | out_ch, 31 | kernel_size, 32 | stride, 33 | padding, 34 | use_norm=True, 35 | use_act=True,act=None, 36 | residual=False): 37 | 38 | super().__init__() 39 | self.use_norm = use_norm 40 | self.use_act = use_act 41 | self.act = nn.LeakyReLU(0.2) if act is None else act 42 | self.norm = nn.BatchNorm1d(out_ch) 43 | self.conv = nn.Conv1d(in_ch,out_ch,kernel_size,stride,padding,bias=self.use_norm) 44 | self.residual = residual 45 | 46 | def forward(self,in_x): 47 | x = self.conv(in_x) 48 | if self.use_norm: 49 | x = self.norm(x) 50 | if self.residual: 51 | x += in_x 52 | if self.use_act: 53 | x = self.act(x) 54 | return x 55 | 56 | class Conv2dBlock(nn.Module): 57 | def __init__(self, 58 | 
in_ch, 59 | out_ch, 60 | kernel_size, 61 | stride, 62 | padding, 63 | use_norm=True, 64 | use_act=True,act=None, 65 | residual=False): 66 | 67 | super().__init__() 68 | self.use_norm = use_norm 69 | self.use_act = use_act 70 | self.act = nn.LeakyReLU(0.2) if act is None else act 71 | self.norm = nn.BatchNorm2d(out_ch) 72 | self.conv = nn.Conv2d(in_ch,out_ch,kernel_size,stride,padding,bias=self.use_norm) 73 | self.residual = residual 74 | 75 | def forward(self,in_x): 76 | x = self.conv(in_x) 77 | if self.use_norm: 78 | x = self.norm(x) 79 | if self.residual: 80 | x += in_x 81 | if self.use_act: 82 | x = self.act(x) 83 | return x 84 | 85 | class ConvTranspose2dBlock(nn.Module): 86 | def __init__(self,in_ch, 87 | out_ch, 88 | kernel_size, 89 | stride, 90 | padding, 91 | use_norm=True, 92 | use_act=True, 93 | act=None): 94 | super().__init__() 95 | self.use_norm = use_norm 96 | self.use_act = use_act 97 | self.act = nn.LeakyReLU(0.2) if act is None else act 98 | self.norm = nn.BatchNorm2d(out_ch) 99 | self.conv = nn.ConvTranspose2d(in_ch,out_ch,kernel_size,stride,padding,bias=self.use_norm) 100 | 101 | def forward(self,x): 102 | x = self.conv(x) 103 | if self.use_norm: 104 | x = self.norm(x) 105 | if self.use_act: 106 | x = self.act(x) 107 | return x 108 | 109 | class DoubleConvTranspose2d(nn.Module): 110 | def __init__(self,in_ch, 111 | skip_ch, 112 | out_ch, 113 | kernel_size, 114 | stride, 115 | padding, 116 | use_norm=True, 117 | use_act=True, 118 | act=None): 119 | super().__init__() 120 | self.use_norm = use_norm 121 | self.use_act = use_act 122 | 123 | self.conv1 = nn.Conv2d(in_ch+skip_ch,in_ch,kernel_size=3,stride=1,padding=1,bias=False) 124 | self.norm1 = nn.BatchNorm2d(in_ch) 125 | self.conv2 = nn.ConvTranspose2d(in_ch,out_ch,kernel_size,stride,padding) 126 | self.norm2 = nn.BatchNorm2d(out_ch) 127 | self.act = nn.LeakyReLU(0.2) if act is None else act 128 | 129 | def forward(self,x): 130 | x = nn.functional.relu(self.norm1(self.conv1(x))) 131 | x = self.conv2(x) 132 | if self.use_norm: 133 | x = self.norm2(x) 134 | if self.use_act: 135 | x = self.act(x) 136 | return x 137 | -------------------------------------------------------------------------------- /wav2mov/inference/models/layers/debug_layers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from inference.utils import get_module_level_logger 4 | logger = get_module_level_logger(__name__) 5 | 6 | class IdentityDebugLayer(nn.Module): 7 | def __init__(self,name): 8 | super().__init__() 9 | self.name = name 10 | def forward(self,x): 11 | logger.debug(f'{self.name} : {x.shape}') 12 | return x 13 | -------------------------------------------------------------------------------- /wav2mov/inference/models/utils.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from torch.nn import init 3 | 4 | from inference.utils import get_module_level_logger 5 | logger = get_module_level_logger(__name__) 6 | 7 | def get_same_padding(kernel_size,stride,in_size=0): 8 | """https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks 9 | https://github.com/DinoMan/speech-driven-animation/blob/dc9fe4aa77b4e042177328ea29675c27e2f56cd4/sda/utils.py#L18-L21 10 | 11 | padding = 'same' 12 | • Padding such that feature map size has size In_size/Stride 13 | 14 | • Output size is mathematically convenient 15 | 16 | • Also called 'half' padding 17 | 18 | out = (in-k+2*p)/s + 1 19 | if out == in/s: 20 | 
in/s = (in-k+2*p)/s + 1 21 | ((in/s)-1)*s + k -in = 2*p 22 | (in-s)+k-in = 2*p 23 | in case of s==1: 24 | p = (k-1)/2 25 | """ 26 | out_size = ceil(in_size/stride) 27 | return ceil(((out_size-1)*stride+ kernel_size-in_size)/2)#(k-1)//2 for same padding 28 | 29 | def init_weights(net, init_type='normal', init_gain=0.02): 30 | """Initialize network weights. 31 | src : https://github1s.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 32 | Parameters: 33 | net (network) -- network to be initialized 34 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 35 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 36 | 37 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 38 | work better for some applications. Feel free to try yourself. 39 | """ 40 | def init_func(m): # define the initialization function 41 | classname = m.__class__.__name__ 42 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 43 | if init_type == 'normal': 44 | init.normal_(m.weight.data, 0.0, init_gain) 45 | elif init_type == 'xavier': 46 | init.xavier_normal_(m.weight.data, gain=init_gain) 47 | elif init_type == 'kaiming': 48 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 49 | elif init_type == 'orthogonal': 50 | init.orthogonal_(m.weight.data, gain=init_gain) 51 | else: 52 | raise NotImplementedError( 53 | 'initialization method [%s] is not implemented' % init_type) 54 | if hasattr(m, 'bias') and m.bias is not None: 55 | init.constant_(m.bias.data, 0.0) 56 | # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 57 | elif classname.find('BatchNorm2d') != -1: 58 | init.normal_(m.weight.data, 1.0, init_gain) 59 | init.constant_(m.bias.data, 0.0) 60 | 61 | logger.debug(f'initializing {net.__class__.__name__} with {init_type} weights') 62 | # print('initialize network with %s' % init_type) 63 | net.apply(init_func) # apply the initialization function 64 | 65 | 66 | def init_net(net, init_type='normal', init_gain=0.02): 67 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 68 | Parameters: 69 | net (network) -- the network to be initialized 70 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 71 | gain (float) -- scaling factor for normal, xavier and orthogonal. 72 | 73 | 74 | Return an initialized network. 
75 | """ 76 | 77 | init_weights(net, init_type, init_gain=init_gain) 78 | return net 79 | 80 | 81 | def squeeze_batch_frames(target): 82 | batch_size,num_frames,*extra = target.shape 83 | return target.reshape(batch_size*num_frames,*extra) 84 | -------------------------------------------------------------------------------- /wav2mov/inference/params.py: -------------------------------------------------------------------------------- 1 | EPSILON = 1e-7 2 | DEVICE = 'cpu' 3 | COARTICULATION_FACTOR = 2 4 | AUDIO_SF = 16000 5 | VIDEO_FPS = 24 6 | STRIDE = AUDIO_SF//VIDEO_FPS 7 | IMAGE_CHANNELS = 3 8 | IMAGE_SIZE = (256,256) 9 | VIDEO_MEAN = [0.5,0.5,0.5] 10 | VIDEO_STD = [0.5,0.5,0.5] 11 | # AUDIO_MEAN = 5.6196e-07 12 | # AUDIO_STD = 0.0142 13 | AUDIO_MEAN = -4.3688e-07 14 | AUDIO_STD = 0.0123 15 | GEN_CHECKPOINT_PATH = r'.\models\checkpoints\gen_Run_11_7_2021__19_18.pt' 16 | -------------------------------------------------------------------------------- /wav2mov/inference/quantize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.quantization 4 | from torch import nn 5 | # here is our floating point instance 6 | from inference.model_utils import get_model 7 | # import the modules used here in this recipe 8 | 9 | 10 | def print_size_of_model(m, label=""): 11 | torch.save(m.state_dict(), "temp.p") 12 | size=os.path.getsize("temp.p") 13 | print("model: ",label,' \t','Size (MB):', size/1e6) 14 | os.remove('temp.p') 15 | return size 16 | 17 | DIR = os.path.dirname(os.path.abspath(__file__)) 18 | model = get_model() 19 | target_path =r'models/checkpoints/gen_quantized.pt' 20 | # this is the call that does the work 21 | model_quantized = torch.quantization.quantize_dynamic( 22 | model, {nn.Conv1d,nn.Conv2d,nn.ConvTranspose2d, nn.Linear,nn.GRU}, dtype=torch.qint8 23 | ) 24 | 25 | # compare the sizes 26 | # f=print_size_of_model(float_lstm,"fp32") 27 | f=print_size_of_model(model,"fp32") 28 | q=print_size_of_model(model_quantized,"int8") 29 | print("{0:.2f} times smaller".format(f/q)) 30 | -------------------------------------------------------------------------------- /wav2mov/inference/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torchvision.io import write_video 5 | from scipy.io.wavfile import write as write_audio 6 | from moviepy import editor as mpy 7 | 8 | from inference import params 9 | def get_module_level_logger(name): 10 | m_logger = logging.getLogger(name) 11 | m_logger.setLevel(logging.DEBUG) 12 | m_logger.propagate = False 13 | handler = logging.StreamHandler() 14 | handler.setFormatter(logging.Formatter("%(levelname)-5s : %(filename)s :%(lineno)s | %(asctime)s | %(msg)s ", "%b %d,%Y %H:%M:%S")) 15 | m_logger.addHandler(handler) 16 | return m_logger 17 | 18 | logger = get_module_level_logger('utils') 19 | 20 | def save_video(video_path, audio, video_frames): 21 | """ 22 | audio_frames : C,S 23 | video_frames : F,C,H,W 24 | """ 25 | # if has batch dimension remove it 26 | hparams = {'video_fps':params.VIDEO_FPS,'audio_sf':params.AUDIO_SF} 27 | if len(video_frames.shape) == 5: 28 | video_frames = video_frames[0] 29 | if video_frames.shape[1] == 1: 30 | video_frames = video_frames.repeat(1, 3, 1, 1) 31 | logger.warning('Grayscale images...') 32 | logger.debug(f'✅ video frames :{video_frames.shape[:]}, audio : {audio.shape[:]}') 33 | video_frames = video_frames.to(torch.uint8) 34 | 
os.makedirs(os.path.dirname(video_path),exist_ok=True) 35 | write_video(filename=video_path, 36 | video_array=video_frames.permute(0, 2, 3, 1), 37 | fps=hparams['video_fps'], 38 | video_codec="h264", 39 | ) 40 | dir_name = os.path.dirname(video_path) 41 | temp_audio_path = os.path.join(dir_name, 'temp', 'temp_audio.wav') 42 | os.makedirs(os.path.dirname(temp_audio_path), exist_ok=True) 43 | write_audio(temp_audio_path,hparams['audio_sf'], audio.cpu().numpy().reshape(-1)) 44 | 45 | video_clip = mpy.VideoFileClip(video_path) 46 | audio_clip = mpy.AudioFileClip(temp_audio_path) 47 | video_clip.audio = audio_clip 48 | video_clip.write_videofile(os.path.join(dir_name, 'fake_video_with_audio.avi'), fps=hparams['video_fps'], codec='png',logger=None) 49 | -------------------------------------------------------------------------------- /wav2mov/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | from datetime import datetime 5 | from torch.utils.tensorboard.writer import SummaryWriter 6 | from pythonjsonlogger import jsonlogger 7 | 8 | logging.basicConfig(level=logging.ERROR,format="%(levelname)-5s : %(name)s : %(asctime)s | %(msg)s ") 9 | TIME_FORMAT = "%b %d,%Y %H:%M:%S" 10 | 11 | def get_module_level_logger(name): 12 | m_logger = logging.getLogger(name) 13 | m_logger.setLevel(logging.DEBUG) 14 | m_logger.propagate = False 15 | handler = logging.StreamHandler() 16 | handler.setFormatter(logging.Formatter("%(levelname)-5s : %(name)s : %(asctime)s | %(msg)s ",TIME_FORMAT)) 17 | m_logger.addHandler(handler) 18 | return m_logger 19 | 20 | 21 | logger = get_module_level_logger(__name__) 22 | class CustomJsonFormatter(jsonlogger.JsonFormatter): 23 | def add_fields(self,log_record,record,message_dict): 24 | # print(log_record,vars(record),message_dict) 25 | # for key,value in vars(record).items(): 26 | # print(f'{key} : {value}') 27 | super().add_fields(log_record,record,message_dict) 28 | # if not log_record.get('timestamp'): 29 | # now = datetime.utcnow().strftime(TIME_FORMAT) 30 | # log_record['timestamp'] = now 31 | # if log_record.get('level'): 32 | # log_record['level'] = log_record['level'].upper() 33 | # else: 34 | # log_record['level'] = record.levelname 35 | if log_record.get('asctime'): 36 | log_record['asctime'] = datetime.utcnow().strftime(TIME_FORMAT) 37 | 38 | 39 | 40 | class TensorLogger: 41 | def __init__(self,runs_dir): 42 | self.writers = {} 43 | self.runs_dir = runs_dir 44 | def create_writer(self,name): 45 | self.writers[name] = SummaryWriter(os.path.join(self.runs_dir,name)) 46 | 47 | def add_writer(self, name, writer): 48 | self.writers[name] = writer 49 | 50 | def add_writers(self,names): 51 | for name in names: 52 | self.create_writer(name) 53 | 54 | def add_scalar(self,writer_name,tag,scalar,global_step): 55 | if writer_name not in self.writers: 56 | logger.warning(f'No writer found named {writer_name}') 57 | self.create_writer(writer_name) 58 | self.writers[writer_name].add_scalar(tag,scalar,global_step) 59 | 60 | 61 | def add_scalars(self,d,global_step): 62 | for writer_name,(tag,scalar) in d.items(): 63 | self.add_scalar(writer_name,tag,scalar,global_step) 64 | 65 | 66 | 67 | 68 | 69 | class Logger: 70 | def __init__(self, name): 71 | self.log_fullpath = None 72 | self.name = name 73 | self.logger = logging.getLogger(name) 74 | self.logger.propagate = False 75 | self.logger.setLevel(logging.DEBUG) 76 | self.is_json = False 77 | self.is_first_log = True 78 | 79 | def __add_handler(self, 
handler: logging.Handler): 80 | handler.setLevel(logging.DEBUG) 81 | self.logger.addHandler(handler) 82 | 83 | @classmethod 84 | def __get_formatter(cls, fmt: str): 85 | return logging.Formatter(fmt,TIME_FORMAT) 86 | 87 | @classmethod 88 | def __get_json_formatter(cls,fmt): 89 | return CustomJsonFormatter(fmt=fmt) 90 | 91 | def __json_log_begin(self): 92 | # print(f'inside json begin') 93 | with open(self.log_fullpath,'a+') as file: 94 | # print(file.read()) 95 | file.write('[\n') 96 | # print(file.read()) 97 | 98 | def __json_log_end(self): 99 | print('writing "]" to json log') 100 | with open(self.log_fullpath,'a+') as file: 101 | file.write(']\n') 102 | 103 | 104 | def add_filehandler(self, log_fullpath, fmt:str=None,in_json=False): 105 | self.log_fullpath = log_fullpath 106 | if fmt is None: 107 | fmt = '%(levelname)-5s : %(filename)s : %(asctime)s : line no: %(lineno)d : %(message)s' 108 | if in_json: 109 | self.is_json = True 110 | folder = os.path.dirname(self.log_fullpath) 111 | filename= os.path.basename(self.log_fullpath) 112 | filename = filename.split('.')[0] + '.json' 113 | 114 | self.log_fullpath = os.path.join(folder,filename) 115 | 116 | self.__json_log_begin() 117 | file_handler = logging.FileHandler(self.log_fullpath) 118 | file_handler.setFormatter(self.__get_json_formatter(fmt)) 119 | else: 120 | file_handler = logging.FileHandler(self.log_fullpath) 121 | file_handler.setFormatter(self.__get_formatter(fmt)) 122 | # self.log_fullpath = re.sub(r'(\\)',os.sep,self.log_fullpath) 123 | os.makedirs(os.path.dirname(self.log_fullpath), exist_ok=True) 124 | self.__add_handler(file_handler) 125 | 126 | def add_console_handler(self, fmt: str = None): 127 | if fmt is None: 128 | fmt = "%(levelname)-5s : %(filename)s : %(asctime)s : line no: %(lineno)d : %(message)s" 129 | console_handler = logging.StreamHandler() 130 | console_handler.setFormatter(self.__get_formatter(fmt)) 131 | self.__add_handler(console_handler) 132 | 133 | def __add_comma(self): 134 | if self.is_first_log: 135 | self.is_first_log = False 136 | return 137 | with open(self.log_fullpath,'a+') as file: 138 | file.write(',\n') 139 | 140 | @property 141 | def debug(self): 142 | if self.is_json:self.__add_comma() 143 | return self.logger.debug 144 | 145 | @property 146 | def info(self): 147 | if self.is_json:self.__add_comma() 148 | return self.logger.info 149 | 150 | @property 151 | def warning(self): 152 | if self.is_json:self.__add_comma() 153 | return self.logger.warning 154 | 155 | @property 156 | def error(self): 157 | if self.is_json:self.__add_comma() 158 | return self.logger.error 159 | 160 | @property 161 | def exception(self): 162 | if self.is_json:self.__add_comma() 163 | return self.logger.exception 164 | 165 | def cleanup(self): 166 | if self.is_json: 167 | self.__json_log_end() 168 | 169 | -------------------------------------------------------------------------------- /wav2mov/losses/ReadMe.md: -------------------------------------------------------------------------------- 1 | # Loss / Criterions 2 | 3 | ## Gan Loss 4 | 5 | + Binary Cross Entropy Loss 6 | 7 | ## Sync Loss 8 | 9 | + Binary Cross Entropy of (cosine similarity between audio and video embeddings) and the (real/fake) labels 10 | 11 | ## L1 Loss 12 | 13 | + Reconstruction loss : Mean Absolute Error 14 | -------------------------------------------------------------------------------- /wav2mov/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .gan_loss import GANLoss 2 | from 
.sync_loss import SyncLoss 3 | from .l1_loss import L1_Loss -------------------------------------------------------------------------------- /wav2mov/losses/gan_loss.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch import nn 4 | 5 | class GANLoss(nn.Module): 6 | """ To abstract away the task of creating real/fake labels and calculating loss 7 | [Reference]: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 8 | """ 9 | def __init__(self,device,real_label=None,fake_label=0.0): 10 | super().__init__() 11 | self.real_label = real_label 12 | self.fake_label = fake_label 13 | # self.register_buffer('real_label',torch.tensor(real_label)) 14 | # self.register_buffer('fake_label',torch.tensor(fake_label)) 15 | self.loss = nn.BCEWithLogitsLoss() 16 | self.device = device 17 | 18 | def get_target_tensor(self,preds,is_real_target): 19 | real_label = torch.tensor( round(random.uniform(0.8,1),2) if self.real_label is None else self.real_label) 20 | target_tensor = real_label if is_real_target else torch.tensor(self.fake_label) 21 | return target_tensor.expand_as(preds).to(self.device) 22 | 23 | def forward(self,preds,is_real_target): 24 | target_tensor = self.get_target_tensor(preds,is_real_target) 25 | return self.loss(preds,target_tensor) 26 | -------------------------------------------------------------------------------- /wav2mov/losses/l1_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class L1_Loss(nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | self.loss = nn.L1Loss() 8 | def forward(self,preds,targets): 9 | """ 10 | preds and targets are of the shape = (N,C,H,W) 11 | """ 12 | # height = preds.shape[-2] 13 | # preds = preds[...,height//2:,:] 14 | # targets = targets[...,height//2:,:] 15 | # # print(torch.mean(abs(preds-targets)).shape) #torch.size([]) 16 | # return torch.mean(abs(preds-targets))#get mean of error of all batch samples 17 | 18 | return self.loss(preds,targets) -------------------------------------------------------------------------------- /wav2mov/losses/sync_loss.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch import nn 4 | 5 | from wav2mov.logger import get_module_level_logger 6 | logger = get_module_level_logger(__name__) 7 | 8 | class SyncLoss(nn.Module): 9 | """Abstracts away the function of applying the loss of synchrony between audio and video frames.
10 | [Reference]: 11 | https://github.com/Rudrabha/Wav2Lip/blob/master/hq_wav2lip_train.py#L181-L186 12 | """ 13 | def __init__(self,device,real_label=None,fake_label=0.0): 14 | super().__init__() 15 | self.real_label = real_label 16 | self.fake_label = fake_label 17 | # self.register_buffer('real_label',torch.tensor(real_label)) 18 | # self.register_buffer('fake_label',torch.tensor(fake_label)) 19 | self.loss = nn.BCELoss() 20 | self.device = device 21 | 22 | def get_target_tensor(self,preds,is_real_target): 23 | real_label = torch.tensor( round(random.uniform(0.8,1),2) if self.real_label is None else self.real_label ) 24 | target_tensor = real_label if is_real_target else torch.tensor(self.fake_label) 25 | return target_tensor.expand_as(preds).to(self.device) 26 | 27 | def forward(self,preds,is_real_target): 28 | # preds = nn.functional.cosine_similarity(audio_embedding,video_embedding) 29 | target_tensor = self.get_target_tensor(preds,is_real_target) 30 | loss = self.loss(preds,target_tensor) 31 | logger.debug(f'[sync] loss {is_real_target} frames :{loss.item():0.4f}') 32 | return loss 33 | -------------------------------------------------------------------------------- /wav2mov/main/data.py: -------------------------------------------------------------------------------- 1 | """ provides utils for datasets and dataloaders """ 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | from collections import namedtuple 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms as vtransforms 10 | 11 | from wav2mov.core.data.datasets import AudioVideoDataset 12 | from wav2mov.core.data.transforms import ResizeGrayscale,Normalize 13 | from wav2mov.logger import get_module_level_logger 14 | logger = get_module_level_logger(__name__) 15 | 16 | DataloadersPack = namedtuple('dataloaders',('train','val')) 17 | 18 | 19 | To_Grayscale = vtransforms.Grayscale(1) 20 | def get_dataset(options,config,hparams): 21 | hparams = hparams['data'] 22 | root_dir = config['train_test_dataset_dir'] 23 | filenames_train_txt = config['filenames_train_txt'] 24 | filenames_test_txt = config['filenames_test_txt'] 25 | video_fps = hparams['video_fps'] 26 | audio_sf = hparams["audio_sf"] 27 | img_size = hparams['img_size'] 28 | target_img_shape = (hparams['img_channels'],img_size,img_size) 29 | 30 | num_videos_train = int(options.num_videos*0.9) 31 | num_videos_test = options.num_videos - num_videos_train 32 | 33 | mean_std_train = get_mean_and_std(root_dir,filenames_train_txt,img_channels=1) 34 | to_gray_transform = ResizeGrayscale(target_img_shape) 35 | normalize = Normalize(mean_std_train) 36 | transforms_composed = vtransforms.Compose([to_gray_transform,normalize]) 37 | dataset_train = AudioVideoDataset(root_dir=root_dir, 38 | filenames_text_filepath=filenames_train_txt, 39 | audio_sf=audio_sf, 40 | video_fps=video_fps, 41 | num_videos=num_videos_train, 42 | transform=transforms_composed) 43 | dataset_test = AudioVideoDataset(root_dir=root_dir, 44 | filenames_text_filepath=filenames_test_txt, 45 | audio_sf=audio_sf, 46 | video_fps=video_fps, 47 | num_videos=num_videos_test, 48 | transform=transforms_composed) 49 | return dataset_train,dataset_test 50 | 51 | def get_dataloaders(options,config,params,shuffle=True,collate_fn=None): 52 | hparams = params['data'] 53 | batch_size = hparams['mini_batch_size'] 54 | train_ds,test_ds = get_dataset(options,config,params) 55 | train_dl = DataLoader(train_ds,batch_size=batch_size,shuffle=shuffle,collate_fn=collate_fn,pin_memory=True)
56 | test_dl = DataLoader(test_ds,batch_size=1,shuffle=shuffle,collate_fn=collate_fn) 57 | return DataloadersPack(train_dl,test_dl) 58 | 59 | 60 | def get_video_mean_and_std(root_dir,filenames,img_channels): 61 | channels_sum,channels_squared_sum,num_batches = 0,0,0 62 | # num_items = 0 63 | progress_bar = tqdm(enumerate(filenames),ascii=True,total=len(filenames),desc='video') 64 | for _,filename in progress_bar: 65 | progress_bar.set_postfix({'file':filename}) 66 | video_path = os.path.join(root_dir,filename,'video_frames.npy') 67 | video = torch.from_numpy(np.load(video_path)) 68 | video = video/255 #of shape (F,H,W,C) 69 | if img_channels==1: 70 | video = video.permute(0,3,1,2)#F,C,H,W 71 | video = To_Grayscale(video) 72 | video = video.permute(0,2,3,1) 73 | 74 | channels_sum += torch.mean(video,dim=[0,1,2]) 75 | #except for the channel dimension as we want mean and std 76 | # for each channel 77 | channels_squared_sum += torch.mean(video**2,dim=[0,1,2]) 78 | # num_items += video.shape[0] 79 | num_batches += 1 80 | mean = channels_sum/num_batches 81 | 82 | std = ((channels_squared_sum/num_batches) - mean**2)**0.5 83 | return mean,std 84 | 85 | def get_audio_mean_and_std(root_dir,filenames): 86 | running_mean_sum , running_squarred_mean_sum = 0,0 87 | progress_bar = tqdm(enumerate(filenames),ascii=True,total=len(filenames),desc='audio') 88 | for _,filename in progress_bar: 89 | progress_bar.set_postfix({'file':filename}) 90 | audio_path = os.path.join(root_dir,filename,'audio.npy') 91 | audio = torch.from_numpy(np.load(audio_path)) 92 | running_mean_sum += torch.mean(audio) 93 | running_squarred_mean_sum += torch.mean(audio**2) 94 | mean = running_mean_sum/len(filenames) 95 | std = ((running_squarred_mean_sum/len(filenames))-mean**2)**0.5 96 | return mean,std 97 | 98 | def get_mean_and_std(root_dir,filenames_txt,img_channels): 99 | logger.debug('Calculating mean and standard deviation for the dataset.Please wait...') 100 | ret = {} 101 | #mean = E(X) 102 | #variance = E(X**2)- E(X)**2 103 | #standard deviation = variance**0.5 104 | filenames_path = os.path.join(root_dir,filenames_txt) 105 | with open(filenames_path) as file: 106 | filenames = file.read().split('\n') 107 | for i,name in enumerate(filenames[:]): 108 | if not name.strip(): 109 | del filenames[i] 110 | 111 | ret['video'] = ([0.5]*img_channels,[0.5]*img_channels) 112 | ret['audio'] = get_audio_mean_and_std(root_dir,filenames) 113 | logger.debug(f'[MEAN and STANDARD_DEVIATION] {ret}') 114 | return ret -------------------------------------------------------------------------------- /wav2mov/main/engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from wav2mov.core.engine import TemplateEngine 4 | from wav2mov.core.engine import CallbackEvents as Events 5 | 6 | 7 | class State: 8 | def __init__(self,names): 9 | self.names = names 10 | for name in self.names: 11 | setattr(self,name,None) 12 | 13 | def reset(self,names): 14 | for name in names: 15 | setattr(self,name,None) 16 | 17 | 18 | class Engine(TemplateEngine): 19 | def __init__(self,options,hparams,config,logger): 20 | super().__init__() 21 | self.logger = logger 22 | self.configure(options,hparams,config) 23 | self.state = State(['num_batches','cur_batch_size','epoch','batch_idx','start_epoch','logs']) 24 | 25 | 26 | def configure(self,options,hparams,config): 27 | self.hparams = hparams 28 | self.options = options 29 | self.config = config 30 | 31 | def load_checkpoint(self,model): 32 | prev_epoch = 0 33 | if 
getattr(self.options, 'model_path', None) is None: 34 | return prev_epoch 35 | loading_version = os.path.basename(self.options.model_path) 36 | self.logger.debug(f'Loading pretrained weights : {self.config.version} <== {loading_version}') 37 | 38 | prev_epoch = model.load(checkpoint_dir=self.options.model_path) 39 | if prev_epoch is None: 40 | prev_epoch = 0 41 | self.logger.debug(f'weights loaded successfully: {self.config.version} <== {loading_version}') 42 | return prev_epoch + 1 43 | 44 | def dispatch(self, event): 45 | super().dispatch(event,state=self.state) 46 | 47 | def run(self,model,dataloaders_ntuple,callbacks=None): 48 | callbacks = callbacks or [] 49 | callbacks = [model] + callbacks 50 | self.register(callbacks) 51 | self.state.start_epoch = self.load_checkpoint(model) 52 | 53 | train_dl = dataloaders_ntuple.train 54 | self.state.num_batches = len(train_dl) 55 | num_epochs = self.hparams['num_epochs'] 56 | self.dispatch(Events.RUN_START) 57 | self.dispatch(Events.TRAIN_START) 58 | for epoch in range(self.state.start_epoch,num_epochs): 59 | self.state.epoch = epoch 60 | self.dispatch(Events.EPOCH_START) 61 | for batch_idx,batch in enumerate(train_dl): 62 | self.state.batch_idx = batch_idx 63 | self.state.cur_batch_size = batch[0].shape[0] #makes the system tight coupled though!? 64 | self.dispatch(Events.BATCH_START) 65 | model.setup_input(batch,state=self.state) 66 | logs = model.optimize(state=self.state) 67 | self.state.logs = logs 68 | self.dispatch(Events.BATCH_END) 69 | self.state.reset(['logs']) 70 | 71 | self.dispatch(Events.EPOCH_END) 72 | self.dispatch(Events.TRAIN_END) 73 | self.dispatch(Events.RUN_END) 74 | 75 | def to_device(self,device,*args): 76 | return [arg.to(device) for arg in args] 77 | 78 | def test(self,model,test_dl): 79 | last_epoch = self.load_checkpoint(model) 80 | if last_epoch is None or last_epoch==0: 81 | self.logger.warning(f'Testing an untrained model !!!.') 82 | sample = next(iter(test_dl)) 83 | audio,audio_frames,video = sample 84 | # audio,audio_frames,video = self.to_device('cpu',audio,audio_frames,video) 85 | model.test(audio,audio_frames,video) 86 | 87 | 88 | -------------------------------------------------------------------------------- /wav2mov/main/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main file of the Wav2Mov Project 3 | It is the entry point to various functions 4 | """ 5 | import os 6 | import random 7 | import torch 8 | from torch.utils.tensorboard.writer import SummaryWriter 9 | 10 | from wav2mov.logger import Logger 11 | from wav2mov.config import get_config 12 | 13 | from wav2mov.params import params 14 | 15 | from wav2mov.main.preprocess import create_from_grid_dataset 16 | from wav2mov.main.train import train_model 17 | from wav2mov.main.test import test_model 18 | 19 | from wav2mov.main.options import Options,set_options 20 | from wav2mov.main.validate_params import check_batchsize 21 | torch.manual_seed(params['seed']) 22 | random.seed(params['seed']) 23 | 24 | def get_logger(config,filehandler_required=False): 25 | local_logger = Logger(__name__) 26 | local_logger.add_console_handler() 27 | if filehandler_required: 28 | local_logger.add_filehandler(config['log_fullpath'],in_json=True) 29 | return local_logger 30 | 31 | def preprocess(preprocess_logger,config): 32 | if options.grid_dataset_dir: 33 | config.update('grid_dataset_dir',options.grid_dataset_dir) 34 | create_from_grid_dataset(config,preprocess_logger) 35 | 36 | def 
train(train_logger,args_options,config): 37 | train_model(args_options,params,config,train_logger) 38 | 39 | def test(test_logger,args_options,config): 40 | test_model(args_options,params,config,test_logger) 41 | 42 | def save_message(options,config,logger): 43 | if not getattr(options,'msg'):return 44 | if options.log=='n':return 45 | path = os.path.join(os.path.dirname(config['log_fullpath']),f'message_{config.version}.txt') 46 | logger.debug('message written to %(path)s'%{'path':path}) 47 | with open(path,'a+') as file: 48 | file.write(options.msg) 49 | file.write('\n') 50 | 51 | def main(config): 52 | 53 | allowed = ('y', 'yes') 54 | if options.preprocess in allowed: 55 | preprocess(logger,config) 56 | 57 | if options.train in allowed: 58 | required = ('num_videos', 'num_epochs') 59 | 60 | if not all(getattr(options, option) for option in required): 61 | raise RuntimeError('Cannot train without the options :', required) 62 | check_batchsize(hparams=params['data']) 63 | train(logger, options,config) 64 | 65 | if options.test in allowed: 66 | required = ('model_path',) 67 | 68 | if not all(getattr(options, option) for option in required): 69 | raise RuntimeError( 70 | f'Cannot test without the options : {required}') 71 | test(logger, options,config) 72 | 73 | if __name__ == '__main__': 74 | 75 | 76 | options = Options().parse() 77 | config = get_config(options.version) 78 | set_options(options,params) 79 | logger = get_logger(config,filehandler_required=options.log in ['y', 'yes']) 80 | save_message(options,config,logger) 81 | try: 82 | main(config) 83 | except Exception as e: 84 | if options.train in ['y','yes']: 85 | params.save(config['params_checkpoint_fullpath']) 86 | logger.exception(e) 87 | 88 | finally: 89 | logger.cleanup() 90 | 91 | 92 | -------------------------------------------------------------------------------- /wav2mov/main/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | class Options(): 4 | def __init__(self): 5 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, 6 | description='Wav2Mov | End to End Speech to facial animation model') 7 | 8 | self.parser.add_argument('--version', 9 | type=str, 10 | help='version of the file being run. 
Example : v1,v2,...', 11 | required=True) 12 | 13 | self.parser.add_argument('--log', 14 | default='y', 15 | choices=['y','n','yes','no'], 16 | type=str, 17 | help='whether to initialize logger') 18 | 19 | self.parser.add_argument('--train', 20 | default='n', 21 | choices=['y','n','yes','no'], 22 | type=str, 23 | help='whether to train the model') 24 | 25 | self.parser.add_argument('--device','-d', 26 | default='cpu', 27 | choices=['cpu','cuda'], 28 | type=str, 29 | help='device on which model operations are done', 30 | required=True) 31 | 32 | self.parser.add_argument('--num_epochs','-e', 33 | type=int, 34 | help='device on which model operations are done', 35 | ) 36 | 37 | self.parser.add_argument('--test', 38 | default='n', 39 | choices=['y','n','yes','no'], 40 | type=str, 41 | help='run test script') 42 | self.parser.add_argument('--train_sync_expert', 43 | default='n', 44 | choices=['y','n','yes','no'], 45 | type=str, 46 | help='only train sync disc') 47 | 48 | self.parser.add_argument('--preprocess', 49 | default='n', 50 | choices=['y','n','yes','no'], 51 | type=str,help='run preprocess script') 52 | 53 | self.parser.add_argument('--grid_dataset_dir','-grid', 54 | type=str, 55 | help="path of raw dataset") 56 | # self.parser.add_argument('--device', 57 | # default='cuda', 58 | # choices=['cpu','cuda'], 59 | # type=str,help='device where processing should be done') 60 | 61 | self.parser.add_argument('--model_path','-path', 62 | type=str, 63 | help='generator checkpoint fullpath') 64 | 65 | self.parser.add_argument('--num_videos','-v', 66 | type=int, 67 | help='num of videos on which the model should be trained') 68 | 69 | 70 | self.parser.add_argument('--msg','-m', 71 | type=str, 72 | help='any message about current run') 73 | 74 | self.parser.add_argument('--test_sample_num','-snum', 75 | type=int, 76 | help='sample to be taken from test dataloader', 77 | ) 78 | 79 | def parse(self): 80 | return self.parser.parse_args() 81 | 82 | def set_device(options, params): 83 | device = options.device 84 | params.set('device', device) 85 | params.set('gen', {**params['gen'], 'device': device}) 86 | 87 | def set_epochs(options,params): 88 | num_epochs = options.num_epochs 89 | params.set('num_epochs',num_epochs) 90 | 91 | def set_train_sync_expert(options,params): 92 | train_sync_expert = True if options.train_sync_expert in ('y','yes') else False 93 | params.set('train_sync_expert',train_sync_expert) 94 | 95 | def set_options(options, params): 96 | set_device(options,params) 97 | set_epochs(options,params) 98 | set_train_sync_expert(options,params) 99 | -------------------------------------------------------------------------------- /wav2mov/main/preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module for creating and saving numoy file from raw dataset 3 | + mouth landmarks and audio frames are extracted. 4 | """ 5 | 6 | import numpy as np 7 | import os 8 | 9 | 10 | from wav2mov.core.data import RawDataset, GridDataset 11 | from wav2mov.datasets import create_file_list as create_file_list_module 12 | 13 | # currently these are not applied anywhere 14 | # audio file in util to see which values are being used instead. 
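# note: a window length/hop of 0.033 s matches the duration of one video frame at 30 fps (1/30 ≈ 0.0333 s),
# i.e. one audio window per video frame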
15 | video_frame_rate = 30 16 | audio_sampling_rate = 16_000 17 | win_len = 0.033 18 | win_hop = 0.033 19 | 20 | 21 | 22 | def create(raw_dataset: RawDataset,config,logger) -> str: 23 | """ Creates numpy file containing video and audio frames from given dataset 24 | 25 | Args: 26 | 27 | + `raw_dataset` (RawDataset): either GridDataset or Ravdess dataset 28 | 29 | + `config` : Config object containing different location information 30 | 31 | Returns: 32 | 33 | + `str`: log string containing info about saved file. 34 | """ 35 | 36 | samples_count = raw_dataset.samples_count 37 | dataset_dir = config['train_test_dataset_dir'] 38 | filenames = [] 39 | 40 | for sample in raw_dataset.generator(get_filepath_only=False,img_size=(256,256),show_progress_bar=True): 41 | audio_filepath,audio_vals = sample.audio 42 | _,video_frames = sample.video 43 | folder_name = os.path.basename(audio_filepath).split('.')[0] 44 | 45 | 46 | if raw_dataset.name in os.path.basename(dataset_dir): 47 | dest_dir = os.path.join(dataset_dir,folder_name) 48 | else: 49 | dest_dir = os.path.join(dataset_dir,f'{raw_dataset.name}',folder_name) 50 | 51 | os.makedirs(dest_dir,exist_ok=True) 52 | np.save(os.path.join(dest_dir,'video_frames.npy'), 53 | np.array(video_frames)) 54 | np.save(os.path.join(dest_dir,'audio.npy'), 55 | audio_vals) 56 | filenames.append(folder_name + '\n') 57 | 58 | filenames[-1] = filenames[-1].strip('\n') 59 | with open(os.path.join(dataset_dir, 'filenames.txt'),'a+') as file: 60 | file.writelines(filenames) 61 | 62 | log = f'Samples generated : {samples_count}\n' 63 | log += f'Location : { dataset_dir }\n' 64 | log += f'Filenames are listed in filenames.txt\n' 65 | logger.info(log) 66 | create_file_list_module.main() 67 | logger.debug('train and test filelists created') 68 | 69 | return log 70 | 71 | 72 | def create_from_grid_dataset(config,logger): 73 | dataset = GridDataset(config['grid_dataset_dir'], 74 | audio_sampling_rate=audio_sampling_rate, 75 | video_frame_rate=video_frame_rate, 76 | samples_count=125) 77 | 78 | create(dataset,config,logger) 79 | print(f'{dataset.__class__.__name__} successfully processed.') 80 | # logger.info(log) 81 | -------------------------------------------------------------------------------- /wav2mov/main/test.py: -------------------------------------------------------------------------------- 1 | # import logging 2 | import os 3 | import re 4 | import torch 5 | from torchvision import utils as vutils 6 | # from tqdm import tqdm 7 | from scipy.io.wavfile import write 8 | from wav2mov.core.data.collates import get_batch_collate 9 | from wav2mov.models.wav2mov_inferencer import Wav2movInferencer 10 | from wav2mov.main.data import get_dataloaders 11 | from wav2mov.utils.plots import save_gif,save_video 12 | 13 | SAMPLE_NUM = 5 14 | 15 | def squeeze_frames(video): 16 | batch_size,num_frames,*extra = video.shape 17 | return video.reshape(batch_size*num_frames,*extra) 18 | 19 | def denormalize_frames(frames): 20 | return ((frames*0.5)+0.5)*255 21 | 22 | def make_path_compatible(path): 23 | if os.sep != '\\':#not windows 24 | return re.sub(r'(\\)+',os.sep,path) 25 | 26 | def test_sample(model,dl,options,hparams,config,logger,suffix): 27 | global SAMPLE_NUM 28 | if hasattr(options,'test_sample_num'): 29 | sample_num = options.test_sample_num 30 | else: 31 | sample_num = SAMPLE_NUM 32 | SAMPLE_NUM = sample_num 33 | logger.debug(f'FOR {suffix} sample') 34 | version = os.path.basename(options.model_path).strip('gen_').split('.')[0] 35 | out_dir = 
os.path.join(config['out_dir'],version) 36 | logger.debug(f'[OUTPUT DIR] {out_dir}') 37 | logger.debug(options.version) 38 | sample_iter = (iter(dl)) 39 | sample = next(sample_iter) 40 | 41 | for _ in range(min(sample_num-1,len(dl)-1)): 42 | sample = next(sample_iter) 43 | 44 | # print(sample,len(sample)) 45 | audio,audio_frames,video = sample 46 | batch_size = audio.shape[0] 47 | if batch_size>1: 48 | audio,audio_frames,video = (audio[0].unsqueeze(0), 49 | audio_frames[0].unsqueeze(0), 50 | video[0].unsqueeze(0)) 51 | fake_video_frames,ref_video_frame = model.test(audio_frames,video,get_ref_video_frame=True) 52 | fake_video_frames = squeeze_frames(fake_video_frames) 53 | video = squeeze_frames(video) 54 | os.makedirs(out_dir,exist_ok=True) 55 | save_path_fake_video_frames = os.path.join(out_dir,f'{suffix}_fake_frames_{version}.png') 56 | save_path_real_video_frames = os.path.join(out_dir,f'{suffix}_real_frames_{version}.png') 57 | save_path_ref_video_frame = os.path.join(out_dir,f'{suffix}_ref_frame_{version}.png') 58 | save_path_fake_video_frames = make_path_compatible(save_path_fake_video_frames) 59 | save_path_real_video_frames = make_path_compatible(save_path_real_video_frames) 60 | save_path_ref_video_frame = make_path_compatible(save_path_ref_video_frame) 61 | 62 | save_video(hparams['data'],os.path.join(out_dir,f'{suffix}_fake_video.avi'),audio,denormalize_frames(fake_video_frames)) 63 | logger.debug(f'video saved : {suffix}_{fake_video_frames.shape}') 64 | 65 | vutils.save_image(denormalize_frames(ref_video_frame),save_path_ref_video_frame,normalize=True) 66 | vutils.save_image(denormalize_frames(fake_video_frames),save_path_fake_video_frames,normalize=True) 67 | vutils.save_image(denormalize_frames(video),save_path_real_video_frames,normalize=True) 68 | 69 | gif_path = os.path.join(out_dir,f'{suffix}_fake_frames_{version}.gif') 70 | save_gif(gif_path,denormalize_frames(fake_video_frames)) 71 | gif_path = os.path.join(out_dir,f'{suffix}_real_frames_{version}.gif') 72 | save_gif(gif_path,denormalize_frames(video)) 73 | 74 | def test_model(options,hparams,config,logger): 75 | logger.debug(f'Testing model...') 76 | version = os.path.basename(options.model_path).strip('gen_').split('.')[0] 77 | logger.debug(f'loading version : {version}') 78 | out_dir = os.path.join(config['out_dir'],version) 79 | 80 | model = Wav2movInferencer(hparams,logger) 81 | checkpoint = torch.load(options.model_path,map_location='cpu') 82 | state_dict,last_epoch = checkpoint['state_dict'],checkpoint['epoch'] 83 | model.load(state_dict) 84 | logger.debug(f'model was trained for {last_epoch+1} epochs. 
') 85 | collate_fn = get_batch_collate(hparams['data']) 86 | logger.debug('Loading dataloaders') 87 | dataloaders = get_dataloaders(options,config,hparams,collate_fn=collate_fn) 88 | train_dl =dataloaders.train 89 | test_dl = dataloaders.val 90 | # write(os.path.join(out_dir,f'audio_{SAMPLE_NUM}.wav'),16000,audio.cpu().numpy().reshape(-1)) 91 | # logger.debug(f'audio saved : audio_{SAMPLE_NUM}.wav') 92 | test_sample(model,train_dl,options,hparams,config,logger,'train') 93 | test_sample(model,test_dl,options,hparams,config,logger,'test') 94 | 95 | logger.debug(f'results are saved in {out_dir}') 96 | msg = "#"*25 97 | msg += f'\ntest_run for version {version}.\n' 98 | msg += '='*25 99 | msg += f'\nlast epoch : {last_epoch+1}\n' 100 | msg += f'curr_version : {config.version}\n' 101 | msg += f'sample num : {SAMPLE_NUM}\n' 102 | msg += "#"*25 103 | 104 | with open(os.path.join(out_dir,'info.txt'),'a+') as file: 105 | file.write(msg) 106 | 107 | -------------------------------------------------------------------------------- /wav2mov/main/train.py: -------------------------------------------------------------------------------- 1 | from wav2mov.models.wav2mov_trainer import Wav2MovTrainer 2 | from wav2mov.main.engine import Engine 3 | from wav2mov.main.callbacks import LossMetersCallback,ModelCheckpoint,TimeTrackerCallback,LoggingCallback 4 | 5 | from wav2mov.core.data.collates import get_batch_collate 6 | from wav2mov.main.data import get_dataloaders 7 | 8 | 9 | 10 | 11 | def train_model(options,hparams,config,logger): 12 | engine = Engine(options,hparams,config,logger) 13 | model = Wav2MovTrainer(hparams,config,logger) 14 | collate_fn = get_batch_collate(hparams['data']) 15 | dataloaders_ntuple = get_dataloaders(options,config,hparams, 16 | collate_fn=collate_fn) 17 | callbacks = [LossMetersCallback(options,hparams,config,logger, 18 | verbose=True), 19 | LoggingCallback(options,hparams,config,logger), 20 | TimeTrackerCallback(hparams,logger), 21 | ModelCheckpoint(model,hparams,config,logger, 22 | save_every=5)] 23 | 24 | engine.run(model,dataloaders_ntuple,callbacks) -------------------------------------------------------------------------------- /wav2mov/main/trained.txt: -------------------------------------------------------------------------------- 1 | E:\Users\VS_Code_Workspace\Python\VirtualEnvironments\wav2mov\wav2mov\runs\v6\Run_7_4_2021__18_24\gen_Run_7_4_2021__18_24.pt -------------------------------------------------------------------------------- /wav2mov/main/validate_params.py: -------------------------------------------------------------------------------- 1 | def check_batchsize(hparams): 2 | if hparams['batch_size']%hparams['mini_batch_size']: 3 | raise ValueError(f'Batch size must be evenly divisible by mini_batch_size\n' 4 | f'Currently batch_size : {hparams["batch_size"]} mini_batch_size :{hparams["mini_batch_size"]}') -------------------------------------------------------------------------------- /wav2mov/models/README.md: -------------------------------------------------------------------------------- 1 | # Architectures 2 | 3 | ## Generator 4 | 5 | * follows UNET architecture 6 | * input : audio_frames (B,F,Sw) and ref_frames (B,F,C,H,W) 7 | * out is IMAGE_SIZE*IMAGE_SIZE 8 | 9 | ## identity Descriminator 10 | 11 | * input is real/fake frame and still face image(ref face image) 12 | * helps to retain identity of face of given reference image and to produce realistic images. 
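Below is a minimal sketch of how the identity discriminator is fed (a generated or real frame together with the reference still). The hyperparameter values and the grayscale 256x256 frame size are illustrative assumptions for this sketch, not the values shipped in params.json:

```python
# illustrative sketch only: 'chs', 'relu_neg_slope', 'lr' and the frame size are assumed values
import torch
from wav2mov.models import IdentityDiscriminator

hparams = {'in_channels': 1,               # grayscale frames
           'chs': [64, 128, 256, 512, 1],  # conv channel progression; last entry is the score-map channels
           'relu_neg_slope': 0.2,
           'lr': 2e-4}
id_disc = IdentityDiscriminator(hparams)

frames = torch.randn(2, 1, 256, 256)     # generated (or real) face frames
ref_still = torch.randn(2, 1, 256, 256)  # reference still image of the same identity
scores = id_disc(frames, ref_still)      # concatenated along channels inside, reduced by strided convs
print(scores.shape)                      # patch-level real/fake scores, e.g. (2, 1, 8, 8) for 256x256 inputs
```
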
13 | 14 | ## sequence Descriminator 15 | 16 | * input is consecutive frames 17 | * helps to produce cohesive video frames. 18 | 19 | ## sync Descriminator 20 | 21 | * input is audio frame(666 points*5) and corresponding video frames. 22 | * helps to synchronize produced video frames with that of audio. 23 | 24 | -------------------------------------------------------------------------------- /wav2mov/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .discriminators.identity_discriminator import IdentityDiscriminator 2 | from .discriminators.sequence_discriminator import SequenceDiscriminator 3 | from .discriminators.sync_discriminator import SyncDiscriminator 4 | from .discriminators.patch_disc import PatchDiscriminator 5 | from .generator.frame_generator import Generator -------------------------------------------------------------------------------- /wav2mov/models/discriminators/identity_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn,optim 3 | from torch.nn import functional as F 4 | 5 | from wav2mov.core.models.base_model import BaseModel 6 | from wav2mov.models.utils import squeeze_batch_frames,get_same_padding 7 | from wav2mov.models.layers.conv_layers import Conv2dBlock 8 | 9 | from wav2mov.logger import get_module_level_logger 10 | logger = get_module_level_logger(__name__) 11 | 12 | class IdentityDiscriminator(BaseModel): 13 | def __init__(self,hparams): 14 | 15 | super().__init__() 16 | self.hparams = hparams 17 | in_channels = hparams['in_channels']*2 18 | chs = self.hparams['chs'] 19 | chs = [in_channels] + chs 20 | padding = get_same_padding(kernel_size=4,stride=2) 21 | relu_neg_slope = self.hparams['relu_neg_slope'] 22 | self.conv_blocks = nn.ModuleList(Conv2dBlock(chs[i],chs[i+1], 23 | kernel_size=(4,4), 24 | stride=2, 25 | padding=padding, 26 | use_norm=True, 27 | use_act=True, 28 | act=nn.LeakyReLU(relu_neg_slope) 29 | ) for i in range(len(chs)-2) 30 | ) 31 | 32 | self.conv_blocks.append(Conv2dBlock(chs[-2],chs[-1], 33 | kernel_size=(4,4), 34 | stride=2, 35 | padding=padding, 36 | use_norm=False, 37 | use_act=False 38 | ) 39 | ) 40 | 41 | def forward(self,x,y): 42 | """ 43 | x : frame image (B,F,H,W) 44 | y : still image 45 | """ 46 | assert x.shape==y.shape 47 | 48 | if len(x.shape)>4:#frame dim present 49 | x = squeeze_batch_frames(x) 50 | y = squeeze_batch_frames(y) 51 | 52 | x = torch.cat([x,y],dim=1)#along channels 53 | for block in self.conv_blocks: 54 | x = block(x) 55 | return x 56 | 57 | def get_optimizer(self): 58 | return optim.Adam(self.parameters(), lr=self.hparams['lr'], betas=(0.5,0.999)) 59 | -------------------------------------------------------------------------------- /wav2mov/models/discriminators/patch_disc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn,optim 3 | from wav2mov.models.utils import squeeze_batch_frames 4 | 5 | from wav2mov.logger import get_module_level_logger 6 | logger = get_module_level_logger(__name__) 7 | 8 | 9 | class PatchDiscriminator(nn.Module): 10 | """ 11 | ref : https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 12 | """ 13 | def __init__(self,hparams,norm_layer=nn.BatchNorm2d,use_bias = True): 14 | super().__init__() 15 | self.hparams = hparams 16 | input_nc = hparams['in_channels'] 17 | ndf = hparams['ndf'] 18 | n_layers = hparams['num_layers'] 19 | use_bias 
= norm_layer==nn.InstanceNorm2d 20 | #batch normalization has affine = True so it has additional scaling and biasing/shifting terms by default. 21 | kw = 4 22 | padw = 1 23 | sequence = [nn.Conv2d(input_nc*2,ndf,kw,2,padw), 24 | nn.LeakyReLU(0.2,inplace=True)]#inplace=True saves memory 25 | 26 | nf_mult = 1 27 | nf_mult_prev = 1 28 | 29 | for n in range(1,n_layers): 30 | nf_mult_prev = nf_mult 31 | nf_mult = min(2**n,8) #multiplier is clipped at 8 : so if ndf is 64 max possible is 64*8=512 32 | sequence += [ 33 | nn.Conv2d(ndf*nf_mult_prev,ndf*nf_mult,kw,2,padw,bias=use_bias), 34 | norm_layer(ndf*nf_mult), 35 | nn.LeakyReLU(0.2,True) 36 | ] 37 | nf_mult_prev = nf_mult 38 | nf_mult = min(2**n_layers,8) 39 | sequence += [ 40 | nn.Conv2d(ndf*nf_mult_prev,ndf*nf_mult,kw,1,padw,bias=use_bias), 41 | norm_layer(ndf*nf_mult), 42 | nn.LeakyReLU(0.2,True) 43 | ] 44 | 45 | sequence += [ 46 | nn.Conv2d(ndf*nf_mult,1,kw,1,padw) 47 | ] 48 | 49 | self.disc = nn.Sequential(*sequence) 50 | 51 | 52 | def forward(self,frame_image,ref_image): 53 | assert frame_image.shape == ref_image.shape 54 | batch_size = frame_image.shape[0] 55 | is_frame_dim_present = False 56 | if len(frame_image.shape)>4: 57 | is_frame_dim_present = True 58 | batch_size,num_frames,*img_shape = frame_image.shape 59 | frame_image = squeeze_batch_frames(frame_image) 60 | ref_image = squeeze_batch_frames(ref_image) 61 | frame_image = torch.cat([frame_image,ref_image],dim=1)#catenate images along the channel dimension 62 | #frame image now has a shape of (N,C+C,H,W) 63 | return self.disc(frame_image) 64 | 65 | def get_optimizer(self): 66 | return optim.Adam(self.parameters(), lr=self.hparams['lr'], betas=(0.5, 0.999)) 67 | -------------------------------------------------------------------------------- /wav2mov/models/discriminators/sequence_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn,optim 3 | 4 | from wav2mov.core.models.base_model import BaseModel 5 | from wav2mov.models.layers.conv_layers import Conv2dBlock 6 | from wav2mov.models.utils import get_same_padding,squeeze_batch_frames 7 | 8 | from wav2mov.logger import get_module_level_logger 9 | logger = get_module_level_logger(__name__) 10 | 11 | class SequenceDiscriminator(BaseModel): 12 | def __init__(self,hparams): 13 | super().__init__() 14 | self.hparams = hparams 15 | relu_neg_slope = self.hparams['relu_neg_slope'] 16 | in_size, h_size, num_layers = self.hparams['in_size'],self.hparams['h_size'],self.hparams['num_layers'] 17 | self.gru = nn.GRU(input_size=in_size,hidden_size=h_size,num_layers=num_layers,batch_first = True) 18 | in_channels = self.hparams['in_channels'] 19 | chs = [in_channels] + self.hparams['chs'] 20 | kernel,stride = 4,2 21 | padding = get_same_padding(kernel,stride) 22 | cnn = nn.ModuleList([Conv2dBlock(chs[i],chs[i+1],kernel,stride,padding, 23 | use_norm=True,use_act=True, 24 | act=nn.LeakyReLU(relu_neg_slope)) for i in range(len(chs)-2)]) 25 | cnn.append(Conv2dBlock(chs[-2],chs[-1],kernel,stride,padding, 26 | use_norm=False,use_act=True, 27 | act=nn.Tanh())) 28 | self.cnn = nn.Sequential(*cnn) 29 | ############################################ 30 | # channels : 3 => 64 => 128 => 256 => 512 31 | # frame sz : 256=> 128 => 64 => 32 => 16 =>8 = Width and height =4 (since only upper half) 32 | # height is half so : final height is 4 33 | # thus out of self.cnn : 512x4x8 34 | ############################################ 35 | 36 | def forward(self,frames): 37 | """frames : 
B,T,C,H,W""" 38 | img_height = frames.shape[-2] 39 | frames = frames[...,0:img_height//2,:]#consider upper half 40 | batch_size,num_frames,*img_size = frames.shape 41 | frames = squeeze_batch_frames(frames) 42 | frames = self.cnn(frames) 43 | frames = frames.reshape(batch_size,num_frames,-1) 44 | out,_ = self.gru(frames)#out is of shape (batch_size,seq_len,num_dir*hidden_dim) 45 | return out[:,-1,:]#batch_size,hidden_size 46 | 47 | def get_optimizer(self): 48 | return optim.Adam(self.parameters(), lr=self.hparams['lr'], betas=(0.5,0.999)) 49 | -------------------------------------------------------------------------------- /wav2mov/models/discriminators/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import torch 4 | from torchvision import utils as vutils 5 | 6 | def squeeze_frames(video): 7 | batch_size,num_frames,*extra = video.shape 8 | return video.reshape(batch_size*num_frames,*extra) 9 | 10 | def denormalize_frames(frames): 11 | return ((frames*0.5)+0.5)*255 12 | 13 | def make_path_compatible(path): 14 | if os.sep != '\\':#not windows 15 | return re.sub(r'(\\)+',os.sep,path) 16 | 17 | 18 | def save_series(frames,config,i=0): 19 | frames = squeeze_frames(frames) 20 | save_dir = os.path.join(config['runs_dir'],'images') 21 | os.makedirs(save_dir,exist_ok=True) 22 | save_path = os.path.join(save_dir,f'sync_frames_fake_{i}.png') 23 | vutils.save_image(denormalize_frames(frames),save_path,normalize=True) -------------------------------------------------------------------------------- /wav2mov/models/generator/audio_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from wav2mov.models.layers.conv_layers import Conv1dBlock,Conv2dBlock 4 | from wav2mov.models.layers.debug_layers import IdentityDebugLayer 5 | from wav2mov.logger import get_module_level_logger 6 | from wav2mov.models.utils import squeeze_batch_frames,get_same_padding 7 | 8 | logger = get_module_level_logger(__name__) 9 | 10 | class AudioEnocoder(nn.Module): 11 | def __init__(self,hparams): 12 | super().__init__() 13 | self.hparams = hparams 14 | padding_31 = get_same_padding(3,1) 15 | padding_32 = get_same_padding(3,2) 16 | padding_42 = get_same_padding(4,2) 17 | # each frame has mfcc of shape (7,13) 18 | self.conv_encoder = nn.Sequential( 19 | Conv2dBlock(1,32,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU()),#7,13 20 | Conv2dBlock(32,32,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 21 | Conv2dBlock(32,32,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 22 | 23 | Conv2dBlock(32,64,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU()), #7,13 24 | Conv2dBlock(64,64,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 25 | Conv2dBlock(64,64,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 26 | 27 | Conv2dBlock(64,128,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU()), #7,13 28 | Conv2dBlock(128,128,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 29 | Conv2dBlock(128,128,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 30 | 31 | Conv2dBlock(128,256,3,(2,1),padding=padding_32,use_norm=True,use_act=True,act=nn.ReLU()), #4,13 32 | Conv2dBlock(256,256,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 33 | 
Conv2dBlock(256,256,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 34 | 35 | Conv2dBlock(256,512,(4,3),(2,1),padding=1,use_norm=True,use_act=True,act=nn.ReLU()), #2,13 36 | Conv2dBlock(512,512,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True), 37 | Conv2dBlock(512,512,(2,5),(1,2),padding=(0,1),use_norm=True,use_act=True,act=nn.ReLU()),#1,6, 38 | ) 39 | 40 | self.features_len = 6*512#out of conv layers 41 | self.num_layers = 1 42 | self.hidden_size = self.hparams['latent_dim_audio'] 43 | self.gru = nn.GRU(input_size=self.features_len, 44 | hidden_size=self.hidden_size, 45 | num_layers=1, 46 | batch_first=True) 47 | self.final_act = nn.Tanh() 48 | 49 | def forward(self, x): 50 | """ x : audio frames of shape B,T,t,13""" 51 | batch_size,num_frames,*_ = x.shape 52 | x = self.conv_encoder(squeeze_batch_frames(x).unsqueeze(1))#add channel dimension 53 | x = x.reshape(batch_size,num_frames,self.features_len) 54 | x,_ = self.gru(x) 55 | return self.final_act(x) # shape (batch_size,num_frames,hidden_size=latent_dim_audio) 56 | -------------------------------------------------------------------------------- /wav2mov/models/generator/frame_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn,optim 3 | 4 | from wav2mov.models.generator.audio_encoder import AudioEnocoder 5 | from wav2mov.models.generator.noise_encoder import NoiseEncoder 6 | from wav2mov.models.generator.id_encoder import IdEncoder 7 | from wav2mov.models.generator.id_decoder import IdDecoder 8 | 9 | from wav2mov.models.utils import squeeze_batch_frames 10 | from wav2mov.logger import get_module_level_logger 11 | logger = get_module_level_logger(__name__) 12 | 13 | class Generator(nn.Module): 14 | def __init__(self,hparams): 15 | super().__init__() 16 | self.hparams = hparams 17 | self.id_encoder = IdEncoder(self.hparams) 18 | self.id_decoder = IdDecoder(self.hparams) 19 | self.audio_encoder = AudioEnocoder(self.hparams) 20 | # self.noise_encoder = NoiseEncoder(self.hparams) 21 | 22 | def forward(self,audio_frames,ref_frames): 23 | batch_size,num_frames,*_ = ref_frames.shape 24 | assert num_frames == audio_frames.shape[1] 25 | encoded_id , intermediates = self.id_encoder(squeeze_batch_frames(ref_frames)) 26 | encoded_id = encoded_id.reshape(batch_size*num_frames,-1,1,1) 27 | encoded_audio = self.audio_encoder(audio_frames).reshape(batch_size*num_frames,-1,1,1) 28 | # encoded_noise = self.noise_encoder(batch_size,num_frames).reshape(batch_size*num_frames,-1,1,1) 29 | # logger.debug(f'encoded_id {encoded_id.shape} encoded_audio {encoded_audio.shape} encoded_noise {encoded_noise.shape}') 30 | encoded = torch.cat([encoded_id,encoded_audio],dim=1)#along channel dimension 31 | gen_frames = self.id_decoder(encoded,intermediates) 32 | _,*img_shape = gen_frames.shape 33 | return gen_frames.reshape(batch_size,num_frames,*img_shape) 34 | 35 | def get_optimizer(self): 36 | return optim.Adam(self.parameters(), lr=self.hparams['lr'], betas=(0.5, 0.999)) 37 | -------------------------------------------------------------------------------- /wav2mov/models/generator/id_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from wav2mov.models.layers.conv_layers import Conv2dBlock,ConvTranspose2dBlock,DoubleConvTranspose2d 5 | from wav2mov.logger import get_module_level_logger 6 | from wav2mov.models.utils import 
get_same_padding 7 | 8 | logger = get_module_level_logger(__name__) 9 | 10 | class IdDecoder(nn.Module): 11 | def __init__(self,hparams): 12 | super().__init__() 13 | self.hparams = hparams 14 | self.in_channels = hparams['in_channels'] 15 | self.latent_channels = hparams['latent_dim'] 16 | latent_id_height,latent_id_width = hparams['latent_dim_id'] 17 | chs = self.hparams['chs'][::-1]#1024,512,256,128,64 18 | chs = [self.latent_channels] + chs + [self.in_channels] 19 | padding = get_same_padding(kernel_size=4,stride=2) 20 | self.convs = nn.ModuleList() 21 | self.convs.append(ConvTranspose2dBlock(in_ch=chs[0], 22 | out_ch=chs[1], 23 | kernel_size=(latent_id_height,latent_id_width), 24 | stride=2, 25 | padding=0, 26 | use_norm=True, 27 | use_act=True, 28 | act=nn.ReLU() 29 | )) 30 | for i in range(1,len(chs)-2): 31 | self.convs.append( DoubleConvTranspose2d(in_ch=chs[i], 32 | skip_ch=chs[i], 33 | out_ch=chs[i+1], 34 | kernel_size=(4,4), 35 | stride=2, 36 | padding=padding, 37 | use_norm=True, 38 | use_act=True, 39 | act=nn.ReLU()) 40 | ) 41 | self.convs.append(ConvTranspose2dBlock(in_ch=chs[-2], 42 | out_ch=chs[-1], 43 | kernel_size=(4,4), 44 | stride=2, 45 | padding=padding, 46 | use_norm=True, 47 | use_act=True, 48 | act=nn.Tanh() 49 | )) 50 | def forward(self,encoded,skip_outs): 51 | """[summary] 52 | 53 | Args: 54 | skip_outs ([type]): 64,128,256,512,1024 55 | encoded ([type]): latent_channels, 56 | """ 57 | encoded = encoded.reshape(encoded.shape[0],self.latent_channels,1,1) 58 | skip_outs = skip_outs[::-1] 59 | for i,block in enumerate(self.convs[:-2]): 60 | x = block(encoded) 61 | encoded = torch.cat([x,skip_outs[i]],dim=1)#cat along channel axis 62 | 63 | encoded = self.convs[-2](encoded) 64 | return self.convs[-1](encoded) 65 | -------------------------------------------------------------------------------- /wav2mov/models/generator/id_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from wav2mov.models.layers.conv_layers import Conv2dBlock 5 | from wav2mov.logger import get_module_level_logger 6 | from wav2mov.models.utils import get_same_padding 7 | 8 | logger = get_module_level_logger(__name__) 9 | 10 | class IdEncoder(nn.Module): 11 | def __init__(self,hparams): 12 | super().__init__() 13 | self.hparams = hparams 14 | in_channels = hparams['in_channels'] 15 | chs = self.hparams['chs'] 16 | chs = [in_channels] + chs +[1]# 1 is added here not in params because see how channels are being used in id_decoder 17 | padding = get_same_padding(kernel_size=4,stride=2) 18 | self.conv_blocks = nn.ModuleList(Conv2dBlock(chs[i],chs[i+1], 19 | kernel_size=(4,4), 20 | stride=2, 21 | padding=padding, 22 | use_norm=True, 23 | use_act=True, 24 | act=nn.ReLU() 25 | ) for i in range(len(chs)-2) 26 | ) 27 | 28 | self.conv_blocks.append(Conv2dBlock(chs[-2],chs[-1], 29 | kernel_size=(4,4), 30 | stride=2, 31 | padding=padding, 32 | use_norm=False, 33 | use_act=True, 34 | act=nn.Tanh())) 35 | def forward(self,images): 36 | intermediates = [] 37 | for block in self.conv_blocks[:-1]: 38 | images = block(images) 39 | intermediates.append(images) 40 | encoded = self.conv_blocks[-1](images) 41 | return encoded,intermediates 42 | -------------------------------------------------------------------------------- /wav2mov/models/generator/noise_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from wav2mov.logger import 
get_module_level_logger 5 | from wav2mov.models.utils import squeeze_batch_frames 6 | 7 | logger = get_module_level_logger(__name__) 8 | 9 | class NoiseEncoder(nn.Module): 10 | def __init__(self,hparams): 11 | super().__init__() 12 | self.hparams = hparams 13 | self.features_len = 10 14 | self.hidden_size = self.hparams['latent_dim_noise'] 15 | self.gru = nn.GRU(input_size=self.features_len, 16 | hidden_size=self.hidden_size, 17 | num_layers=1, 18 | batch_first=True) 19 | #input should be of shape batch_size,seq_len,input_size 20 | def forward(self,batch_size,num_frames): 21 | noise = torch.randn(batch_size,num_frames,self.features_len,device=self.hparams['device']) 22 | out,_ = self.gru(noise) 23 | return out#(batch_size,seq_len,hidden_size) 24 | -------------------------------------------------------------------------------- /wav2mov/models/layers/conv_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from wav2mov.models.utils import get_same_padding 5 | 6 | class DoubleConv2dBlock(nn.Module): 7 | """height and width are halved in feature map""" 8 | def __init__(self, 9 | in_ch, 10 | out_ch, 11 | use_norm=True, 12 | act=None): 13 | super().__init__() 14 | self.use_norm = use_norm 15 | self.conv1 = nn.Conv2d(in_ch, out_ch, 3,padding=get_same_padding(3,1),bias=not self.use_norm) 16 | self.batch_norm = nn.BatchNorm2d(out_ch) 17 | self.relu = nn.LeakyReLU(0.2) if act is None else act() 18 | self.conv2 = nn.Conv2d(out_ch, out_ch, 3) 19 | 20 | def forward(self, x): 21 | # print('double conv block',x.shape,type(x),x.device,next(self.parameters()).device) 22 | x = self.conv1(x) 23 | if self.use_norm : 24 | x = self.batch_norm(x) 25 | return self.relu(self.conv2(self.relu(x))) 26 | 27 | class Conv1dBlock(nn.Module): 28 | def __init__(self, 29 | in_ch, 30 | out_ch, 31 | kernel_size, 32 | stride, 33 | padding, 34 | use_norm=True, 35 | use_act=True,act=None, 36 | residual=False): 37 | 38 | super().__init__() 39 | self.use_norm = use_norm 40 | self.use_act = use_act 41 | self.act = nn.LeakyReLU(0.2) if act is None else act 42 | self.norm = nn.BatchNorm1d(out_ch) 43 | self.conv = nn.Conv1d(in_ch,out_ch,kernel_size,stride,padding,bias=self.use_norm) 44 | self.residual = residual 45 | 46 | def forward(self,in_x): 47 | x = self.conv(in_x) 48 | if self.use_norm: 49 | x = self.norm(x) 50 | if self.residual: 51 | x += in_x 52 | if self.use_act: 53 | x = self.act(x) 54 | return x 55 | 56 | class Conv2dBlock(nn.Module): 57 | def __init__(self, 58 | in_ch, 59 | out_ch, 60 | kernel_size, 61 | stride, 62 | padding, 63 | use_norm=True, 64 | use_act=True,act=None, 65 | residual=False): 66 | 67 | super().__init__() 68 | self.use_norm = use_norm 69 | self.use_act = use_act 70 | self.act = nn.LeakyReLU(0.2) if act is None else act 71 | self.norm = nn.BatchNorm2d(out_ch) 72 | self.conv = nn.Conv2d(in_ch,out_ch,kernel_size,stride,padding,bias=self.use_norm) 73 | self.residual = residual 74 | 75 | def forward(self,in_x): 76 | x = self.conv(in_x) 77 | if self.use_norm: 78 | x = self.norm(x) 79 | if self.residual: 80 | x += in_x 81 | if self.use_act: 82 | x = self.act(x) 83 | return x 84 | 85 | class ConvTranspose2dBlock(nn.Module): 86 | def __init__(self,in_ch, 87 | out_ch, 88 | kernel_size, 89 | stride, 90 | padding, 91 | use_norm=True, 92 | use_act=True, 93 | act=None): 94 | super().__init__() 95 | self.use_norm = use_norm 96 | self.use_act = use_act 97 | self.act = nn.LeakyReLU(0.2) if act is None else act 98 | self.norm = 
nn.BatchNorm2d(out_ch) 99 | self.conv = nn.ConvTranspose2d(in_ch,out_ch,kernel_size,stride,padding,bias=self.use_norm) 100 | 101 | def forward(self,x): 102 | x = self.conv(x) 103 | if self.use_norm: 104 | x = self.norm(x) 105 | if self.use_act: 106 | x = self.act(x) 107 | return x 108 | 109 | class DoubleConvTranspose2d(nn.Module): 110 | def __init__(self,in_ch, 111 | skip_ch, 112 | out_ch, 113 | kernel_size, 114 | stride, 115 | padding, 116 | use_norm=True, 117 | use_act=True, 118 | act=None): 119 | super().__init__() 120 | self.use_norm = use_norm 121 | self.use_act = use_act 122 | 123 | self.conv1 = nn.Conv2d(in_ch+skip_ch,in_ch,kernel_size=3,stride=1,padding=1,bias=False) 124 | self.norm1 = nn.BatchNorm2d(in_ch) 125 | self.conv2 = nn.ConvTranspose2d(in_ch,out_ch,kernel_size,stride,padding) 126 | self.norm2 = nn.BatchNorm2d(out_ch) 127 | self.act = nn.LeakyReLU(0.2) if act is None else act 128 | 129 | def forward(self,x): 130 | x = nn.functional.relu(self.norm1(self.conv1(x))) 131 | x = self.conv2(x) 132 | if self.use_norm: 133 | x = self.norm2(x) 134 | if self.use_act: 135 | x = self.act(x) 136 | return x -------------------------------------------------------------------------------- /wav2mov/models/layers/debug_layers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from wav2mov.logger import get_module_level_logger 4 | logger = get_module_level_logger(__name__) 5 | 6 | class IdentityDebugLayer(nn.Module): 7 | def __init__(self,name): 8 | super().__init__() 9 | self.name = name 10 | def forward(self,x): 11 | logger.debug(f'{self.name} : {x.shape}') 12 | return x 13 | -------------------------------------------------------------------------------- /wav2mov/models/utils.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from torch.nn import init 3 | 4 | from wav2mov.logger import get_module_level_logger 5 | logger = get_module_level_logger(__name__) 6 | 7 | def get_same_padding(kernel_size,stride,in_size=0): 8 | """https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks 9 | https://github.com/DinoMan/speech-driven-animation/blob/dc9fe4aa77b4e042177328ea29675c27e2f56cd4/sda/utils.py#L18-L21 10 | 11 | padding = 'same' 12 | • Padding such that feature map size has size In_size/Stride 13 | 14 | • Output size is mathematically convenient 15 | 16 | • Also called 'half' padding 17 | 18 | out = (in-k+2*p)/s + 1 19 | if out == in/s: 20 | in/s = (in-k+2*p)/s + 1 21 | ((in/s)-1)*s + k -in = 2*p 22 | (in-s)+k-in = 2*p 23 | in case of s==1: 24 | p = (k-1)/2 25 | """ 26 | out_size = ceil(in_size/stride) 27 | return ceil(((out_size-1)*stride+ kernel_size-in_size)/2)#(k-1)//2 for same padding 28 | 29 | def init_weights(net, init_type='normal', init_gain=0.02): 30 | """Initialize network weights. 31 | src : https://github1s.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 32 | Parameters: 33 | net (network) -- network to be initialized 34 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 35 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 36 | 37 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 38 | work better for some applications. Feel free to try yourself. 
39 | """ 40 | def init_func(m): # define the initialization function 41 | classname = m.__class__.__name__ 42 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 43 | if init_type == 'normal': 44 | init.normal_(m.weight.data, 0.0, init_gain) 45 | elif init_type == 'xavier': 46 | init.xavier_normal_(m.weight.data, gain=init_gain) 47 | elif init_type == 'kaiming': 48 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 49 | elif init_type == 'orthogonal': 50 | init.orthogonal_(m.weight.data, gain=init_gain) 51 | else: 52 | raise NotImplementedError( 53 | 'initialization method [%s] is not implemented' % init_type) 54 | if hasattr(m, 'bias') and m.bias is not None: 55 | init.constant_(m.bias.data, 0.0) 56 | # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 57 | elif classname.find('BatchNorm2d') != -1: 58 | init.normal_(m.weight.data, 1.0, init_gain) 59 | init.constant_(m.bias.data, 0.0) 60 | 61 | logger.debug(f'initializing {net.__class__.__name__} with {init_type} weights') 62 | # print('initialize network with %s' % init_type) 63 | net.apply(init_func) # apply the initialization function 64 | 65 | 66 | def init_net(net, init_type='normal', init_gain=0.02): 67 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 68 | Parameters: 69 | net (network) -- the network to be initialized 70 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 71 | gain (float) -- scaling factor for normal, xavier and orthogonal. 72 | 73 | 74 | Return an initialized network. 75 | """ 76 | 77 | init_weights(net, init_type, init_gain=init_gain) 78 | return net 79 | 80 | 81 | def squeeze_batch_frames(target): 82 | batch_size,num_frames,*extra = target.shape 83 | return target.reshape(batch_size*num_frames,*extra) -------------------------------------------------------------------------------- /wav2mov/models/wav2mov_inferencer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from wav2mov.core.models import TemplateModel 4 | from wav2mov.core.data.utils import AudioUtil 5 | 6 | from wav2mov.models import Generator 7 | 8 | import logging 9 | logging.basicConfig(level=logging.INFO) 10 | logger = logging.getLogger(__name__) 11 | logger.setLevel(logging.DEBUG) 12 | 13 | def no_grad_wrapper(fn): 14 | def wrapper(*args,**kwargs): 15 | with torch.no_grad(): 16 | return fn(*args,**kwargs) 17 | return wrapper 18 | 19 | class Wav2movInferencer(TemplateModel): 20 | def __init__(self,hparams,logger): 21 | super().__init__() 22 | self.hparams = hparams 23 | self.logger = logger 24 | self.device = 'cpu' 25 | self.gen = Generator(hparams['gen']) 26 | self.stride = self.hparams['data']['audio_sf']//self.hparams['data']['video_fps'] 27 | self.audio_util = AudioUtil(self.hparams['data']['coarticulation_factor'],self.stride,self.device) 28 | 29 | def load(self,checkpoint): 30 | self.gen.load_state_dict(checkpoint) 31 | 32 | def _squeeze_batch_frames(self,target): 33 | batch_size,num_frames,*extra = target.shape 34 | return target.reshape(batch_size*num_frames,*extra) 35 | 36 | def forward(self,audio_frames,ref_video_frames): 37 | self.logger.debug(f'[GENERATION] audio_frames : {audio_frames.shape} | ref_video_frames : {ref_video_frames.shape}') 38 | self.gen.eval() 39 | batch_size,num_frames,*extra = ref_video_frames.shape 40 | assert batch_size==audio_frames.shape[0] and num_frames 
==audio_frames.shape[1] 41 | # audio_frames = self._squeeze_batch_frames(audio_frames) 42 | # ref_video_frames = self._squeeze_batch_frames(ref_video_frames) 43 | fake_frames = self.gen(audio_frames,ref_video_frames) 44 | return fake_frames 45 | 46 | @no_grad_wrapper 47 | def generate(self,audio_frames,ref_video_frames,fraction=None): 48 | if fraction is None: 49 | return self(audio_frames,ref_video_frames) 50 | else: 51 | return self._generate_with_fraction(audio_frames,ref_video_frames,fraction) 52 | 53 | @no_grad_wrapper 54 | def test(self,audio_frames,video): 55 | """test the generation of face images 56 | 57 | Args: 58 | audio (tensor | (B,S)): 59 | audio_frames (tensor | (B,F,Sw)): 60 | video (tensor | (B,F,C,H,W)): 61 | """ 62 | fraction = 2 63 | ref_video_frames = self._get_ref_video_frames(video) 64 | 65 | return self.generate(audio_frames,ref_video_frames,fraction),ref_video_frames[:,0,...] 66 | 67 | def __get_random_indices(self,low,high,num_frames): 68 | return random.sample(range(low,high),k=num_frames) 69 | 70 | def _get_ref_video_frames(self,video): 71 | num_frames = video.shape[1] 72 | ################################################################ 73 | # random_ids = self.__get_random_indices(0,num_frames,num_frames) 74 | # ref_video_frames = video[:,random_ids] 75 | ################################################################ 76 | REF_FRAME_IDX = int(0.6*num_frames) 77 | ref_video_frames = video[:,REF_FRAME_IDX,:,:,:] 78 | ref_video_frames = ref_video_frames.unsqueeze(1) 79 | ref_video_frames = ref_video_frames.repeat(1,num_frames,1,1,1) 80 | ################################################################ 81 | return ref_video_frames 82 | 83 | def _squeeze_frames(self,video): 84 | bsize,nframes,*extra = video.shape 85 | return video.reshape(bsize*nframes,*extra) 86 | 87 | 88 | def _generate_with_fraction(self,audio_frames,ref_video_frames,fraction): 89 | num_frames = audio_frames.shape[1] 90 | start_frame = 0 91 | fake_video_frames= [] 92 | for i in range(fraction): 93 | end_frame = int((1/fraction)*(i+1)*num_frames) if i!=fraction-1 else num_frames 94 | audio_frames_sample = audio_frames[:,start_frame:end_frame,...] 95 | ref_video_frames_sample = ref_video_frames[:,start_frame:end_frame,...] 
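                # (descriptive note) each chunk is generated independently here and the pieces are
                # concatenated back along the frame axis (dim=1) by the torch.cat after this loop,
                # which keeps peak memory bounded by the chunk length rather than the full clip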
96 | fake_video_frames.append(self(audio_frames_sample,ref_video_frames_sample)) 97 | start_frame = end_frame 98 | 99 | return torch.cat(fake_video_frames,dim=1) 100 | 101 | -------------------------------------------------------------------------------- /wav2mov/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 10, 3 | "ref_frame_idx": -10, 4 | "device": "cpu", 5 | "img_size": 256, 6 | "num_epochs": 4, 7 | "train_sync_expert": false, 8 | "train_sync": false, 9 | "pre_learning_epochs": 100, 10 | "adversarial_with_id": 300, 11 | "adversarial_with_sync": 100, 12 | "stop_adversarial_with_sync": 1000, 13 | "adversarial_with_seq": 300, 14 | "img_channels": 3, 15 | "num_frames_fraction": 15, 16 | "data": { 17 | "img_size": 256, 18 | "img_channels": 3, 19 | "coarticulation_factor": 2, 20 | "audio_sf": 16000, 21 | "video_fps": 24, 22 | "batch_size": 6, 23 | "mini_batch_size": 6, 24 | "mean": 0.516, 25 | "std": 0.236 26 | }, 27 | "disc": { 28 | "sync_disc": { 29 | "in_channels": 3, 30 | "lr": 1e-4, 31 | "relu_neg_slope": 0.01 32 | }, 33 | "sequence_disc": { 34 | "in_channels": 3, 35 | "chs": [64, 128, 256, 512, 1], 36 | "in_size": 32, 37 | "h_size": 256, 38 | "num_layers": 1, 39 | "lr": 1e-4, 40 | "relu_neg_slope": 0.01 41 | }, 42 | 43 | "identity_disc": { 44 | "in_channels": 3, 45 | "chs": [64, 128, 256, 512, 1024, 1], 46 | "lr": 1e-4, 47 | "relu_neg_slope": 0.01 48 | }, 49 | "patch_disc": { 50 | "ndf": 64, 51 | "in_channels": 3, 52 | "num_layers": 3, 53 | "lr": 1e-4 54 | } 55 | }, 56 | 57 | "gen": { 58 | "in_channels": 3, 59 | "chs": [64, 128, 256, 512, 1024], 60 | "latent_dim": 272, 61 | "latent_dim_id": [8, 8], 62 | "comment": "laten_dim not eq latent_dim_id + latent_dim_audio, its 4x4 + 256", 63 | "latent_dim_audio": 256, 64 | "device": "cpu", 65 | "lr": 2e-4 66 | }, 67 | 68 | "scales_archieved": { 69 | "lambda_seq_disc": 0.3, 70 | "lambda_sync_disc": 0.8, 71 | "lambda_id_disc": 1, 72 | "lambda_L1": 50 73 | }, 74 | "scales": { 75 | "lambda_seq_disc": 0.6, 76 | "lambda_sync_disc": 0.8, 77 | "lambda_id_disc": 1, 78 | "lambda_L1": 50 79 | }, 80 | "scheduler": { 81 | "gen": { 82 | "step_size": 20, 83 | "gamma": 0.02 84 | }, 85 | "discs": { 86 | "step_size": 20, 87 | "gamma": 0.1 88 | }, 89 | "max_epoch": 100 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /wav2mov/params.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from wav2mov.settings import BASE_DIR 4 | from wav2mov.logger import get_module_level_logger 5 | logger = get_module_level_logger(__name__) 6 | class Params: 7 | def __init__(self): 8 | self.vals = {} 9 | 10 | 11 | 12 | def _flatten_dict(self, d): 13 | flattened = {} 14 | for k, v in d.items(): 15 | if isinstance(v, dict): 16 | inner_d = self._flatten_dict(v) 17 | flattened = {**flattened, **inner_d} 18 | else: 19 | flattened[k] = v 20 | return flattened 21 | 22 | def update(self,key,value): 23 | if key in self.vals: 24 | logger.warning(f'Updating existing parameter {key} : changing from {self.vals[key]} with {value}') 25 | 26 | self.vals[key] = value 27 | 28 | def _update_vals_from_dict(self, d: dict): 29 | self.vals = d 30 | # d = self._flatten_dict(d) 31 | # self.vals = {**self.vals,**d} 32 | 33 | @classmethod 34 | def from_json(cls, json_file_fullpath): 35 | with open(json_file_fullpath, 'r') as file: 36 | configs = json.load(file) 37 | obj = cls() 38 | obj._update_vals_from_dict(configs) 39 | 
return obj 40 | 41 | def __getitem__(self, item): 42 | if item not in self.vals: 43 | logger.error(f'{self.__class__.__name__} object has no key called {item}') 44 | raise KeyError(f'No key called {item}') 45 | 46 | return self.vals[item] 47 | 48 | def __repr__(self): 49 | fstr = '' 50 | for key,val in self.vals.items(): 51 | fstr += f"{key} : {val}\n" 52 | return fstr 53 | 54 | def set(self,key,val): 55 | if key in self.vals: 56 | logger.warning(f'updating existing value of {key} with value {self.vals[key] } to {val}') 57 | self.vals[key] = val 58 | 59 | def save(self,file_fullpath): 60 | with open(file_fullpath,'w') as file: 61 | json.dump(self.vals,file) 62 | 63 | 64 | #Singleton Object 65 | params = Params.from_json(os.path.join(BASE_DIR, 'params.json')) 66 | 67 | if __name__ == '__main__': 68 | pass 69 | -------------------------------------------------------------------------------- /wav2mov/plans/README.md: -------------------------------------------------------------------------------- 1 | # Plans 2 | 3 | ![Plan V1](images/plan_v1.png) 4 | ![System](images/system.png) 5 | ![File Structure](images/components.png) 6 | ![Generator](images/gen_arch.png) 7 | -------------------------------------------------------------------------------- /wav2mov/plans/extras.md: -------------------------------------------------------------------------------- 1 | # ✨ Mystics of Python 🐍 2 | 3 | ## Exploring the unexplored 🌍 🗺 4 | 5 | ------------------ 6 | 7 | #### Image shapes: 8 | 9 | + Numpy / Matplotlib / opencv treats image with shape H,W,C 10 | 11 | + Pytorch requires image to be in thes shape C,H,W 12 | 13 | #### * and ** operators 14 | 15 | ------------------ 16 | 17 | Below are 6 different use cases for * and ** in python programming: 18 | [source : StackOverflow : https://stackoverflow.com/a/59630576/12988588](https://stackoverflow.com/a/59630576/12988588) 19 | 20 | + __To accept any number of positional arguments using *args:__ 21 | 22 | ```python 23 | def foo(*args): 24 | pass 25 | ``` 26 | 27 | here foo accepts any number of positional arguments, 28 | i. e., the following calls are valid ``foo(1)``, ``foo(1, 'bar')`` 29 | 30 | + __To accept any number of keyword arguments using **kwargs:__ 31 | 32 | ```python 33 | def foo(**kwargs): 34 | pass 35 | ``` 36 | 37 | here 'foo' accepts any number of keyword arguments, 38 | i. e., the following calls are valid ``foo(name='Tom')``,``foo(name='Tom', age=33)`` 39 | 40 | + __To accept any number of positional and keyword arguments using *args, **kwargs:__ 41 | 42 | ```python 43 | def foo(*args, **kwargs): 44 | pass 45 | ``` 46 | 47 | here foo accepts any number of positional and keyword arguments, 48 | i. e., the following calls are valid ``foo(1,name='Tom')``, ``foo(1, 'bar', name='Tom', age=33) 49 | `` 50 | 51 | + __To enforce keyword only arguments using *:__ 52 | 53 | ```python 54 | def foo(pos1, pos2, *, kwarg1): 55 | pass 56 | ``` 57 | 58 | here * means that foo only accept keyword arguments after pos2, hence foo(1, 2, 3) raises TypeError 59 | but ``foo(1, 2, kwarg1=3)`` is ok. 60 | 61 | + __To express no further interest in more positional arguments using `*_` (Note: this is a convention only):__ 62 | 63 | ```python 64 | def foo(bar, baz, *_): 65 | pass 66 | ``` 67 | 68 | means (by convention) foo only uses bar and baz arguments in its working and will ignore others. 
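  For example, with the definition above a call such as ``foo(1, 2, 3, 4)`` is still valid; the extra positional arguments are collected into ``_`` and simply never used:

```python
def foo(bar, baz, *_):
    return bar + baz

print(foo(1, 2, 3, 4))  # prints 3; the trailing 3 and 4 are swallowed by *_
```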
69 | 70 | + __To express no further interest in more keyword arguments using `\**_` (Note: this is a convention only):__ 71 | 72 | ```python 73 | def foo(bar, baz, **_): 74 | pass 75 | ``` 76 | 77 | means (by convention) foo only uses bar and baz arguments in its working and will ignore others. 78 | 79 | + __BONUS__: From python 3.8 onward, one can use `/` in function definition to enforce positional only parameters. 80 | In the following example, parameters a and b are positional-only, while c or d can be positional or keyword, 81 | and e or f are required to be keywords: 82 | 83 | ```python 84 | def f(a, b, /, c, d, *, e, f): 85 | pass 86 | ``` 87 | 88 | 89 | -------------------------------------------------------------------------------- /wav2mov/plans/images/components.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/plans/images/components.png -------------------------------------------------------------------------------- /wav2mov/plans/images/gen_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/plans/images/gen_arch.png -------------------------------------------------------------------------------- /wav2mov/plans/images/plan_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/plans/images/plan_v1.png -------------------------------------------------------------------------------- /wav2mov/plans/images/system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/plans/images/system.png -------------------------------------------------------------------------------- /wav2mov/plans/observations.md: -------------------------------------------------------------------------------- 1 | # Observations 🔍 2 | 3 | *** 4 | 5 | + __sync discriminator loss is very small (of the order of e-9)__ 6 | 7 | + solved by using removing batch normalization and using leaky version of relu with slope of 0.2 instead of vanilla relu. 8 | 9 | + __changes in losses is very small__ 10 | 11 | ## lip variations 12 | + it requires longer times to get variations . 13 | + size of dataset plays huge role 14 | (60 videos with 2 actores took nearly 200 epochs to show mild variations between frames) 15 | + train sync discriminator only on real frames initially both on synchronized and unsynchronized audio and video frames. 
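  A minimal sketch of that warm-up step (the helper below and its shapes are illustrative only, not the repo's actual training loop, which is driven through `run_sync_expert.sh` via the `--train_sync_expert` flag): synchronized audio/video pairs are labelled real, while audio shifted by a few frames against the same video frames is labelled fake.

```python
import torch

def make_sync_pretraining_batch(audio_frames, video_frames, offset=3):
    """Build real-only training pairs for the sync discriminator.
    audio_frames: (B, T, Sw) audio windows, video_frames: (B, T, C, H, W).
    Synchronized pairs get label 1; audio shifted by `offset` frames
    against the same video frames gets label 0."""
    synced_audio = audio_frames[:, :-offset]    # aligned with the video frames
    unsynced_audio = audio_frames[:, offset:]   # shifted audio, same length
    frames = video_frames[:, :-offset]
    audio = torch.cat([synced_audio, unsynced_audio], dim=0)
    video = torch.cat([frames, frames], dim=0)
    labels = torch.cat([torch.ones(audio_frames.shape[0]),
                        torch.zeros(audio_frames.shape[0])])
    return audio, video, labels
```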
16 | -------------------------------------------------------------------------------- /wav2mov/preprocess.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | TITLE wav2mov 3 | echo Running preprocessing 4 | 5 | set DEVICE="cpu" 6 | set VERSION="preprocess_500_a23456" 7 | set LOG="y" 8 | set GRID_DATASET_DIR="/content/drive/MyDrive/Colab Notebooks/projects/wav2mov/datasets/grid_a5_500_a10to14_raw" 9 | echo %GRID_DATASET_DIR% 10 | @REM python main/main.py --preprocess=y -grid=%GRID_DATASET_DIR% --device=%DEVICE% --version=%VERSION% --log=%LOG% -------------------------------------------------------------------------------- /wav2mov/preprocess.sh: -------------------------------------------------------------------------------- 1 | echo "Preprocess" 2 | DEVICE='cpu' 3 | VERSION='preprocess_500_a23456' 4 | LOG='y' 5 | GRID_DATASET_DIR='/content/drive/MyDrive/Colab Notebooks/projects/wav2mov/datasets/grid_a5_500_a10to14_raw' 6 | python main/main.py --preprocess=y -grid="${GRID_DATASET_DIR}" --device=${DEVICE} --version=${VERSION} --log=${LOG} 7 | -------------------------------------------------------------------------------- /wav2mov/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/requirements.txt -------------------------------------------------------------------------------- /wav2mov/run.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | TITLE wav2mov 3 | echo Running training script 4 | 5 | set EPOCHS=1 6 | set NUM_VIDEOS=14 7 | set VERSION="v9" 8 | set "COMMENT='GPU | scaled id ,gen_id losses by FRACTION | prelearning till 15 epoch | 10 l1_l and 1 id_l'" 9 | set DEVICE="cuda" 10 | set IS_TRAIN="y" 11 | set LOG="n" 12 | set "MODEL_PATH='/content/drive/MyDrive/Colab Notebooks/projects/wav2mov_engine/wav2mov/runs/v9/Run_6_5_2021__17_28'" 13 | @REM echo 'options chosen are '"$EPOCHS"' is_training = '"$IS_TRAIN" 14 | @REM #python main/main.py --train=%IS_TRAIN% -e=%EPOCHS% -v=%NUM_VIDEOS% -m="%COMMENT%" --device=%DEVICE% --version=%VERSION% --log=%LOG% --model_path=%MODEL_PATH% 15 | python main/main.py --train=%IS_TRAIN% -e=%EPOCHS% -v=%NUM_VIDEOS% -m="%COMMENT%" --device=%DEVICE% --version=%VERSION% --log=%LOG% 16 | -------------------------------------------------------------------------------- /wav2mov/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # run main/main.py 3 | EPOCHS=300 4 | NUM_VIDEOS=120 5 | VERSION="v16_sync_expert" 6 | TRAIN_SYNC_EXPERT='n' 7 | COMMENT='GPU | 225 to 300| no noise encoder | interleaved sync training | only l1 at the beginning' 8 | # COMMENT='GPU | 375 to 450| seq upper half and sync lower half | sync bce loss' 9 | DEVICE="cuda" 10 | IS_TRAIN="y" 11 | LOG='y' 12 | MODEL_PATH="/content/drive/MyDrive/Colab Notebooks/wav2mov-dev_phase_8/wav2mov/runs/${VERSION}/Run_11_7_2021__17_44" 13 | #echo 'options chosen are '"$EPOCHS"' is_training = '"$IS_TRAIN" 14 | # python main/main.py --train=$IS_TRAIN -e=$EPOCHS -v=$NUM_VIDEOS -m="$COMMENT" --device=$DEVICE --version=$VERSION --log=$LOG --train_sync_expert=${TRAIN_SYNC_EXPERT} 15 | python main/main.py --train=$IS_TRAIN -e=$EPOCHS -v=$NUM_VIDEOS -m="$COMMENT" --device=$DEVICE --version=$VERSION --log=$LOG --model_path="$MODEL_PATH" --train_sync_expert=${TRAIN_SYNC_EXPERT} 16 | 
-------------------------------------------------------------------------------- /wav2mov/run_sync_expert.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # run main/main.py 3 | EPOCHS=450 4 | NUM_VIDEOS=120 5 | VERSION="sync_expert" 6 | TRAIN_SYNC_EXPERT='y' 7 | COMMENT='GPU |350 to 450| train only sync disc to make it a expert' 8 | # COMMENT='GPU | 375 to 450| seq upper half and sync lower half | sync bce loss' 9 | DEVICE="cuda" 10 | IS_TRAIN="y" 11 | LOG='y' 12 | # MODEL_PATH="/content/drive/MyDrive/Colab Notebooks/wav2mov-dev_phase_8/wav2mov/runs/v16_sync_expert/Run_10_7_2021__23_46" 13 | # MODEL_PATH="/content/drive/MyDrive/Colab Notebooks/wav2mov-dev_phase_8/wav2mov/runs/${VERSION}/Run_10_7_2021__14_19" 14 | MODEL_PATH="/content/drive/MyDrive/Colab Notebooks/wav2mov-dev_phase_8/wav2mov/runs/${VERSION}/Run_11_7_2021__11_24" 15 | #echo 'options chosen are '"$EPOCHS"' is_training = '"$IS_TRAIN" 16 | # python main/main.py --train=$IS_TRAIN -e=$EPOCHS -v=$NUM_VIDEOS -m="$COMMENT" --device=$DEVICE --version=$VERSION --log=$LOG --train_sync_expert=${TRAIN_SYNC_EXPERT} 17 | python main/main.py --train=$IS_TRAIN -e=$EPOCHS -v=$NUM_VIDEOS -m="$COMMENT" --device=$DEVICE --version=$VERSION --log=$LOG --model_path="$MODEL_PATH" --train_sync_expert=${TRAIN_SYNC_EXPERT} 18 | -------------------------------------------------------------------------------- /wav2mov/settings.py: -------------------------------------------------------------------------------- 1 | """ Base Dir """ 2 | import os 3 | 4 | 5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | -------------------------------------------------------------------------------- /wav2mov/test.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | TITLE Wav2Mov 3 | 4 | set is_test=y 5 | set "model_path=E:\Users\VS_Code_Workspace\Python\VirtualEnvironments\wav2mov\wav2mov\runs\v6\%1\gen_%1.pt" 6 | set log=n 7 | set device=cpu 8 | set version=v9 9 | echo %model_path% 10 | python main/main.py --device=%device% --test=%is_test% --version=%version% --model_path=%model_path% --log=%log% -v=14 -------------------------------------------------------------------------------- /wav2mov/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | VERSION="v16_sync_expert" 3 | echo $1 4 | is_test='y' 5 | model_path="/content/drive/MyDrive/Colab Notebooks/wav2mov-dev_phase_8/wav2mov/runs/${VERSION}/${1}/gen_${1}.pt" 6 | log='n' 7 | device='cpu' 8 | sample_num=2 9 | python main/main.py --device=${device} --test=${is_test} --version=${VERSION} --model_path="${model_path}" --log=${log} -v=14 --test_sample_num=${sample_num} 10 | -------------------------------------------------------------------------------- /wav2mov/tests/old/log.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/tests/old/log.json -------------------------------------------------------------------------------- /wav2mov/tests/old/test.py: -------------------------------------------------------------------------------- 1 | lines = ['aytala\n','macha'] 2 | with open('test.txt','a+') as f: 3 | f.writelines(lines) 4 | 5 | with open('test.txt','r') as f: 6 | read_lines = f.read().split('\n') 7 | print(read_lines) 8 | 9 | 10 | -------------------------------------------------------------------------------- 
/wav2mov/tests/old/test.txt: -------------------------------------------------------------------------------- 1 | aytala 2 | machaaytala 3 | machaaytala 4 | machaaytala 5 | machaaytala 6 | machaaytala 7 | machaaytala 8 | macha -------------------------------------------------------------------------------- /wav2mov/tests/old/test_3d_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from wav2mov.models.sequence_discriminator import SequenceDiscriminatorCNN 5 | 6 | 7 | def get_shape(in_shape,kernels,strides,padding): 8 | 9 | new_shape = [] 10 | for i,k,s,p in zip(in_shape,kernels,strides,padding): 11 | new_shape.append(((i-k+2*p)/s)+1) 12 | 13 | return new_shape 14 | 15 | def test(hparams): 16 | x1 = torch.randn(1,1,256,256) #N,Cin,Depth,Height,Width 17 | x2 = torch.randn(1,1,256,256) 18 | model = SequenceDiscriminatorCNN(hparams) 19 | 20 | print(model(x1,x2).shape) 21 | strides = (2,2,2) 22 | padding = (1,1,1) 23 | in_shape = (2,256,256) 24 | kernels = (2,4,4) 25 | print(get_shape(in_shape,kernels,strides,padding)) 26 | """ 27 | out shape : torch.Size([1, 6, 2, 128, 128]) 28 | out shape : torch.Size([1, 32, 2, 64, 64]) 29 | out shape : torch.Size([1, 64, 2, 32, 32]) 30 | out shape : torch.Size([1, 32, 2, 16, 16]) 31 | out shape : torch.Size([1, 16, 2, 8, 8]) 32 | out shape : torch.Size([1, 1, 2, 4, 4]) 33 | torch.Size([1, 32]) 34 | 35 | 1=>4=>8=>16=>8=>4=>1 36 | out shape : torch.Size([1, 4, 2, 128, 128]) 37 | out shape : torch.Size([1, 8, 2, 64, 64]) 38 | out shape : torch.Size([1, 16, 2, 32, 32]) 39 | out shape : torch.Size([1, 8, 2, 16, 16]) 40 | out shape : torch.Size([1, 4, 2, 8, 8]) 41 | out shape : torch.Size([1, 1, 2, 4, 4]) 42 | torch.Size([1, 32]) 43 | [2.0, 128.0, 128.0] 44 | """ 45 | 46 | def main(): 47 | hparams ={'in_channels':1,'chs':[4,8,16,8,4,1]} 48 | test(hparams) 49 | 50 | if __name__=='__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /wav2mov/tests/old/test_audio_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import unittest 4 | from wav2mov.core.data.utils import AudioUtil 5 | from wav2mov.utils.audio import StridedAudioV2 6 | 7 | 8 | 9 | logger = logging.getLogger(__file__) 10 | logger.setLevel(logging.DEBUG) 11 | class TestAudioUtil(unittest.TestCase): 12 | def test_frames_cont(self): 13 | BATCH_SIZE = 2 14 | audio = torch.randn(BATCH_SIZE,668*40) 15 | 16 | for coarticulation_factor in range(5): 17 | strided_audio = StridedAudioV2(666,coarticulation_factor) 18 | get_frames_from_idx,_ = strided_audio.get_frame_wrapper(audio) 19 | logging.debug('For coarticulation factor of {}'.format(coarticulation_factor)) 20 | for i in range((audio.shape[0]//666)): 21 | frame = get_frames_from_idx(i) 22 | if i%10==0: 23 | logging.debug(f'{i},{frame.shape}') 24 | self.assertEqual(frame.shape, (1,(coarticulation_factor*2 + 1)*666)) 25 | 26 | def test_frames_range(self): 27 | BATCH_SIZE = 2 28 | audio = torch.randn(BATCH_SIZE,668*40) 29 | 30 | for coarticulation_factor in range(5): 31 | strided_audio = StridedAudioV2(666,coarticulation_factor) 32 | _,get_frames_from_range = strided_audio.get_frame_wrapper(audio) 33 | logging.debug('For coarticulation factor of {}'.format(coarticulation_factor)) 34 | num_frames=5 35 | for i in range((audio.shape[0]//666)): 36 | frame = get_frames_from_range(i,i+num_frames-1) 37 | if i%10==0: 38 | 
logging.debug(f'{i},{frame.shape}') 39 | self.assertEqual(frame.shape, (1,(num_frames+2)*666)) 40 | 41 | def test_limit_audio(self): 42 | COARTICULATION_FACTOR = 2 43 | STRIDE = 666 44 | audio_util = AudioUtil(COARTICULATION_FACTOR,STRIDE) 45 | audio = torch.randn(1,45000) 46 | limited_audio = audio_util.get_limited_audio(audio,5,2) 47 | self.assertEqual(limited_audio.shape,(1,5*STRIDE +2*COARTICULATION_FACTOR*STRIDE)) 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /wav2mov/tests/old/test_batching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | from wav2mov.params import params 4 | from wav2mov.config import get_config 5 | 6 | from wav2mov.core.data.utils import AudioUtil 7 | from wav2mov.core.data.collates import get_batch_collate 8 | 9 | from wav2mov.main.data import get_dataloaders 10 | 11 | STRIDE = 666 12 | audio_util = AudioUtil(2,stride=STRIDE) 13 | 14 | class Options: 15 | num_videos = 9 16 | v = 'v_test' 17 | options = Options() 18 | config = get_config(options.v) 19 | 20 | class TestBatching(unittest.TestCase): 21 | def test_audio_batching(self): 22 | AUDIO_LEN = 44200 23 | audio = torch.randn(AUDIO_LEN) 24 | frames = audio_util.get_audio_frames(audio.unsqueeze(0)) 25 | self.assertEqual(frames.shape,(AUDIO_LEN//STRIDE,STRIDE*(4+1))) 26 | 27 | 28 | def test_collate(self): 29 | params.update('data',{**params['data'],'mini_batch_size':2}) 30 | params.update('data',{**params['data'],'coarticulation_factor':0}) 31 | collate = get_batch_collate(hparams=params['data']) 32 | 33 | dl = get_dataloaders(options,config,params,get_mean_std=False,collate_fn=collate) 34 | 35 | batch_size = params['data']['mini_batch_size'] 36 | # print(vars(dl.train)) 37 | # self.assertEqual(len(dl.train),batch_size) 38 | 39 | for i,sample in enumerate(dl.train): 40 | audio = sample.audio 41 | audio_frames = sample.audio_frames 42 | video = sample.video 43 | print(f'Batch {i}') 44 | print(f'audio :{audio.shape}') 45 | print(f'audio frames :{audio_frames.shape}') 46 | print(f'video : {video.shape} ') 47 | 48 | self.assertEqual(sample.audio.shape[0],batch_size) 49 | 50 | if __name__ == '__main__': 51 | unittest.main() -------------------------------------------------------------------------------- /wav2mov/tests/old/test_config.py: -------------------------------------------------------------------------------- 1 | from wav2mov.config import config 2 | 3 | for k,v in config.vals.items(): 4 | print(f"{k} : {v}") -------------------------------------------------------------------------------- /wav2mov/tests/old/test_dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import torch 4 | from wav2mov.core.data.datasets import AudioVideoDataset 5 | from wav2mov.utils.audio import StridedAudio 6 | from settings import BASE_DIR 7 | ROOT_DIR = os.path.join(BASE_DIR,'datasets','grid_dataset') 8 | 9 | def test(): 10 | strided_audio = StridedAudio(16000//24,1) 11 | dataset = AudioVideoDataset(ROOT_DIR,os.path.join(ROOT_DIR,'filenames.txt'),video_fps=24,audio_sf=16000) 12 | for i,sample in enumerate(dataset): 13 | audio,video = sample 14 | stride = math.floor(16_000/24) 15 | print(f"Stride = {stride}") 16 | print(f"Audio shape,video shape , audio.shape[0]//stride") 17 | print(audio.shape,video.shape,audio.shape[0]//stride) 18 | get_frames = strided_audio.get_frame_wrapper(audio) 19 
| for i in range(video.shape[0]): 20 | frame = get_frames(i) 21 | print(frame[0].shape,frame[1]) 22 | break 23 | print("="*10) 24 | if i==3: 25 | break 26 | def main(): 27 | test() 28 | return 29 | if __name__=='__main__': 30 | main() -------------------------------------------------------------------------------- /wav2mov/tests/old/test_discriminators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | from wav2mov.models.identity_discriminator import IdentityDiscriminator 5 | from wav2mov.models.sequence_discriminator import SequenceDiscriminator 6 | from wav2mov.models.sync_discriminator import SyncDiscriminator 7 | from wav2mov.models.patch_disc import PatchDiscriminator 8 | 9 | from utils import get_input_of_shape,no_grad 10 | BATCH_SIZE = 1 11 | IMAGE_SIZE = 256 12 | CHANNELS = 3 13 | 14 | 15 | 16 | 17 | class TestDescs(unittest.TestCase): 18 | @no_grad 19 | def test_identity(self): 20 | x = get_input_of_shape((BATCH_SIZE,CHANNELS,IMAGE_SIZE,IMAGE_SIZE)) 21 | y = get_input_of_shape((BATCH_SIZE,CHANNELS,IMAGE_SIZE,IMAGE_SIZE)) 22 | desc = IdentityDiscriminator() 23 | out = desc(x,y) 24 | print(f"identity descriminator : input :{x.shape} and {y.shape} | output : {out.shape}") 25 | self.assertEqual(out.shape,(BATCH_SIZE,16)) 26 | 27 | @no_grad 28 | def test_sequence(self): 29 | image_size = IMAGE_SIZE*IMAGE_SIZE*3 30 | hidden_size = 100 31 | 32 | desc = SequenceDiscriminator(image_size,hidden_size,num_layers=1) 33 | x = get_input_of_shape((BATCH_SIZE,2,image_size)) 34 | out = desc(x) 35 | print(f"sequence descriminator : input :{x.shape} | output : {out.shape}") 36 | self.assertEqual(out.shape,(BATCH_SIZE,hidden_size)) 37 | 38 | @no_grad 39 | def test_sync(self): 40 | audio = get_input_of_shape((BATCH_SIZE,666)) 41 | image = get_input_of_shape((BATCH_SIZE,CHANNELS,IMAGE_SIZE,IMAGE_SIZE)) 42 | desc = SyncDiscriminator() 43 | out = desc(audio,image) 44 | print(f"sync descriminator : input :{audio.shape} and {image.shape} | output : {out.shape}") 45 | self.assertEqual(out.shape,(BATCH_SIZE,128)) 46 | 47 | @no_grad 48 | def test_patch_disc(self): 49 | frame_image = get_input_of_shape((BATCH_SIZE,1,256,256)) 50 | still_image = get_input_of_shape((BATCH_SIZE,1,256,256)) 51 | disc = PatchDiscriminator(1,ndf=64) 52 | out = disc(frame_image,still_image) 53 | print(f'patch disc out: {out.shape}') 54 | self.assertEqual(out.shape,(BATCH_SIZE,1,30,30)) 55 | 56 | def main(): 57 | unittest.main() 58 | return 59 | 60 | if __name__=='__main__': 61 | main() -------------------------------------------------------------------------------- /wav2mov/tests/old/test_file_logger.py: -------------------------------------------------------------------------------- 1 | from wav2mov.logger import Logger 2 | 3 | def test(): 4 | logger = Logger(__name__) 5 | logger.add_console_handler() 6 | logger.add_filehandler('logs/log.log') 7 | 8 | log = [] 9 | for i in range(5): 10 | log.append(f'log {i}') 11 | 12 | logger.info('\t'.join(log)) 13 | 14 | 15 | if __name__ == '__main__': 16 | test() -------------------------------------------------------------------------------- /wav2mov/tests/old/test_logger.py: -------------------------------------------------------------------------------- 1 | from wav2mov.logger import Logger 2 | 3 | log_path = r'E:\Users\VS_Code_Workspace\Python\VirtualEnvironments\wav2mov\wav2mov\tests\log.json' 4 | def test_logger(): 5 | logger = Logger(__file__) 6 | logger.add_console_handler() 7 | 
logger.add_filehandler(log_path,in_json=True) 8 | logger.debug('testing json logger %d',extra={'great':1}) 9 | logger.debug('testing json logger',extra={'great':1}) 10 | logger.debug('testing json logger',extra={'great':1}) 11 | logger.debug('testing json logger',extra={'great':1}) 12 | 13 | def main(): 14 | test_logger() 15 | if __name__=='__main__': 16 | main() 17 | -------------------------------------------------------------------------------- /wav2mov/tests/old/test_main_data.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from wav2mov.main.data import get_dataloaders 4 | from wav2mov.config import config 5 | from wav2mov.params import params 6 | 7 | import argparse 8 | # argparser = argparse.ArgumentParser(description='test dataloader') 9 | # argparser.add_argument('--num_videos','-n',type=int,help='number of videos') 10 | # options = argparser.parse_args() 11 | class Options:pass 12 | options = Options() 13 | options.num_videos=36 14 | class TestData(unittest.TestCase): 15 | def test_dataloader(self): 16 | dataloader,mean,std = get_dataloaders(options,config,params,shuffle=True) 17 | sample = next(iter(dataloader.train)) 18 | print(f'train dl : {len(dataloader.train)}') 19 | print(f'val dl : {len(dataloader.val)}') 20 | print(f'channel wise mean : {mean}') 21 | print(f'channel wise std : {std}') 22 | print(f'audio shape : {sample.audio.shape }') 23 | print(f'video shape : {sample.video.shape }') 24 | 25 | if __name__ == '__main__': 26 | unittest.main() -------------------------------------------------------------------------------- /wav2mov/tests/old/test_module_level_logger.py: -------------------------------------------------------------------------------- 1 | from wav2mov.main import callbacks 2 | 3 | def test(): 4 | m_logger = callbacks.m_logger 5 | m_logger.debug(f'{m_logger.name} {m_logger.level}') 6 | if __name__ == '__main__': 7 | test() -------------------------------------------------------------------------------- /wav2mov/tests/old/test_packing.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn.utils.rnn import pad_sequence,pack_padded_sequence,pad_packed_sequence 4 | 5 | from wav2mov.core.data.collates import collate_fn 6 | 7 | from wav2mov.main.data import get_dataloaders 8 | from wav2mov.params import params 9 | from wav2mov.config import config 10 | 11 | import argparse 12 | def parse_args(): 13 | arg_parser = argparse.ArgumentParser(description='Testing Variable Sequence Batch Size') 14 | arg_parser.add_argument('--samples','-s',type=int,default=1,help='number of samples whose details will be printed') 15 | arg_parser.add_argument('--batch_size','-b',type=int,default=10,help='batch size for dataloader') 16 | return arg_parser.parse_args() 17 | 18 | 19 | 20 | 21 | 22 | def test_batch_size(options): 23 | 24 | dl,_,_ = get_dataloaders(config,params,get_mean_std=False,collate_fn=collate_fn) 25 | dl = dl.train 26 | for i in range(options.samples): 27 | sample,lens = next(iter(dl)) 28 | audios,videos = sample 29 | audio_lens,video_lens = lens 30 | print(audios.shape) 31 | videos = pack_padded_sequence(videos,lengths=video_lens,batch_first=True,enforce_sorted=False) 32 | audios = pack_padded_sequence(audios, lengths = audio_lens,batch_first=True,enforce_sorted = False) 33 | print(videos,audios) 34 | print(f' sample {i+1} '.center(30,'=')) 35 | # print('video shape : ',sample.video.shape) 36 | # 
print('audio shape : ',sample.audio.shape) 37 | 38 | print('video shape : ',videos.data.shape) 39 | print('audio shape : ',audios.data.shape) 40 | print('batch sizes : ',videos.batch_sizes) 41 | # print('lens audio : ',lens.audio) 42 | # print('lens video :',lens.video) 43 | p ,l= pad_packed_sequence(videos,batch_first=True) 44 | print(p.shape,l) 45 | print(''*30) 46 | 47 | 48 | if __name__ == '__main__': 49 | options = parse_args() 50 | BATCH_SIZE = options.batch_size 51 | data = params['data'] 52 | params.set('data', {**data, 'batch_size': BATCH_SIZE}) 53 | test_batch_size(options) 54 | -------------------------------------------------------------------------------- /wav2mov/tests/old/test_settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | TEST_DIR = os.path.dirname(os.path.abspath(__file__)) -------------------------------------------------------------------------------- /wav2mov/tests/old/test_stored_video.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms as vtransforms 3 | 4 | import numpy as np 5 | from torchvision.transforms.transforms import CenterCrop 6 | 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | parser.add_argument('--gray','-g',choices=['y','n'],default='y',type=str,help='whether the output should be grayscale or color image') 11 | args = parser.parse_args() 12 | print('settings : gray(y/n): ',args.gray) 13 | from wav2mov.utils.plots import show_img 14 | 15 | 16 | 17 | path = r'E:\Users\VS_Code_Workspace\Python\VirtualEnvironments\wav2mov\wav2mov\datasets\grid_dataset\s10_l_bbat9p\video_frames.npy' 18 | 19 | 20 | def test_db(): 21 | 22 | 23 | video = np.load(path).astype('float64') 24 | video = video.transpose(0,3,1,2) 25 | video = torch.from_numpy(video) 26 | print(f'video shape{video.shape}') 27 | 28 | 29 | 30 | 31 | channels = video.shape[1] 32 | print(' before '.center(10,'=')) 33 | print('mean :',[torch.mean(video[0][:,i,...]) for i in range(channels)]) 34 | print('std : ', [torch.std(video[0][:,i,...]) for i in range(channels)]) 35 | video = video/255 36 | 37 | channels = 1 if args.gray == 'y' else 3 38 | transforms = vtransforms.Compose( 39 | 40 | [ 41 | vtransforms.Grayscale(1), 42 | # vtransforms.CenterCrop(256), 43 | vtransforms.Resize((256, 256)), 44 | vtransforms.Normalize([0.5]*channels, [0.5]*channels) 45 | ]) 46 | print(' after '.center(10,'=')) 47 | print('mean :',[torch.mean(video[0][:,i,...]) for i in range(channels)]) 48 | print('std : ', [torch.std(video[0][:,i,...]) for i in range(channels)]) 49 | print('max : ', [torch.max(video[0][:,i,...]) for i in range(channels)]) 50 | print('min : ', [torch.min(video[0][:,i,...]) for i in range(channels)]) 51 | video = transforms(video) 52 | # show_img(video[0]) 53 | show_img(video[0],cmap='gray') 54 | 55 | 56 | def main(): 57 | test_db() 58 | 59 | if __name__=='__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /wav2mov/tests/old/test_tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | import time 2 | from torch.utils.tensorboard.writer import SummaryWriter 3 | from wav2mov.logger import TensorLogger 4 | 5 | 6 | def test(): 7 | logger = TensorLogger('logs') 8 | writer_1 = SummaryWriter('logs/exp1/test1') 9 | writer_2 = SummaryWriter('logs/exp1/test2') 10 | 
logger.add_writer('test1',writer_1) 11 | logger.add_writer('test2',writer_2) 12 | 13 | for i in range(10): 14 | logger.add_scalar('test1','aytala',i+1,i) 15 | print('adding 1 ',i) 16 | time.sleep(2) 17 | # writer_1.add_scalar('test1',i+2,i) 18 | 19 | for i in range(10): 20 | logger.add_scalar('test1','macha1',i*2,i) 21 | print('adding 2 ',i) 22 | time.sleep(2) 23 | 24 | if __name__ == '__main__': 25 | test() -------------------------------------------------------------------------------- /wav2mov/tests/old/test_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | # from wav2mov.models.generator import GeneratorBW 4 | from wav2mov.models.generator_v6 import GeneratorBW,Generator,Encoder,Decoder 5 | 6 | params = { 7 | "img_dim":[256,256], 8 | "retain_dim":True, 9 | "device":"cpu", 10 | "in_channels": 1, 11 | "enc_chs": [64, 128, 256, 512, 1024], 12 | "dec_chs": [1024, 512, 256, 128, 64], 13 | "up_chs": [1026, 512, 256, 128, 64], 14 | } 15 | 16 | 17 | 18 | class TestUnet(unittest.TestCase): 19 | def test_encoder(self): 20 | encoder = Encoder(chs=[params['in_channels']] +params['enc_chs']) 21 | image = torch.randn(1,1,256,256) 22 | out = encoder(image) 23 | print('testing encoder ') 24 | for layer in out: 25 | print(layer.shape) 26 | 27 | def test_decoder(self): 28 | encoder = Encoder(chs=[params['in_channels']] + params['enc_chs']) 29 | image = torch.randn(1,1,256,256) 30 | encoded = encoder(image)[::-1] 31 | 32 | audio_noise_encoded = torch.randn(1,2,8,8) 33 | encoded[0] = torch.cat([encoded[0],audio_noise_encoded],dim=1) 34 | decoder = Decoder(up_chs=params['up_chs'],dec_chs=params['dec_chs']) 35 | image = torch.randn(1,1,256,256) 36 | out = decoder(encoded[0],encoded[1:]) 37 | print('test_decoder : ',out.shape) 38 | 39 | def test_gen(self): 40 | gen = Generator(params) 41 | frame_img = torch.randn(1,1,256,256) 42 | audio_noise = torch.randn(1,2,8,8) 43 | out = gen(frame_img,audio_noise) 44 | self.assertEqual(out.shape,(1,1,256,256)) 45 | 46 | def test(): 47 | with torch.no_grad(): 48 | x = torch.randn(1,1,256,256) 49 | a = torch.randn(1,666) 50 | gen = GeneratorBW( { 51 | "device": "cpu", 52 | "in_channels": 1, 53 | "enc_chs": [64, 128, 256, 512, 1024], 54 | "dec_chs": [1024, 512, 256, 128, 64], 55 | "up_chs": [1026, 512, 256, 128, 64], 56 | "img_dim": [256, 256], 57 | "retain_dim": True, 58 | "lr": 1e-4 59 | }) 60 | gen.eval() 61 | assert gen(a,x).shape==(1,1,256,256) 62 | print("Test Passed " ,1) 63 | 64 | if __name__ == '__main__': 65 | unittest.main() 66 | -------------------------------------------------------------------------------- /wav2mov/tests/old/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import unittest 4 | 5 | from test_settings import TEST_DIR 6 | 7 | from wav2mov.utils.plots import save_gif 8 | 9 | 10 | class TestUtils(unittest.TestCase): 11 | def test_save_gif_1(self): 12 | images = torch.randn(10,3,256,256) 13 | save_gif(os.path.join(TEST_DIR,'logs','test_1.gif'),images) 14 | def test_save_gif_2(self): 15 | images = [] 16 | for _ in range(5): 17 | images.append(torch.randn(3,256,256)) 18 | save_gif(os.path.join(TEST_DIR, 'logs', 'test_2.gif'), images) 19 | 20 | def main(): 21 | unittest.main() 22 | 23 | if __name__=='__main__': 24 | main() 25 | -------------------------------------------------------------------------------- /wav2mov/tests/old/test_wav2mov.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | from wav2mov.core.data.collates import get_batch_collate 5 | 6 | from wav2mov.config import get_config 7 | from wav2mov.params import params 8 | 9 | from wav2mov.models.wav2mov import Wav2Mov 10 | from wav2mov.main.engine import Engine 11 | from wav2mov.main.callbacks import LossMetersCallback,TimeTrackerCallback,ModelCheckpoint 12 | from wav2mov.main.data import get_dataloaders 13 | 14 | from wav2mov.main.options import Options,set_options 15 | 16 | logger = logging.getLogger(__file__) 17 | logger.setLevel(logging.DEBUG) 18 | 19 | # BATCH_SIZE = params['data']['batch_size'] 20 | BATCH_SIZE = 1 21 | 22 | NUM_FRAMES = 25 23 | 24 | def get_input(): 25 | audio = torch.randn(BATCH_SIZE,666*9,device='cuda') 26 | video = torch.randn(BATCH_SIZE,NUM_FRAMES,1,256,256,device='cuda') 27 | audio_frames = torch.randn(BATCH_SIZE,NUM_FRAMES,666+4*666,device='cuda') 28 | return audio,video,audio_frames 29 | 30 | def test(options,hparams,config,logger): 31 | engine = Engine(options,hparams,config,logger) 32 | model = Wav2Mov(hparams,config,logger) 33 | collate_fn = get_batch_collate(hparams['data']) 34 | dataloaders_ntuple = get_dataloaders(options,config,hparams, 35 | get_mean_std=False, 36 | collate_fn=collate_fn) 37 | callbacks = [LossMetersCallback(options,hparams,logger, 38 | verbose=True), 39 | TimeTrackerCallback(hparams,logger), 40 | ModelCheckpoint(model,hparams, 41 | save_every=5)] 42 | 43 | engine.run(model,dataloaders_ntuple,callbacks) 44 | 45 | if __name__ == '__main__': 46 | options = Options().parse() 47 | set_options(options,params) 48 | config = get_config(options.version) 49 | test(options,params,config,logger) -------------------------------------------------------------------------------- /wav2mov/tests/old/test_wav2mov_v7.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | from wav2mov.config import get_config 5 | from wav2mov.params import params 6 | 7 | from wav2mov.models.wav2mov_v7 import Wav2MovBW 8 | 9 | 10 | logger = logging.getLogger(__file__) 11 | logger.setLevel(logging.DEBUG) 12 | config = get_config('v_test') 13 | 14 | # BATCH_SIZE = params['data']['batch_size'] 15 | BATCH_SIZE = 1 16 | 17 | NUM_FRAMES = 25 18 | 19 | def get_input(): 20 | audio = torch.randn(BATCH_SIZE,666*9,device='cuda') 21 | video = torch.randn(BATCH_SIZE,NUM_FRAMES,1,256,256,device='cuda') 22 | audio_frames = torch.randn(BATCH_SIZE,NUM_FRAMES,666+4*666,device='cuda') 23 | return audio,video,audio_frames 24 | def test(): 25 | model = Wav2MovBW(config,params,logger) 26 | audio,video,audio_frames = get_input() 27 | 28 | model.on_train_start() 29 | for epoch in range(1): 30 | model.on_batch_start() 31 | batch_size,num_frames,channels,height,width = video.shape 32 | audio_frames = audio_frames.reshape(-1,audio_frames.shape[-1]) 33 | model.set_input(audio_frames,video.reshape(batch_size*num_frames,channels,height,width)) 34 | frame_img = video[:,NUM_FRAMES//2,...] 
35 | frame_img = frame_img.repeat((1,NUM_FRAMES,1,1,1))#id and gen requires still image for each audio_frame as condition 36 | frame_img = frame_img.reshape(batch_size*num_frames,channels,height,width) 37 | # print('frame_img ',frame_img.shape) 38 | model.set_condition(frame_img) 39 | model.optimize_parameters() 40 | model.optimize_sequence(video,audio) 41 | 42 | 43 | if __name__ == '__main__': 44 | test() -------------------------------------------------------------------------------- /wav2mov/tests/old/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def get_input_of_shape(shape): 3 | return torch.randn(shape) 4 | 5 | 6 | def no_grad(func): 7 | def wrapper(*args, **kwargs): 8 | with torch.no_grad(): 9 | out = func(*args, **kwargs) 10 | return out 11 | return wrapper 12 | -------------------------------------------------------------------------------- /wav2mov/tests/test_audio_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from wav2mov.params import params 3 | from wav2mov.config import get_config 4 | from wav2mov.models.generator.audio_encoder import AudioEnocoder 5 | 6 | from wav2mov.logger import get_module_level_logger 7 | logger = get_module_level_logger(__file__) 8 | config = get_config('test_sync') 9 | def test(): 10 | model = AudioEnocoder(params['gen']) 11 | audio = torch.randn(2,10,7,13) 12 | out = model(audio) 13 | logger.debug(f'out shape {out.shape}') 14 | assert(out.shape==(2,10,params['gen']['latent_dim_audio'])) 15 | 16 | def main(): 17 | test() 18 | 19 | if __name__=='__main__': 20 | main() 21 | -------------------------------------------------------------------------------- /wav2mov/tests/test_audio_frames.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import os 3 | import torch 4 | from matplotlib import pyplot as plt 5 | from scipy.io.wavfile import write 6 | 7 | from wav2mov.config import get_config 8 | from wav2mov.params import params as hparams 9 | 10 | from wav2mov.main.options import Options 11 | from wav2mov.core.data.collates import get_batch_collate 12 | from wav2mov.main.data import get_dataloaders as get_dl 13 | 14 | DIR = os.path.dirname(os.path.abspath(__file__)) 15 | def test(): 16 | options = Options().parse() 17 | config = get_config('test_audio_frames') 18 | collate_fn = get_batch_collate(hparams['data']) 19 | dls = get_dl(options,config,hparams,collate_fn=collate_fn) 20 | train_dl = dls.train 21 | sample = next(iter(train_dl)) 22 | audio,audio_frames,video = sample 23 | 24 | print(f'audio : {audio.shape}') 25 | print(f'audio frames : {audio_frames.shape}') 26 | print(f'video : {video.shape}') 27 | means,stds,maxs,mins = [],[],[],[] 28 | for audio_frame in audio_frames[0][:10]: 29 | maxs.append(torch.max(audio_frame).item()) 30 | mins.append(torch.min(audio_frame).item()) 31 | means.append(torch.mean(audio_frame).item()) 32 | stds.append(torch.std(audio_frame).item()) 33 | print('means ',means) 34 | print('stds ',stds) 35 | print('maxs ',maxs) 36 | print('mins ',mins) 37 | plt.imshow(audio_frames[0][30]) 38 | plt.show() 39 | 40 | if __name__ == '__main__': 41 | test() 42 | -------------------------------------------------------------------------------- /wav2mov/tests/test_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import unittest 4 | 5 | from 
wav2mov.models.generator.id_encoder import IdEncoder 6 | from wav2mov.models.generator.id_decoder import IdDecoder 7 | from wav2mov.models.generator.frame_generator import Generator 8 | from wav2mov.logger import get_module_level_logger 9 | from wav2mov.utils.plots import show_img 10 | logger = get_module_level_logger(__name__) 11 | 12 | class TestGen(unittest.TestCase): 13 | def test_id_encoder(self): 14 | logger.debug(f'id_encoder test') 15 | hparams = { 16 | 'in_channels':1, 17 | 'chs':[64,128,256,512,1024], 18 | 'latent_dim_id':(8,8) 19 | } 20 | 21 | id_encoder = IdEncoder(hparams) 22 | images = torch.randn(1,1,256,256) 23 | encoded,intermediates = id_encoder(images) 24 | req_h,req_w = 128,128 25 | for i,intermediate in enumerate(intermediates): 26 | self.assertEqual((req_h,req_w),intermediate.shape[-2:]) 27 | logger.debug(f'{i} : {intermediate.shape}') 28 | req_h,req_w = req_h//2,req_w//2 29 | logger.debug(f'final encoded {encoded.shape}') 30 | 31 | def test_id_decoder(self): 32 | logger.debug(f'id_decoder test') 33 | hparams = { 34 | 'in_channels':1, 35 | 'chs':[64,128,256,512,1024], 36 | 'latent_dim':16, 37 | 'latent_dim_id':(8,8) 38 | } 39 | id_enocder = IdEncoder(hparams) 40 | id_decoder = IdDecoder(hparams) 41 | images = torch.randn(1,1,256,256) 42 | encoded,intermediates= id_enocder(images) 43 | decoded = id_decoder(encoded,intermediates) 44 | self.assertEqual(decoded.shape,(1,1,256,256)) 45 | 46 | def test_generator(self): 47 | logger.debug(f'test generator') 48 | 49 | hparams = { 50 | 'in_channels':3, 51 | 'chs':[64,128,256,512,1024], 52 | 'latent_dim':16+256+10, 53 | 'latent_dim_id':(8,8), 54 | 'latent_dim_audio':256, 55 | 'latent_dim_noise':10, 56 | 'device':'cpu', 57 | 'lr':2e-4 58 | } 59 | 60 | gen = Generator(hparams) 61 | ref_frames = torch.randn(1,5,3,256,256) 62 | audio_frames = torch.zeros(1,5,5*666) 63 | out = gen(audio_frames,ref_frames) 64 | self.assertEqual(out.shape,(1,5,3,256,256)) 65 | 66 | # show_img(ref_frames[0][0],cmap='gray') 67 | # show_img(out[0][0],cmap='gray') 68 | 69 | test_on_real_img(gen) 70 | 71 | def test_on_real_img(gen): 72 | gen.eval() 73 | ref_frames = torch.from_numpy(np.load(r'E:\Users\VS_Code_Workspace\Python\VirtualEnvironments\wav2mov\wav2mov\datasets\grid_dataset_256_256\s10_l_bbat9p\video_frames.npy')) 74 | show_img(ref_frames[0].permute(2,0,1)) 75 | ref_frames = ref_frames[0].permute(2,0,1).unsqueeze(0).unsqueeze(0).float() 76 | # logger.debug(ref_frames.shape) 77 | audio_frames = torch.randn(1,1,5*666) 78 | out = gen(audio_frames,ref_frames) 79 | show_img(out[0][0]) 80 | 81 | if __name__ == '__main__': 82 | unittest.main() -------------------------------------------------------------------------------- /wav2mov/tests/test_main_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from wav2mov.core.data.collates import get_batch_collate 3 | from wav2mov.params import params 4 | from wav2mov.config import get_config 5 | from wav2mov.logger import get_module_level_logger 6 | from wav2mov.main.data import get_dataloaders 7 | from wav2mov.main.options import Options 8 | 9 | logger = get_module_level_logger(__name__) 10 | def test(options,hparams,config): 11 | collate_fn = get_batch_collate(hparams['data']) 12 | dataloader_pack = get_dataloaders(options,config,params,collate_fn=collate_fn) 13 | train_dl,test_dl = dataloader_pack 14 | logger.debug(f'train : {len(train_dl)} test : {len(test_dl)}') 15 | 16 | dl_iter = iter(train_dl) 17 | for _ in range(min(len(train_dl),10)): 18 | 
sample = next(dl_iter) 19 | audio,video = sample.audio,sample.video 20 | logger.debug(f'video {video.shape} : {torch.mean(video,dim=[0,1,3,4])} ,{torch.std(video,dim=[0,1,3,4])}') 21 | logger.debug(f'audio {audio.shape} : {torch.mean(audio)} ,{torch.std(audio)}') 22 | 23 | def main(): 24 | options = Options().parse() 25 | config = get_config(options.version) 26 | test(options,params,config) 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /wav2mov/tests/test_seq_disc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from wav2mov.params import params 3 | from wav2mov.models import SequenceDiscriminator 4 | 5 | from wav2mov.logger import get_module_level_logger 6 | logger = get_module_level_logger(__file__) 7 | 8 | def test(): 9 | model = SequenceDiscriminator(params['disc']['sequence_disc']) 10 | frames = torch.randn(1,10,3,256,256) 11 | out = model(frames) 12 | logger.debug(f'out shape {out.shape}') 13 | assert(out.shape==(1,256)) 14 | 15 | def main(): 16 | test() 17 | 18 | if __name__=='__main__': 19 | main() 20 | -------------------------------------------------------------------------------- /wav2mov/tests/test_sync_disc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from wav2mov.params import params 3 | from wav2mov.config import get_config 4 | from wav2mov.models import SyncDiscriminator 5 | 6 | from wav2mov.logger import get_module_level_logger 7 | logger = get_module_level_logger(__file__) 8 | config = get_config('test_sync') 9 | def test(): 10 | model = SyncDiscriminator(params['disc']['sync_disc'],config) 11 | audio_frames = torch.randn(2,12,13) 12 | video_frames = torch.randn(2,5,3,256,256) 13 | out = model(audio_frames,video_frames)[0] 14 | logger.debug(f'out shape {out.shape}') 15 | assert(out.shape==(2,1)) 16 | 17 | def main(): 18 | test() 19 | 20 | if __name__=='__main__': 21 | main() 22 | -------------------------------------------------------------------------------- /wav2mov/tests/test_video.py: -------------------------------------------------------------------------------- 1 | import dlib 2 | import cv2 3 | import os 4 | from wav2mov.logger import get_module_level_logger 5 | logger = get_module_level_logger(__name__) 6 | 7 | 8 | DIR = r'D:\dataset_lip\GRID\video_6sub_500' 9 | video = os.path.join(DIR,'s4_p_swbxza.mov') 10 | 11 | 12 | face_detector = dlib.get_frontal_face_detector() 13 | 14 | def convert_and_trim_bb(image, rect): 15 | """ from pyimagesearch 16 | https://www.pyimagesearch.com/2021/04/19/face-detection-with-dlib-hog-and-cnn/ 17 | """ 18 | # extract the starting and ending (x, y)-coordinates of the 19 | # bounding box 20 | startX = rect.left() 21 | startY = rect.top() 22 | endX = rect.right() 23 | endY = rect.bottom() 24 | # ensure the bounding box coordinates fall within the spatial 25 | # dimensions of the image 26 | startX = max(0, startX) 27 | startY = max(0, startY) 28 | endX = min(endX, image.shape[1]) 29 | endY = min(endY, image.shape[0]) 30 | # compute the width and height of the bounding box 31 | w = endX - startX 32 | h = endY - startY 33 | # return our bounding box coordinates 34 | return (startX, startY, w, h) 35 | def show_img(image_name,image): 36 | cv2.imshow(str(image_name),image) 37 | cv2.waitKey(0) 38 | cv2.destroyAllWindows() 39 | 40 | def get_video_frames(video_path,img_size:tuple): 41 | try: 42 | logger.debug(video_path) 43 | cap = 
cv2.VideoCapture(str(video_path)) 44 | if(not cap.isOpened()): 45 | logger.error("Cannot open video stream or file!") 46 | frames = [] 47 | i = 0 48 | while cap.isOpened(): 49 | frameId = cap.get(1) 50 | ret, image = cap.read() 51 | if not ret: 52 | break 53 | try: 54 | i+=1 55 | #image[top_row:bottom_row,left_column:right_column] 56 | image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)#other libraries including matplotlib,dlib expects image in RGB 57 | print(image.shape,i) 58 | face = face_detector(image)[0]#get first face object 59 | print(face.top(),face.bottom()) 60 | x,y,w,h = convert_and_trim_bb(image,face) 61 | # image = cv2.resize(image[y:y+h,x:x+w],img_size,interpolation=cv2.INTER_CUBIC) 62 | 63 | # show_img(len(frames),image) 64 | except Exception as e: 65 | # print(e) 66 | # logger.error(e) 67 | continue 68 | frames.append(image) 69 | return frames 70 | except Exception as e: 71 | logger.exception(e) 72 | 73 | 74 | if __name__ == '__main__': 75 | get_video_frames(video,(256,256)) -------------------------------------------------------------------------------- /wav2mov/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Package for util files""" -------------------------------------------------------------------------------- /wav2mov/utils/audio.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | class StridedAudio: 5 | def __init__(self,stride,coarticulation_factor=0): 6 | self.coarticulation_factor = coarticulation_factor 7 | self.stride = stride 8 | if isinstance(self.stride,float): 9 | self.stride = math.floor(self.stride) 10 | 11 | 12 | def get_frame_wrapper(self,audio): 13 | # padding = torch.tensor([]) 14 | # audio = torch.from_numpy(audio) 15 | 16 | if self.coarticulation_factor!=0: 17 | padding = torch.cat([torch.tensor([0]*self.stride) for _ in range(self.coarticulation_factor)],dim=0) 18 | audio = torch.cat([padding,audio,padding]) 19 | 20 | audio = audio.unsqueeze(0) 21 | 22 | def get_frame(idx): 23 | center_frame_pos = idx*self.stride 24 | # print(center_frame_pos) 25 | center_frame = audio[:,center_frame_pos:center_frame_pos+self.stride] 26 | if self.coarticulation_factor == 0: 27 | return center_frame, center_frame_pos+self.stride 28 | # curr_idx = center_frame_pos 29 | # print(f'audio shape {audio.shape} centerframe {center_frame.shape}') 30 | for i in range(self.coarticulation_factor): 31 | center_frame = torch.cat([audio[:,center_frame_pos-(i-1)*self.stride:center_frame_pos-i*self.stride],center_frame],dim=1) 32 | last_pos = 0 33 | for i in range(self.coarticulation_factor): 34 | center_frame = torch.cat([audio[:,center_frame_pos+(i+1)*self.stride:center_frame_pos+(i+2)*self.stride]]) 35 | last_pos = center_frame_pos+(i+2)*self.stride 36 | return center_frame,last_pos 37 | return get_frame 38 | 39 | 40 | class StridedAudioV2: 41 | def __init__(self, stride, coarticulation_factor=0,device='cpu'): 42 | """ 43 | 0.15s window ==>16k*0.15 points = 2400 44 | so padding on either side will be 2400//2 = 1200 45 | """ 46 | self.coarticulation_factor = coarticulation_factor 47 | self.stride = stride 48 | if isinstance(self.stride, float): 49 | self.stride = math.floor(self.stride) 50 | self.pad_len = self.coarticulation_factor*self.stride 51 | self.device = device 52 | 53 | def get_frame_wrapper(self, audio): 54 | # padding = torch.tensor([]) 55 | # audio = torch.from_numpy(audio) 56 | 57 | if self.coarticulation_factor != 0: 58 | 59 | padding = 
torch.zeros((audio.shape[0],self.pad_len),device=self.device) 60 | 61 | # padding = torch.cat([torch.tensor([0]*self.stride) 62 | # for _ in range(self.coarticulation_factor)], dim=0) 63 | 64 | audio = torch.cat([padding, audio,padding],dim=1) 65 | # print(audio.shape,self.padding.shape) 66 | 67 | 68 | def get_frame_from_idx(idx): 69 | center_idx= (idx) + (self.coarticulation_factor) 70 | start_pos = (center_idx-self.coarticulation_factor)*self.stride 71 | end_pos = (center_idx+self.coarticulation_factor+1)*self.stride 72 | return audio[:,start_pos:end_pos] 73 | 74 | def get_frames_from_range(start_idx,end_idx): 75 | start_pos = start_idx + self.coarticulation_factor 76 | start_idx = (start_pos-self.coarticulation_factor)*self.stride 77 | end_pos = end_idx+self.coarticulation_factor 78 | end_idx = (end_pos+self.coarticulation_factor+1)*self.stride 79 | return audio[:,start_idx:end_idx] 80 | 81 | return get_frame_from_idx,get_frames_from_range 82 | 83 | 84 | -------------------------------------------------------------------------------- /wav2mov/utils/cnn_shape_calc.py: -------------------------------------------------------------------------------- 1 | def out_d(in_d,k,p,s): 2 | """calculates out shape of 2d cnn 3 | 4 | Args: 5 | in_d (int): height or width 6 | k (int):kernel height or width 7 | 8 | p (int): padding along height or width 9 | 10 | s (int): stride along height or width 11 | 12 | Returns: 13 | int 14 | """ 15 | return ((in_d-k+2*p)/s)+1 16 | 17 | def out_d_transposed(in_d,k,p,s): 18 | """ find z and p 19 | insert z number of zeros between each row and column (2*(i-1)x2*(i-1)) 20 | pad with p number of zeros 21 | perform standard convolution 22 | """ 23 | return ((in_d-1)*s)+k-(2*p) -------------------------------------------------------------------------------- /wav2mov/utils/files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | def move_dir(src,dest): 5 | os.makedirs(dest, exist_ok=True) 6 | print('src :',src) 7 | print('dest :',dest) 8 | shutil.move(src, dest) 9 | print(f'[SHUTIL]: Folder moved : src ({src}) dest ({dest})') 10 | 11 | def mov_log_dir(options,config): 12 | src = os.path.dirname(config['log_fullpath']) 13 | if os.path.exists(src): 14 | dest = os.path.join(os.path.join(config['base_dir'],'logs'), 15 | options.version) 16 | # os.makedirs(dest,exist_ok=True) 17 | move_dir(src,dest) 18 | 19 | def mov_out_dir(options,config): 20 | if not options.test in ['y', 'yes']: 21 | return 22 | run = os.path.basename(options.model_path).strip('gen_').split('.')[0] 23 | version_dir = os.path.dirname(os.path.dirname(options.model_path)) 24 | print('version dir',version_dir) 25 | version = os.path.basename(version_dir) 26 | if 'v' not in version: 27 | version = options.version 28 | src = os.path.join(config['out_dir'], run) 29 | if os.path.exists(src): 30 | dest = os.path.join(config['out_dir'], version) 31 | # os.makedirs(dest,exist_ok=True) 32 | move_dir(src, dest) 33 | -------------------------------------------------------------------------------- /wav2mov/utils/misc.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | def log_run_time(logger): 4 | def timeit(func): 5 | def timed(*args, **kwargs): 6 | start_time = time.time() 7 | res = func(*args, **kwargs) 8 | end_time = time.time() 9 | time_taken = end_time-start_time 10 | logger.log(f'[TIME TAKEN] {func.__name__} took {time_taken:0.2f} seconds (or) {time_taken/60:0.2f} 
minutes',type="DEBUG") 11 | return res 12 | return timed 13 | return timeit 14 | 15 | 16 | 17 | class AverageMeter: 18 | def __init__(self,name,fmt=':0.4f'): 19 | 20 | self.name = name 21 | self.fmt = fmt 22 | self.reset() 23 | 24 | def reset(self): 25 | self.sum = 0 26 | self.count = 0 27 | self.avg = 0 28 | 29 | def _update_average(self): 30 | self.avg = self.sum/self.count 31 | 32 | def update(self,val,n): 33 | self.count +=n 34 | self.sum += val*n 35 | self._update_average() 36 | 37 | def __str__(self): 38 | fmt_str = '{name} : {avg' + self.fmt + '}' 39 | return fmt_str.format(**self.__dict__) 40 | 41 | def add(self,val): 42 | self.sum += val 43 | self._update_average() 44 | 45 | class AverageMetersList: 46 | def __init__(self, names, fmt=':0.4f'): 47 | self.meters = {name:AverageMeter(name,fmt) for name in names} 48 | 49 | def update(self,d:dict): 50 | """update the average meters 51 | 52 | Args: 53 | d (dict): key is the name of the meter and value is a tuple containing value and the multiplier 54 | 55 | """ 56 | for name,(value,n) in d.items(): 57 | if n==0: 58 | continue 59 | self.meters[name].update(value,n) 60 | 61 | 62 | def reset(self): 63 | for name in self.meters.keys(): 64 | self.meters[name].reset() 65 | 66 | def as_list(self): 67 | return self.meters.values() 68 | 69 | def average(self): 70 | return {name:meter.avg for name,meter in self.meters.items()} 71 | 72 | def get(self,name): 73 | if name not in self.meters: 74 | raise KeyError(f'{name} has not average meter initialized') 75 | return self.meters.get(name) 76 | 77 | def __str__(self): 78 | avg = self.average() 79 | return '\t'.join(f'{key}:{val:0.4f}' for key,val in avg.items()) 80 | 81 | class ProgressMeter: 82 | def __init__(self,steps,meters,prefix=''): 83 | self.batch_fmt_str = self._get_epoch_fmt_str(steps) 84 | self.meters = meters 85 | self.prefix = prefix 86 | 87 | def get_display_str(self,step): 88 | entries = [self.prefix + self.batch_fmt_str.format(step)] 89 | entries += [str(meter) for meter in self.meters] 90 | return '\t'.join(entries) 91 | 92 | def _get_epoch_fmt_str(self,steps): 93 | num_digits = len(str(steps//1)) 94 | fmt = '{:' + str(num_digits) + 'd}' 95 | return '[' + fmt +'/'+ fmt.format(steps) + ']' 96 | 97 | 98 | 99 | def get_duration_in_minutes_seconds(self,duration): 100 | return duration//60,duration -------------------------------------------------------------------------------- /wav2mov/utils/plots.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.functional import Tensor 4 | from torchvision.io import write_video 5 | 6 | import imageio 7 | import numpy as np 8 | from matplotlib import pyplot as plt 9 | 10 | from moviepy import editor as mpy 11 | from scipy.io.wavfile import write as write_audio 12 | 13 | import warnings 14 | warnings.filterwarnings( "ignore", module = r"matplotlib\..*" ) 15 | 16 | from wav2mov.logger import get_module_level_logger 17 | logger = get_module_level_logger(__name__) 18 | 19 | def no_grad_wrapper(fn): 20 | def wrapper(*args,**kwargs): 21 | with torch.no_grad(): 22 | return fn(*args,**kwargs) 23 | return wrapper 24 | 25 | @no_grad_wrapper 26 | def show_img(img,cmap='viridis'): 27 | if isinstance(img,np.ndarray): 28 | img_np = img 29 | else: 30 | if len(img.shape)>3: 31 | img = img.squeeze(0) 32 | img_np = img.cpu().numpy() 33 | img_np = np.transpose(img_np, (1, 2, 0)) 34 | # print(img_np.shape) 35 | # print(img_np) 36 | if img_np.shape[2]==1:#if single channel 37 | img_np = 
img_np.squeeze(2) 38 | plt.imshow(img_np,cmap=cmap) 39 | plt.show() 40 | # 41 | 42 | def save_gif(gif_path,images,duration=0.5): 43 | """creates gif 44 | 45 | Args: 46 | gif_path (str): path where gif file should be saved 47 | images (torch.functional.Tensor): tensor of images of shape (N,C,H,W) 48 | """ 49 | if isinstance(images,Tensor): 50 | images = images.numpy() 51 | 52 | 53 | images = images.transpose(0,2,3,1).astype('uint8') 54 | 55 | imageio.mimsave(gif_path,images,duration=duration) 56 | 57 | def save_video(hparams,video_path,audio,video_frames): 58 | """ 59 | audio : C,S 60 | video_frames : T,C,H,W 61 | """ 62 | if video_frames.shape[1]==1: 63 | video_frames = video_frames.repeat(1,3,1,1) 64 | logger.debug(f'video frames :{video_frames.shape}, audio : {audio.shape}') 65 | video_frames = video_frames.to(torch.uint8) 66 | write_video(filename= video_path, 67 | video_array = video_frames.permute(0,2,3,1), 68 | fps = hparams['video_fps'], 69 | video_codec="h264", 70 | # audio_array= audio, 71 | # audio_fps = hparams['audio_sf'], 72 | # audio_codec = 'mp3' 73 | ) 74 | dir_name = os.path.dirname(video_path) 75 | temp_audio_path = os.path.join(dir_name,'temp','temp_audio.wav') 76 | os.makedirs(os.path.dirname(temp_audio_path),exist_ok=True) 77 | write_audio(temp_audio_path,hparams['audio_sf'],audio.cpu().numpy().reshape(-1)) 78 | 79 | video_clip = mpy.VideoFileClip(video_path) 80 | audio_clip = mpy.AudioFileClip(temp_audio_path) 81 | video_clip.audio = audio_clip 82 | video_clip.write_videofile(os.path.join(dir_name,'fake_video_with_audio.avi'),fps=hparams['video_fps'],codec='png') 83 | 84 | def save_video_v2(hparams,filepath,audio,video_frames): 85 | def get_video_frames(idx): 86 | idx = int(idx) 87 | # logger.debug(f'{video_frames.shape} ,{video_frames[idx].shape}') 88 | frame = video_frames[idx].permute(1,2,0).squeeze() 89 | return frame.cpu().numpy().astype(np.uint8) 90 | 91 | logger.debug('saving video please wait...') 92 | num_frames = video_frames.shape[0] 93 | video_fps = hparams['data']['video_fps'] 94 | audio_sf = hparams['data']['audio_sf'] 95 | duration = audio.squeeze().shape[0]/audio_sf 96 | # duration = 10 97 | # duration = math.ceil(num_frames/video_fps) 98 | logger.debug(f'duration {duration} seconds') 99 | dir_name = os.path.dirname(filepath) 100 | temp_audio_path = os.path.join(dir_name,'temp','temp_audio.wav') 101 | os.makedirs(os.path.dirname(temp_audio_path),exist_ok=True) 102 | # print(audio.cpu().numpy().reshape(-1).shape) 103 | write_audio(temp_audio_path,audio_sf,audio.cpu().numpy().reshape(-1)) 104 | 105 | video_clip = mpy.VideoClip(make_frame=get_video_frames,duration=duration) 106 | audio_clip = mpy.AudioFileClip(temp_audio_path,fps=audio_sf) 107 | video_clip = video_clip.set_audio(audio_clip) 108 | # print(filepath,video_clip.write_videofile.__doc__) 109 | video_clip.write_videofile( filepath, 110 | fps=video_fps, 111 | codec="png", 112 | bitrate=None, 113 | audio=True, 114 | audio_fps=audio_sf, 115 | preset="medium", 116 | # audio_nbytes=4, 117 | audio_codec=None, 118 | audio_bitrate=None, 119 | # audio_bufsize=2000, 120 | temp_audiofile=None, 121 | # temp_audiofile_path="", 122 | remove_temp=True, 123 | write_logfile=False, 124 | threads=None, 125 | ffmpeg_params=['-s','256x256','-aspect','1:1'], 126 | logger="bar", 127 | # pixel_format='gray 128 | ) 129 | --------------------------------------------------------------------------------
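A minimal usage sketch for the plotting utilities in wav2mov/utils/plots.py above; it is illustrative only and not part of the repository sources. It assumes an hparams dictionary exposing hparams['data']['video_fps'] and hparams['data']['audio_sf'] (the keys read inside save_video_v2), uses placeholder output paths, and substitutes randomly generated tensors for real generator output: video_frames shaped (T,C,H,W) and a flat audio waveform, as in the docstrings.

import torch
from wav2mov.utils.plots import save_gif, save_video_v2

hparams = {'data': {'video_fps': 25, 'audio_sf': 16000}}   # assumed values, for illustration only
video_frames = torch.rand(25, 3, 256, 256) * 255           # one second of fake frames, (T,C,H,W) in [0,255]
audio = torch.randn(1, 16000)                               # one second of fake 16 kHz audio

# save_gif converts the tensor to uint8 HxWxC frames before writing the gif
save_gif('sample.gif', video_frames, duration=1 / hparams['data']['video_fps'])

# save_video_v2 writes the waveform to a temporary wav file and muxes it with the
# frames through moviepy/ffmpeg, so both must be available on the machine
save_video_v2(hparams, 'sample.avi', audio, video_frames)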