├── .gitignore
├── .gitmodules
├── README.md
├── setup.py
├── wav2mov-docs
│   ├── gan_setup.PNG
│   └── gen_arch.PNG
└── wav2mov
    ├── __init__.py
    ├── config.json
    ├── config.py
    ├── core
    │   ├── data
    │   │   ├── __init__.py
    │   │   ├── collates.py
    │   │   ├── dataloaders.py
    │   │   ├── datasets.py
    │   │   ├── note.txt
    │   │   ├── raw_datasets.py
    │   │   ├── transforms.py
    │   │   └── utils.py
    │   ├── engine
    │   │   ├── __init__.py
    │   │   ├── callbacks.py
    │   │   └── engine.py
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── base_model.py
    │   │   └── template_model.py
    │   └── utils
    │       ├── __init__.py
    │       ├── average_meter.py
    │       ├── checkpoints.py
    │       ├── logger.py
    │       ├── misc.py
    │       └── os_utils.py
    ├── datasets
    │   └── create_file_list.py
    ├── inference
    │   ├── __init__.py
    │   ├── audio_utils.py
    │   ├── generate.py
    │   ├── generate.sh
    │   ├── image_utils.py
    │   ├── model_utils.py
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── generator
    │   │   │   ├── __init__.py
    │   │   │   ├── audio_encoder.py
    │   │   │   ├── frame_generator.py
    │   │   │   ├── id_decoder.py
    │   │   │   └── id_encoder.py
    │   │   ├── layers
    │   │   │   ├── conv_layers.py
    │   │   │   └── debug_layers.py
    │   │   └── utils.py
    │   ├── params.py
    │   ├── quantize.py
    │   └── utils.py
    ├── logger.py
    ├── losses
    │   ├── ReadMe.md
    │   ├── __init__.py
    │   ├── gan_loss.py
    │   ├── l1_loss.py
    │   └── sync_loss.py
    ├── main
    │   ├── callbacks.py
    │   ├── data.py
    │   ├── engine.py
    │   ├── main.py
    │   ├── options.py
    │   ├── preprocess.py
    │   ├── test.py
    │   ├── train.py
    │   ├── trained.txt
    │   └── validate_params.py
    ├── models
    │   ├── README.md
    │   ├── __init__.py
    │   ├── discriminators
    │   │   ├── identity_discriminator.py
    │   │   ├── patch_disc.py
    │   │   ├── sequence_discriminator.py
    │   │   ├── sync_discriminator.py
    │   │   └── utils.py
    │   ├── generator
    │   │   ├── audio_encoder.py
    │   │   ├── frame_generator.py
    │   │   ├── id_decoder.py
    │   │   ├── id_encoder.py
    │   │   └── noise_encoder.py
    │   ├── layers
    │   │   ├── conv_layers.py
    │   │   └── debug_layers.py
    │   ├── utils.py
    │   ├── wav2mov_inferencer.py
    │   ├── wav2mov_template.py
    │   └── wav2mov_trainer.py
    ├── params.json
    ├── params.py
    ├── plans
    │   ├── README.md
    │   ├── extras.md
    │   ├── images
    │   │   ├── components.png
    │   │   ├── gen_arch.png
    │   │   ├── plan_v1.png
    │   │   └── system.png
    │   └── observations.md
    ├── preprocess.bat
    ├── preprocess.sh
    ├── pylintrc
    ├── requirements.txt
    ├── run.bat
    ├── run.sh
    ├── run_sync_expert.sh
    ├── settings.py
    ├── test.bat
    ├── test.sh
    ├── tests
    │   ├── old
    │   │   ├── log.json
    │   │   ├── test.py
    │   │   ├── test.txt
    │   │   ├── test_3d_cnn.py
    │   │   ├── test_audio_util.py
    │   │   ├── test_batching.py
    │   │   ├── test_config.py
    │   │   ├── test_dataset.py
    │   │   ├── test_discriminators.py
    │   │   ├── test_file_logger.py
    │   │   ├── test_logger.py
    │   │   ├── test_main_data.py
    │   │   ├── test_module_level_logger.py
    │   │   ├── test_packing.py
    │   │   ├── test_settings.py
    │   │   ├── test_stored_video.py
    │   │   ├── test_tensorboard_logger.py
    │   │   ├── test_unet.py
    │   │   ├── test_utils.py
    │   │   ├── test_wav2mov.py
    │   │   ├── test_wav2mov_v7.py
    │   │   └── utils.py
    │   ├── test_audio_encoder.py
    │   ├── test_audio_frames.py
    │   ├── test_generator.py
    │   ├── test_main_dataloader.py
    │   ├── test_seq_disc.py
    │   ├── test_sync_disc.py
    │   └── test_video.py
    └── utils
        ├── __init__.py
        ├── audio.py
        ├── cnn_shape_calc.py
        ├── files.py
        ├── misc.py
        └── plots.py
/.gitignore:
--------------------------------------------------------------------------------
1 | shape_predictor_68_face_landmarks.dat
2 | wav2mov/datasets/*
3 | !wav2mov/datasets/create_file_list.py
4 | wav2mov/runs/
5 | wav2mov/logs/
6 | wav2mov/tests/res
7 | wav2mov/plans/*.txt
8 | *.png
9 | !wav2mov/plans/images/*.png
10 | !wav2mov-docs/*.png
11 | *.gif
12 | *.pt
13 | *.jfif
14 | *.wav
15 | *.avi
16 | archieves/
17 | # Vs code
18 | .vscode
19 | #vim files
20 | *.vimrc
21 | *.*~
22 | # Byte-compiled / optimized / DLL files
23 | __pycache__/
24 | *.py[cod]
25 | *$py.class
26 |
27 | # C extensions
28 | *.so
29 |
30 | # Distribution / packaging
31 | .Python
32 | build/
33 | develop-eggs/
34 | dist/
35 | downloads/
36 | eggs/
37 | .eggs/
38 | lib/
39 | lib64/
40 | parts/
41 | sdist/
42 | var/
43 | wheels/
44 | share/python-wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 | MANIFEST
49 |
50 | # PyInstaller
51 | # Usually these files are written by a python script from a template
52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest
54 | *.spec
55 |
56 | # Installer logs
57 | pip-log.txt
58 | pip-delete-this-directory.txt
59 |
60 | # Unit test / coverage reports
61 | htmlcov/
62 | .tox/
63 | .nox/
64 | .coverage
65 | .coverage.*
66 | .cache
67 | nosetests.xml
68 | coverage.xml
69 | *.cover
70 | *.py,cover
71 | .hypothesis/
72 | .pytest_cache/
73 | cover/
74 |
75 | # Translations
76 | *.mo
77 | *.pot
78 |
79 | # Django stuff:
80 | *.log
81 | local_settings.py
82 | db.sqlite3
83 | db.sqlite3-journal
84 |
85 | # Flask stuff:
86 | instance/
87 | .webassets-cache
88 |
89 | # Scrapy stuff:
90 | .scrapy
91 |
92 | # Sphinx documentation
93 | docs/_build/
94 |
95 | # PyBuilder
96 | .pybuilder/
97 | target/
98 |
99 | # Jupyter Notebook
100 | .ipynb_checkpoints
101 |
102 | # IPython
103 | profile_default/
104 | ipython_config.py
105 |
106 | # pyenv
107 | # For a library or package, you might want to ignore these files since the code is
108 | # intended to run in multiple environments; otherwise, check them in:
109 | # .python-version
110 |
111 | # pipenv
112 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
113 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
114 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
115 | # install all needed dependencies.
116 | #Pipfile.lock
117 |
118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
119 | __pypackages__/
120 |
121 | # Celery stuff
122 | celerybeat-schedule
123 | celerybeat.pid
124 |
125 | # SageMath parsed files
126 | *.sage.py
127 |
128 | # Environments
129 | .env
130 | .venv
131 | env/
132 | venv/
133 | ENV/
134 | env.bak/
135 | venv.bak/
136 |
137 | # Spyder project settings
138 | .spyderproject
139 | .spyproject
140 |
141 | # Rope project settings
142 | .ropeproject
143 |
144 | # mkdocs documentation
145 | /site
146 |
147 | # mypy
148 | .mypy_cache/
149 | .dmypy.json
150 | dmypy.json
151 |
152 | # Pyre type checker
153 | .pyre/
154 |
155 | # pytype static type analyzer
156 | .pytype/
157 |
158 | # Cython debug symbols
159 | cython_debug/
160 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Wav2Mov](https://wav2mov.vercel.app)
2 |
3 | ## Speech To Facial Animation Using GANs
4 |
5 |
6 | [Python](https://www.python.org/) [PyTorch](https://pytorch.org/) [References](#1)
7 |
8 | This repo contains the PyTorch implementation of speech-driven facial animation: given a face image and a speech clip, the model generates a talking-face video using Generative Adversarial Nets (see [References](#1)).
9 |
10 |
11 | ## Results
12 |
13 | Some of the generated videos can be found [here](https://wav2mov-examples.vercel.app/examples).
14 |
15 | ## Implementation
16 | ### GAN setup
17 | ![GAN setup](wav2mov-docs/gan_setup.PNG)
18 | ![Generator architecture](wav2mov-docs/gen_arch.PNG)
19 | ## References
20 |
21 | [1] Generative Adversarial Nets
22 | ```bibtex
23 | @article{goodfellow2014generative,
24 | title={Generative adversarial networks},
25 | author={Goodfellow, Ian J and Pouget-Abadie, Jean and Mirza, Mehdi and Xu, Bing and Warde-Farley, David and Ozair, Sherjil and Courville, Aaron and Bengio, Yoshua},
26 | journal={arXiv preprint arXiv:1406.2661},
27 | year={2014}
28 | }
29 | ```
30 |
31 | [2] The Audio-Visual Lombard Grid Speech Corpus
32 |
33 | [Dataset website](http://spandh.dcs.shef.ac.uk/avlombard/)
34 |
35 | ```bibtex
36 | @article{Alghamdi_2018,
37 | doi = {10.1121/1.5042758},
38 | url = {https://doi.org/10.1121%2F1.5042758},
39 | year = 2018,
40 | month = {jun},
41 | publisher = {Acoustical Society of America ({ASA})},
42 | volume = {143},
43 | number = {6},
44 | pages = {EL523--EL529},
45 | author = {Najwa Alghamdi and Steve Maddock and Ricard Marxer and Jon Barker and Guy J. Brown},
46 | title = {A corpus of audio-visual Lombard speech with frontal and profile views},
47 | journal = {The Journal of the Acoustical Society of America}
48 | }
49 | ```
50 |
51 | [3] Realistic Facial Animation using GANs
52 |
53 | [Code](https://github.com/DinoMan/speech-driven-animation)
54 |
55 | ```bibtex
56 | @article{Vougioukas_2019,
57 | doi = {10.1007/s11263-019-01251-8},
58 | url = {https://doi.org/10.1007%2Fs11263-019-01251-8},
59 | year = 2019,
60 | month = {oct},
61 | publisher = {Springer Science and Business Media {LLC}},
62 | volume = {128},
63 | number = {5},
64 | pages = {1398--1413},
65 | author = {Konstantinos Vougioukas and Stavros Petridis and Maja Pantic},
66 | title = {Realistic Speech-Driven Facial Animation with {GANs}},
67 | journal = {International Journal of Computer Vision}
68 | }
69 | ```
70 |
71 | [4] End to End Facial Animation using Temporal GANs
72 | ```bibtex
73 | @article{vougioukas2018end,
74 | title={End-to-end speech-driven facial animation with temporal gans},
75 | author={Vougioukas, Konstantinos and Petridis, Stavros and Pantic, Maja},
76 | journal={arXiv preprint arXiv:1805.09313},
77 | year={2018}
78 | }
79 | ```
80 |
81 | [5] A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild
82 |
83 |
84 | [Code](https://github.com/Rudrabha/Wav2Lip)
85 | ```bibtex
86 | @inproceedings{10.1145/3394171.3413532,
87 | author = {Prajwal, K R and Mukhopadhyay, Rudrabha and Namboodiri, Vinay P. and Jawahar, C.V.},
88 | title = {A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild},
89 | year = {2020},
90 | isbn = {9781450379885},
91 | publisher = {Association for Computing Machinery},
92 | address = {New York, NY, USA},
93 | url = {https://doi.org/10.1145/3394171.3413532},
94 | doi = {10.1145/3394171.3413532},
95 | booktitle = {Proceedings of the 28th ACM International Conference on Multimedia},
96 | pages = {484–492},
97 | numpages = {9},
98 | keywords = {lip sync, talking face generation, video generation},
99 | location = {Seattle, WA, USA},
100 | series = {MM '20}
101 | }
102 | ```
103 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup,find_packages
2 | setup(name="wav2mov",version="1.0.0",packages= find_packages())
3 |
--------------------------------------------------------------------------------
/wav2mov-docs/gan_setup.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov-docs/gan_setup.PNG
--------------------------------------------------------------------------------
/wav2mov-docs/gen_arch.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov-docs/gen_arch.PNG
--------------------------------------------------------------------------------
/wav2mov/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/__init__.py
--------------------------------------------------------------------------------
/wav2mov/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "fixed": {
3 | "seed": 10,
4 | "grid_dataset_dir": "D:\\dataset_lip\\GRID"
5 | },
6 | "runtime": {
7 | "runs_dir": "%(base_dir)s\\runs\\%(v)s\\%(version)s",
8 | "log_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\%(log_filename)s.log",
9 | "train_test_dataset_dir": "%(base_dir)s\\datasets\\grid_dataset_a5_500_a10to14",
10 | "filenames_txt": "%(base_dir)s\\datasets\\grid_dataset_a5_500_a10to14\\filenames.txt",
11 | "filenames_train_txt": "%(base_dir)s\\datasets\\grid_dataset_a5_500_a10to14\\filenames_train.txt",
12 | "filenames_test_txt": "%(base_dir)s\\datasets\\grid_dataset_a5_500_a10to14\\filenames_test.txt",
13 | "gen_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\gen_%(version)s.pt",
14 | "seq_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\seq_disc_%(version)s.pt",
15 | "sync_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\sync_disc_%(version)s.pt",
16 | "sync_expert_checkpoint_fullpath": "/content/drive/MyDrive/Colab Notebooks/wav2mov-dev_phase_8/wav2mov/runs/sync_expert/Run_11_7_2021__13_36",
17 | "id_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\id_disc_%(version)s.pt",
18 | "params_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\hparams_%(version)s.json",
19 | "optim_gen_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\optim_gen_%(version)s.pt",
20 | "optim_seq_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\optim_seq_disc_%(version)s.pt",
21 | "optim_sync_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\optim_sync_disc_%(version)s.pt",
22 | "optim_id_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\optim_id_disc_%(version)s.pt",
23 | "optim_params_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\optim_hparams_%(version)s.json",
24 | "scheduler_gen_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\scheduler_gen_%(version)s.pt",
25 | "scheduler_seq_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\scheduler_seq_disc_%(version)s.pt",
26 | "scheduler_sync_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\scheduler_sync_disc_%(version)s.pt",
27 | "scheduler_id_disc_checkpoint_fullpath": "%(base_dir)s\\runs\\%(v)s\\%(version)s\\scheduler_id_disc_%(version)s.pt",
28 | "out_dir": "%(base_dir)s\\out\\%(v)s"
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/wav2mov/config.py:
--------------------------------------------------------------------------------
1 | """Config"""
2 | import json
3 | import logging
4 | import os
5 | import re
6 | import pytz
7 | from datetime import datetime
8 | from wav2mov.logger import get_module_level_logger
9 |
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | logger = get_module_level_logger(__name__)
12 | logger.setLevel(logging.WARNING)
13 |
14 | def get_curr_run_str():
15 | now = datetime.now().astimezone(pytz.timezone("Asia/Calcutta"))
16 | date,time = now.date(),now.time()
17 | day,month,year = date.day,date.month,date.year
18 | hour , minutes = time.hour,time.minute
19 | return 'Run_{}_{}_{}__{}_{}'.format(day,month,year,hour,minutes)
20 |
21 | class Config :
22 | def __init__(self,v):
23 | self.vals = {'base_dir':BASE_DIR}
24 | self.version = get_curr_run_str()
25 | #these two are used to decide whether to create a path using os.makedirs
26 | self.fixed_paths = set()
27 | self.runtime_paths = set()
28 | self.v = v
29 |
30 | def _rectify_paths(self,val):
31 | return val % {'base_dir': self.vals['base_dir'],
32 | 'version': self.version,
33 | 'v':self.v,
34 | 'checkpoint_filename': 'checkpoint_'+self.version,
35 | 'log_filename': 'log_' + self.version
36 | }
37 |
38 | def _flatten_dict(self,d):
39 | flattened= {}
40 | for k,v in d.items():
41 | if isinstance(v,dict):
42 | inner_d = self._flatten_dict(v)
43 | for k_inner in inner_d.keys():
44 | if k== 'fixed': self.fixed_paths.add(k_inner)
45 | else: self.runtime_paths.add(k_inner)
46 |
47 | flattened = {**flattened,**inner_d}
48 | else:
49 | flattened[k] = v
50 | return flattened
51 |
52 | def _update_vals_from_dict(self,d:dict):
53 | d = self._flatten_dict(d)
54 | for key,val in d.items():
55 | if isinstance(val,str):
56 | value = self._rectify_paths(val)
57 |
58 | else:
59 | value = val
60 | self.vals[key] = value
61 |
62 | @classmethod
63 | def from_json(cls,json_file_fullpath,v):
64 | with open(json_file_fullpath,'r') as file:
65 | configs = json.load(file)
66 | obj = cls(v)
67 | obj._update_vals_from_dict(configs)
68 | return obj
69 |
70 | def update(self,key,value):
71 | if key in self.vals:
72 | logger.warning(f'Updating existing parameter {key} : changing from {self.vals[key]} to {value}')
73 | self.vals[key] = value
74 |
75 | def __getitem__(self,item):
76 | if item not in self.vals:
77 | logger.error(f'{self.__class__.__name__} object has no key called {item}')
78 | raise KeyError('No key called ', item)
79 | if item in self.fixed_paths:
80 | return self.vals[item]
81 | if os.sep != '\\':
82 | self.vals[item] = re.sub(r'(\\)+', os.sep, self.vals[item])
83 | if 'fullpath' in item or 'dir' in item:
84 | path = os.path.dirname(self.vals[item]) if '.' in os.path.basename(self.vals[item]) else self.vals[item]
85 | os.makedirs(path,exist_ok=True)
86 | logger.debug(f'accessing {self.vals[item]}')
87 | return self.vals[item]
88 |
89 | def get_config(v):
90 | config = Config.from_json(os.path.join(BASE_DIR,'config.json'),v)
91 | return config
92 |
93 | if __name__=='__main__':
94 | pass
95 |
--------------------------------------------------------------------------------
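
A minimal usage sketch (not part of the repo) for the `Config` helper above; `get_config` and the keys shown come from `config.py`/`config.json`, while the `'v1'` tag is an arbitrary example:

```python
# Hypothetical usage of wav2mov/config.py; the 'v1' tag is illustrative.
from wav2mov.config import get_config

config = get_config(v='v1')                    # 'v' names the run family
runs_dir = config['runs_dir']                  # resolves %(base_dir)s/runs/%(v)s/%(version)s and creates it
gen_ckpt = config['gen_checkpoint_fullpath']   # parent directory is created on first access
config.update('device', 'cuda')                # add or override a runtime value
```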
/wav2mov/core/data/__init__.py:
--------------------------------------------------------------------------------
1 | """ Datasets utils
2 | """
3 | from .raw_datasets import RawDataset,GridDataset,RavdessDataset
4 |
5 |
--------------------------------------------------------------------------------
/wav2mov/core/data/collates.py:
--------------------------------------------------------------------------------
1 | """ collate functions """
2 | from collections import namedtuple
3 | import torch
4 | from wav2mov.core.data.utils import Sample,SampleWithFrames
5 |
6 | from wav2mov.core.utils.logger import get_module_level_logger
7 | logger = get_module_level_logger(__name__)
8 | from wav2mov.core.data.utils import AudioUtil
9 | Lens = namedtuple('lens', ('audio', 'video'))
10 |
11 | def get_frames_limit(audio_lens,video_lens,stride):
12 | audio_frames_lens = [audio_len//stride for audio_len in audio_lens]
13 | return min(min(audio_frames_lens),min(video_lens))
14 |
15 | def video_frange_start(num_frames,req_frames):
16 | return (num_frames-req_frames)//2
17 |
18 | def video_frange_end(num_frames,req_frames):
19 | return (num_frames + req_frames)//2
20 |
21 | def get_batch_collate(hparams):
22 | stride = hparams['audio_sf']// hparams['video_fps']
23 | audio_util = AudioUtil(hparams['audio_sf'],hparams['coarticulation_factor'],stride)
24 | def collate_fn(batch):
25 | videos = [(sample.video,sample.video.shape[0]) for sample in batch]
26 | videos,video_lens = list(zip(*videos))
27 |
28 | audios = [(sample.audio.unsqueeze(0),sample.audio.shape[0]) for sample in batch]
29 | audios,audio_lens = list(zip(*audios))
30 |
31 | req_frames = get_frames_limit(audio_lens,video_lens,stride)
32 | ranges = [(video_frange_start(video.shape[0],req_frames),video_frange_end(video.shape[0],req_frames)) for video in videos]
33 | #[<------- total_frames-------->]
34 | #[<---><---req_frames----><---->]
35 | # ^ ^
36 | # frange_start frange_end
37 | videos = [video[ranges[i][0]:ranges[i][1],... ].unsqueeze(0) for i,video in enumerate(videos)]
38 |
39 | audio_frames = [audio_util.get_audio_frames(audio,num_frames=req_frames,get_mfccs=True).unsqueeze(0) for i,audio in enumerate(audios)]
40 | audios = [audio_util.get_limited_audio(audio,num_frames=req_frames,get_mfccs=False) for audio in audios]
41 | return SampleWithFrames(torch.cat(audios),torch.cat(audio_frames),torch.cat(videos))
42 |
43 | return collate_fn
--------------------------------------------------------------------------------
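
A sketch (assumptions flagged in comments) of how `get_batch_collate` would typically be plugged into a `DataLoader`:

```python
# Sketch only: the hparams values below are illustrative, not the project's actual settings.
from torch.utils.data import DataLoader
from wav2mov.core.data.collates import get_batch_collate

hparams = {'audio_sf': 16000, 'video_fps': 25, 'coarticulation_factor': 2}
collate_fn = get_batch_collate(hparams)

# `dataset` is assumed to yield Sample(audio, video) items,
# e.g. the AudioVideoDataset defined in wav2mov/core/data/datasets.py.
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
audio, audio_frames, video = next(iter(loader))   # SampleWithFrames unpacks like a tuple
```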
/wav2mov/core/data/dataloaders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import NamedTuple
3 | from collections import namedtuple
4 | from torch.utils.data import DataLoader,Dataset,random_split
5 |
6 |
7 | class DataloadersPack:
8 |
9 |
10 | def get_dataloaders(self,dataset:Dataset,splits,batch_size,seed):
11 |
12 | train_set,test_set,validation_set = self.__split_dataset(dataset,splits,seed)
13 | train_dataloader = self.__get_dataloader(train_set,batch_size)
14 | test_dataloader = self.__get_dataloader(test_set,batch_size)
15 | validation_dataloader = self.__get_dataloader(validation_set,batch_size)
16 |
17 | dataloaders_pack = namedtuple('dataloaders_pack',['train','validation','test'])
18 | return dataloaders_pack(train_dataloader,validation_dataloader,test_dataloader)
19 |
20 | def __get_dataloader(self,dataset,batch_size):
21 |
22 | return DataLoader(dataset,
23 | batch_size=batch_size,
24 | num_workers=2, #[WARNING] each worker process re-imports the entry-point module
25 | shuffle=True)
26 |
27 | def __split_dataset(self,dataset,splits,seed):
28 | N = len(dataset)
29 | train_size,test_size,val_size = self.__get_split_sizes(N,splits)
30 | return random_split(dataset,[train_size,test_size,val_size],
31 | generator=torch.Generator().manual_seed(seed))
32 |
33 | def __get_split_sizes(self,N,splits):
34 | train_size = int((splits.train_size/10)*N)
35 | test_size = int((splits.test_size/10)*N)
36 | validation_size = N-train_size-test_size
37 | return train_size,test_size,validation_size
38 |
39 |
40 |
41 |
42 | def splitted_dataloaders(dataset:Dataset,splits:namedtuple,batch_size,seed)->NamedTuple:
43 |
44 | return DataloadersPack().get_dataloaders(dataset,splits,batch_size,seed)
--------------------------------------------------------------------------------
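
A short sketch of `splitted_dataloaders`; the `Splits` namedtuple and the stand-in dataset are hypothetical, but the splits object only needs `train_size` and `test_size` fields expressed in tenths of the dataset:

```python
# Sketch: splitting one dataset into train/validation/test loaders.
import torch
from collections import namedtuple
from torch.utils.data import TensorDataset
from wav2mov.core.data.dataloaders import splitted_dataloaders

dataset = TensorDataset(torch.randn(100, 8))                # stand-in dataset with 100 samples
Splits = namedtuple('Splits', ['train_size', 'test_size'])  # sizes in tenths of the dataset
splits = Splits(train_size=8, test_size=1)                  # 80% train, 10% test, rest validation
loaders = splitted_dataloaders(dataset, splits, batch_size=4, seed=10)  # seed 10 matches config.json
train_loader, val_loader, test_loader = loaders.train, loaders.validation, loaders.test
```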
/wav2mov/core/data/datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import os
4 | import torch
5 | from torch.utils.data import Dataset
6 | from collections import namedtuple
7 |
8 | from wav2mov.core.data.utils import Sample
9 |
10 |
11 | class AudioVideoDataset(Dataset):
12 | """
13 | Dataset for audio and video numpy files
14 | """
15 | def __init__(self,
16 | root_dir,
17 | filenames_text_filepath,
18 | audio_sf,
19 | video_fps,
20 | num_videos,
21 | transform=None):
22 | self.root_dir = root_dir
23 | self.audio_fs = audio_sf
24 | self.video_fps = video_fps
25 | self.stride = math.floor(self.audio_fs / self.video_fps)
26 | self.transform = transform
27 | self.filenames = []
28 | with open(filenames_text_filepath,'r') as file:
29 | self.filenames = file.read().split('\n')
30 | self.filenames = self.filenames[:num_videos]
31 |
32 | def __len__(self):
33 | return len(self.filenames)
34 |
35 | def __load_from_np(self,path):
36 | return np.load(path)
37 |
38 | def __get_folder_name(self,curr_file):
39 | return os.path.join(self.root_dir,curr_file)
40 |
41 | def get_audio(self,idx):
42 | folder = self.__get_folder_name(self.filenames[idx])
43 | audio_filepath = os.path.join(folder,'audio.npy')
44 | return self.__load_from_np(audio_filepath)
45 |
46 | def get_video_frames(self,idx):
47 | folder = self.__get_folder_name(self.filenames[idx])
48 | video_filepath = os.path.join(folder,'video_frames.npy')
49 | return self.__load_from_np(video_filepath)
50 |
51 | def __getitem__(self,idx):
52 | if torch.is_tensor(idx):
53 | idx = idx.tolist()
54 |
55 | audio = self.get_audio(idx)
56 | video = self.get_video_frames(idx)
57 | audio = torch.from_numpy(audio)
58 | video = torch.from_numpy(video).permute(0,3,1,2)#F,H,W,C ==> F,C,H,W
59 | video = video/255
60 | sample = Sample(audio,video)
61 | if self.transform:
62 | sample = self.transform(sample)
63 | return sample
64 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
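
A sketch of constructing `AudioVideoDataset`; the directory layout follows the `filenames.txt` convention used in `config.json` and `create_file_list.py`, while the rates are illustrative assumptions:

```python
# Sketch: each entry in filenames.txt names a folder containing audio.npy and video_frames.npy.
from wav2mov.core.data.datasets import AudioVideoDataset

dataset = AudioVideoDataset(root_dir='datasets/grid_dataset_a5_500_a10to14',
                            filenames_text_filepath='datasets/grid_dataset_a5_500_a10to14/filenames.txt',
                            audio_sf=16000,     # assumed sampling rate
                            video_fps=25,       # assumed frame rate
                            num_videos=100)
sample = dataset[0]   # Sample(audio, video); video is (F, C, H, W) scaled to [0, 1]
```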
/wav2mov/core/data/note.txt:
--------------------------------------------------------------------------------
1 | original GRID DATASET has videos of shape (H,W,C)==(480, 720, 3)
--------------------------------------------------------------------------------
/wav2mov/core/data/raw_datasets.py:
--------------------------------------------------------------------------------
1 | """Raw Datasets Classes : Grid and Ravdess"""
2 | import os
3 | from abc import abstractmethod
4 | from collections import namedtuple
5 | from tqdm import tqdm
6 |
7 | from wav2mov.core.data.utils import get_video_frames,get_audio,get_audio_from_video
8 |
9 |
10 | SampleContainer = namedtuple('SampleContainer',['audio','video'])
11 | Sample = namedtuple('Sample',['path','val'])
12 |
13 |
14 |
15 | class RawDataset:
16 | """
17 | Description of RawDataset
18 |
19 | Args:
20 | location (undefined):
21 | audio_sampling_rate=None (undefined):
22 | video_frame_rate=None (undefined):
23 | website=None (undefined):
24 |
25 | """
26 |
27 |
28 | def __init__(self,
29 | location,
30 | audio_sampling_rate=None,
31 | video_frame_rate=None,
32 | samples_count = None,
33 | website=None):
34 | """
35 | Description of __init__
36 |
37 | Args:
38 | self (undefined):
39 | location (undefined):
40 | audio_sampling_rate=None (undefined):
41 | video_frame_rate=None (undefined):
42 | samples_count=None
43 | website=None (undefined):
44 |
45 | """
46 | self.root_location = location
47 | self.audio_sampling_rate = audio_sampling_rate
48 | self.video_frame_rate = video_frame_rate
49 | self.samples_count = samples_count
50 | self.website = website
51 |
52 | def info(self):
53 | print(f'{self.__class__.__name__} : {self.website}')
54 |
55 | @abstractmethod
56 | def generator(self):
57 | raise NotImplementedError(f'{self.__class__.__name__} should implement generator method')
58 |
59 | class RavdessDataset(RawDataset):
60 |
61 | """Dataset class for RAVDESS dataset.
62 |
63 | [LINK] https://zenodo.org/record/1188976#.YBlpZOgzZPY
64 |
65 | Folder structure
66 | ------------------
67 |
68 | `
69 | root_location:folder
70 | |
71 | |___actor1:folder
72 | | |
73 | | |_ video 1:file
74 | | |_ video 2:file
75 | | |_ ...
76 | |
77 | |___actor2:folder
78 | |
79 | |_ video 1:file
80 | |...
81 | `
82 |
83 | """
84 | link = r"https://zenodo.org/record/1188976#.YBlpZOgzZPY"
85 | name = 'ravdess_dataset'
86 | def __init__(self,location,audio_sampling_rate,video_frame_rate,samples_count=None,link=None):
87 | super().__init__(location,audio_sampling_rate,video_frame_rate,samples_count,link)
88 | self.sub_folders = (folder for folder in os.listdir(location)
89 | if os.path.isdir(os.path.join(location,folder)))
90 |
91 |
92 |
93 | def generator(self)->Sample:
94 | """yields audio and video filenames one by one
95 |
96 | Returns:
97 | Sample: contains audio and video filepaths
98 |
99 | Yields:
100 | Iterator[Sample]:
101 |
102 | Examples:
103 |
104 | >>>dataset = RavdessDataset(...)
105 | >>>for sample in dataset.generator():
106 | >>> audio_filepath = sample.audio
107 |
108 | """
109 | for actor in self.sub_folders:
110 | actor_path = os.path.join(self.root_location,actor)
111 | videos = [video for _,_,video in os.walk(actor_path)] # videos[0] holds the filenames directly under actor_path
112 | limit = self.samples_count if self.samples_count is not None else len(videos[0])
113 | for video in tqdm(videos[0][:limit],
114 | desc="Ravdess Dataset",
115 | total=len(videos[0][:limit]),ascii=True,colour="green"):
116 | audio_path = get_audio_from_video(os.path.join(actor_path,video))
117 | video_path = os.path.join(actor_path,video)
118 | yield SampleContainer(video=Sample(video_path,val=None),
119 | audio=Sample(audio_path,val=None))
120 |
121 |
122 | class GridDataset(RawDataset):
123 | """Dataset class for Grid dataset.
124 |
125 | [LINK] http://spandh.dcs.shef.ac.uk/avlombard/
126 |
127 | [LINK] [PAPER] https://asa.scitation.org/doi/10.1121/1.5042758
128 |
129 | `Folder structure`::
130 |
131 | root_location:folder
132 | |
133 | |___audio:folder
134 | | |
135 | | |_ .wav:file
136 | | |_ .wav:file
137 | | |_ ...
138 | |
139 | |___video:folder
140 | |
141 | |_ .mov:file
142 | |_ .mov:file
143 | |_ ...
144 |
145 | """
146 | link = "http://spandh.dcs.shef.ac.uk/avlombard/"
147 | name = 'grid_dataset'
148 | def __init__(self,location,audio_sampling_rate,video_frame_rate,samples_count=None,link=None):
149 | super().__init__(location,
150 | audio_sampling_rate,
151 | video_frame_rate,
152 | samples_count,
153 | link)
154 |
155 |
156 | def generator(self,get_filepath_only=True,img_size=(256,256),show_progress_bar=True)->Sample:
157 | video_folder = os.path.join(self.root_location,'video/')
158 | audio_folder = os.path.join(self.root_location,'audio/')
159 | videos = [video for _,_,video in os.walk(video_folder)]
160 | if self.samples_count is None :
161 | self.samples_count = len(videos[0])
162 | self.samples_count = min(len(videos[0]),self.samples_count)
163 | limit = self.samples_count
164 | progress_bar = tqdm(enumerate(videos[0][:limit]),
165 | desc="Processing Grid Dataset",
166 | total=len(videos[0][:limit]), ascii=True, colour="green") if show_progress_bar else enumerate(videos[0][:limit])
167 |
168 | for idx,video_filename in progress_bar:
169 | if isinstance(progress_bar,tqdm):
170 | progress_bar.set_postfix({'file':video_filename})
171 | audio_filename = video_filename.split('.')[0] + '.wav'
172 | video_path = os.path.join(video_folder, video_filename)
173 | audio_path = os.path.join(audio_folder, audio_filename)
174 |
175 | video_val = None if get_filepath_only else get_video_frames(video_path, img_size=img_size)
176 | audio_val = None if get_filepath_only else get_audio(audio_path)
177 |
178 | res = SampleContainer(video=Sample(video_path, val=video_val),
179 | audio=Sample(audio_path, val=audio_val))
180 | # if idx+1==limit: # what if the actual length of videos is less than the limit passed by the caller
181 | # return res
182 | # else:
183 | yield res
184 |
185 |
186 |
--------------------------------------------------------------------------------
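
A sketch of walking the raw GRID corpus with `GridDataset.generator`; the dataset location mirrors `grid_dataset_dir` in `config.json`, and the rates are illustrative assumptions:

```python
# Sketch: iterate over (video, audio) file pairs of the GRID/Lombard corpus.
from wav2mov.core.data.raw_datasets import GridDataset

dataset = GridDataset(location='D:/dataset_lip/GRID',
                      audio_sampling_rate=16000,   # assumed
                      video_frame_rate=25,         # assumed
                      samples_count=10)
for sample in dataset.generator(get_filepath_only=True):
    video_path = sample.video.path   # Sample(path, val); val stays None when only paths are requested
    audio_path = sample.audio.path
```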
/wav2mov/core/data/transforms.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms as vtransforms
2 |
3 | from wav2mov.core.data.utils import Sample
4 |
5 | class ResizeGrayscale:
6 | def __init__(self,target_shape):
7 | _,*img_size = target_shape
8 | self.transform = vtransforms.Compose(
9 | [
10 | # vtransforms.Grayscale(1),
11 | vtransforms.Resize(img_size),
12 | # vtransforms.Normalize([0.5]*img_channels, [0.5]*img_channels)
13 | ])
14 | def __call__(self,sample):
15 | video = sample.video
16 | video = self.transform(video)
17 | return Sample(sample.audio,video)
18 |
19 | class Normalize:
20 | def __init__(self,mean_std):
21 | audio_mean,audio_std = mean_std['audio']
22 | video_mean,video_std = mean_std['video']
23 |
24 | self.audio_transform = lambda x: (x-audio_mean)/audio_std
25 | self.video_transform = vtransforms.Normalize(video_mean,video_std)
26 |
27 | def __call__(self,sample):
28 | audio,video = sample.audio,sample.video
29 | audio = self.audio_transform(audio)
30 | video = self.video_transform(video)
31 | return Sample(audio,video)
--------------------------------------------------------------------------------
/wav2mov/core/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .callbacks import Callbacks,CallbackEvents,CallbackDispatcher
2 | from .engine import TemplateEngine
3 |
--------------------------------------------------------------------------------
/wav2mov/core/engine/callbacks.py:
--------------------------------------------------------------------------------
1 |
2 | from enum import Enum
3 | class CallbackEvents(Enum):
4 | TRAIN_START = "on_train_start"
5 | TRAIN_END = "on_train_end"
6 | EPOCH_START = "on_epoch_start"
7 | EPOCH_END = "on_epoch_end"
8 | BATCH_START = "on_batch_start"
9 | BATCH_END = "on_batch_end"
10 | RUN_START = "on_run_start"
11 | RUN_END = "on_run_end"
12 | NONE = "none"
13 |
14 | class Callbacks:
15 | def on_train_start(self, *args,**kwargs):
16 | pass
17 | def on_train_end(self, *args,**kwargs):
18 | pass
19 | def on_epoch_start(self, *args,**kwargs):
20 | pass
21 | def on_epoch_end(self,*args, **kwargs):
22 | pass
23 | def on_batch_start(self,*args, **kwargs):
24 | pass
25 | def on_batch_end(self, *args,**kwargs):
26 | pass
27 | def on_run_start(self,*args,**kwargs):
28 | pass
29 | def on_run_end(self,*args,**kwargs):
30 | pass
31 |
32 | class CallbackDispatcher:
33 | def __init__(self):
34 | self.callbacks = []
35 |
36 | def register(self,callbacks):
37 | self.callbacks = callbacks
38 |
39 | def dispatch(self,event,*args,**kwargs):
40 | for callback in self.callbacks:
41 | # print(callback.__class__.__name__,args)
42 | getattr(callback,event.value)(*args,**kwargs)
43 |
44 |
--------------------------------------------------------------------------------
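
A minimal sketch of the callback mechanism above; `LossLogger` is a hypothetical callback, not one from the repo:

```python
# Sketch: dispatching an event to registered callbacks.
from wav2mov.core.engine.callbacks import Callbacks, CallbackEvents, CallbackDispatcher

class LossLogger(Callbacks):
    def on_epoch_end(self, *args, **kwargs):
        print('epoch finished:', kwargs.get('epoch'))

dispatcher = CallbackDispatcher()
dispatcher.register([LossLogger()])
dispatcher.dispatch(CallbackEvents.EPOCH_END, epoch=1)   # calls on_epoch_end on every registered callback
```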
/wav2mov/core/engine/engine.py:
--------------------------------------------------------------------------------
1 | from wav2mov.core.engine.callbacks import CallbackEvents,CallbackDispatcher
2 | from wav2mov.core.utils.logger import get_module_level_logger
3 | logger = get_module_level_logger(__name__)
4 |
5 | class TemplateEngine:
6 |
7 | def __init__(self):
8 | self.__event = CallbackEvents.NONE
9 | self.__dispatcher = CallbackDispatcher()
10 |
11 | @property
12 | def event(self):
13 | return self.__event
14 |
15 | @event.setter
16 | def event(self,event):
17 | self.__event = event
18 |
19 | def register(self,callbacks):
20 | self.__dispatcher.register(callbacks)
21 |
22 | def dispatch(self,event,*args,**kwargs):
23 | self.event = event
24 | # logger.debug(f'{self.event},args : {args} kwargs : {kwargs}')
25 | self.__dispatcher.dispatch(event,*args,**kwargs)
26 |
27 | def on_run_start(self,*args,**kwargs):
28 | pass
29 | def on_run_end(self,*args,**kwargs):
30 | pass
31 | def on_train_start(self,*args,**kwargs):
32 | pass
33 | def on_epoch_start(self,*args,**kwargs):
34 | pass
35 | def on_batch_start(self,*args,**kwargs):
36 | pass
37 | def on_batch_end(self,*args,**kwargs):
38 | pass
39 | def log(self,*args,**kwargs):
40 | pass
41 | def validate(self,*args,**kwargs):
42 | pass
43 | def on_epoch_end(self,*args,**kwargs):
44 | pass
45 | def on_train_end(self,*args,**kwargs):
46 | pass
47 | def run(self,*args,**kwargs):
48 | """
49 | training script goes here
50 | """
51 | print("[TEMPLATE ENGINE] 'run' function not implemented")
52 | pass
53 | def test(self,*args,**kwargs):
54 | """
55 | test script goes here
56 | """
57 | print("[TEMPLATE ENGINE] 'testing' function not implemented")
58 | pass
--------------------------------------------------------------------------------
/wav2mov/core/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_model import BaseModel
2 | from .template_model import TemplateModel
--------------------------------------------------------------------------------
/wav2mov/core/models/base_model.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | import torch
3 | from torch import nn
4 |
5 | import logging
6 | logger = logging.getLogger(__name__)
7 |
8 | class BaseModel(nn.Module):
9 | def __init__(self):
10 | super().__init__()
11 |
12 |
13 | def save_to(self,checkpoint_fullpath):
14 | torch.save(self,checkpoint_fullpath)
15 | logger.info(f'Model saved at {checkpoint_fullpath}')
16 |
17 | def load_from(self,checkpoint_fullpath):
18 | try:
19 | self.load_state_dict(torch.load(checkpoint_fullpath))
20 | except Exception:
21 | logger.error(f'Cannot load checkpoint from {checkpoint_fullpath}')
22 |
23 | @abstractmethod
24 | def forward(self,*args):
25 | raise NotImplementedError(f'Forward method is not defined in {self.__class__.__name__}')
26 |
27 | def freeze_learning(self):
28 | for p in self.parameters():
29 | p.requires_grad = False
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/wav2mov/core/models/template_model.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | class TemplateModel(nn.Module):
4 | def __init__(self):
5 | super().__init__()
6 |
7 | def forward(self,*args,**kwargs):
8 | raise NotImplementedError("Forward method must be implemented")
9 |
10 | def on_run_start(self,*args,**kwargs):
11 | pass
12 | def on_run_end(self,*args,**kwargs):
13 | pass
14 | def on_train_start(self,*args,**kwargs):
15 | pass
16 | def on_epoch_start(self,*args,**kwargs):
17 | pass
18 | def on_batch_start(self,*args,**kwargs):
19 | pass
20 | def setup_input(self,*args,**kwargs):
21 | pass
22 | def optimize_parameters(self,*args,**kwargs):
23 | pass
24 | def on_batch_end(self,*args,**kwargs):
25 | pass
26 | def on_epoch_end(self,*args,**kwargs):
27 | pass
28 | def log(self,*args,**kwargs):
29 | pass
30 | def validate(self,*args,**kwargs):
31 | pass
32 | def on_train_end(self,*args,**kwargs):
33 | pass
34 |
35 | def load(self,checkpoint):
36 | self.load_state_dict(checkpoint)
--------------------------------------------------------------------------------
/wav2mov/core/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .checkpoints import save_checkpoint,save_fully_trained_model
--------------------------------------------------------------------------------
/wav2mov/core/utils/average_meter.py:
--------------------------------------------------------------------------------
1 |
2 | class AverageMeter():
3 | """class for tracking metrics generated during training
4 | """
5 | def __init__(self,name,fmt=':0.2f'):
6 | self.name = name
7 | self.fmt = fmt
8 | self.reset()
9 |
10 | def reset(self):
11 | self.curr_val = 0
12 | self.avg = 0
13 | self.sum = 0
14 | self.count = 0
15 |
16 | def update(self,val,n=1):
17 | self.curr_val = val
18 | self.sum += val*n
19 | self.count += n
20 | self.avg = self.sum / self.count
21 |
--------------------------------------------------------------------------------
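
A small sketch of `AverageMeter` in a training loop; the loss values are made up:

```python
# Sketch: tracking a running, sample-weighted average loss.
from wav2mov.core.utils.average_meter import AverageMeter

loss_meter = AverageMeter('gen_loss', fmt=':0.4f')
for batch_loss, batch_size in [(0.9, 4), (0.7, 4), (0.5, 2)]:   # illustrative values
    loss_meter.update(batch_loss, n=batch_size)
print(loss_meter.name, loss_meter.avg)   # weighted average over all samples seen so far
```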
/wav2mov/core/utils/checkpoints.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | def save_checkpoint(model,loss_val,location,save_params=True):
4 | statedict = {
5 | 'model_state_dict': model.state_dict(),
6 | 'min_val_loss' : loss_val,
7 | 'model_arch': str(model)
8 | }
9 | os.makedirs(os.path.dirname(location),exist_ok=True)
10 | torch.save(statedict,location)
11 |
12 | if save_params:
13 | hparams_dir = os.path.dirname(location)
14 | hparams_filename = 'hparams_' + os.path.basename(location).split('.')[0] + '.json'
15 | model.hparams.to_json(os.path.join(hparams_dir,hparams_filename))
16 | # with open(os.path.join(hparams_dir,hparams_filename),'a+') as file:
17 | # json.dump(model.hparams.asdict(),)
18 |
19 |
20 |
21 |
22 | def load_checkpoint(location):
23 | return torch.load(location)
24 |
25 | def save_fully_trained_model(model,loss_val,location,save_params=True):
26 | basename,file_extension = os.path.basename(location).split('.')
27 | new_name = basename +'_fully_trained.' + file_extension
28 | new_location = os.path.join(os.path.dirname(location),new_name)
29 |
30 | save_checkpoint(model,loss_val,new_location,save_params=save_params)
31 |
--------------------------------------------------------------------------------
/wav2mov/core/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logging.basicConfig(level=logging.DEBUG,format="%(levelname)s : %(name)s : %(asctime)s | %(msg)s ")
3 |
4 |
5 | def get_module_level_logger(name):
6 | logger = logging.getLogger(name)
7 | logger.setLevel(logging.DEBUG)
8 | logger.propagate = False
9 | return logger
10 |
--------------------------------------------------------------------------------
/wav2mov/core/utils/misc.py:
--------------------------------------------------------------------------------
1 | from colorama import init,Fore,Back,Style
2 |
3 | from .os_utils import is_windows
4 | from tqdm import tqdm
5 |
6 | init() # colorama
7 |
8 |
9 | class COLOURS:
10 | RED = Fore.RED
11 | GREEN = Fore.GREEN
12 | BLUE = Fore.BLUE
13 | YELLOW = Fore.YELLOW
14 | WHITE = Fore.WHITE
15 |
16 | # def colored(text,color=None,reset=True):
17 | # if color is None:
18 | # color = COLOURS.WHITE
19 | # return f"{color}{text}{Style.RESET_ALL}" if reset else f"{color}{text}"
20 |
21 | def __coloured(text:str,colour)->str:
22 | return f"{colour}{text}{Style.RESET_ALL}"
23 |
24 | def error(text:str)->str:
25 | return __coloured(text,COLOURS.RED)
26 |
27 | def debug(text:str)->str:
28 | return __coloured(text,COLOURS.BLUE)
29 |
30 | def success(text:str)->str:
31 | return __coloured(text,COLOURS.GREEN)
32 |
33 |
34 |
35 |
36 |
37 | def string_wrapper(msg):
38 | if is_windows():
39 | return debug(msg)
40 | else:
41 | return msg
42 |
43 | def get_tqdm_iterator(iterable,description,colour="green"):
44 | """setups the tqdm progress bar
45 |
46 | Arguments:
47 | `iterable` {Iterable}
48 | `description` {str} -- Description to be shown in progress bar
49 |
50 | Keyword Arguments:
51 | `colour` {str} -- colour of progress bar (default: {"green"})
52 |
53 | Returns:
54 | `tqdm`
55 | """
56 | return tqdm(iterable,ascii=True,desc=description,total=len(iterable),colour=colour)
57 |
58 |
--------------------------------------------------------------------------------
/wav2mov/core/utils/os_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 |
4 | def is_windows():
5 | return sys.platform.startswith('win')
--------------------------------------------------------------------------------
/wav2mov/datasets/create_file_list.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from tqdm import tqdm
4 |
5 | FILENAME = 'filenames.txt'
6 | FILENAME_TRAIN = 'filenames_train.txt'
7 | FILENAME_TEST = 'filenames_test.txt'
8 | DATASET = 'grid_dataset_a5_500_a10to14'
9 | FROM_DIR = ''
10 | TO_DIR = ''
11 |
12 | VIDEOS_PER_ACTOR = 30
13 |
14 | if len(FROM_DIR)==0:
15 | FROM_DIR = os.path.dirname(os.path.abspath(__file__))
16 | FROM_DIR = os.path.join(FROM_DIR,DATASET)
17 | TO_DIR = FROM_DIR
18 |
19 | def get_folders_list():
20 | folders = [folder for _,folder,_ in os.walk(FROM_DIR)][0]
21 | folders_strs = []
22 | for folder in tqdm(folders):
23 | if os.path.isdir(os.path.join(FROM_DIR,folder)):
24 | folders_strs.append(folder+'\n')
25 | return folders_strs
26 |
27 | def write_to_text_file(folders_list):
28 | with open(os.path.join(TO_DIR,FILENAME),'w') as file:
29 | file.writelines(folders_list)
30 | print(f'list written to {os.path.join(TO_DIR,FILENAME)}')
31 |
32 |
33 | def write_to_file(file_path,content_list):
34 | with open(file_path,'w') as file:
35 | file.writelines(content_list)
36 |
37 | def get_actors(folders_list):
38 | actors = set()
39 | for folder in folders_list:
40 | actors.add(folder.split('_')[0])
41 | return sorted(actors)
42 |
43 | def get_train_test_list(folders_list):
44 | actors = get_actors(folders_list)
45 | train_actors = set(actors[:-1])
46 | train_dict = {}
47 | test_list = []
48 | for folder in folders_list:
49 | actor = folder.split('_')[0]
50 | if actor in train_actors:
51 | if len(train_dict.get(actor,[])) < VIDEOS_PER_ACTOR:
--------------------------------------------------------------------------------
/wav2mov/inference/audio_utils.py:
--------------------------------------------------------------------------------
96 | if num_frames > possible_num_frames:
97 | raise ValueError(f'given audio has {possible_num_frames} frames but {num_frames} frames requested.')
98 | start_idx = (possible_num_frames-num_frames)//2
99 | end_idx = (possible_num_frames+num_frames)//2 #start_idx + (num_frames)
100 | padding = torch.zeros((1,self.coarticulation_factor*self.stride),device=self.device)
101 | audio = torch.cat([padding,audio,padding],dim=1)
102 | if get_mfccs:
103 | frames = [self.get_frame_from_idx(audio,idx) for idx in range(start_idx,end_idx)]
104 | frames = [self.extract_mfccs(frame) for frame in frames]# each of shape [t,13]
105 | # frames = [((frame-mean[i])/(std[i]+1e-7)) for i,frame in enumerate(frames)]
106 | frames = torch.stack(frames,axis=0)# 1,num_frames,(t,13)
107 | # logger.warning(f'frames {frames.shape} mean : {mean.shape}')
108 | return (frames-mean)/(std+1e-7)
109 | frames = [self.get_frame_from_idx(audio,idx) for idx in range(start_idx,end_idx)]
110 | #each frame is of shape (1,frame_size) so can be catenated along zeroth dimension .
111 | return torch.cat(frames,dim=0)
112 |
113 | def get_limited_audio(self,audio,num_frames,start_frame=None,get_mfccs=False) :
114 | possible_num_frames = audio.shape[-1]//self.stride
115 | if num_frames>possible_num_frames:
116 | logger.error(f'Given num_frames {num_frames} is larger than the possible_num_frames {possible_num_frames}')
117 |
118 | mean,std = self.get_mfccs_mean_std(audio)
119 | padding = torch.zeros((audio.shape[0],self.coarticulation_factor*self.stride),device=self.device)
120 | audio = torch.cat([padding,audio,padding],dim=1)
121 |
122 | # possible_num_frames = audio.shape[-1]//self.stride
123 | actual_start_frame = (possible_num_frames-num_frames)//2
124 | # [......................................................]
125 | # [................................]
126 | # |<-----num_frames---------------->|
127 | #.........^
128 | # actual start frame
129 | if start_frame is None:
130 | start_frame = actual_start_frame
131 |
132 | if start_frame+num_frames>possible_num_frames:#[why > not >=]think if possible num_frames is 50 and 50 is the required num_frames and start_frame is zero
133 | logger.warning(f'Given audio has {possible_num_frames} frames. Given start frame {start_frame} cannot be used for getting {num_frames} frames. Changing start frame to {actual_start_frame}.')
134 | start_frame = actual_start_frame
135 |
136 | end_frame = start_frame + (num_frames) #exclusive
137 |
138 | start_pos = self.__get_center_idx(start_frame)
139 | end_pos = self.__get_center_idx(end_frame-1)
140 |
141 | audio = audio[:,self.__get_start_idx(start_pos):self.__get_end_idx(end_pos)]
142 | if get_mfccs:
143 | mfccs = list(map(self.extract_mfccs,audio))
144 | # mfccs = [(mfcc-mean[i]/(std[i]+1e-7)) for i,mfcc in enumerate(mfccs)]
145 | mfccs = torch.stack(mfccs,axis=0)
146 | return (mfccs-mean)/(std+1e-7)
147 | return audio
148 |
149 | def load_audio(audio_path):
150 | audio,sr = librosa.load(audio_path,sr=params.AUDIO_SF)
151 | return torch.from_numpy(audio)
152 |
153 | def get_num_frames(audio):
154 | audio_len = audio.squeeze().shape[0]
155 | return audio_len//params.STRIDE
156 |
157 | def preprocess_audio(audio):
158 | if not isinstance(audio,Tensor):
159 | audio = torch.from_numpy(audio)
160 | audio = (audio-params.AUDIO_MEAN)/(params.AUDIO_STD + params.EPSILON)
161 | framewise_mfccs = AudioUtil(params.AUDIO_SF,params.COARTICULATION_FACTOR,
162 | params.STRIDE,params.DEVICE).get_audio_frames(audio,get_mfccs=True)
163 | return framewise_mfccs
164 |
--------------------------------------------------------------------------------
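
A sketch mirroring how `generate.py` uses the helpers above (the input path is an assumption):

```python
# Sketch: turning a wav file into framewise MFCC features for the generator.
from inference import audio_utils

audio = audio_utils.load_audio('inputs/audio.wav')   # loaded/resampled at params.AUDIO_SF
num_frames = audio_utils.get_num_frames(audio)       # audio length // params.STRIDE
audio_feats = audio_utils.preprocess_audio(audio)    # normalized per-frame MFCCs
```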
/wav2mov/inference/generate.py:
--------------------------------------------------------------------------------
1 | # get audio and image
2 | # get audio_frames and extract face from image and resize
3 | # apply transforms
4 | # generator forward pass
5 | import argparse
6 | import os
7 | import torch
8 |
9 | from inference import audio_utils,image_utils,utils,model_utils
10 | logger = utils.get_module_level_logger(__name__)
11 |
12 | def is_exists(file):
13 | return os.path.isfile(file)
14 |
15 | def create_batch(sample):
16 | return sample.unsqueeze(0)
17 |
18 | def save_video(audio,video_frames):
19 | video_path = os.path.join(DIR,'out','fake_video_without_audio.avi')
20 | #####################################################################
21 | #⚠ IMPORTANT: DE-NORMALIZE, OTHERWISE THE OUTPUT WILL BE BLACK!
22 | #####################################################################
23 | video_frames = ((video_frames*0.5)+0.5)*255
24 | utils.save_video(video_path,audio,video_frames)
25 |
26 | DIR = os.path.dirname(os.path.abspath(__file__))
27 |
28 | def generate_video(image,audio):
29 | audio_feats = audio_utils.preprocess_audio(audio)
30 | num_frames = audio_feats.shape[0]
31 | image = image_utils.preprocess_image(image)
32 | images = image_utils.repeat_img(image,num_frames)
33 | logger.debug(f'✅ Preprocessing Done.')
34 | model = model_utils.get_model()
35 | logger.debug(f'✅ Model Loaded.')
36 | with torch.no_grad():
37 | model.eval()
38 | gen_frames = model(create_batch(audio_feats),create_batch(images))
39 | logger.debug(f'✅ Frames Generated.')
40 | B,T,*img_shape = gen_frames.shape
41 | gen_frames = gen_frames.reshape(B*T,*img_shape)
42 | save_video(audio,gen_frames)
43 |
44 | if __name__ == '__main__':
45 | logger.debug(f'✅ Modules loaded.')
46 | # default_image = r'ref_frame_Run_27_5_2021__16_6.png'
47 | # default_image = r'ref_frame_Run_19_6_2021__12_14.png'
48 | # default_image = r'image.jfif'
49 | # default_image = r'musk.jfif'
50 | default_image = r'train_ref_frame_Run_9_7_2021__14_52.png'
51 | default_image = r'train_ref_frame_Run_11_7_2021__19_18.png'
52 | # default_audio = r'03-01-01-01-02-02-01.wav'
53 | default_audio = r'audio.wav'
54 | arg_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
55 | description="Wav2Mov | Speech Controlled Facial Animation")
56 | arg_parser.add_argument('--image','-i',type=str,required=False,
57 | help='image containing face of the person',
58 | default=os.path.join(DIR,'inputs',default_image))
59 | arg_parser.add_argument('--audio','-a',type=str,required=False,
60 | help='speech for which face should be animated',
61 | default=os.path.join(DIR,'inputs',default_audio))
62 | options = arg_parser.parse_args()
63 |
64 |
65 | image_path = options.image
66 | audio_path = options.audio
67 | if not (is_exists(image_path)):
68 | raise FileNotFoundError(f'[ERROR] ❌ image path is incorrect :{image_path}')
69 | if not (is_exists(audio_path)):
70 | raise FileNotFoundError(f'[ERROR] ❌ audio path is incorrect :{audio_path}')
71 | image = image_utils.load_image(image_path)
72 | audio = audio_utils.load_audio(audio_path)
73 | generate_video(image,audio)
74 | logger.debug('✅ Video saved')
75 |
--------------------------------------------------------------------------------
/wav2mov/inference/generate.sh:
--------------------------------------------------------------------------------
1 | python -m inference.generate
2 |
--------------------------------------------------------------------------------
/wav2mov/inference/image_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import dlib
3 | import torch
4 | from torchvision import utils as vutils
5 | from torchvision.transforms import Normalize
6 | from inference import params
7 |
8 | face_detector = None
9 | def load_face_detector():
10 | return dlib.get_frontal_face_detector()
11 |
12 | def convert_and_trim_bb(image, rect):
13 | """ from pyimagesearch
14 | https://www.pyimagesearch.com/2021/04/19/face-detection-with-dlib-hog-and-cnn/
15 | """
16 | # extract the starting and ending (x, y)-coordinates of the
17 | # bounding box
18 | start_x = rect.left()
19 | start_y = rect.top()
20 | endX = rect.right()
21 | endY = rect.bottom()
22 | # ensure the bounding box coordinates fall within the spatial
23 | # dimensions of the image
24 | start_x = max(0, start_x)
25 | start_y = max(0, start_y)
26 | endX = min(endX, image.shape[1])
27 | endY = min(endY, image.shape[0])
28 | # compute the width and height of the bounding box
29 | w = endX - start_x
30 | h = endY - start_y
31 | # return our bounding box coordinates
32 | return (start_x, start_y, w, h)
33 |
34 |
35 | def load_image(image_path):
36 | img = cv2.imread(image_path,cv2.IMREAD_COLOR)
37 | if params.IMAGE_CHANNELS==1:
38 | return cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
39 | # height,width = params.IMAGE_SIZE
40 | # img = cv2.resize(img,(width,height),cv2.INTER_CUBIC)
41 | return cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
42 |
43 | def save_image(image,path,normalize=False):
44 | vutils.save_image(image,path,normalize=normalize)
45 |
46 | def show_image(image):
47 | cv2.imshow('image',image)
48 | cv2.waitKey(0)
49 | cv2.destroyAllWindows()
50 |
51 | def preprocess_image(image):
52 | global face_detector
53 | if face_detector is None:
54 | face_detector = load_face_detector()
55 | face = face_detector(image)[0]
56 | x,y,w,h = convert_and_trim_bb(image,face)
57 | target_height,target_width = params.IMAGE_SIZE
58 | image = cv2.resize(image[y:y+h,x:x+w],(target_width,target_height),interpolation=cv2.INTER_CUBIC)
59 | if len(image.shape)==2:
60 | image = image.reshape(image.shape[0],image.shape[1],1)
61 | # show_image(image)
62 | image = torch.from_numpy(image).float().permute(2,0,1)/255
63 | image= Normalize(params.VIDEO_MEAN,params.VIDEO_STD)(image)
64 | # show_image(image.permute(1,2,0).numpy())
65 | return image
66 |
67 | def repeat_img(img,n):
68 | if img.shape==3:
69 | img = img.unsqueeze(0)
70 |
71 | return img.repeat(n,1,1,1)
72 |
--------------------------------------------------------------------------------
/wav2mov/inference/model_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from inference.models import Generator
4 | from inference import params
5 | from inference.utils import get_module_level_logger
6 | logger = get_module_level_logger(__name__)
7 |
8 | gen_hparams = {
9 | "in_channels": 3,
10 | "chs": [64, 128, 256, 512, 1024],
11 | "latent_dim": 272,
12 | "latent_dim_id": [8, 8],
13 | "comment": "laten_dim not eq latent_dim_id + latent_dim_audio, its 4x4 + 256",
14 | "latent_dim_audio": 256,
15 | "device": "cpu",
16 | "lr": 2e-4
17 | }
18 |
19 |
20 | def get_model(checkpoint_path=None):
21 | if checkpoint_path is None:
22 | checkpoint_path = params.GEN_CHECKPOINT_PATH
23 |
24 | if not os.path.isfile(checkpoint_path):
25 | logger.error(f'NO FILE : {checkpoint_path}')
26 | raise FileNotFoundError(
27 | 'Please make sure to put generator file in pt_files folder ')
28 | state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']
29 | model = Generator(gen_hparams)
30 | model.load_state_dict(state_dict)
31 | model.eval()
32 | return model
33 |
--------------------------------------------------------------------------------
/wav2mov/inference/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .generator.frame_generator import Generator
2 |
--------------------------------------------------------------------------------
/wav2mov/inference/models/generator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/inference/models/generator/__init__.py
--------------------------------------------------------------------------------
/wav2mov/inference/models/generator/audio_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from inference.models.layers.conv_layers import Conv1dBlock,Conv2dBlock
4 | from inference.models.layers.debug_layers import IdentityDebugLayer
5 | from inference.models.utils import squeeze_batch_frames,get_same_padding
6 | from inference.utils import get_module_level_logger
7 |
8 | logger = get_module_level_logger(__name__)
9 |
10 | class AudioEnocoder(nn.Module):
11 | def __init__(self,hparams):
12 | super().__init__()
13 | self.hparams = hparams
14 | padding_31 = get_same_padding(3,1)
15 | padding_32 = get_same_padding(3,2)
16 | padding_42 = get_same_padding(4,2)
17 | # each frame has mfcc of shape (7,13)
18 | self.conv_encoder = nn.Sequential(
19 | Conv2dBlock(1,32,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU()),#7,13
20 | Conv2dBlock(32,32,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
21 | Conv2dBlock(32,32,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
22 |
23 | Conv2dBlock(32,64,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU()), #7,13
24 | Conv2dBlock(64,64,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
25 | Conv2dBlock(64,64,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
26 |
27 | Conv2dBlock(64,128,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU()), #7,13
28 | Conv2dBlock(128,128,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
29 | Conv2dBlock(128,128,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
30 |
31 | Conv2dBlock(128,256,3,(2,1),padding=padding_32,use_norm=True,use_act=True,act=nn.ReLU()), #4,13
32 | Conv2dBlock(256,256,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
33 | Conv2dBlock(256,256,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
34 |
35 | Conv2dBlock(256,512,(4,3),(2,1),padding=1,use_norm=True,use_act=True,act=nn.ReLU()), #2,13
36 | Conv2dBlock(512,512,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
37 | Conv2dBlock(512,512,(2,5),(1,2),padding=(0,1),use_norm=True,use_act=True,act=nn.ReLU()),#1,6,
38 | )
39 |
40 | self.features_len = 6*512#out of conv layers
41 | self.num_layers = 1
42 | self.hidden_size = self.hparams['latent_dim_audio']
43 | self.gru = nn.GRU(input_size=self.features_len,
44 | hidden_size=self.hidden_size,
45 | num_layers=1,
46 | batch_first=True)
47 | self.final_act = nn.Tanh()
48 |
49 | def forward(self, x):
50 | """ x : audio frames of shape B,T,t,13"""
51 | batch_size,num_frames,*_ = x.shape
52 | x = self.conv_encoder(squeeze_batch_frames(x).unsqueeze(1))#add channel dimension
53 | x = x.reshape(batch_size,num_frames,self.features_len)
54 | x,_ = self.gru(x)
55 | return self.final_act(x) # shape (batch_size,num_frames,hidden_size=latent_dim_audio)
56 |
--------------------------------------------------------------------------------
/wav2mov/inference/models/generator/frame_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn,optim
3 |
4 | from inference.models.generator.audio_encoder import AudioEnocoder
5 | from inference.models.generator.id_encoder import IdEncoder
6 | from inference.models.generator.id_decoder import IdDecoder
7 |
8 | from inference.models.utils import squeeze_batch_frames
9 | from inference.utils import get_module_level_logger
10 | logger = get_module_level_logger(__name__)
11 |
12 | class Generator(nn.Module):
13 | def __init__(self,hparams):
14 | super().__init__()
15 | self.hparams = hparams
16 | self.id_encoder = IdEncoder(self.hparams)
17 | self.id_decoder = IdDecoder(self.hparams)
18 | self.audio_encoder = AudioEnocoder(self.hparams)
19 | # self.noise_encoder = NoiseEncoder(self.hparams)
20 |
21 | def forward(self,audio_frames,ref_frames):
22 | batch_size,num_frames,*_ = ref_frames.shape
23 | assert num_frames == audio_frames.shape[1]
24 | encoded_id , intermediates = self.id_encoder(squeeze_batch_frames(ref_frames))
25 | encoded_id = encoded_id.reshape(batch_size*num_frames,-1,1,1)
26 | encoded_audio = self.audio_encoder(audio_frames).reshape(batch_size*num_frames,-1,1,1)
27 | # encoded_noise = self.noise_encoder(batch_size,num_frames).reshape(batch_size*num_frames,-1,1,1)
28 | # logger.debug(f'encoded_id {encoded_id.shape} encoded_audio {encoded_audio.shape} encoded_noise {encoded_noise.shape}')
29 | encoded = torch.cat([encoded_id,encoded_audio],dim=1)#along channel dimension
30 | gen_frames = self.id_decoder(encoded,intermediates)
31 | _,*img_shape = gen_frames.shape
32 | return gen_frames.reshape(batch_size,num_frames,*img_shape)
33 |
34 | def get_optimizer(self):
35 | return optim.Adam(self.parameters(), lr=self.hparams['lr'], betas=(0.5, 0.999))
36 |
--------------------------------------------------------------------------------
/wav2mov/inference/models/generator/id_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from inference.models.layers.conv_layers import Conv2dBlock,ConvTranspose2dBlock,DoubleConvTranspose2d
5 | from inference.models.utils import get_same_padding
6 | from inference.utils import get_module_level_logger
7 |
8 | logger = get_module_level_logger(__name__)
9 |
10 | class IdDecoder(nn.Module):
11 | def __init__(self,hparams):
12 | super().__init__()
13 | self.hparams = hparams
14 | self.in_channels = hparams['in_channels']
15 | self.latent_channels = hparams['latent_dim']
16 | latent_id_height,latent_id_width = hparams['latent_dim_id']
17 | chs = self.hparams['chs'][::-1]#1024,512,256,128,64
18 | chs = [self.latent_channels] + chs + [self.in_channels]
19 | padding = get_same_padding(kernel_size=4,stride=2)
20 | self.convs = nn.ModuleList()
21 | self.convs.append(ConvTranspose2dBlock(in_ch=chs[0],
22 | out_ch=chs[1],
23 | kernel_size=(latent_id_height,latent_id_width),
24 | stride=2,
25 | padding=0,
26 | use_norm=True,
27 | use_act=True,
28 | act=nn.ReLU()
29 | ))
30 | for i in range(1,len(chs)-2):
31 | self.convs.append( DoubleConvTranspose2d(in_ch=chs[i],
32 | skip_ch=chs[i],
33 | out_ch=chs[i+1],
34 | kernel_size=(4,4),
35 | stride=2,
36 | padding=padding,
37 | use_norm=True,
38 | use_act=True,
39 | act=nn.ReLU())
40 | )
41 | self.convs.append(ConvTranspose2dBlock(in_ch=chs[-2],
42 | out_ch=chs[-1],
43 | kernel_size=(4,4),
44 | stride=2,
45 | padding=padding,
46 | use_norm=True,
47 | use_act=True,
48 | act=nn.Tanh()
49 | ))
50 | def forward(self,encoded,skip_outs):
51 |         """Decodes the concatenated identity+audio latent code into an image frame.
52 | 
53 |         Args:
54 |             encoded : latent code, reshaped internally to (N,latent_channels,1,1)
55 |             skip_outs : encoder intermediates (channels 64,128,256,512,1024) used as skip connections
56 |         """
57 | encoded = encoded.reshape(encoded.shape[0],self.latent_channels,1,1)
58 | skip_outs = skip_outs[::-1]
59 | for i,block in enumerate(self.convs[:-2]):
60 | x = block(encoded)
61 | encoded = torch.cat([x,skip_outs[i]],dim=1)#cat along channel axis
62 |
63 | encoded = self.convs[-2](encoded)
64 | return self.convs[-1](encoded)
65 |
--------------------------------------------------------------------------------
/wav2mov/inference/models/generator/id_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from inference.models.layers.conv_layers import Conv2dBlock
5 | from inference.models.utils import get_same_padding
6 | from inference.utils import get_module_level_logger
7 |
8 | logger = get_module_level_logger(__name__)
9 |
10 | class IdEncoder(nn.Module):
11 | def __init__(self,hparams):
12 | super().__init__()
13 | self.hparams = hparams
14 | in_channels = hparams['in_channels']
15 | chs = self.hparams['chs']
16 |         chs = [in_channels] + chs +[1]# the extra 1-channel output is added here (not in params) to match how the channels are consumed in id_decoder
17 | padding = get_same_padding(kernel_size=4,stride=2)
18 | self.conv_blocks = nn.ModuleList(Conv2dBlock(chs[i],chs[i+1],
19 | kernel_size=(4,4),
20 | stride=2,
21 | padding=padding,
22 | use_norm=True,
23 | use_act=True,
24 | act=nn.ReLU()
25 | ) for i in range(len(chs)-2)
26 | )
27 |
28 | self.conv_blocks.append(Conv2dBlock(chs[-2],chs[-1],
29 | kernel_size=(4,4),
30 | stride=2,
31 | padding=padding,
32 | use_norm=False,
33 | use_act=True,
34 | act=nn.Tanh()))
35 | def forward(self,images):
36 | intermediates = []
37 | for block in self.conv_blocks[:-1]:
38 | images = block(images)
39 | intermediates.append(images)
40 | encoded = self.conv_blocks[-1](images)
41 | return encoded,intermediates
42 |
--------------------------------------------------------------------------------
/wav2mov/inference/models/layers/conv_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from inference.models.utils import get_same_padding
5 |
6 | class DoubleConv2dBlock(nn.Module):
7 | """height and width are halved in feature map"""
8 | def __init__(self,
9 | in_ch,
10 | out_ch,
11 | use_norm=True,
12 | act=None):
13 | super().__init__()
14 | self.use_norm = use_norm
15 | self.conv1 = nn.Conv2d(in_ch, out_ch, 3,padding=get_same_padding(3,1),bias=not self.use_norm)
16 | self.batch_norm = nn.BatchNorm2d(out_ch)
17 | self.relu = nn.LeakyReLU(0.2) if act is None else act()
18 | self.conv2 = nn.Conv2d(out_ch, out_ch, 3)
19 |
20 | def forward(self, x):
21 | # print('double conv block',x.shape,type(x),x.device,next(self.parameters()).device)
22 | x = self.conv1(x)
23 | if self.use_norm :
24 | x = self.batch_norm(x)
25 | return self.relu(self.conv2(self.relu(x)))
26 |
27 | class Conv1dBlock(nn.Module):
28 | def __init__(self,
29 | in_ch,
30 | out_ch,
31 | kernel_size,
32 | stride,
33 | padding,
34 | use_norm=True,
35 | use_act=True,act=None,
36 | residual=False):
37 |
38 | super().__init__()
39 | self.use_norm = use_norm
40 | self.use_act = use_act
41 | self.act = nn.LeakyReLU(0.2) if act is None else act
42 | self.norm = nn.BatchNorm1d(out_ch)
43 | self.conv = nn.Conv1d(in_ch,out_ch,kernel_size,stride,padding,bias=self.use_norm)
44 | self.residual = residual
45 |
46 | def forward(self,in_x):
47 | x = self.conv(in_x)
48 | if self.use_norm:
49 | x = self.norm(x)
50 | if self.residual:
51 | x += in_x
52 | if self.use_act:
53 | x = self.act(x)
54 | return x
55 |
56 | class Conv2dBlock(nn.Module):
57 | def __init__(self,
58 | in_ch,
59 | out_ch,
60 | kernel_size,
61 | stride,
62 | padding,
63 | use_norm=True,
64 | use_act=True,act=None,
65 | residual=False):
66 |
67 | super().__init__()
68 | self.use_norm = use_norm
69 | self.use_act = use_act
70 | self.act = nn.LeakyReLU(0.2) if act is None else act
71 | self.norm = nn.BatchNorm2d(out_ch)
72 | self.conv = nn.Conv2d(in_ch,out_ch,kernel_size,stride,padding,bias=self.use_norm)
73 | self.residual = residual
74 |
75 | def forward(self,in_x):
76 | x = self.conv(in_x)
77 | if self.use_norm:
78 | x = self.norm(x)
79 | if self.residual:
80 | x += in_x
81 | if self.use_act:
82 | x = self.act(x)
83 | return x
84 |
85 | class ConvTranspose2dBlock(nn.Module):
86 | def __init__(self,in_ch,
87 | out_ch,
88 | kernel_size,
89 | stride,
90 | padding,
91 | use_norm=True,
92 | use_act=True,
93 | act=None):
94 | super().__init__()
95 | self.use_norm = use_norm
96 | self.use_act = use_act
97 | self.act = nn.LeakyReLU(0.2) if act is None else act
98 | self.norm = nn.BatchNorm2d(out_ch)
99 | self.conv = nn.ConvTranspose2d(in_ch,out_ch,kernel_size,stride,padding,bias=self.use_norm)
100 |
101 | def forward(self,x):
102 | x = self.conv(x)
103 | if self.use_norm:
104 | x = self.norm(x)
105 | if self.use_act:
106 | x = self.act(x)
107 | return x
108 |
109 | class DoubleConvTranspose2d(nn.Module):
110 | def __init__(self,in_ch,
111 | skip_ch,
112 | out_ch,
113 | kernel_size,
114 | stride,
115 | padding,
116 | use_norm=True,
117 | use_act=True,
118 | act=None):
119 | super().__init__()
120 | self.use_norm = use_norm
121 | self.use_act = use_act
122 |
123 | self.conv1 = nn.Conv2d(in_ch+skip_ch,in_ch,kernel_size=3,stride=1,padding=1,bias=False)
124 | self.norm1 = nn.BatchNorm2d(in_ch)
125 | self.conv2 = nn.ConvTranspose2d(in_ch,out_ch,kernel_size,stride,padding)
126 | self.norm2 = nn.BatchNorm2d(out_ch)
127 | self.act = nn.LeakyReLU(0.2) if act is None else act
128 |
129 | def forward(self,x):
130 | x = nn.functional.relu(self.norm1(self.conv1(x)))
131 | x = self.conv2(x)
132 | if self.use_norm:
133 | x = self.norm2(x)
134 | if self.use_act:
135 | x = self.act(x)
136 | return x
137 |
--------------------------------------------------------------------------------
/wav2mov/inference/models/layers/debug_layers.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from inference.utils import get_module_level_logger
4 | logger = get_module_level_logger(__name__)
5 |
6 | class IdentityDebugLayer(nn.Module):
7 | def __init__(self,name):
8 | super().__init__()
9 | self.name = name
10 | def forward(self,x):
11 | logger.debug(f'{self.name} : {x.shape}')
12 | return x
13 |
--------------------------------------------------------------------------------
/wav2mov/inference/models/utils.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 | from torch.nn import init
3 |
4 | from inference.utils import get_module_level_logger
5 | logger = get_module_level_logger(__name__)
6 |
7 | def get_same_padding(kernel_size,stride,in_size=0):
8 | """https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks
9 | https://github.com/DinoMan/speech-driven-animation/blob/dc9fe4aa77b4e042177328ea29675c27e2f56cd4/sda/utils.py#L18-L21
10 |
11 | padding = 'same'
12 | • Padding such that feature map size has size In_size/Stride
13 |
14 | • Output size is mathematically convenient
15 |
16 | • Also called 'half' padding
17 |
18 | out = (in-k+2*p)/s + 1
19 | if out == in/s:
20 | in/s = (in-k+2*p)/s + 1
21 | ((in/s)-1)*s + k -in = 2*p
22 | (in-s)+k-in = 2*p
23 | in case of s==1:
24 | p = (k-1)/2
25 | """
26 | out_size = ceil(in_size/stride)
27 | return ceil(((out_size-1)*stride+ kernel_size-in_size)/2)#(k-1)//2 for same padding
28 |
29 | def init_weights(net, init_type='normal', init_gain=0.02):
30 | """Initialize network weights.
31 | src : https://github1s.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
32 | Parameters:
33 | net (network) -- network to be initialized
34 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
35 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
36 |
37 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
38 | work better for some applications. Feel free to try yourself.
39 | """
40 | def init_func(m): # define the initialization function
41 | classname = m.__class__.__name__
42 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
43 | if init_type == 'normal':
44 | init.normal_(m.weight.data, 0.0, init_gain)
45 | elif init_type == 'xavier':
46 | init.xavier_normal_(m.weight.data, gain=init_gain)
47 | elif init_type == 'kaiming':
48 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
49 | elif init_type == 'orthogonal':
50 | init.orthogonal_(m.weight.data, gain=init_gain)
51 | else:
52 | raise NotImplementedError(
53 | 'initialization method [%s] is not implemented' % init_type)
54 | if hasattr(m, 'bias') and m.bias is not None:
55 | init.constant_(m.bias.data, 0.0)
56 | # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
57 | elif classname.find('BatchNorm2d') != -1:
58 | init.normal_(m.weight.data, 1.0, init_gain)
59 | init.constant_(m.bias.data, 0.0)
60 |
61 | logger.debug(f'initializing {net.__class__.__name__} with {init_type} weights')
62 | # print('initialize network with %s' % init_type)
63 | net.apply(init_func) # apply the initialization function
64 |
65 |
66 | def init_net(net, init_type='normal', init_gain=0.02):
67 |     """Initialize the network weights and return the network.
68 | Parameters:
69 | net (network) -- the network to be initialized
70 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
71 | gain (float) -- scaling factor for normal, xavier and orthogonal.
72 |
73 |
74 | Return an initialized network.
75 | """
76 |
77 | init_weights(net, init_type, init_gain=init_gain)
78 | return net
79 |
80 |
81 | def squeeze_batch_frames(target):
82 | batch_size,num_frames,*extra = target.shape
83 | return target.reshape(batch_size*num_frames,*extra)
84 |
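85 | 
86 | if __name__ == '__main__':
87 |     # illustrative sanity checks for the helpers above (not used by the inference pipeline; example shapes only)
88 |     import torch
89 |     assert get_same_padding(kernel_size=3, stride=1) == 1  # 'same' padding for a 3x3 conv with stride 1
90 |     assert get_same_padding(kernel_size=4, stride=2) == 1  # halves the spatial size
91 |     frames = torch.zeros(2, 5, 1, 7, 13)                   # (batch, frames, channels, height, width)
92 |     assert squeeze_batch_frames(frames).shape == (10, 1, 7, 13)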
--------------------------------------------------------------------------------
/wav2mov/inference/params.py:
--------------------------------------------------------------------------------
1 | EPSILON = 1e-7
2 | DEVICE = 'cpu'
3 | COARTICULATION_FACTOR = 2
4 | AUDIO_SF = 16000
5 | VIDEO_FPS = 24
6 | STRIDE = AUDIO_SF//VIDEO_FPS
7 | IMAGE_CHANNELS = 3
8 | IMAGE_SIZE = (256,256)
9 | VIDEO_MEAN = [0.5,0.5,0.5]
10 | VIDEO_STD = [0.5,0.5,0.5]
11 | # AUDIO_MEAN = 5.6196e-07
12 | # AUDIO_STD = 0.0142
13 | AUDIO_MEAN = -4.3688e-07
14 | AUDIO_STD = 0.0123
15 | GEN_CHECKPOINT_PATH = r'.\models\checkpoints\gen_Run_11_7_2021__19_18.pt'
16 |
--------------------------------------------------------------------------------
/wav2mov/inference/quantize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.quantization
4 | from torch import nn
5 | # here is our floating point instance
6 | from inference.model_utils import get_model
7 | # import the modules used here in this recipe
8 |
9 |
10 | def print_size_of_model(m, label=""):
11 | torch.save(m.state_dict(), "temp.p")
12 | size=os.path.getsize("temp.p")
13 | print("model: ",label,' \t','Size (MB):', size/1e6)
14 | os.remove('temp.p')
15 | return size
16 |
17 | DIR = os.path.dirname(os.path.abspath(__file__))
18 | model = get_model()
19 | target_path =r'models/checkpoints/gen_quantized.pt'
20 | # this is the call that does the work
21 | model_quantized = torch.quantization.quantize_dynamic(
22 | model, {nn.Conv1d,nn.Conv2d,nn.ConvTranspose2d, nn.Linear,nn.GRU}, dtype=torch.qint8
23 | )
24 |
25 | # compare the sizes
26 | # f=print_size_of_model(float_lstm,"fp32")
27 | f=print_size_of_model(model,"fp32")
28 | q=print_size_of_model(model_quantized,"int8")
29 | print("{0:.2f} times smaller".format(f/q))
30 |
--------------------------------------------------------------------------------
/wav2mov/inference/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import torch
4 | from torchvision.io import write_video
5 | from scipy.io.wavfile import write as write_audio
6 | from moviepy import editor as mpy
7 |
8 | from inference import params
9 | def get_module_level_logger(name):
10 | m_logger = logging.getLogger(name)
11 | m_logger.setLevel(logging.DEBUG)
12 | m_logger.propagate = False
13 | handler = logging.StreamHandler()
14 | handler.setFormatter(logging.Formatter("%(levelname)-5s : %(filename)s :%(lineno)s | %(asctime)s | %(msg)s ", "%b %d,%Y %H:%M:%S"))
15 | m_logger.addHandler(handler)
16 | return m_logger
17 |
18 | logger = get_module_level_logger('utils')
19 |
20 | def save_video(video_path, audio, video_frames):
21 | """
22 | audio_frames : C,S
23 | video_frames : F,C,H,W
24 | """
25 | # if has batch dimension remove it
26 | hparams = {'video_fps':params.VIDEO_FPS,'audio_sf':params.AUDIO_SF}
27 | if len(video_frames.shape) == 5:
28 | video_frames = video_frames[0]
29 | if video_frames.shape[1] == 1:
30 | video_frames = video_frames.repeat(1, 3, 1, 1)
31 | logger.warning('Grayscale images...')
32 | logger.debug(f'✅ video frames :{video_frames.shape[:]}, audio : {audio.shape[:]}')
33 | video_frames = video_frames.to(torch.uint8)
34 | os.makedirs(os.path.dirname(video_path),exist_ok=True)
35 | write_video(filename=video_path,
36 | video_array=video_frames.permute(0, 2, 3, 1),
37 | fps=hparams['video_fps'],
38 | video_codec="h264",
39 | )
40 | dir_name = os.path.dirname(video_path)
41 | temp_audio_path = os.path.join(dir_name, 'temp', 'temp_audio.wav')
42 | os.makedirs(os.path.dirname(temp_audio_path), exist_ok=True)
43 | write_audio(temp_audio_path,hparams['audio_sf'], audio.cpu().numpy().reshape(-1))
44 |
45 | video_clip = mpy.VideoFileClip(video_path)
46 | audio_clip = mpy.AudioFileClip(temp_audio_path)
47 | video_clip.audio = audio_clip
48 | video_clip.write_videofile(os.path.join(dir_name, 'fake_video_with_audio.avi'), fps=hparams['video_fps'], codec='png',logger=None)
49 |
--------------------------------------------------------------------------------
/wav2mov/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import re
4 | from datetime import datetime
5 | from torch.utils.tensorboard.writer import SummaryWriter
6 | from pythonjsonlogger import jsonlogger
7 |
8 | logging.basicConfig(level=logging.ERROR,format="%(levelname)-5s : %(name)s : %(asctime)s | %(msg)s ")
9 | TIME_FORMAT = "%b %d,%Y %H:%M:%S"
10 |
11 | def get_module_level_logger(name):
12 | m_logger = logging.getLogger(name)
13 | m_logger.setLevel(logging.DEBUG)
14 | m_logger.propagate = False
15 | handler = logging.StreamHandler()
16 | handler.setFormatter(logging.Formatter("%(levelname)-5s : %(name)s : %(asctime)s | %(msg)s ",TIME_FORMAT))
17 | m_logger.addHandler(handler)
18 | return m_logger
19 |
20 |
21 | logger = get_module_level_logger(__name__)
22 | class CustomJsonFormatter(jsonlogger.JsonFormatter):
23 | def add_fields(self,log_record,record,message_dict):
24 | # print(log_record,vars(record),message_dict)
25 | # for key,value in vars(record).items():
26 | # print(f'{key} : {value}')
27 | super().add_fields(log_record,record,message_dict)
28 | # if not log_record.get('timestamp'):
29 | # now = datetime.utcnow().strftime(TIME_FORMAT)
30 | # log_record['timestamp'] = now
31 | # if log_record.get('level'):
32 | # log_record['level'] = log_record['level'].upper()
33 | # else:
34 | # log_record['level'] = record.levelname
35 | if log_record.get('asctime'):
36 | log_record['asctime'] = datetime.utcnow().strftime(TIME_FORMAT)
37 |
38 |
39 |
40 | class TensorLogger:
41 | def __init__(self,runs_dir):
42 | self.writers = {}
43 | self.runs_dir = runs_dir
44 | def create_writer(self,name):
45 | self.writers[name] = SummaryWriter(os.path.join(self.runs_dir,name))
46 |
47 | def add_writer(self, name, writer):
48 | self.writers[name] = writer
49 |
50 | def add_writers(self,names):
51 | for name in names:
52 | self.create_writer(name)
53 |
54 | def add_scalar(self,writer_name,tag,scalar,global_step):
55 | if writer_name not in self.writers:
56 | logger.warning(f'No writer found named {writer_name}')
57 | self.create_writer(writer_name)
58 | self.writers[writer_name].add_scalar(tag,scalar,global_step)
59 |
60 |
61 | def add_scalars(self,d,global_step):
62 | for writer_name,(tag,scalar) in d.items():
63 | self.add_scalar(writer_name,tag,scalar,global_step)
64 |
65 |
66 |
67 |
68 |
69 | class Logger:
70 | def __init__(self, name):
71 | self.log_fullpath = None
72 | self.name = name
73 | self.logger = logging.getLogger(name)
74 | self.logger.propagate = False
75 | self.logger.setLevel(logging.DEBUG)
76 | self.is_json = False
77 | self.is_first_log = True
78 |
79 | def __add_handler(self, handler: logging.Handler):
80 | handler.setLevel(logging.DEBUG)
81 | self.logger.addHandler(handler)
82 |
83 | @classmethod
84 | def __get_formatter(cls, fmt: str):
85 | return logging.Formatter(fmt,TIME_FORMAT)
86 |
87 | @classmethod
88 | def __get_json_formatter(cls,fmt):
89 | return CustomJsonFormatter(fmt=fmt)
90 |
91 | def __json_log_begin(self):
92 | # print(f'inside json begin')
93 | with open(self.log_fullpath,'a+') as file:
94 | # print(file.read())
95 | file.write('[\n')
96 | # print(file.read())
97 |
98 | def __json_log_end(self):
99 | print('writing "]" to json log')
100 | with open(self.log_fullpath,'a+') as file:
101 | file.write(']\n')
102 |
103 |
104 | def add_filehandler(self, log_fullpath, fmt:str=None,in_json=False):
105 |         self.log_fullpath = log_fullpath
106 |         if fmt is None:
107 |             fmt = '%(levelname)-5s : %(filename)s : %(asctime)s : line no: %(lineno)d : %(message)s'
108 |         #create the log directory before any handler tries to open the log file
109 |         os.makedirs(os.path.dirname(self.log_fullpath), exist_ok=True)
110 |         if in_json:
111 |             self.is_json = True
112 |             folder = os.path.dirname(self.log_fullpath)
113 |             filename = os.path.basename(self.log_fullpath)
114 |             filename = filename.split('.')[0] + '.json'
115 |             self.log_fullpath = os.path.join(folder,filename)
116 | 
117 |             self.__json_log_begin()
118 |             file_handler = logging.FileHandler(self.log_fullpath)
119 |             file_handler.setFormatter(self.__get_json_formatter(fmt))
120 |         else:
121 |             file_handler = logging.FileHandler(self.log_fullpath)
122 |             file_handler.setFormatter(self.__get_formatter(fmt))
123 |         # self.log_fullpath = re.sub(r'(\\)',os.sep,self.log_fullpath)
124 |         self.__add_handler(file_handler)
125 |
126 | def add_console_handler(self, fmt: str = None):
127 | if fmt is None:
128 | fmt = "%(levelname)-5s : %(filename)s : %(asctime)s : line no: %(lineno)d : %(message)s"
129 | console_handler = logging.StreamHandler()
130 | console_handler.setFormatter(self.__get_formatter(fmt))
131 | self.__add_handler(console_handler)
132 |
133 | def __add_comma(self):
134 | if self.is_first_log:
135 | self.is_first_log = False
136 | return
137 | with open(self.log_fullpath,'a+') as file:
138 | file.write(',\n')
139 |
140 | @property
141 | def debug(self):
142 | if self.is_json:self.__add_comma()
143 | return self.logger.debug
144 |
145 | @property
146 | def info(self):
147 | if self.is_json:self.__add_comma()
148 | return self.logger.info
149 |
150 | @property
151 | def warning(self):
152 | if self.is_json:self.__add_comma()
153 | return self.logger.warning
154 |
155 | @property
156 | def error(self):
157 | if self.is_json:self.__add_comma()
158 | return self.logger.error
159 |
160 | @property
161 | def exception(self):
162 | if self.is_json:self.__add_comma()
163 | return self.logger.exception
164 |
165 | def cleanup(self):
166 | if self.is_json:
167 | self.__json_log_end()
168 |
169 |
--------------------------------------------------------------------------------
/wav2mov/losses/ReadMe.md:
--------------------------------------------------------------------------------
1 | # Loss / Criterions
2 |
3 | ## GAN Loss
4 | 
5 | + Binary cross-entropy loss (with logits) on the discriminator outputs
6 | 
7 | ## Sync Loss
8 | 
9 | + Binary cross-entropy between the cosine similarity of the audio and video embeddings and the real/fake labels (see the combined usage sketch below)
10 | 
11 | ## L1 Loss
12 | 
13 | + Reconstruction loss : mean absolute error between generated and ground-truth frames
14 |
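15 | ## Usage sketch
16 | 
17 | A minimal sketch of how these criteria can be combined for a generator update. It is illustrative only: the tensor shapes and the equal weighting are assumptions, not the values used in training.
18 | 
19 | ```python
20 | import torch
21 | from wav2mov.losses import GANLoss, SyncLoss, L1_Loss
22 | 
23 | device = 'cpu'
24 | gan_criterion = GANLoss(device)    # BCE-with-logits on raw discriminator scores
25 | sync_criterion = SyncLoss(device)  # BCE on similarity scores in [0, 1]
26 | recon_criterion = L1_Loss()
27 | 
28 | disc_scores = torch.randn(4, 1)           # hypothetical raw discriminator logits
29 | similarity = torch.rand(4, 1)             # hypothetical audio-video similarity in [0, 1]
30 | fake_frames = torch.rand(4, 1, 256, 256)  # generated frames
31 | real_frames = torch.rand(4, 1, 256, 256)  # ground-truth frames
32 | 
33 | # the generator is optimized to make the discriminators predict "real"
34 | loss = (gan_criterion(disc_scores, is_real_target=True)
35 |         + sync_criterion(similarity, is_real_target=True)
36 |         + recon_criterion(fake_frames, real_frames))
37 | ```
38 | 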
--------------------------------------------------------------------------------
/wav2mov/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .gan_loss import GANLoss
2 | from .sync_loss import SyncLoss
3 | from .l1_loss import L1_Loss
--------------------------------------------------------------------------------
/wav2mov/losses/gan_loss.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch import nn
4 |
5 | class GANLoss(nn.Module):
6 | """ To abstract away the task of creating real/fake labels and calculating loss
7 | [Reference]: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
8 | """
9 | def __init__(self,device,real_label=None,fake_label=0.0):
10 | super().__init__()
11 | self.real_label = real_label
12 | self.fake_label = fake_label
13 | # self.register_buffer('real_label',torch.tensor(real_label))
14 | # self.register_buffer('fake_label',torch.tensor(fake_label))
15 | self.loss = nn.BCEWithLogitsLoss()
16 | self.device = device
17 |
18 | def get_target_tensor(self,preds,is_real_target):
19 | real_label = torch.tensor( round(random.uniform(0.8,1),2) if self.real_label is None else self.real_label)
20 | target_tensor = real_label if is_real_target else torch.tensor(self.fake_label)
21 | return target_tensor.expand_as(preds).to(self.device)
22 |
23 | def forward(self,preds,is_real_target):
24 | target_tensor = self.get_target_tensor(preds,is_real_target)
25 | return self.loss(preds,target_tensor)
26 |
--------------------------------------------------------------------------------
/wav2mov/losses/l1_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class L1_Loss(nn.Module):
5 | def __init__(self):
6 | super().__init__()
7 | self.loss = nn.L1Loss()
8 | def forward(self,preds,targets):
9 | """
10 | preds and targets are of the shape = (N,C,H,W)
11 | """
12 | # height = preds.shape[-2]
13 | # preds = preds[...,height//2:,:]
14 | # targets = targets[...,height//2:,:]
15 | # # print(torch.mean(abs(preds-targets)).shape) #torch.size([])
16 | # return torch.mean(abs(preds-targets))#get mean of error of all batch samples
17 |
18 | return self.loss(preds,targets)
--------------------------------------------------------------------------------
/wav2mov/losses/sync_loss.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch import nn
4 |
5 | from wav2mov.logger import get_module_level_logger
6 | logger = get_module_level_logger(__name__)
7 |
8 | class SyncLoss(nn.Module):
9 |     """Abstracts away the task of applying the synchronization loss between audio and video frames.
10 | [Reference]:
11 | https://github.com/Rudrabha/Wav2Lip/blob/master/hq_wav2lip_train.py#L181-L186
12 | """
13 | def __init__(self,device,real_label=None,fake_label=0.0):
14 | super().__init__()
15 | self.real_label = real_label
16 | self.fake_label = fake_label
17 | # self.register_buffer('real_label',torch.tensor(real_label))
18 | # self.register_buffer('fake_label',torch.tensor(fake_label))
19 | self.loss = nn.BCELoss()
20 | self.device = device
21 |
22 | def get_target_tensor(self,preds,is_real_target):
23 | real_label = torch.tensor( round(random.uniform(0.8,1),2) if self.real_label is None else self.real_label )
24 |         target_tensor = real_label if is_real_target else torch.tensor(self.fake_label)
25 | return target_tensor.expand_as(preds).to(self.device)
26 |
27 | def forward(self,preds,is_real_target):
28 | # preds = nn.functional.cosine_similarity(audio_embedding,video_embedding)
29 | target_tensor = self.get_target_tensor(preds,is_real_target)
30 | loss = self.loss(preds,target_tensor)
31 | logger.debug(f'[sync] loss {is_real_target} frames :{loss.item():0.4f}')
32 | return loss
33 |
--------------------------------------------------------------------------------
/wav2mov/main/data.py:
--------------------------------------------------------------------------------
1 | """ provides utils for datasets and dataloaders """
2 | import os
3 | import numpy as np
4 | from tqdm import tqdm
5 | from collections import namedtuple
6 |
7 | import torch
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms as vtransforms
10 |
11 | from wav2mov.core.data.datasets import AudioVideoDataset
12 | from wav2mov.core.data.transforms import ResizeGrayscale,Normalize
13 | from wav2mov.logger import get_module_level_logger
14 | logger = get_module_level_logger(__name__)
15 |
16 | DataloadersPack = namedtuple('dataloaders',('train','val'))
17 |
18 |
19 | To_Grayscale = vtransforms.Grayscale(1)
20 | def get_dataset(options,config,hparams):
21 | hparams = hparams['data']
22 | root_dir = config['train_test_dataset_dir']
23 | filenames_train_txt = config['filenames_train_txt']
24 | filenames_test_txt = config['filenames_test_txt']
25 | video_fps = hparams['video_fps']
26 | audio_sf = hparams["audio_sf"]
27 | img_size = hparams['img_size']
28 | target_img_shape = (hparams['img_channels'],img_size,img_size)
29 |
30 | num_videos_train = int(options.num_videos*0.9)
31 | num_videos_test = options.num_videos - num_videos_train
32 |
33 | mean_std_train = get_mean_and_std(root_dir,filenames_train_txt,img_channels=1)
34 | to_gray_transform = ResizeGrayscale(target_img_shape)
35 | normalize = Normalize(mean_std_train)
36 | transforms_composed = vtransforms.Compose([to_gray_transform,normalize])
37 | dataset_train = AudioVideoDataset(root_dir=root_dir,
38 | filenames_text_filepath=filenames_train_txt,
39 | audio_sf=audio_sf,
40 | video_fps=video_fps,
41 | num_videos=num_videos_train,
42 | transform=transforms_composed)
43 | dataset_test = AudioVideoDataset(root_dir=root_dir,
44 | filenames_text_filepath=filenames_test_txt,
45 | audio_sf=audio_sf,
46 | video_fps=video_fps,
47 | num_videos=num_videos_test,
48 | transform=transforms_composed)
49 | return dataset_train,dataset_test
50 |
51 | def get_dataloaders(options,config,params,shuffle=True,collate_fn=None):
52 | hparams = params['data']
53 | batch_size = hparams['mini_batch_size']
54 | train_ds,test_ds = get_dataset(options,config,params)
55 | train_dl = DataLoader(train_ds,batch_size=batch_size,shuffle=shuffle,collate_fn=collate_fn,pin_memory=True)
56 | test_dl = DataLoader(test_ds,batch_size=1,shuffle=shuffle,collate_fn=collate_fn)
57 | return DataloadersPack(train_dl,test_dl)
58 |
59 |
60 | def get_video_mean_and_std(root_dir,filenames,img_channels):
61 | channels_sum,channels_squared_sum,num_batches = 0,0,0
62 | # num_items = 0
63 | progress_bar = tqdm(enumerate(filenames),ascii=True,total=len(filenames),desc='video')
64 | for _,filename in progress_bar:
65 | progress_bar.set_postfix({'file':filename})
66 | video_path = os.path.join(root_dir,filename,'video_frames.npy')
67 | video = torch.from_numpy(np.load(video_path))
68 | video = video/255 #of shape (F,H,W,C)
69 | if img_channels==1:
70 | video = video.permute(0,3,1,2)#F,C,H,W
71 | video = To_Grayscale(video)
72 | video = video.permute(0,2,3,1)
73 |
74 | channels_sum += torch.mean(video,dim=[0,1,2])
75 | #except for the channel dimension as we want mean and std
76 | # for each channel
77 | channels_squared_sum += torch.mean(video**2,dim=[0,1,2])
78 | # num_items += video.shape[0]
79 | num_batches += 1
80 | mean = channels_sum/num_batches
81 |
82 | std = ((channels_squared_sum/num_batches) - mean**2)**0.5
83 | return mean,std
84 |
85 | def get_audio_mean_and_std(root_dir,filenames):
86 | running_mean_sum , running_squarred_mean_sum = 0,0
87 | progress_bar = tqdm(enumerate(filenames),ascii=True,total=len(filenames),desc='audio')
88 | for _,filename in progress_bar:
89 | progress_bar.set_postfix({'file':filename})
90 | audio_path = os.path.join(root_dir,filename,'audio.npy')
91 | audio = torch.from_numpy(np.load(audio_path))
92 | running_mean_sum += torch.mean(audio)
93 | running_squarred_mean_sum += torch.mean(audio**2)
94 | mean = running_mean_sum/len(filenames)
95 | std = ((running_squarred_mean_sum/len(filenames))-mean**2)**0.5
96 | return mean,std
97 |
98 | def get_mean_and_std(root_dir,filenames_txt,img_channels):
99 |     logger.debug('Calculating mean and standard deviation for the dataset. Please wait...')
100 | ret = {}
101 | #mean = E(X)
102 | #variance = E(X**2)- E(X)**2
103 | #standard deviation = variance**0.5
104 | filenames_path = os.path.join(root_dir,filenames_txt)
105 | with open(filenames_path) as file:
106 | filenames = file.read().split('\n')
107 |         #drop empty entries (e.g. a trailing newline) instead of deleting
108 |         #from the list while iterating, which can skip elements
109 |         filenames = [name for name in filenames if name.strip()]
110 |
111 | ret['video'] = ([0.5]*img_channels,[0.5]*img_channels)
112 | ret['audio'] = get_audio_mean_and_std(root_dir,filenames)
113 | logger.debug(f'[MEAN and STANDARD_DEVIATION] {ret}')
114 | return ret
--------------------------------------------------------------------------------
/wav2mov/main/engine.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from wav2mov.core.engine import TemplateEngine
4 | from wav2mov.core.engine import CallbackEvents as Events
5 |
6 |
7 | class State:
8 | def __init__(self,names):
9 | self.names = names
10 | for name in self.names:
11 | setattr(self,name,None)
12 |
13 | def reset(self,names):
14 | for name in names:
15 | setattr(self,name,None)
16 |
17 |
18 | class Engine(TemplateEngine):
19 | def __init__(self,options,hparams,config,logger):
20 | super().__init__()
21 | self.logger = logger
22 | self.configure(options,hparams,config)
23 | self.state = State(['num_batches','cur_batch_size','epoch','batch_idx','start_epoch','logs'])
24 |
25 |
26 | def configure(self,options,hparams,config):
27 | self.hparams = hparams
28 | self.options = options
29 | self.config = config
30 |
31 | def load_checkpoint(self,model):
32 | prev_epoch = 0
33 | if getattr(self.options, 'model_path', None) is None:
34 | return prev_epoch
35 | loading_version = os.path.basename(self.options.model_path)
36 | self.logger.debug(f'Loading pretrained weights : {self.config.version} <== {loading_version}')
37 |
38 | prev_epoch = model.load(checkpoint_dir=self.options.model_path)
39 | if prev_epoch is None:
40 | prev_epoch = 0
41 | self.logger.debug(f'weights loaded successfully: {self.config.version} <== {loading_version}')
42 | return prev_epoch + 1
43 |
44 | def dispatch(self, event):
45 | super().dispatch(event,state=self.state)
46 |
47 | def run(self,model,dataloaders_ntuple,callbacks=None):
48 | callbacks = callbacks or []
49 | callbacks = [model] + callbacks
50 | self.register(callbacks)
51 | self.state.start_epoch = self.load_checkpoint(model)
52 |
53 | train_dl = dataloaders_ntuple.train
54 | self.state.num_batches = len(train_dl)
55 | num_epochs = self.hparams['num_epochs']
56 | self.dispatch(Events.RUN_START)
57 | self.dispatch(Events.TRAIN_START)
58 | for epoch in range(self.state.start_epoch,num_epochs):
59 | self.state.epoch = epoch
60 | self.dispatch(Events.EPOCH_START)
61 | for batch_idx,batch in enumerate(train_dl):
62 | self.state.batch_idx = batch_idx
63 |                 self.state.cur_batch_size = batch[0].shape[0] #makes the system tightly coupled though!?
64 | self.dispatch(Events.BATCH_START)
65 | model.setup_input(batch,state=self.state)
66 | logs = model.optimize(state=self.state)
67 | self.state.logs = logs
68 | self.dispatch(Events.BATCH_END)
69 | self.state.reset(['logs'])
70 |
71 | self.dispatch(Events.EPOCH_END)
72 | self.dispatch(Events.TRAIN_END)
73 | self.dispatch(Events.RUN_END)
74 |
75 | def to_device(self,device,*args):
76 | return [arg.to(device) for arg in args]
77 |
78 | def test(self,model,test_dl):
79 | last_epoch = self.load_checkpoint(model)
80 | if last_epoch is None or last_epoch==0:
81 |             self.logger.warning('Testing an untrained model!')
82 | sample = next(iter(test_dl))
83 | audio,audio_frames,video = sample
84 | # audio,audio_frames,video = self.to_device('cpu',audio,audio_frames,video)
85 | model.test(audio,audio_frames,video)
86 |
87 |
88 |
--------------------------------------------------------------------------------
/wav2mov/main/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Main file of the Wav2Mov Project
3 | It is the entry point to various functions
4 | """
5 | import os
6 | import random
7 | import torch
8 | from torch.utils.tensorboard.writer import SummaryWriter
9 |
10 | from wav2mov.logger import Logger
11 | from wav2mov.config import get_config
12 |
13 | from wav2mov.params import params
14 |
15 | from wav2mov.main.preprocess import create_from_grid_dataset
16 | from wav2mov.main.train import train_model
17 | from wav2mov.main.test import test_model
18 |
19 | from wav2mov.main.options import Options,set_options
20 | from wav2mov.main.validate_params import check_batchsize
21 | torch.manual_seed(params['seed'])
22 | random.seed(params['seed'])
23 |
24 | def get_logger(config,filehandler_required=False):
25 | local_logger = Logger(__name__)
26 | local_logger.add_console_handler()
27 | if filehandler_required:
28 | local_logger.add_filehandler(config['log_fullpath'],in_json=True)
29 | return local_logger
30 |
31 | def preprocess(preprocess_logger,config):
32 | if options.grid_dataset_dir:
33 | config.update('grid_dataset_dir',options.grid_dataset_dir)
34 | create_from_grid_dataset(config,preprocess_logger)
35 |
36 | def train(train_logger,args_options,config):
37 | train_model(args_options,params,config,train_logger)
38 |
39 | def test(test_logger,args_options,config):
40 | test_model(args_options,params,config,test_logger)
41 |
42 | def save_message(options,config,logger):
43 | if not getattr(options,'msg'):return
44 | if options.log=='n':return
45 | path = os.path.join(os.path.dirname(config['log_fullpath']),f'message_{config.version}.txt')
46 | logger.debug('message written to %(path)s'%{'path':path})
47 | with open(path,'a+') as file:
48 | file.write(options.msg)
49 | file.write('\n')
50 |
51 | def main(config):
52 |
53 | allowed = ('y', 'yes')
54 | if options.preprocess in allowed:
55 | preprocess(logger,config)
56 |
57 | if options.train in allowed:
58 | required = ('num_videos', 'num_epochs')
59 |
60 | if not all(getattr(options, option) for option in required):
61 |             raise RuntimeError(f'Cannot train without the options : {required}')
62 | check_batchsize(hparams=params['data'])
63 | train(logger, options,config)
64 |
65 | if options.test in allowed:
66 | required = ('model_path',)
67 |
68 | if not all(getattr(options, option) for option in required):
69 | raise RuntimeError(
70 | f'Cannot test without the options : {required}')
71 | test(logger, options,config)
72 |
73 | if __name__ == '__main__':
74 |
75 |
76 | options = Options().parse()
77 | config = get_config(options.version)
78 | set_options(options,params)
79 | logger = get_logger(config,filehandler_required=options.log in ['y', 'yes'])
80 | save_message(options,config,logger)
81 | try:
82 | main(config)
83 | except Exception as e:
84 | if options.train in ['y','yes']:
85 | params.save(config['params_checkpoint_fullpath'])
86 | logger.exception(e)
87 |
88 | finally:
89 | logger.cleanup()
90 |
91 |
92 |
--------------------------------------------------------------------------------
/wav2mov/main/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | class Options():
4 | def __init__(self):
5 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
6 | description='Wav2Mov | End to End Speech to facial animation model')
7 |
8 | self.parser.add_argument('--version',
9 | type=str,
10 | help='version of the file being run. Example : v1,v2,...',
11 | required=True)
12 |
13 | self.parser.add_argument('--log',
14 | default='y',
15 | choices=['y','n','yes','no'],
16 | type=str,
17 | help='whether to initialize logger')
18 |
19 | self.parser.add_argument('--train',
20 | default='n',
21 | choices=['y','n','yes','no'],
22 | type=str,
23 | help='whether to train the model')
24 |
25 | self.parser.add_argument('--device','-d',
26 | default='cpu',
27 | choices=['cpu','cuda'],
28 | type=str,
29 | help='device on which model operations are done',
30 | required=True)
31 |
32 | self.parser.add_argument('--num_epochs','-e',
33 | type=int,
34 |                                  help='number of epochs for which the model should be trained',
35 | )
36 |
37 | self.parser.add_argument('--test',
38 | default='n',
39 | choices=['y','n','yes','no'],
40 | type=str,
41 | help='run test script')
42 | self.parser.add_argument('--train_sync_expert',
43 | default='n',
44 | choices=['y','n','yes','no'],
45 | type=str,
46 | help='only train sync disc')
47 |
48 | self.parser.add_argument('--preprocess',
49 | default='n',
50 | choices=['y','n','yes','no'],
51 | type=str,help='run preprocess script')
52 |
53 | self.parser.add_argument('--grid_dataset_dir','-grid',
54 | type=str,
55 | help="path of raw dataset")
56 | # self.parser.add_argument('--device',
57 | # default='cuda',
58 | # choices=['cpu','cuda'],
59 | # type=str,help='device where processing should be done')
60 |
61 | self.parser.add_argument('--model_path','-path',
62 | type=str,
63 | help='generator checkpoint fullpath')
64 |
65 | self.parser.add_argument('--num_videos','-v',
66 | type=int,
67 | help='num of videos on which the model should be trained')
68 |
69 |
70 | self.parser.add_argument('--msg','-m',
71 | type=str,
72 | help='any message about current run')
73 |
74 | self.parser.add_argument('--test_sample_num','-snum',
75 | type=int,
76 | help='sample to be taken from test dataloader',
77 | )
78 |
79 | def parse(self):
80 | return self.parser.parse_args()
81 |
82 | def set_device(options, params):
83 | device = options.device
84 | params.set('device', device)
85 | params.set('gen', {**params['gen'], 'device': device})
86 |
87 | def set_epochs(options,params):
88 | num_epochs = options.num_epochs
89 | params.set('num_epochs',num_epochs)
90 |
91 | def set_train_sync_expert(options,params):
92 |     train_sync_expert = options.train_sync_expert in ('y','yes')
93 | params.set('train_sync_expert',train_sync_expert)
94 |
95 | def set_options(options, params):
96 | set_device(options,params)
97 | set_epochs(options,params)
98 | set_train_sync_expert(options,params)
99 |
--------------------------------------------------------------------------------
/wav2mov/main/preprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | Module for creating and saving numpy files from the raw dataset
3 | + video frames and audio samples are extracted and saved as .npy files.
4 | """
5 |
6 | import numpy as np
7 | import os
8 |
9 |
10 | from wav2mov.core.data import RawDataset, GridDataset
11 | from wav2mov.datasets import create_file_list as create_file_list_module
12 |
13 | # currently these are not applied anywhere;
14 | # see the audio utils module for the values that are actually used.
15 | video_frame_rate = 30
16 | audio_sampling_rate = 16_000
17 | win_len = 0.033
18 | win_hop = 0.033
19 |
20 |
21 |
22 | def create(raw_dataset: RawDataset,config,logger) -> str:
23 | """ Creates numpy file containing video and audio frames from given dataset
24 |
25 | Args:
26 |
27 | + `raw_dataset` (RawDataset): either GridDataset or Ravdess dataset
28 |
29 | + `config` : Config object containing different location information
30 |
31 | Returns:
32 |
33 | + `str`: log string containing info about saved file.
34 | """
35 |
36 | samples_count = raw_dataset.samples_count
37 | dataset_dir = config['train_test_dataset_dir']
38 | filenames = []
39 |
40 | for sample in raw_dataset.generator(get_filepath_only=False,img_size=(256,256),show_progress_bar=True):
41 | audio_filepath,audio_vals = sample.audio
42 | _,video_frames = sample.video
43 | folder_name = os.path.basename(audio_filepath).split('.')[0]
44 |
45 |
46 | if raw_dataset.name in os.path.basename(dataset_dir):
47 | dest_dir = os.path.join(dataset_dir,folder_name)
48 | else:
49 | dest_dir = os.path.join(dataset_dir,f'{raw_dataset.name}',folder_name)
50 |
51 | os.makedirs(dest_dir,exist_ok=True)
52 | np.save(os.path.join(dest_dir,'video_frames.npy'),
53 | np.array(video_frames))
54 | np.save(os.path.join(dest_dir,'audio.npy'),
55 | audio_vals)
56 | filenames.append(folder_name + '\n')
57 |
58 | filenames[-1] = filenames[-1].strip('\n')
59 | with open(os.path.join(dataset_dir, 'filenames.txt'),'a+') as file:
60 | file.writelines(filenames)
61 |
62 | log = f'Samples generated : {samples_count}\n'
63 | log += f'Location : { dataset_dir }\n'
64 | log += f'Filenames are listed in filenames.txt\n'
65 | logger.info(log)
66 | create_file_list_module.main()
67 | logger.debug('train and test filelists created')
68 |
69 | return log
70 |
71 |
72 | def create_from_grid_dataset(config,logger):
73 | dataset = GridDataset(config['grid_dataset_dir'],
74 | audio_sampling_rate=audio_sampling_rate,
75 | video_frame_rate=video_frame_rate,
76 | samples_count=125)
77 |
78 | create(dataset,config,logger)
79 | print(f'{dataset.__class__.__name__} successfully processed.')
80 | # logger.info(log)
81 |
--------------------------------------------------------------------------------
/wav2mov/main/test.py:
--------------------------------------------------------------------------------
1 | # import logging
2 | import os
3 | import re
4 | import torch
5 | from torchvision import utils as vutils
6 | # from tqdm import tqdm
7 | from scipy.io.wavfile import write
8 | from wav2mov.core.data.collates import get_batch_collate
9 | from wav2mov.models.wav2mov_inferencer import Wav2movInferencer
10 | from wav2mov.main.data import get_dataloaders
11 | from wav2mov.utils.plots import save_gif,save_video
12 |
13 | SAMPLE_NUM = 5
14 |
15 | def squeeze_frames(video):
16 | batch_size,num_frames,*extra = video.shape
17 | return video.reshape(batch_size*num_frames,*extra)
18 |
19 | def denormalize_frames(frames):
20 | return ((frames*0.5)+0.5)*255
21 |
22 | def make_path_compatible(path):
23 |     if os.sep != '\\':#not windows
24 |         return re.sub(r'(\\)+',os.sep,path)
25 |     return path
26 | def test_sample(model,dl,options,hparams,config,logger,suffix):
27 | global SAMPLE_NUM
28 | if hasattr(options,'test_sample_num'):
29 | sample_num = options.test_sample_num
30 | else:
31 | sample_num = SAMPLE_NUM
32 | SAMPLE_NUM = sample_num
33 | logger.debug(f'FOR {suffix} sample')
34 | version = os.path.basename(options.model_path).strip('gen_').split('.')[0]
35 | out_dir = os.path.join(config['out_dir'],version)
36 | logger.debug(f'[OUTPUT DIR] {out_dir}')
37 | logger.debug(options.version)
38 | sample_iter = (iter(dl))
39 | sample = next(sample_iter)
40 |
41 | for _ in range(min(sample_num-1,len(dl)-1)):
42 | sample = next(sample_iter)
43 |
44 | # print(sample,len(sample))
45 | audio,audio_frames,video = sample
46 | batch_size = audio.shape[0]
47 | if batch_size>1:
48 | audio,audio_frames,video = (audio[0].unsqueeze(0),
49 | audio_frames[0].unsqueeze(0),
50 | video[0].unsqueeze(0))
51 | fake_video_frames,ref_video_frame = model.test(audio_frames,video,get_ref_video_frame=True)
52 | fake_video_frames = squeeze_frames(fake_video_frames)
53 | video = squeeze_frames(video)
54 | os.makedirs(out_dir,exist_ok=True)
55 | save_path_fake_video_frames = os.path.join(out_dir,f'{suffix}_fake_frames_{version}.png')
56 | save_path_real_video_frames = os.path.join(out_dir,f'{suffix}_real_frames_{version}.png')
57 | save_path_ref_video_frame = os.path.join(out_dir,f'{suffix}_ref_frame_{version}.png')
58 | save_path_fake_video_frames = make_path_compatible(save_path_fake_video_frames)
59 | save_path_real_video_frames = make_path_compatible(save_path_real_video_frames)
60 | save_path_ref_video_frame = make_path_compatible(save_path_ref_video_frame)
61 |
62 | save_video(hparams['data'],os.path.join(out_dir,f'{suffix}_fake_video.avi'),audio,denormalize_frames(fake_video_frames))
63 | logger.debug(f'video saved : {suffix}_{fake_video_frames.shape}')
64 |
65 | vutils.save_image(denormalize_frames(ref_video_frame),save_path_ref_video_frame,normalize=True)
66 | vutils.save_image(denormalize_frames(fake_video_frames),save_path_fake_video_frames,normalize=True)
67 | vutils.save_image(denormalize_frames(video),save_path_real_video_frames,normalize=True)
68 |
69 | gif_path = os.path.join(out_dir,f'{suffix}_fake_frames_{version}.gif')
70 | save_gif(gif_path,denormalize_frames(fake_video_frames))
71 | gif_path = os.path.join(out_dir,f'{suffix}_real_frames_{version}.gif')
72 | save_gif(gif_path,denormalize_frames(video))
73 |
74 | def test_model(options,hparams,config,logger):
75 | logger.debug(f'Testing model...')
76 | version = os.path.basename(options.model_path).strip('gen_').split('.')[0]
77 | logger.debug(f'loading version : {version}')
78 | out_dir = os.path.join(config['out_dir'],version)
79 |
80 | model = Wav2movInferencer(hparams,logger)
81 | checkpoint = torch.load(options.model_path,map_location='cpu')
82 | state_dict,last_epoch = checkpoint['state_dict'],checkpoint['epoch']
83 | model.load(state_dict)
84 | logger.debug(f'model was trained for {last_epoch+1} epochs. ')
85 | collate_fn = get_batch_collate(hparams['data'])
86 | logger.debug('Loading dataloaders')
87 | dataloaders = get_dataloaders(options,config,hparams,collate_fn=collate_fn)
88 | train_dl =dataloaders.train
89 | test_dl = dataloaders.val
90 | # write(os.path.join(out_dir,f'audio_{SAMPLE_NUM}.wav'),16000,audio.cpu().numpy().reshape(-1))
91 | # logger.debug(f'audio saved : audio_{SAMPLE_NUM}.wav')
92 | test_sample(model,train_dl,options,hparams,config,logger,'train')
93 | test_sample(model,test_dl,options,hparams,config,logger,'test')
94 |
95 | logger.debug(f'results are saved in {out_dir}')
96 | msg = "#"*25
97 | msg += f'\ntest_run for version {version}.\n'
98 | msg += '='*25
99 | msg += f'\nlast epoch : {last_epoch+1}\n'
100 | msg += f'curr_version : {config.version}\n'
101 | msg += f'sample num : {SAMPLE_NUM}\n'
102 | msg += "#"*25
103 |
104 | with open(os.path.join(out_dir,'info.txt'),'a+') as file:
105 | file.write(msg)
106 |
107 |
--------------------------------------------------------------------------------
/wav2mov/main/train.py:
--------------------------------------------------------------------------------
1 | from wav2mov.models.wav2mov_trainer import Wav2MovTrainer
2 | from wav2mov.main.engine import Engine
3 | from wav2mov.main.callbacks import LossMetersCallback,ModelCheckpoint,TimeTrackerCallback,LoggingCallback
4 |
5 | from wav2mov.core.data.collates import get_batch_collate
6 | from wav2mov.main.data import get_dataloaders
7 |
8 |
9 |
10 |
11 | def train_model(options,hparams,config,logger):
12 | engine = Engine(options,hparams,config,logger)
13 | model = Wav2MovTrainer(hparams,config,logger)
14 | collate_fn = get_batch_collate(hparams['data'])
15 | dataloaders_ntuple = get_dataloaders(options,config,hparams,
16 | collate_fn=collate_fn)
17 | callbacks = [LossMetersCallback(options,hparams,config,logger,
18 | verbose=True),
19 | LoggingCallback(options,hparams,config,logger),
20 | TimeTrackerCallback(hparams,logger),
21 | ModelCheckpoint(model,hparams,config,logger,
22 | save_every=5)]
23 |
24 | engine.run(model,dataloaders_ntuple,callbacks)
--------------------------------------------------------------------------------
/wav2mov/main/trained.txt:
--------------------------------------------------------------------------------
1 | E:\Users\VS_Code_Workspace\Python\VirtualEnvironments\wav2mov\wav2mov\runs\v6\Run_7_4_2021__18_24\gen_Run_7_4_2021__18_24.pt
--------------------------------------------------------------------------------
/wav2mov/main/validate_params.py:
--------------------------------------------------------------------------------
1 | def check_batchsize(hparams):
2 | if hparams['batch_size']%hparams['mini_batch_size']:
3 | raise ValueError(f'Batch size must be evenly divisible by mini_batch_size\n'
4 | f'Currently batch_size : {hparams["batch_size"]} mini_batch_size :{hparams["mini_batch_size"]}')
--------------------------------------------------------------------------------
/wav2mov/models/README.md:
--------------------------------------------------------------------------------
1 | # Architectures
2 |
3 | ## Generator
4 |
5 | * follows a UNET-style encoder-decoder architecture
6 | * input : audio_frames (B,F,Sw) and ref_frames (B,F,C,H,W)
7 | * output frames are of size IMAGE_SIZE x IMAGE_SIZE
8 | 
9 | ## Identity Discriminator
10 | 
11 | * input is a real/fake frame and the still face image (reference face image)
12 | * helps retain the facial identity of the given reference image and produce realistic frames. See the sketch at the end of this file for how the frame pair is fed to the discriminator.
13 | 
14 | ## Sequence Discriminator
15 | 
16 | * input is a window of consecutive frames
17 | * helps produce temporally coherent video frames.
18 | 
19 | ## Sync Discriminator
20 | 
21 | * input is an audio frame (666 points x 5) and the corresponding video frames.
22 | * helps synchronize the generated video frames with the audio.
23 |
24 |
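25 | ## Example : discriminator input
26 | 
27 | A minimal sketch of how the identity/patch discriminators receive their input. It is illustrative only: the batch size, channel count and image size are assumptions based on the default single-channel 256x256 setup. The candidate frame and the still reference frame are concatenated along the channel dimension before the convolutional stack, as in `identity_discriminator.py`:
28 | 
29 | ```python
30 | import torch
31 | 
32 | frame = torch.rand(8, 1, 256, 256)        # real or generated frame
33 | ref_frame = torch.rand(8, 1, 256, 256)    # still reference face image
34 | disc_input = torch.cat([frame, ref_frame], dim=1)  # shape (8, 2, 256, 256)
35 | ```
36 | 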
--------------------------------------------------------------------------------
/wav2mov/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .discriminators.identity_discriminator import IdentityDiscriminator
2 | from .discriminators.sequence_discriminator import SequenceDiscriminator
3 | from .discriminators.sync_discriminator import SyncDiscriminator
4 | from .discriminators.patch_disc import PatchDiscriminator
5 | from .generator.frame_generator import Generator
--------------------------------------------------------------------------------
/wav2mov/models/discriminators/identity_discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn,optim
3 | from torch.nn import functional as F
4 |
5 | from wav2mov.core.models.base_model import BaseModel
6 | from wav2mov.models.utils import squeeze_batch_frames,get_same_padding
7 | from wav2mov.models.layers.conv_layers import Conv2dBlock
8 |
9 | from wav2mov.logger import get_module_level_logger
10 | logger = get_module_level_logger(__name__)
11 |
12 | class IdentityDiscriminator(BaseModel):
13 | def __init__(self,hparams):
14 |
15 | super().__init__()
16 | self.hparams = hparams
17 | in_channels = hparams['in_channels']*2
18 | chs = self.hparams['chs']
19 | chs = [in_channels] + chs
20 | padding = get_same_padding(kernel_size=4,stride=2)
21 | relu_neg_slope = self.hparams['relu_neg_slope']
22 | self.conv_blocks = nn.ModuleList(Conv2dBlock(chs[i],chs[i+1],
23 | kernel_size=(4,4),
24 | stride=2,
25 | padding=padding,
26 | use_norm=True,
27 | use_act=True,
28 | act=nn.LeakyReLU(relu_neg_slope)
29 | ) for i in range(len(chs)-2)
30 | )
31 |
32 | self.conv_blocks.append(Conv2dBlock(chs[-2],chs[-1],
33 | kernel_size=(4,4),
34 | stride=2,
35 | padding=padding,
36 | use_norm=False,
37 | use_act=False
38 | )
39 | )
40 |
41 | def forward(self,x,y):
42 | """
43 | x : frame image (B,F,H,W)
44 | y : still image
45 | """
46 | assert x.shape==y.shape
47 |
48 | if len(x.shape)>4:#frame dim present
49 | x = squeeze_batch_frames(x)
50 | y = squeeze_batch_frames(y)
51 |
52 | x = torch.cat([x,y],dim=1)#along channels
53 | for block in self.conv_blocks:
54 | x = block(x)
55 | return x
56 |
57 | def get_optimizer(self):
58 | return optim.Adam(self.parameters(), lr=self.hparams['lr'], betas=(0.5,0.999))
59 |
--------------------------------------------------------------------------------
/wav2mov/models/discriminators/patch_disc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn,optim
3 | from wav2mov.models.utils import squeeze_batch_frames
4 |
5 | from wav2mov.logger import get_module_level_logger
6 | logger = get_module_level_logger(__name__)
7 |
8 |
9 | class PatchDiscriminator(nn.Module):
10 | """
11 | ref : https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
12 | """
13 | def __init__(self,hparams,norm_layer=nn.BatchNorm2d,use_bias = True):
14 | super().__init__()
15 | self.hparams = hparams
16 | input_nc = hparams['in_channels']
17 | ndf = hparams['ndf']
18 | n_layers = hparams['num_layers']
19 | use_bias = norm_layer==nn.InstanceNorm2d
20 | #batch normalization has affine = True so it has additional scaling and biasing/shifting terms by default.
21 | kw = 4
22 | padw = 1
23 | sequence = [nn.Conv2d(input_nc*2,ndf,kw,2,padw),
24 | nn.LeakyReLU(0.2,inplace=True)]#inplace=True saves memory
25 |
26 | nf_mult = 1
27 | nf_mult_prev = 1
28 |
29 | for n in range(1,n_layers):
30 | nf_mult_prev = nf_mult
31 | nf_mult = min(2**n,8) #multiplier is clipped at 8 : so if ndf is 64 max possible is 64*8=512
32 | sequence += [
33 | nn.Conv2d(ndf*nf_mult_prev,ndf*nf_mult,kw,2,padw,bias=use_bias),
34 | norm_layer(ndf*nf_mult),
35 | nn.LeakyReLU(0.2,True)
36 | ]
37 | nf_mult_prev = nf_mult
38 | nf_mult = min(2**n_layers,8)
39 | sequence += [
40 | nn.Conv2d(ndf*nf_mult_prev,ndf*nf_mult,kw,1,padw,bias=use_bias),
41 | norm_layer(ndf*nf_mult),
42 | nn.LeakyReLU(0.2,True)
43 | ]
44 |
45 | sequence += [
46 | nn.Conv2d(ndf*nf_mult,1,kw,1,padw)
47 | ]
48 |
49 | self.disc = nn.Sequential(*sequence)
50 |
51 |
52 | def forward(self,frame_image,ref_image):
53 | assert frame_image.shape == ref_image.shape
54 | batch_size = frame_image.shape[0]
55 | is_frame_dim_present = False
56 | if len(frame_image.shape)>4:
57 | is_frame_dim_present = True
58 | batch_size,num_frames,*img_shape = frame_image.shape
59 | frame_image = squeeze_batch_frames(frame_image)
60 | ref_image = squeeze_batch_frames(ref_image)
61 | frame_image = torch.cat([frame_image,ref_image],dim=1)#catenate images along the channel dimension
62 | #frame image now has a shape of (N,C+C,H,W)
63 | return self.disc(frame_image)
64 |
65 | def get_optimizer(self):
66 | return optim.Adam(self.parameters(), lr=self.hparams['lr'], betas=(0.5, 0.999))
67 |
--------------------------------------------------------------------------------
/wav2mov/models/discriminators/sequence_discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn,optim
3 |
4 | from wav2mov.core.models.base_model import BaseModel
5 | from wav2mov.models.layers.conv_layers import Conv2dBlock
6 | from wav2mov.models.utils import get_same_padding,squeeze_batch_frames
7 |
8 | from wav2mov.logger import get_module_level_logger
9 | logger = get_module_level_logger(__name__)
10 |
11 | class SequenceDiscriminator(BaseModel):
12 | def __init__(self,hparams):
13 | super().__init__()
14 | self.hparams = hparams
15 | relu_neg_slope = self.hparams['relu_neg_slope']
16 | in_size, h_size, num_layers = self.hparams['in_size'],self.hparams['h_size'],self.hparams['num_layers']
17 | self.gru = nn.GRU(input_size=in_size,hidden_size=h_size,num_layers=num_layers,batch_first = True)
18 | in_channels = self.hparams['in_channels']
19 | chs = [in_channels] + self.hparams['chs']
20 | kernel,stride = 4,2
21 | padding = get_same_padding(kernel,stride)
22 | cnn = nn.ModuleList([Conv2dBlock(chs[i],chs[i+1],kernel,stride,padding,
23 | use_norm=True,use_act=True,
24 | act=nn.LeakyReLU(relu_neg_slope)) for i in range(len(chs)-2)])
25 | cnn.append(Conv2dBlock(chs[-2],chs[-1],kernel,stride,padding,
26 | use_norm=False,use_act=True,
27 | act=nn.Tanh()))
28 | self.cnn = nn.Sequential(*cnn)
29 | ############################################
30 |         # channels : 3 => 64 => 128 => 256 => 512 => 1
31 |         # frame width : 256 => 128 => 64 => 32 => 16 => 8
32 |         # only the upper half is kept, so height goes 128 => 64 => 32 => 16 => 8 => 4
33 |         # thus out of self.cnn : 1 x 4 x 8 = 32 features per frame (the GRU's in_size)
34 | ############################################
35 |
36 | def forward(self,frames):
37 | """frames : B,T,C,H,W"""
38 | img_height = frames.shape[-2]
39 | frames = frames[...,0:img_height//2,:]#consider upper half
40 | batch_size,num_frames,*img_size = frames.shape
41 | frames = squeeze_batch_frames(frames)
42 | frames = self.cnn(frames)
43 | frames = frames.reshape(batch_size,num_frames,-1)
44 | out,_ = self.gru(frames)#out is of shape (batch_size,seq_len,num_dir*hidden_dim)
45 | return out[:,-1,:]#batch_size,hidden_size
46 |
47 | def get_optimizer(self):
48 | return optim.Adam(self.parameters(), lr=self.hparams['lr'], betas=(0.5,0.999))
49 |
--------------------------------------------------------------------------------
/wav2mov/models/discriminators/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import torch
4 | from torchvision import utils as vutils
5 |
6 | def squeeze_frames(video):
7 | batch_size,num_frames,*extra = video.shape
8 | return video.reshape(batch_size*num_frames,*extra)
9 |
10 | def denormalize_frames(frames):
11 | return ((frames*0.5)+0.5)*255
12 |
13 | def make_path_compatible(path):
14 |     if os.sep != '\\':#not windows : convert backslashes to the native separator
15 |         return re.sub(r'(\\)+',os.sep,path)
16 |     return path
17 |
18 | def save_series(frames,config,i=0):
19 | frames = squeeze_frames(frames)
20 | save_dir = os.path.join(config['runs_dir'],'images')
21 | os.makedirs(save_dir,exist_ok=True)
22 | save_path = os.path.join(save_dir,f'sync_frames_fake_{i}.png')
23 | vutils.save_image(denormalize_frames(frames),save_path,normalize=True)
--------------------------------------------------------------------------------
/wav2mov/models/generator/audio_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from wav2mov.models.layers.conv_layers import Conv1dBlock,Conv2dBlock
4 | from wav2mov.models.layers.debug_layers import IdentityDebugLayer
5 | from wav2mov.logger import get_module_level_logger
6 | from wav2mov.models.utils import squeeze_batch_frames,get_same_padding
7 |
8 | logger = get_module_level_logger(__name__)
9 |
10 | class AudioEnocoder(nn.Module):
11 | def __init__(self,hparams):
12 | super().__init__()
13 | self.hparams = hparams
14 | padding_31 = get_same_padding(3,1)
15 | padding_32 = get_same_padding(3,2)
16 | padding_42 = get_same_padding(4,2)
17 | # each frame has mfcc of shape (7,13)
18 | self.conv_encoder = nn.Sequential(
19 | Conv2dBlock(1,32,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU()),#7,13
20 | Conv2dBlock(32,32,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
21 | Conv2dBlock(32,32,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
22 |
23 | Conv2dBlock(32,64,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU()), #7,13
24 | Conv2dBlock(64,64,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
25 | Conv2dBlock(64,64,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
26 |
27 | Conv2dBlock(64,128,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU()), #7,13
28 | Conv2dBlock(128,128,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
29 | Conv2dBlock(128,128,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
30 |
31 | Conv2dBlock(128,256,3,(2,1),padding=padding_32,use_norm=True,use_act=True,act=nn.ReLU()), #4,13
32 | Conv2dBlock(256,256,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
33 | Conv2dBlock(256,256,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
34 |
35 | Conv2dBlock(256,512,(4,3),(2,1),padding=1,use_norm=True,use_act=True,act=nn.ReLU()), #2,13
36 | Conv2dBlock(512,512,3,1,padding=padding_31,use_norm=True,use_act=True,act=nn.ReLU(),residual=True),
37 | Conv2dBlock(512,512,(2,5),(1,2),padding=(0,1),use_norm=True,use_act=True,act=nn.ReLU()),#1,6,
38 | )
39 |
40 | self.features_len = 6*512#out of conv layers
41 | self.num_layers = 1
42 | self.hidden_size = self.hparams['latent_dim_audio']
43 | self.gru = nn.GRU(input_size=self.features_len,
44 | hidden_size=self.hidden_size,
45 | num_layers=1,
46 | batch_first=True)
47 | self.final_act = nn.Tanh()
48 |
49 | def forward(self, x):
50 |         """x : audio frame MFCCs of shape (B, T, 7, 13), one (7,13) MFCC block per video frame"""
51 | batch_size,num_frames,*_ = x.shape
52 | x = self.conv_encoder(squeeze_batch_frames(x).unsqueeze(1))#add channel dimension
53 | x = x.reshape(batch_size,num_frames,self.features_len)
54 | x,_ = self.gru(x)
55 | return self.final_act(x) # shape (batch_size,num_frames,hidden_size=latent_dim_audio)
56 |
--------------------------------------------------------------------------------
/wav2mov/models/generator/frame_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn,optim
3 |
4 | from wav2mov.models.generator.audio_encoder import AudioEnocoder
5 | from wav2mov.models.generator.noise_encoder import NoiseEncoder
6 | from wav2mov.models.generator.id_encoder import IdEncoder
7 | from wav2mov.models.generator.id_decoder import IdDecoder
8 |
9 | from wav2mov.models.utils import squeeze_batch_frames
10 | from wav2mov.logger import get_module_level_logger
11 | logger = get_module_level_logger(__name__)
12 |
13 | class Generator(nn.Module):
14 | def __init__(self,hparams):
15 | super().__init__()
16 | self.hparams = hparams
17 | self.id_encoder = IdEncoder(self.hparams)
18 | self.id_decoder = IdDecoder(self.hparams)
19 | self.audio_encoder = AudioEnocoder(self.hparams)
20 | # self.noise_encoder = NoiseEncoder(self.hparams)
21 |
22 | def forward(self,audio_frames,ref_frames):
23 | batch_size,num_frames,*_ = ref_frames.shape
24 | assert num_frames == audio_frames.shape[1]
25 | encoded_id , intermediates = self.id_encoder(squeeze_batch_frames(ref_frames))
26 | encoded_id = encoded_id.reshape(batch_size*num_frames,-1,1,1)
27 | encoded_audio = self.audio_encoder(audio_frames).reshape(batch_size*num_frames,-1,1,1)
28 | # encoded_noise = self.noise_encoder(batch_size,num_frames).reshape(batch_size*num_frames,-1,1,1)
29 | # logger.debug(f'encoded_id {encoded_id.shape} encoded_audio {encoded_audio.shape} encoded_noise {encoded_noise.shape}')
30 | encoded = torch.cat([encoded_id,encoded_audio],dim=1)#along channel dimension
31 | gen_frames = self.id_decoder(encoded,intermediates)
32 | _,*img_shape = gen_frames.shape
33 | return gen_frames.reshape(batch_size,num_frames,*img_shape)
34 |
35 | def get_optimizer(self):
36 | return optim.Adam(self.parameters(), lr=self.hparams['lr'], betas=(0.5, 0.999))
37 |
--------------------------------------------------------------------------------
/wav2mov/models/generator/id_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from wav2mov.models.layers.conv_layers import Conv2dBlock,ConvTranspose2dBlock,DoubleConvTranspose2d
5 | from wav2mov.logger import get_module_level_logger
6 | from wav2mov.models.utils import get_same_padding
7 |
8 | logger = get_module_level_logger(__name__)
9 |
10 | class IdDecoder(nn.Module):
11 | def __init__(self,hparams):
12 | super().__init__()
13 | self.hparams = hparams
14 | self.in_channels = hparams['in_channels']
15 | self.latent_channels = hparams['latent_dim']
16 | latent_id_height,latent_id_width = hparams['latent_dim_id']
17 | chs = self.hparams['chs'][::-1]#1024,512,256,128,64
18 | chs = [self.latent_channels] + chs + [self.in_channels]
19 | padding = get_same_padding(kernel_size=4,stride=2)
20 | self.convs = nn.ModuleList()
21 | self.convs.append(ConvTranspose2dBlock(in_ch=chs[0],
22 | out_ch=chs[1],
23 | kernel_size=(latent_id_height,latent_id_width),
24 | stride=2,
25 | padding=0,
26 | use_norm=True,
27 | use_act=True,
28 | act=nn.ReLU()
29 | ))
30 | for i in range(1,len(chs)-2):
31 | self.convs.append( DoubleConvTranspose2d(in_ch=chs[i],
32 | skip_ch=chs[i],
33 | out_ch=chs[i+1],
34 | kernel_size=(4,4),
35 | stride=2,
36 | padding=padding,
37 | use_norm=True,
38 | use_act=True,
39 | act=nn.ReLU())
40 | )
41 | self.convs.append(ConvTranspose2dBlock(in_ch=chs[-2],
42 | out_ch=chs[-1],
43 | kernel_size=(4,4),
44 | stride=2,
45 | padding=padding,
46 | use_norm=True,
47 | use_act=True,
48 | act=nn.Tanh()
49 | ))
50 | def forward(self,encoded,skip_outs):
51 |         """Decode the concatenated (identity + audio) latent back into a frame.
52 |
53 |         Args:
54 |             encoded (Tensor): latent with `latent_channels` features per sample
55 |             skip_outs (list[Tensor]): encoder skip connections (64,128,256,512,1024 channels)
56 |         """
57 | encoded = encoded.reshape(encoded.shape[0],self.latent_channels,1,1)
58 | skip_outs = skip_outs[::-1]
59 | for i,block in enumerate(self.convs[:-2]):
60 | x = block(encoded)
61 | encoded = torch.cat([x,skip_outs[i]],dim=1)#cat along channel axis
62 |
63 | encoded = self.convs[-2](encoded)
64 | return self.convs[-1](encoded)
65 |
--------------------------------------------------------------------------------
/wav2mov/models/generator/id_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from wav2mov.models.layers.conv_layers import Conv2dBlock
5 | from wav2mov.logger import get_module_level_logger
6 | from wav2mov.models.utils import get_same_padding
7 |
8 | logger = get_module_level_logger(__name__)
9 |
10 | class IdEncoder(nn.Module):
11 | def __init__(self,hparams):
12 | super().__init__()
13 | self.hparams = hparams
14 | in_channels = hparams['in_channels']
15 | chs = self.hparams['chs']
16 |         chs = [in_channels] + chs +[1]# the trailing 1 is appended here rather than in params; see how the channel list is consumed in id_decoder
17 | padding = get_same_padding(kernel_size=4,stride=2)
18 | self.conv_blocks = nn.ModuleList(Conv2dBlock(chs[i],chs[i+1],
19 | kernel_size=(4,4),
20 | stride=2,
21 | padding=padding,
22 | use_norm=True,
23 | use_act=True,
24 | act=nn.ReLU()
25 | ) for i in range(len(chs)-2)
26 | )
27 |
28 | self.conv_blocks.append(Conv2dBlock(chs[-2],chs[-1],
29 | kernel_size=(4,4),
30 | stride=2,
31 | padding=padding,
32 | use_norm=False,
33 | use_act=True,
34 | act=nn.Tanh()))
35 | def forward(self,images):
36 | intermediates = []
37 | for block in self.conv_blocks[:-1]:
38 | images = block(images)
39 | intermediates.append(images)
40 | encoded = self.conv_blocks[-1](images)
41 | return encoded,intermediates
42 |
--------------------------------------------------------------------------------
/wav2mov/models/generator/noise_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from wav2mov.logger import get_module_level_logger
5 | from wav2mov.models.utils import squeeze_batch_frames
6 |
7 | logger = get_module_level_logger(__name__)
8 |
9 | class NoiseEncoder(nn.Module):
10 | def __init__(self,hparams):
11 | super().__init__()
12 | self.hparams = hparams
13 | self.features_len = 10
14 | self.hidden_size = self.hparams['latent_dim_noise']
15 | self.gru = nn.GRU(input_size=self.features_len,
16 | hidden_size=self.hidden_size,
17 | num_layers=1,
18 | batch_first=True)
19 | #input should be of shape batch_size,seq_len,input_size
20 | def forward(self,batch_size,num_frames):
21 | noise = torch.randn(batch_size,num_frames,self.features_len,device=self.hparams['device'])
22 | out,_ = self.gru(noise)
23 | return out#(batch_size,seq_len,hidden_size)
24 |
--------------------------------------------------------------------------------
/wav2mov/models/layers/conv_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from wav2mov.models.utils import get_same_padding
5 |
6 | class DoubleConv2dBlock(nn.Module):
7 | """height and width are halved in feature map"""
8 | def __init__(self,
9 | in_ch,
10 | out_ch,
11 | use_norm=True,
12 | act=None):
13 | super().__init__()
14 | self.use_norm = use_norm
15 | self.conv1 = nn.Conv2d(in_ch, out_ch, 3,padding=get_same_padding(3,1),bias=not self.use_norm)
16 | self.batch_norm = nn.BatchNorm2d(out_ch)
17 | self.relu = nn.LeakyReLU(0.2) if act is None else act()
18 | self.conv2 = nn.Conv2d(out_ch, out_ch, 3)
19 |
20 | def forward(self, x):
21 | # print('double conv block',x.shape,type(x),x.device,next(self.parameters()).device)
22 | x = self.conv1(x)
23 | if self.use_norm :
24 | x = self.batch_norm(x)
25 | return self.relu(self.conv2(self.relu(x)))
26 |
27 | class Conv1dBlock(nn.Module):
28 | def __init__(self,
29 | in_ch,
30 | out_ch,
31 | kernel_size,
32 | stride,
33 | padding,
34 | use_norm=True,
35 | use_act=True,act=None,
36 | residual=False):
37 |
38 | super().__init__()
39 | self.use_norm = use_norm
40 | self.use_act = use_act
41 | self.act = nn.LeakyReLU(0.2) if act is None else act
42 | self.norm = nn.BatchNorm1d(out_ch)
43 | self.conv = nn.Conv1d(in_ch,out_ch,kernel_size,stride,padding,bias=self.use_norm)
44 | self.residual = residual
45 |
46 | def forward(self,in_x):
47 | x = self.conv(in_x)
48 | if self.use_norm:
49 | x = self.norm(x)
50 | if self.residual:
51 | x += in_x
52 | if self.use_act:
53 | x = self.act(x)
54 | return x
55 |
56 | class Conv2dBlock(nn.Module):
57 | def __init__(self,
58 | in_ch,
59 | out_ch,
60 | kernel_size,
61 | stride,
62 | padding,
63 | use_norm=True,
64 | use_act=True,act=None,
65 | residual=False):
66 |
67 | super().__init__()
68 | self.use_norm = use_norm
69 | self.use_act = use_act
70 | self.act = nn.LeakyReLU(0.2) if act is None else act
71 | self.norm = nn.BatchNorm2d(out_ch)
72 | self.conv = nn.Conv2d(in_ch,out_ch,kernel_size,stride,padding,bias=self.use_norm)
73 | self.residual = residual
74 |
75 | def forward(self,in_x):
76 | x = self.conv(in_x)
77 | if self.use_norm:
78 | x = self.norm(x)
79 | if self.residual:
80 | x += in_x
81 | if self.use_act:
82 | x = self.act(x)
83 | return x
84 |
85 | class ConvTranspose2dBlock(nn.Module):
86 | def __init__(self,in_ch,
87 | out_ch,
88 | kernel_size,
89 | stride,
90 | padding,
91 | use_norm=True,
92 | use_act=True,
93 | act=None):
94 | super().__init__()
95 | self.use_norm = use_norm
96 | self.use_act = use_act
97 | self.act = nn.LeakyReLU(0.2) if act is None else act
98 | self.norm = nn.BatchNorm2d(out_ch)
99 | self.conv = nn.ConvTranspose2d(in_ch,out_ch,kernel_size,stride,padding,bias=self.use_norm)
100 |
101 | def forward(self,x):
102 | x = self.conv(x)
103 | if self.use_norm:
104 | x = self.norm(x)
105 | if self.use_act:
106 | x = self.act(x)
107 | return x
108 |
109 | class DoubleConvTranspose2d(nn.Module):
110 | def __init__(self,in_ch,
111 | skip_ch,
112 | out_ch,
113 | kernel_size,
114 | stride,
115 | padding,
116 | use_norm=True,
117 | use_act=True,
118 | act=None):
119 | super().__init__()
120 | self.use_norm = use_norm
121 | self.use_act = use_act
122 |
123 | self.conv1 = nn.Conv2d(in_ch+skip_ch,in_ch,kernel_size=3,stride=1,padding=1,bias=False)
124 | self.norm1 = nn.BatchNorm2d(in_ch)
125 | self.conv2 = nn.ConvTranspose2d(in_ch,out_ch,kernel_size,stride,padding)
126 | self.norm2 = nn.BatchNorm2d(out_ch)
127 | self.act = nn.LeakyReLU(0.2) if act is None else act
128 |
129 | def forward(self,x):
130 | x = nn.functional.relu(self.norm1(self.conv1(x)))
131 | x = self.conv2(x)
132 | if self.use_norm:
133 | x = self.norm2(x)
134 | if self.use_act:
135 | x = self.act(x)
136 | return x
--------------------------------------------------------------------------------
/wav2mov/models/layers/debug_layers.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from wav2mov.logger import get_module_level_logger
4 | logger = get_module_level_logger(__name__)
5 |
6 | class IdentityDebugLayer(nn.Module):
7 | def __init__(self,name):
8 | super().__init__()
9 | self.name = name
10 | def forward(self,x):
11 | logger.debug(f'{self.name} : {x.shape}')
12 | return x
13 |
--------------------------------------------------------------------------------
/wav2mov/models/utils.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 | from torch.nn import init
3 |
4 | from wav2mov.logger import get_module_level_logger
5 | logger = get_module_level_logger(__name__)
6 |
7 | def get_same_padding(kernel_size,stride,in_size=0):
8 | """https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks
9 | https://github.com/DinoMan/speech-driven-animation/blob/dc9fe4aa77b4e042177328ea29675c27e2f56cd4/sda/utils.py#L18-L21
10 |
11 | padding = 'same'
12 | • Padding such that feature map size has size In_size/Stride
13 |
14 | • Output size is mathematically convenient
15 |
16 | • Also called 'half' padding
17 |
18 | out = (in-k+2*p)/s + 1
19 | if out == in/s:
20 | in/s = (in-k+2*p)/s + 1
21 | ((in/s)-1)*s + k -in = 2*p
22 |         (in-s)+k-in = 2*p  =>  2*p = k - s
23 | in case of s==1:
24 | p = (k-1)/2
25 | """
26 | out_size = ceil(in_size/stride)
27 | return ceil(((out_size-1)*stride+ kernel_size-in_size)/2)#(k-1)//2 for same padding
28 |
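# Worked example (illustrative) for the helper above:
#   get_same_padding(4, 2) -> ceil(((0 - 1)*2 + 4 - 0)/2) = 1, the padding used with the 4x4, stride-2 blocks
#   get_same_padding(3, 1) -> ceil(((0 - 1)*1 + 3 - 0)/2) = 1, matching p = (k - 1)/2 for stride 1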
29 | def init_weights(net, init_type='normal', init_gain=0.02):
30 | """Initialize network weights.
31 | src : https://github1s.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
32 | Parameters:
33 | net (network) -- network to be initialized
34 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
35 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
36 |
37 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
38 | work better for some applications. Feel free to try yourself.
39 | """
40 | def init_func(m): # define the initialization function
41 | classname = m.__class__.__name__
42 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
43 | if init_type == 'normal':
44 | init.normal_(m.weight.data, 0.0, init_gain)
45 | elif init_type == 'xavier':
46 | init.xavier_normal_(m.weight.data, gain=init_gain)
47 | elif init_type == 'kaiming':
48 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
49 | elif init_type == 'orthogonal':
50 | init.orthogonal_(m.weight.data, gain=init_gain)
51 | else:
52 | raise NotImplementedError(
53 | 'initialization method [%s] is not implemented' % init_type)
54 | if hasattr(m, 'bias') and m.bias is not None:
55 | init.constant_(m.bias.data, 0.0)
56 | # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
57 | elif classname.find('BatchNorm2d') != -1:
58 | init.normal_(m.weight.data, 1.0, init_gain)
59 | init.constant_(m.bias.data, 0.0)
60 |
61 | logger.debug(f'initializing {net.__class__.__name__} with {init_type} weights')
62 | # print('initialize network with %s' % init_type)
63 | net.apply(init_func) # apply the initialization function
64 |
65 |
66 | def init_net(net, init_type='normal', init_gain=0.02):
67 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
68 | Parameters:
69 | net (network) -- the network to be initialized
70 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
71 | gain (float) -- scaling factor for normal, xavier and orthogonal.
72 |
73 |
74 | Return an initialized network.
75 | """
76 |
77 | init_weights(net, init_type, init_gain=init_gain)
78 | return net
79 |
80 |
81 | def squeeze_batch_frames(target):
82 | batch_size,num_frames,*extra = target.shape
83 | return target.reshape(batch_size*num_frames,*extra)
--------------------------------------------------------------------------------
/wav2mov/models/wav2mov_inferencer.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from wav2mov.core.models import TemplateModel
4 | from wav2mov.core.data.utils import AudioUtil
5 |
6 | from wav2mov.models import Generator
7 |
8 | import logging
9 | logging.basicConfig(level=logging.INFO)
10 | logger = logging.getLogger(__name__)
11 | logger.setLevel(logging.DEBUG)
12 |
13 | def no_grad_wrapper(fn):
14 | def wrapper(*args,**kwargs):
15 | with torch.no_grad():
16 | return fn(*args,**kwargs)
17 | return wrapper
18 |
19 | class Wav2movInferencer(TemplateModel):
20 | def __init__(self,hparams,logger):
21 | super().__init__()
22 | self.hparams = hparams
23 | self.logger = logger
24 | self.device = 'cpu'
25 | self.gen = Generator(hparams['gen'])
26 | self.stride = self.hparams['data']['audio_sf']//self.hparams['data']['video_fps']
27 | self.audio_util = AudioUtil(self.hparams['data']['coarticulation_factor'],self.stride,self.device)
28 |
29 | def load(self,checkpoint):
30 | self.gen.load_state_dict(checkpoint)
31 |
32 | def _squeeze_batch_frames(self,target):
33 | batch_size,num_frames,*extra = target.shape
34 | return target.reshape(batch_size*num_frames,*extra)
35 |
36 | def forward(self,audio_frames,ref_video_frames):
37 | self.logger.debug(f'[GENERATION] audio_frames : {audio_frames.shape} | ref_video_frames : {ref_video_frames.shape}')
38 | self.gen.eval()
39 | batch_size,num_frames,*extra = ref_video_frames.shape
40 | assert batch_size==audio_frames.shape[0] and num_frames ==audio_frames.shape[1]
41 | # audio_frames = self._squeeze_batch_frames(audio_frames)
42 | # ref_video_frames = self._squeeze_batch_frames(ref_video_frames)
43 | fake_frames = self.gen(audio_frames,ref_video_frames)
44 | return fake_frames
45 |
46 | @no_grad_wrapper
47 | def generate(self,audio_frames,ref_video_frames,fraction=None):
48 | if fraction is None:
49 | return self(audio_frames,ref_video_frames)
50 | else:
51 | return self._generate_with_fraction(audio_frames,ref_video_frames,fraction)
52 |
53 | @no_grad_wrapper
54 | def test(self,audio_frames,video):
55 |         """Test the generation of face images.
56 |
57 |         Args:
58 |             audio_frames (tensor): audio frames of shape (B,F,Sw), where Sw is the
59 |                 number of audio samples per video frame
60 |             video (tensor): real video frames of shape (B,F,C,H,W)
61 |         """
62 | fraction = 2
63 | ref_video_frames = self._get_ref_video_frames(video)
64 |
65 | return self.generate(audio_frames,ref_video_frames,fraction),ref_video_frames[:,0,...]
66 |
67 | def __get_random_indices(self,low,high,num_frames):
68 | return random.sample(range(low,high),k=num_frames)
69 |
70 | def _get_ref_video_frames(self,video):
71 | num_frames = video.shape[1]
72 | ################################################################
73 | # random_ids = self.__get_random_indices(0,num_frames,num_frames)
74 | # ref_video_frames = video[:,random_ids]
75 | ################################################################
76 | REF_FRAME_IDX = int(0.6*num_frames)
77 | ref_video_frames = video[:,REF_FRAME_IDX,:,:,:]
78 | ref_video_frames = ref_video_frames.unsqueeze(1)
79 | ref_video_frames = ref_video_frames.repeat(1,num_frames,1,1,1)
80 | ################################################################
81 | return ref_video_frames
82 |
83 | def _squeeze_frames(self,video):
84 | bsize,nframes,*extra = video.shape
85 | return video.reshape(bsize*nframes,*extra)
86 |
87 |
88 | def _generate_with_fraction(self,audio_frames,ref_video_frames,fraction):
89 | num_frames = audio_frames.shape[1]
90 | start_frame = 0
91 | fake_video_frames= []
92 | for i in range(fraction):
93 | end_frame = int((1/fraction)*(i+1)*num_frames) if i!=fraction-1 else num_frames
94 | audio_frames_sample = audio_frames[:,start_frame:end_frame,...]
95 | ref_video_frames_sample = ref_video_frames[:,start_frame:end_frame,...]
96 | fake_video_frames.append(self(audio_frames_sample,ref_video_frames_sample))
97 | start_frame = end_frame
98 |
99 | return torch.cat(fake_video_frames,dim=1)
100 |
101 |
--------------------------------------------------------------------------------
/wav2mov/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 10,
3 | "ref_frame_idx": -10,
4 | "device": "cpu",
5 | "img_size": 256,
6 | "num_epochs": 4,
7 | "train_sync_expert": false,
8 | "train_sync": false,
9 | "pre_learning_epochs": 100,
10 | "adversarial_with_id": 300,
11 | "adversarial_with_sync": 100,
12 | "stop_adversarial_with_sync": 1000,
13 | "adversarial_with_seq": 300,
14 | "img_channels": 3,
15 | "num_frames_fraction": 15,
16 | "data": {
17 | "img_size": 256,
18 | "img_channels": 3,
19 | "coarticulation_factor": 2,
20 | "audio_sf": 16000,
21 | "video_fps": 24,
22 | "batch_size": 6,
23 | "mini_batch_size": 6,
24 | "mean": 0.516,
25 | "std": 0.236
26 | },
27 | "disc": {
28 | "sync_disc": {
29 | "in_channels": 3,
30 | "lr": 1e-4,
31 | "relu_neg_slope": 0.01
32 | },
33 | "sequence_disc": {
34 | "in_channels": 3,
35 | "chs": [64, 128, 256, 512, 1],
36 | "in_size": 32,
37 | "h_size": 256,
38 | "num_layers": 1,
39 | "lr": 1e-4,
40 | "relu_neg_slope": 0.01
41 | },
42 |
43 | "identity_disc": {
44 | "in_channels": 3,
45 | "chs": [64, 128, 256, 512, 1024, 1],
46 | "lr": 1e-4,
47 | "relu_neg_slope": 0.01
48 | },
49 | "patch_disc": {
50 | "ndf": 64,
51 | "in_channels": 3,
52 | "num_layers": 3,
53 | "lr": 1e-4
54 | }
55 | },
56 |
57 | "gen": {
58 | "in_channels": 3,
59 | "chs": [64, 128, 256, 512, 1024],
60 | "latent_dim": 272,
61 | "latent_dim_id": [8, 8],
62 |     "comment": "latent_dim is not latent_dim_id + latent_dim_audio; it is 4x4 + 256 = 272",
63 | "latent_dim_audio": 256,
64 | "device": "cpu",
65 | "lr": 2e-4
66 | },
67 |
68 | "scales_archieved": {
69 | "lambda_seq_disc": 0.3,
70 | "lambda_sync_disc": 0.8,
71 | "lambda_id_disc": 1,
72 | "lambda_L1": 50
73 | },
74 | "scales": {
75 | "lambda_seq_disc": 0.6,
76 | "lambda_sync_disc": 0.8,
77 | "lambda_id_disc": 1,
78 | "lambda_L1": 50
79 | },
80 | "scheduler": {
81 | "gen": {
82 | "step_size": 20,
83 | "gamma": 0.02
84 | },
85 | "discs": {
86 | "step_size": 20,
87 | "gamma": 0.1
88 | },
89 | "max_epoch": 100
90 | }
91 | }
92 |
--------------------------------------------------------------------------------
/wav2mov/params.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from wav2mov.settings import BASE_DIR
4 | from wav2mov.logger import get_module_level_logger
5 | logger = get_module_level_logger(__name__)
6 | class Params:
7 | def __init__(self):
8 | self.vals = {}
9 |
10 |
11 |
12 | def _flatten_dict(self, d):
13 | flattened = {}
14 | for k, v in d.items():
15 | if isinstance(v, dict):
16 | inner_d = self._flatten_dict(v)
17 | flattened = {**flattened, **inner_d}
18 | else:
19 | flattened[k] = v
20 | return flattened
21 |
22 | def update(self,key,value):
23 | if key in self.vals:
24 |             logger.warning(f'Updating existing parameter {key} : changing from {self.vals[key]} to {value}')
25 |
26 | self.vals[key] = value
27 |
28 | def _update_vals_from_dict(self, d: dict):
29 | self.vals = d
30 | # d = self._flatten_dict(d)
31 | # self.vals = {**self.vals,**d}
32 |
33 | @classmethod
34 | def from_json(cls, json_file_fullpath):
35 | with open(json_file_fullpath, 'r') as file:
36 | configs = json.load(file)
37 | obj = cls()
38 | obj._update_vals_from_dict(configs)
39 | return obj
40 |
41 | def __getitem__(self, item):
42 | if item not in self.vals:
43 | logger.error(f'{self.__class__.__name__} object has no key called {item}')
44 | raise KeyError(f'No key called {item}')
45 |
46 | return self.vals[item]
47 |
48 | def __repr__(self):
49 | fstr = ''
50 | for key,val in self.vals.items():
51 | fstr += f"{key} : {val}\n"
52 | return fstr
53 |
54 | def set(self,key,val):
55 | if key in self.vals:
56 | logger.warning(f'updating existing value of {key} with value {self.vals[key] } to {val}')
57 | self.vals[key] = val
58 |
59 | def save(self,file_fullpath):
60 | with open(file_fullpath,'w') as file:
61 | json.dump(self.vals,file)
62 |
63 |
64 | #Singleton Object
65 | params = Params.from_json(os.path.join(BASE_DIR, 'params.json'))
66 |
67 | if __name__ == '__main__':
68 | pass
69 |
--------------------------------------------------------------------------------
/wav2mov/plans/README.md:
--------------------------------------------------------------------------------
1 | # Plans
2 |
3 | 
4 | 
5 | 
6 | 
7 |
--------------------------------------------------------------------------------
/wav2mov/plans/extras.md:
--------------------------------------------------------------------------------
1 | # ✨ Mystics of Python 🐍
2 |
3 | ## Exploring the unexplored 🌍 🗺
4 |
5 | ------------------
6 |
7 | #### Image shapes
8 |
9 | + NumPy / Matplotlib / OpenCV treat images as arrays of shape (H, W, C)
10 |
11 | + PyTorch expects images in the shape (C, H, W); a small conversion sketch is shown below
12 |
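A minimal sketch of converting between the two layouts (the array sizes are illustrative):

```python
import numpy as np
import torch

# an OpenCV / Matplotlib style image: (H, W, C)
img_hwc = np.random.rand(256, 256, 3).astype(np.float32)

# NumPy -> PyTorch: move channels first -> (C, H, W)
img_chw = torch.from_numpy(img_hwc).permute(2, 0, 1)

# PyTorch -> NumPy: move channels last again -> (H, W, C)
back_hwc = img_chw.permute(1, 2, 0).numpy()

print(img_hwc.shape, tuple(img_chw.shape), back_hwc.shape)
# (256, 256, 3) (3, 256, 256) (256, 256, 3)
```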
13 | #### * and ** operators
14 |
15 | ------------------
16 |
17 | Below are 6 different use cases for * and ** in python programming:
18 | [source : StackOverflow : https://stackoverflow.com/a/59630576/12988588](https://stackoverflow.com/a/59630576/12988588)
19 |
20 | + __To accept any number of positional arguments using *args:__
21 |
22 | ```python
23 | def foo(*args):
24 | pass
25 | ```
26 |
27 | here foo accepts any number of positional arguments,
28 |   i.e., the following calls are valid ``foo(1)``, ``foo(1, 'bar')``
29 |
30 | + __To accept any number of keyword arguments using **kwargs:__
31 |
32 | ```python
33 | def foo(**kwargs):
34 | pass
35 | ```
36 |
37 |   here 'foo' accepts any number of keyword arguments,
38 |   i.e., the following calls are valid ``foo(name='Tom')``, ``foo(name='Tom', age=33)``
39 |
40 | + __To accept any number of positional and keyword arguments using *args, **kwargs:__
41 |
42 | ```python
43 | def foo(*args, **kwargs):
44 | pass
45 | ```
46 |
47 | here foo accepts any number of positional and keyword arguments,
48 |   i.e., the following calls are valid ``foo(1, name='Tom')``, ``foo(1, 'bar', name='Tom', age=33)``
49 |
50 |
51 | + __To enforce keyword only arguments using *:__
52 |
53 | ```python
54 | def foo(pos1, pos2, *, kwarg1):
55 | pass
56 | ```
57 |
58 |   here * means that foo only accepts keyword arguments after pos2, hence ``foo(1, 2, 3)`` raises TypeError
59 | but ``foo(1, 2, kwarg1=3)`` is ok.
60 |
61 | + __To express no further interest in more positional arguments using `*_` (Note: this is a convention only):__
62 |
63 | ```python
64 | def foo(bar, baz, *_):
65 | pass
66 | ```
67 |
68 |   means (by convention) that foo uses only the bar and baz arguments and ignores any others.
69 |
70 | + __To express no further interest in more keyword arguments using `**_` (Note: this is a convention only):__
71 |
72 | ```python
73 | def foo(bar, baz, **_):
74 | pass
75 | ```
76 |
77 |   means (by convention) that foo uses only the bar and baz arguments and ignores any others.
78 |
79 | + __BONUS__: From Python 3.8 onward, one can use `/` in a function definition to enforce positional-only parameters.
80 |   In the following example, parameters a and b are positional-only, while c and d can be positional or keyword,
81 |   and e and f are required to be keyword-only:
82 |
83 | ```python
84 | def f(a, b, /, c, d, *, e, f):
85 | pass
86 | ```
87 |
88 |
89 |
--------------------------------------------------------------------------------
/wav2mov/plans/images/components.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/plans/images/components.png
--------------------------------------------------------------------------------
/wav2mov/plans/images/gen_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/plans/images/gen_arch.png
--------------------------------------------------------------------------------
/wav2mov/plans/images/plan_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/plans/images/plan_v1.png
--------------------------------------------------------------------------------
/wav2mov/plans/images/system.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/plans/images/system.png
--------------------------------------------------------------------------------
/wav2mov/plans/observations.md:
--------------------------------------------------------------------------------
1 | # Observations 🔍
2 |
3 | ***
4 |
5 | + __sync discriminator loss is very small (of the order of e-9)__
6 |
7 |   + solved by removing batch normalization and using LeakyReLU with a negative slope of 0.2 instead of vanilla ReLU.
8 |
9 | + __changes in losses are very small__
10 |
11 | ## Lip variations
12 | + it takes longer training for lip variations to appear.
13 | + the size of the dataset plays a huge role
14 | (60 videos with 2 actors took nearly 200 epochs to show mild variations between frames)
15 | + initially, train the sync discriminator only on real frames, on both synchronized and unsynchronized audio-video pairs; a minimal training sketch is given below.
16 |
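A minimal sketch of this warm-up step (not the project's actual training loop; the `sync_disc(audio_frames, video_frames)` call signature and the use of `torch.roll` to build unsynchronized pairs are assumptions for illustration). The discriminator sees only real frames, labelled 1 when paired with their own audio and 0 when paired with audio shifted out of sync:

```python
import torch
from torch import nn

def sync_warmup_step(sync_disc, optimizer, audio_frames, video_frames):
    """One warm-up step on real frames: in-sync vs. time-shifted audio."""
    bce = nn.BCEWithLogitsLoss()
    # positive pair: audio aligned with its own video frames
    real_logits = sync_disc(audio_frames, video_frames)
    # negative pair: roll the audio along the frame axis so it no longer matches
    shifted_audio = torch.roll(audio_frames, shifts=3, dims=1)
    fake_logits = sync_disc(shifted_audio, video_frames)
    loss = (bce(real_logits, torch.ones_like(real_logits)) +
            bce(fake_logits, torch.zeros_like(fake_logits)))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```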
--------------------------------------------------------------------------------
/wav2mov/preprocess.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 | TITLE wav2mov
3 | echo Running preprocessing
4 |
5 | set DEVICE="cpu"
6 | set VERSION="preprocess_500_a23456"
7 | set LOG="y"
8 | set GRID_DATASET_DIR="/content/drive/MyDrive/Colab Notebooks/projects/wav2mov/datasets/grid_a5_500_a10to14_raw"
9 | echo %GRID_DATASET_DIR%
10 | @REM python main/main.py --preprocess=y -grid=%GRID_DATASET_DIR% --device=%DEVICE% --version=%VERSION% --log=%LOG%
--------------------------------------------------------------------------------
/wav2mov/preprocess.sh:
--------------------------------------------------------------------------------
1 | echo "Preprocess"
2 | DEVICE='cpu'
3 | VERSION='preprocess_500_a23456'
4 | LOG='y'
5 | GRID_DATASET_DIR='/content/drive/MyDrive/Colab Notebooks/projects/wav2mov/datasets/grid_a5_500_a10to14_raw'
6 | python main/main.py --preprocess=y -grid="${GRID_DATASET_DIR}" --device=${DEVICE} --version=${VERSION} --log=${LOG}
7 |
--------------------------------------------------------------------------------
/wav2mov/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/requirements.txt
--------------------------------------------------------------------------------
/wav2mov/run.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 | TITLE wav2mov
3 | echo Running training script
4 |
5 | set EPOCHS=1
6 | set NUM_VIDEOS=14
7 | set VERSION="v9"
8 | set "COMMENT='GPU | scaled id ,gen_id losses by FRACTION | prelearning till 15 epoch | 10 l1_l and 1 id_l'"
9 | set DEVICE="cuda"
10 | set IS_TRAIN="y"
11 | set LOG="n"
12 | set "MODEL_PATH='/content/drive/MyDrive/Colab Notebooks/projects/wav2mov_engine/wav2mov/runs/v9/Run_6_5_2021__17_28'"
13 | @REM echo 'options chosen are '"$EPOCHS"' is_training = '"$IS_TRAIN"
14 | @REM #python main/main.py --train=%IS_TRAIN% -e=%EPOCHS% -v=%NUM_VIDEOS% -m="%COMMENT%" --device=%DEVICE% --version=%VERSION% --log=%LOG% --model_path=%MODEL_PATH%
15 | python main/main.py --train=%IS_TRAIN% -e=%EPOCHS% -v=%NUM_VIDEOS% -m="%COMMENT%" --device=%DEVICE% --version=%VERSION% --log=%LOG%
16 |
--------------------------------------------------------------------------------
/wav2mov/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # run main/main.py
3 | EPOCHS=300
4 | NUM_VIDEOS=120
5 | VERSION="v16_sync_expert"
6 | TRAIN_SYNC_EXPERT='n'
7 | COMMENT='GPU | 225 to 300| no noise encoder | interleaved sync training | only l1 at the beginning'
8 | # COMMENT='GPU | 375 to 450| seq upper half and sync lower half | sync bce loss'
9 | DEVICE="cuda"
10 | IS_TRAIN="y"
11 | LOG='y'
12 | MODEL_PATH="/content/drive/MyDrive/Colab Notebooks/wav2mov-dev_phase_8/wav2mov/runs/${VERSION}/Run_11_7_2021__17_44"
13 | #echo 'options chosen are '"$EPOCHS"' is_training = '"$IS_TRAIN"
14 | # python main/main.py --train=$IS_TRAIN -e=$EPOCHS -v=$NUM_VIDEOS -m="$COMMENT" --device=$DEVICE --version=$VERSION --log=$LOG --train_sync_expert=${TRAIN_SYNC_EXPERT}
15 | python main/main.py --train=$IS_TRAIN -e=$EPOCHS -v=$NUM_VIDEOS -m="$COMMENT" --device=$DEVICE --version=$VERSION --log=$LOG --model_path="$MODEL_PATH" --train_sync_expert=${TRAIN_SYNC_EXPERT}
16 |
--------------------------------------------------------------------------------
/wav2mov/run_sync_expert.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # run main/main.py
3 | EPOCHS=450
4 | NUM_VIDEOS=120
5 | VERSION="sync_expert"
6 | TRAIN_SYNC_EXPERT='y'
7 | COMMENT='GPU |350 to 450| train only sync disc to make it a expert'
8 | # COMMENT='GPU | 375 to 450| seq upper half and sync lower half | sync bce loss'
9 | DEVICE="cuda"
10 | IS_TRAIN="y"
11 | LOG='y'
12 | # MODEL_PATH="/content/drive/MyDrive/Colab Notebooks/wav2mov-dev_phase_8/wav2mov/runs/v16_sync_expert/Run_10_7_2021__23_46"
13 | # MODEL_PATH="/content/drive/MyDrive/Colab Notebooks/wav2mov-dev_phase_8/wav2mov/runs/${VERSION}/Run_10_7_2021__14_19"
14 | MODEL_PATH="/content/drive/MyDrive/Colab Notebooks/wav2mov-dev_phase_8/wav2mov/runs/${VERSION}/Run_11_7_2021__11_24"
15 | #echo 'options chosen are '"$EPOCHS"' is_training = '"$IS_TRAIN"
16 | # python main/main.py --train=$IS_TRAIN -e=$EPOCHS -v=$NUM_VIDEOS -m="$COMMENT" --device=$DEVICE --version=$VERSION --log=$LOG --train_sync_expert=${TRAIN_SYNC_EXPERT}
17 | python main/main.py --train=$IS_TRAIN -e=$EPOCHS -v=$NUM_VIDEOS -m="$COMMENT" --device=$DEVICE --version=$VERSION --log=$LOG --model_path="$MODEL_PATH" --train_sync_expert=${TRAIN_SYNC_EXPERT}
18 |
--------------------------------------------------------------------------------
/wav2mov/settings.py:
--------------------------------------------------------------------------------
1 | """ Base Dir """
2 | import os
3 |
4 |
5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
6 |
--------------------------------------------------------------------------------
/wav2mov/test.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 | TITLE Wav2Mov
3 |
4 | set is_test=y
5 | set "model_path=E:\Users\VS_Code_Workspace\Python\VirtualEnvironments\wav2mov\wav2mov\runs\v6\%1\gen_%1.pt"
6 | set log=n
7 | set device=cpu
8 | set version=v9
9 | echo %model_path%
10 | python main/main.py --device=%device% --test=%is_test% --version=%version% --model_path=%model_path% --log=%log% -v=14
--------------------------------------------------------------------------------
/wav2mov/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | VERSION="v16_sync_expert"
3 | echo $1
4 | is_test='y'
5 | model_path="/content/drive/MyDrive/Colab Notebooks/wav2mov-dev_phase_8/wav2mov/runs/${VERSION}/${1}/gen_${1}.pt"
6 | log='n'
7 | device='cpu'
8 | sample_num=2
9 | python main/main.py --device=${device} --test=${is_test} --version=${VERSION} --model_path="${model_path}" --log=${log} -v=14 --test_sample_num=${sample_num}
10 |
--------------------------------------------------------------------------------
/wav2mov/tests/old/log.json:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrashanthaTP/wav2mov/fabf89aec6c149b223a9d4d187f763363177abe1/wav2mov/tests/old/log.json
--------------------------------------------------------------------------------
/wav2mov/tests/old/test.py:
--------------------------------------------------------------------------------
1 | lines = ['aytala\n','macha']
2 | with open('test.txt','a+') as f:
3 | f.writelines(lines)
4 |
5 | with open('test.txt','r') as f:
6 | read_lines = f.read().split('\n')
7 | print(read_lines)
8 |
9 |
10 |
--------------------------------------------------------------------------------
/wav2mov/tests/old/test.txt:
--------------------------------------------------------------------------------
1 | aytala
2 | machaaytala
3 | machaaytala
4 | machaaytala
5 | machaaytala
6 | machaaytala
7 | machaaytala
8 | macha
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_3d_cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from wav2mov.models.sequence_discriminator import SequenceDiscriminatorCNN
5 |
6 |
7 | def get_shape(in_shape,kernels,strides,padding):
8 |
9 | new_shape = []
10 | for i,k,s,p in zip(in_shape,kernels,strides,padding):
11 | new_shape.append(((i-k+2*p)/s)+1)
12 |
13 | return new_shape
14 |
15 | def test(hparams):
16 | x1 = torch.randn(1,1,256,256) #N,Cin,Depth,Height,Width
17 | x2 = torch.randn(1,1,256,256)
18 | model = SequenceDiscriminatorCNN(hparams)
19 |
20 | print(model(x1,x2).shape)
21 | strides = (2,2,2)
22 | padding = (1,1,1)
23 | in_shape = (2,256,256)
24 | kernels = (2,4,4)
25 | print(get_shape(in_shape,kernels,strides,padding))
26 | """
27 | out shape : torch.Size([1, 6, 2, 128, 128])
28 | out shape : torch.Size([1, 32, 2, 64, 64])
29 | out shape : torch.Size([1, 64, 2, 32, 32])
30 | out shape : torch.Size([1, 32, 2, 16, 16])
31 | out shape : torch.Size([1, 16, 2, 8, 8])
32 | out shape : torch.Size([1, 1, 2, 4, 4])
33 | torch.Size([1, 32])
34 |
35 | 1=>4=>8=>16=>8=>4=>1
36 | out shape : torch.Size([1, 4, 2, 128, 128])
37 | out shape : torch.Size([1, 8, 2, 64, 64])
38 | out shape : torch.Size([1, 16, 2, 32, 32])
39 | out shape : torch.Size([1, 8, 2, 16, 16])
40 | out shape : torch.Size([1, 4, 2, 8, 8])
41 | out shape : torch.Size([1, 1, 2, 4, 4])
42 | torch.Size([1, 32])
43 | [2.0, 128.0, 128.0]
44 | """
45 |
46 | def main():
47 | hparams ={'in_channels':1,'chs':[4,8,16,8,4,1]}
48 | test(hparams)
49 |
50 | if __name__=='__main__':
51 | main()
52 |
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_audio_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import unittest
4 | from wav2mov.core.data.utils import AudioUtil
5 | from wav2mov.utils.audio import StridedAudioV2
6 |
7 |
8 |
9 | logger = logging.getLogger(__file__)
10 | logger.setLevel(logging.DEBUG)
11 | class TestAudioUtil(unittest.TestCase):
12 | def test_frames_cont(self):
13 | BATCH_SIZE = 2
14 | audio = torch.randn(BATCH_SIZE,668*40)
15 |
16 | for coarticulation_factor in range(5):
17 | strided_audio = StridedAudioV2(666,coarticulation_factor)
18 | get_frames_from_idx,_ = strided_audio.get_frame_wrapper(audio)
19 | logging.debug('For coarticulation factor of {}'.format(coarticulation_factor))
20 | for i in range((audio.shape[0]//666)):
21 | frame = get_frames_from_idx(i)
22 | if i%10==0:
23 | logging.debug(f'{i},{frame.shape}')
24 | self.assertEqual(frame.shape, (1,(coarticulation_factor*2 + 1)*666))
25 |
26 | def test_frames_range(self):
27 | BATCH_SIZE = 2
28 | audio = torch.randn(BATCH_SIZE,668*40)
29 |
30 | for coarticulation_factor in range(5):
31 | strided_audio = StridedAudioV2(666,coarticulation_factor)
32 | _,get_frames_from_range = strided_audio.get_frame_wrapper(audio)
33 | logging.debug('For coarticulation factor of {}'.format(coarticulation_factor))
34 | num_frames=5
35 | for i in range((audio.shape[0]//666)):
36 | frame = get_frames_from_range(i,i+num_frames-1)
37 | if i%10==0:
38 | logging.debug(f'{i},{frame.shape}')
39 | self.assertEqual(frame.shape, (1,(num_frames+2)*666))
40 |
41 | def test_limit_audio(self):
42 | COARTICULATION_FACTOR = 2
43 | STRIDE = 666
44 | audio_util = AudioUtil(COARTICULATION_FACTOR,STRIDE)
45 | audio = torch.randn(1,45000)
46 | limited_audio = audio_util.get_limited_audio(audio,5,2)
47 | self.assertEqual(limited_audio.shape,(1,5*STRIDE +2*COARTICULATION_FACTOR*STRIDE))
48 |
49 | if __name__ == '__main__':
50 | unittest.main()
51 |
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_batching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from wav2mov.params import params
4 | from wav2mov.config import get_config
5 |
6 | from wav2mov.core.data.utils import AudioUtil
7 | from wav2mov.core.data.collates import get_batch_collate
8 |
9 | from wav2mov.main.data import get_dataloaders
10 |
11 | STRIDE = 666
12 | audio_util = AudioUtil(2,stride=STRIDE)
13 |
14 | class Options:
15 | num_videos = 9
16 | v = 'v_test'
17 | options = Options()
18 | config = get_config(options.v)
19 |
20 | class TestBatching(unittest.TestCase):
21 | def test_audio_batching(self):
22 | AUDIO_LEN = 44200
23 | audio = torch.randn(AUDIO_LEN)
24 | frames = audio_util.get_audio_frames(audio.unsqueeze(0))
25 | self.assertEqual(frames.shape,(AUDIO_LEN//STRIDE,STRIDE*(4+1)))
26 |
27 |
28 | def test_collate(self):
29 | params.update('data',{**params['data'],'mini_batch_size':2})
30 | params.update('data',{**params['data'],'coarticulation_factor':0})
31 | collate = get_batch_collate(hparams=params['data'])
32 |
33 | dl = get_dataloaders(options,config,params,get_mean_std=False,collate_fn=collate)
34 |
35 | batch_size = params['data']['mini_batch_size']
36 | # print(vars(dl.train))
37 | # self.assertEqual(len(dl.train),batch_size)
38 |
39 | for i,sample in enumerate(dl.train):
40 | audio = sample.audio
41 | audio_frames = sample.audio_frames
42 | video = sample.video
43 | print(f'Batch {i}')
44 | print(f'audio :{audio.shape}')
45 | print(f'audio frames :{audio_frames.shape}')
46 | print(f'video : {video.shape} ')
47 |
48 | self.assertEqual(sample.audio.shape[0],batch_size)
49 |
50 | if __name__ == '__main__':
51 | unittest.main()
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_config.py:
--------------------------------------------------------------------------------
1 | from wav2mov.config import config
2 |
3 | for k,v in config.vals.items():
4 | print(f"{k} : {v}")
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_dataset.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import torch
4 | from wav2mov.core.data.datasets import AudioVideoDataset
5 | from wav2mov.utils.audio import StridedAudio
6 | from settings import BASE_DIR
7 | ROOT_DIR = os.path.join(BASE_DIR,'datasets','grid_dataset')
8 |
9 | def test():
10 | strided_audio = StridedAudio(16000//24,1)
11 | dataset = AudioVideoDataset(ROOT_DIR,os.path.join(ROOT_DIR,'filenames.txt'),video_fps=24,audio_sf=16000)
12 | for i,sample in enumerate(dataset):
13 | audio,video = sample
14 | stride = math.floor(16_000/24)
15 | print(f"Stride = {stride}")
16 | print(f"Audio shape,video shape , audio.shape[0]//stride")
17 | print(audio.shape,video.shape,audio.shape[0]//stride)
18 | get_frames = strided_audio.get_frame_wrapper(audio)
19 | for i in range(video.shape[0]):
20 | frame = get_frames(i)
21 | print(frame[0].shape,frame[1])
22 | break
23 | print("="*10)
24 | if i==3:
25 | break
26 | def main():
27 | test()
28 | return
29 | if __name__=='__main__':
30 | main()
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_discriminators.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 |
4 | from wav2mov.models.identity_discriminator import IdentityDiscriminator
5 | from wav2mov.models.sequence_discriminator import SequenceDiscriminator
6 | from wav2mov.models.sync_discriminator import SyncDiscriminator
7 | from wav2mov.models.patch_disc import PatchDiscriminator
8 |
9 | from utils import get_input_of_shape,no_grad
10 | BATCH_SIZE = 1
11 | IMAGE_SIZE = 256
12 | CHANNELS = 3
13 |
14 |
15 |
16 |
17 | class TestDescs(unittest.TestCase):
18 | @no_grad
19 | def test_identity(self):
20 | x = get_input_of_shape((BATCH_SIZE,CHANNELS,IMAGE_SIZE,IMAGE_SIZE))
21 | y = get_input_of_shape((BATCH_SIZE,CHANNELS,IMAGE_SIZE,IMAGE_SIZE))
22 | desc = IdentityDiscriminator()
23 | out = desc(x,y)
24 | print(f"identity descriminator : input :{x.shape} and {y.shape} | output : {out.shape}")
25 | self.assertEqual(out.shape,(BATCH_SIZE,16))
26 |
27 | @no_grad
28 | def test_sequence(self):
29 | image_size = IMAGE_SIZE*IMAGE_SIZE*3
30 | hidden_size = 100
31 |
32 | desc = SequenceDiscriminator(image_size,hidden_size,num_layers=1)
33 | x = get_input_of_shape((BATCH_SIZE,2,image_size))
34 | out = desc(x)
35 | print(f"sequence descriminator : input :{x.shape} | output : {out.shape}")
36 | self.assertEqual(out.shape,(BATCH_SIZE,hidden_size))
37 |
38 | @no_grad
39 | def test_sync(self):
40 | audio = get_input_of_shape((BATCH_SIZE,666))
41 | image = get_input_of_shape((BATCH_SIZE,CHANNELS,IMAGE_SIZE,IMAGE_SIZE))
42 | desc = SyncDiscriminator()
43 | out = desc(audio,image)
44 | print(f"sync descriminator : input :{audio.shape} and {image.shape} | output : {out.shape}")
45 | self.assertEqual(out.shape,(BATCH_SIZE,128))
46 |
47 | @no_grad
48 | def test_patch_disc(self):
49 | frame_image = get_input_of_shape((BATCH_SIZE,1,256,256))
50 | still_image = get_input_of_shape((BATCH_SIZE,1,256,256))
51 | disc = PatchDiscriminator(1,ndf=64)
52 | out = disc(frame_image,still_image)
53 | print(f'patch disc out: {out.shape}')
54 | self.assertEqual(out.shape,(BATCH_SIZE,1,30,30))
55 |
56 | def main():
57 | unittest.main()
58 | return
59 |
60 | if __name__=='__main__':
61 | main()
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_file_logger.py:
--------------------------------------------------------------------------------
1 | from wav2mov.logger import Logger
2 |
3 | def test():
4 | logger = Logger(__name__)
5 | logger.add_console_handler()
6 | logger.add_filehandler('logs/log.log')
7 |
8 | log = []
9 | for i in range(5):
10 | log.append(f'log {i}')
11 |
12 | logger.info('\t'.join(log))
13 |
14 |
15 | if __name__ == '__main__':
16 | test()
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_logger.py:
--------------------------------------------------------------------------------
1 | from wav2mov.logger import Logger
2 |
3 | log_path = r'E:\Users\VS_Code_Workspace\Python\VirtualEnvironments\wav2mov\wav2mov\tests\log.json'
4 | def test_logger():
5 | logger = Logger(__file__)
6 | logger.add_console_handler()
7 | logger.add_filehandler(log_path,in_json=True)
8 | logger.debug('testing json logger %d',extra={'great':1})
9 | logger.debug('testing json logger',extra={'great':1})
10 | logger.debug('testing json logger',extra={'great':1})
11 | logger.debug('testing json logger',extra={'great':1})
12 |
13 | def main():
14 | test_logger()
15 | if __name__=='__main__':
16 | main()
17 |
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_main_data.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from wav2mov.main.data import get_dataloaders
4 | from wav2mov.config import config
5 | from wav2mov.params import params
6 |
7 | import argparse
8 | # argparser = argparse.ArgumentParser(description='test dataloader')
9 | # argparser.add_argument('--num_videos','-n',type=int,help='number of videos')
10 | # options = argparser.parse_args()
11 | class Options:pass
12 | options = Options()
13 | options.num_videos=36
14 | class TestData(unittest.TestCase):
15 | def test_dataloader(self):
16 | dataloader,mean,std = get_dataloaders(options,config,params,shuffle=True)
17 | sample = next(iter(dataloader.train))
18 | print(f'train dl : {len(dataloader.train)}')
19 | print(f'val dl : {len(dataloader.val)}')
20 | print(f'channel wise mean : {mean}')
21 | print(f'channel wise std : {std}')
22 | print(f'audio shape : {sample.audio.shape }')
23 | print(f'video shape : {sample.video.shape }')
24 |
25 | if __name__ == '__main__':
26 | unittest.main()
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_module_level_logger.py:
--------------------------------------------------------------------------------
1 | from wav2mov.main import callbacks
2 |
3 | def test():
4 | m_logger = callbacks.m_logger
5 | m_logger.debug(f'{m_logger.name} {m_logger.level}')
6 | if __name__ == '__main__':
7 | test()
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_packing.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torch.nn.utils.rnn import pad_sequence,pack_padded_sequence,pad_packed_sequence
4 |
5 | from wav2mov.core.data.collates import collate_fn
6 |
7 | from wav2mov.main.data import get_dataloaders
8 | from wav2mov.params import params
9 | from wav2mov.config import config
10 |
11 | import argparse
12 | def parse_args():
13 | arg_parser = argparse.ArgumentParser(description='Testing Variable Sequence Batch Size')
14 | arg_parser.add_argument('--samples','-s',type=int,default=1,help='number of samples whose details will be printed')
15 | arg_parser.add_argument('--batch_size','-b',type=int,default=10,help='batch size for dataloader')
16 | return arg_parser.parse_args()
17 |
18 |
19 |
20 |
21 |
22 | def test_batch_size(options):
23 |
24 | dl,_,_ = get_dataloaders(config,params,get_mean_std=False,collate_fn=collate_fn)
25 | dl_iter = iter(dl.train)#create the iterator once so each iteration fetches a new batch
26 | for i in range(options.samples):
27 | sample,lens = next(dl_iter)
28 | audios,videos = sample
29 | audio_lens,video_lens = lens
30 | print(audios.shape)
31 | videos = pack_padded_sequence(videos,lengths=video_lens,batch_first=True,enforce_sorted=False)
32 | audios = pack_padded_sequence(audios, lengths = audio_lens,batch_first=True,enforce_sorted = False)
33 | print(videos,audios)
34 | print(f' sample {i+1} '.center(30,'='))
35 | # print('video shape : ',sample.video.shape)
36 | # print('audio shape : ',sample.audio.shape)
37 |
38 | print('video shape : ',videos.data.shape)
39 | print('audio shape : ',audios.data.shape)
40 | print('batch sizes : ',videos.batch_sizes)
41 | # print('lens audio : ',lens.audio)
42 | # print('lens video :',lens.video)
43 | p ,l= pad_packed_sequence(videos,batch_first=True)
44 | print(p.shape,l)
45 | print('='*30)
46 |
47 |
48 | if __name__ == '__main__':
49 | options = parse_args()
50 | BATCH_SIZE = options.batch_size
51 | data = params['data']
52 | params.set('data', {**data, 'batch_size': BATCH_SIZE})
53 | test_batch_size(options)
54 |
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_settings.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | TEST_DIR = os.path.dirname(os.path.abspath(__file__))
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_stored_video.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms as vtransforms
3 |
4 | import numpy as np
5 | from torchvision.transforms.transforms import CenterCrop
6 |
7 | import argparse
8 |
9 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
10 | parser.add_argument('--gray','-g',choices=['y','n'],default='y',type=str,help='whether the output should be grayscale or color image')
11 | args = parser.parse_args()
12 | print('settings : gray(y/n): ',args.gray)
13 | from wav2mov.utils.plots import show_img
14 |
15 |
16 |
17 | path = r'E:\Users\VS_Code_Workspace\Python\VirtualEnvironments\wav2mov\wav2mov\datasets\grid_dataset\s10_l_bbat9p\video_frames.npy'
18 |
19 |
20 | def test_db():
21 |
22 |
23 | video = np.load(path).astype('float64')
24 | video = video.transpose(0,3,1,2)
25 | video = torch.from_numpy(video)
26 | print(f'video shape{video.shape}')
27 |
28 |
29 |
30 |
31 | channels = video.shape[1]
32 | print(' before '.center(10,'='))
33 | print('mean :',[torch.mean(video[:,i,...]) for i in range(channels)])
34 | print('std : ', [torch.std(video[:,i,...]) for i in range(channels)])
35 | video = video/255
36 |
37 | channels = 1 if args.gray == 'y' else 3
38 | transform_list = [vtransforms.Grayscale(1)] if channels == 1 else []
39 | transform_list += [
40 | # vtransforms.CenterCrop(256),
41 | vtransforms.Resize((256, 256)),
42 | vtransforms.Normalize([0.5]*channels, [0.5]*channels)
43 | ]
44 | transforms = vtransforms.Compose(transform_list)
45 | video = transforms(video)#apply the transforms before printing the 'after' statistics
46 | print(' after '.center(10,'='))
47 | print('mean :',[torch.mean(video[:,i,...]) for i in range(channels)])
48 | print('std : ', [torch.std(video[:,i,...]) for i in range(channels)])
49 | print('max : ', [torch.max(video[:,i,...]) for i in range(channels)])
50 | print('min : ', [torch.min(video[:,i,...]) for i in range(channels)])
51 |
52 | # show_img(video[0])
53 | show_img(video[0],cmap='gray' if channels == 1 else None)
54 |
55 |
56 | def main():
57 | test_db()
58 |
59 | if __name__=='__main__':
60 | main()
61 |
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_tensorboard_logger.py:
--------------------------------------------------------------------------------
1 | import time
2 | from torch.utils.tensorboard.writer import SummaryWriter
3 | from wav2mov.logger import TensorLogger
4 |
5 |
6 | def test():
7 | logger = TensorLogger('logs')
8 | writer_1 = SummaryWriter('logs/exp1/test1')
9 | writer_2 = SummaryWriter('logs/exp1/test2')
10 | logger.add_writer('test1',writer_1)
11 | logger.add_writer('test2',writer_2)
12 |
13 | for i in range(10):
14 | logger.add_scalar('test1','aytala',i+1,i)
15 | print('adding 1 ',i)
16 | time.sleep(2)
17 | # writer_1.add_scalar('test1',i+2,i)
18 |
19 | for i in range(10):
20 | logger.add_scalar('test1','macha1',i*2,i)
21 | print('adding 2 ',i)
22 | time.sleep(2)
23 |
24 | if __name__ == '__main__':
25 | test()
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | # from wav2mov.models.generator import GeneratorBW
4 | from wav2mov.models.generator_v6 import GeneratorBW,Generator,Encoder,Decoder
5 |
6 | params = {
7 | "img_dim":[256,256],
8 | "retain_dim":True,
9 | "device":"cpu",
10 | "in_channels": 1,
11 | "enc_chs": [64, 128, 256, 512, 1024],
12 | "dec_chs": [1024, 512, 256, 128, 64],
13 | "up_chs": [1026, 512, 256, 128, 64],
14 | }
15 |
16 |
17 |
18 | class TestUnet(unittest.TestCase):
19 | def test_encoder(self):
20 | encoder = Encoder(chs=[params['in_channels']] +params['enc_chs'])
21 | image = torch.randn(1,1,256,256)
22 | out = encoder(image)
23 | print('testing encoder ')
24 | for layer in out:
25 | print(layer.shape)
26 |
27 | def test_decoder(self):
28 | encoder = Encoder(chs=[params['in_channels']] + params['enc_chs'])
29 | image = torch.randn(1,1,256,256)
30 | encoded = encoder(image)[::-1]
31 |
32 | audio_noise_encoded = torch.randn(1,2,8,8)
33 | encoded[0] = torch.cat([encoded[0],audio_noise_encoded],dim=1)
34 | decoder = Decoder(up_chs=params['up_chs'],dec_chs=params['dec_chs'])
35 | image = torch.randn(1,1,256,256)
36 | out = decoder(encoded[0],encoded[1:])
37 | print('test_decoder : ',out.shape)
38 |
39 | def test_gen(self):
40 | gen = Generator(params)
41 | frame_img = torch.randn(1,1,256,256)
42 | audio_noise = torch.randn(1,2,8,8)
43 | out = gen(frame_img,audio_noise)
44 | self.assertEqual(out.shape,(1,1,256,256))
45 |
46 | def test():
47 | with torch.no_grad():
48 | x = torch.randn(1,1,256,256)
49 | a = torch.randn(1,666)
50 | gen = GeneratorBW( {
51 | "device": "cpu",
52 | "in_channels": 1,
53 | "enc_chs": [64, 128, 256, 512, 1024],
54 | "dec_chs": [1024, 512, 256, 128, 64],
55 | "up_chs": [1026, 512, 256, 128, 64],
56 | "img_dim": [256, 256],
57 | "retain_dim": True,
58 | "lr": 1e-4
59 | })
60 | gen.eval()
61 | assert gen(a,x).shape==(1,1,256,256)
62 | print("Test Passed " ,1)
63 |
64 | if __name__ == '__main__':
65 | unittest.main()
66 |
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import unittest
4 |
5 | from test_settings import TEST_DIR
6 |
7 | from wav2mov.utils.plots import save_gif
8 |
9 |
10 | class TestUtils(unittest.TestCase):
11 | def test_save_gif_1(self):
12 | images = torch.randn(10,3,256,256)
13 | save_gif(os.path.join(TEST_DIR,'logs','test_1.gif'),images)
14 | def test_save_gif_2(self):
15 | images = []
16 | for _ in range(5):
17 | images.append(torch.randn(3,256,256))
18 | save_gif(os.path.join(TEST_DIR, 'logs', 'test_2.gif'), images)
19 |
20 | def main():
21 | unittest.main()
22 |
23 | if __name__=='__main__':
24 | main()
25 |
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_wav2mov.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 |
4 | from wav2mov.core.data.collates import get_batch_collate
5 |
6 | from wav2mov.config import get_config
7 | from wav2mov.params import params
8 |
9 | from wav2mov.models.wav2mov import Wav2Mov
10 | from wav2mov.main.engine import Engine
11 | from wav2mov.main.callbacks import LossMetersCallback,TimeTrackerCallback,ModelCheckpoint
12 | from wav2mov.main.data import get_dataloaders
13 |
14 | from wav2mov.main.options import Options,set_options
15 |
16 | logger = logging.getLogger(__file__)
17 | logger.setLevel(logging.DEBUG)
18 |
19 | # BATCH_SIZE = params['data']['batch_size']
20 | BATCH_SIZE = 1
21 |
22 | NUM_FRAMES = 25
23 |
24 | def get_input():
25 | audio = torch.randn(BATCH_SIZE,666*9,device='cuda')
26 | video = torch.randn(BATCH_SIZE,NUM_FRAMES,1,256,256,device='cuda')
27 | audio_frames = torch.randn(BATCH_SIZE,NUM_FRAMES,666+4*666,device='cuda')
28 | return audio,video,audio_frames
29 |
30 | def test(options,hparams,config,logger):
31 | engine = Engine(options,hparams,config,logger)
32 | model = Wav2Mov(hparams,config,logger)
33 | collate_fn = get_batch_collate(hparams['data'])
34 | dataloaders_ntuple = get_dataloaders(options,config,hparams,
35 | get_mean_std=False,
36 | collate_fn=collate_fn)
37 | callbacks = [LossMetersCallback(options,hparams,logger,
38 | verbose=True),
39 | TimeTrackerCallback(hparams,logger),
40 | ModelCheckpoint(model,hparams,
41 | save_every=5)]
42 |
43 | engine.run(model,dataloaders_ntuple,callbacks)
44 |
45 | if __name__ == '__main__':
46 | options = Options().parse()
47 | set_options(options,params)
48 | config = get_config(options.version)
49 | test(options,params,config,logger)
--------------------------------------------------------------------------------
/wav2mov/tests/old/test_wav2mov_v7.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 |
4 | from wav2mov.config import get_config
5 | from wav2mov.params import params
6 |
7 | from wav2mov.models.wav2mov_v7 import Wav2MovBW
8 |
9 |
10 | logger = logging.getLogger(__file__)
11 | logger.setLevel(logging.DEBUG)
12 | config = get_config('v_test')
13 |
14 | # BATCH_SIZE = params['data']['batch_size']
15 | BATCH_SIZE = 1
16 |
17 | NUM_FRAMES = 25
18 |
19 | def get_input():
20 | audio = torch.randn(BATCH_SIZE,666*9,device='cuda')
21 | video = torch.randn(BATCH_SIZE,NUM_FRAMES,1,256,256,device='cuda')
22 | audio_frames = torch.randn(BATCH_SIZE,NUM_FRAMES,666+4*666,device='cuda')
23 | return audio,video,audio_frames
24 | def test():
25 | model = Wav2MovBW(config,params,logger)
26 | audio,video,audio_frames = get_input()
27 |
28 | model.on_train_start()
29 | for epoch in range(1):
30 | model.on_batch_start()
31 | batch_size,num_frames,channels,height,width = video.shape
32 | audio_frames = audio_frames.reshape(-1,audio_frames.shape[-1])
33 | model.set_input(audio_frames,video.reshape(batch_size*num_frames,channels,height,width))
34 | frame_img = video[:,NUM_FRAMES//2,...]
35 | frame_img = frame_img.repeat((1,NUM_FRAMES,1,1,1))#id and gen requires still image for each audio_frame as condition
36 | frame_img = frame_img.reshape(batch_size*num_frames,channels,height,width)
37 | # print('frame_img ',frame_img.shape)
38 | model.set_condition(frame_img)
39 | model.optimize_parameters()
40 | model.optimize_sequence(video,audio)
41 |
42 |
43 | if __name__ == '__main__':
44 | test()
--------------------------------------------------------------------------------
/wav2mov/tests/old/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | def get_input_of_shape(shape):
3 | return torch.randn(shape)
4 |
5 |
6 | def no_grad(func):
7 | def wrapper(*args, **kwargs):
8 | with torch.no_grad():
9 | out = func(*args, **kwargs)
10 | return out
11 | return wrapper
12 |
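A minimal sketch of how the no_grad decorator above might be applied; the wrapped function and the stand-in model are illustrative, not repository code, and the decorator is assumed to be in scope:

import torch
from torch import nn

@no_grad
def run_inference(model, x):
    # executed under torch.no_grad(), so no autograd graph is built
    return model(x)

out = run_inference(nn.Linear(4, 2), torch.randn(1, 4))
print(out.requires_grad)  # False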
--------------------------------------------------------------------------------
/wav2mov/tests/test_audio_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from wav2mov.params import params
3 | from wav2mov.config import get_config
4 | from wav2mov.models.generator.audio_encoder import AudioEnocoder
5 |
6 | from wav2mov.logger import get_module_level_logger
7 | logger = get_module_level_logger(__file__)
8 | config = get_config('test_sync')
9 | def test():
10 | model = AudioEnocoder(params['gen'])
11 | audio = torch.randn(2,10,7,13)
12 | out = model(audio)
13 | logger.debug(f'out shape {out.shape}')
14 | assert(out.shape==(2,10,params['gen']['latent_dim_audio']))
15 |
16 | def main():
17 | test()
18 |
19 | if __name__=='__main__':
20 | main()
21 |
--------------------------------------------------------------------------------
/wav2mov/tests/test_audio_frames.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import os
3 | import torch
4 | from matplotlib import pyplot as plt
5 | from scipy.io.wavfile import write
6 |
7 | from wav2mov.config import get_config
8 | from wav2mov.params import params as hparams
9 |
10 | from wav2mov.main.options import Options
11 | from wav2mov.core.data.collates import get_batch_collate
12 | from wav2mov.main.data import get_dataloaders as get_dl
13 |
14 | DIR = os.path.dirname(os.path.abspath(__file__))
15 | def test():
16 | options = Options().parse()
17 | config = get_config('test_audio_frames')
18 | collate_fn = get_batch_collate(hparams['data'])
19 | dls = get_dl(options,config,hparams,collate_fn=collate_fn)
20 | train_dl = dls.train
21 | sample = next(iter(train_dl))
22 | audio,audio_frames,video = sample
23 |
24 | print(f'audio : {audio.shape}')
25 | print(f'audio frames : {audio_frames.shape}')
26 | print(f'video : {video.shape}')
27 | means,stds,maxs,mins = [],[],[],[]
28 | for audio_frame in audio_frames[0][:10]:
29 | maxs.append(torch.max(audio_frame).item())
30 | mins.append(torch.min(audio_frame).item())
31 | means.append(torch.mean(audio_frame).item())
32 | stds.append(torch.std(audio_frame).item())
33 | print('means ',means)
34 | print('stds ',stds)
35 | print('maxs ',maxs)
36 | print('mins ',mins)
37 | plt.imshow(audio_frames[0][30])
38 | plt.show()
39 |
40 | if __name__ == '__main__':
41 | test()
42 |
--------------------------------------------------------------------------------
/wav2mov/tests/test_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import unittest
4 |
5 | from wav2mov.models.generator.id_encoder import IdEncoder
6 | from wav2mov.models.generator.id_decoder import IdDecoder
7 | from wav2mov.models.generator.frame_generator import Generator
8 | from wav2mov.logger import get_module_level_logger
9 | from wav2mov.utils.plots import show_img
10 | logger = get_module_level_logger(__name__)
11 |
12 | class TestGen(unittest.TestCase):
13 | def test_id_encoder(self):
14 | logger.debug(f'id_encoder test')
15 | hparams = {
16 | 'in_channels':1,
17 | 'chs':[64,128,256,512,1024],
18 | 'latent_dim_id':(8,8)
19 | }
20 |
21 | id_encoder = IdEncoder(hparams)
22 | images = torch.randn(1,1,256,256)
23 | encoded,intermediates = id_encoder(images)
24 | req_h,req_w = 128,128
25 | for i,intermediate in enumerate(intermediates):
26 | self.assertEqual((req_h,req_w),intermediate.shape[-2:])
27 | logger.debug(f'{i} : {intermediate.shape}')
28 | req_h,req_w = req_h//2,req_w//2
29 | logger.debug(f'final encoded {encoded.shape}')
30 |
31 | def test_id_decoder(self):
32 | logger.debug(f'id_decoder test')
33 | hparams = {
34 | 'in_channels':1,
35 | 'chs':[64,128,256,512,1024],
36 | 'latent_dim':16,
37 | 'latent_dim_id':(8,8)
38 | }
39 | id_encoder = IdEncoder(hparams)
40 | id_decoder = IdDecoder(hparams)
41 | images = torch.randn(1,1,256,256)
42 | encoded,intermediates = id_encoder(images)
43 | decoded = id_decoder(encoded,intermediates)
44 | self.assertEqual(decoded.shape,(1,1,256,256))
45 |
46 | def test_generator(self):
47 | logger.debug(f'test generator')
48 |
49 | hparams = {
50 | 'in_channels':3,
51 | 'chs':[64,128,256,512,1024],
52 | 'latent_dim':16+256+10,
53 | 'latent_dim_id':(8,8),
54 | 'latent_dim_audio':256,
55 | 'latent_dim_noise':10,
56 | 'device':'cpu',
57 | 'lr':2e-4
58 | }
59 |
60 | gen = Generator(hparams)
61 | ref_frames = torch.randn(1,5,3,256,256)
62 | audio_frames = torch.zeros(1,5,5*666)
63 | out = gen(audio_frames,ref_frames)
64 | self.assertEqual(out.shape,(1,5,3,256,256))
65 |
66 | # show_img(ref_frames[0][0],cmap='gray')
67 | # show_img(out[0][0],cmap='gray')
68 |
69 | test_on_real_img(gen)
70 |
71 | def test_on_real_img(gen):
72 | gen.eval()
73 | ref_frames = torch.from_numpy(np.load(r'E:\Users\VS_Code_Workspace\Python\VirtualEnvironments\wav2mov\wav2mov\datasets\grid_dataset_256_256\s10_l_bbat9p\video_frames.npy'))
74 | show_img(ref_frames[0].permute(2,0,1))
75 | ref_frames = ref_frames[0].permute(2,0,1).unsqueeze(0).unsqueeze(0).float()
76 | # logger.debug(ref_frames.shape)
77 | audio_frames = torch.randn(1,1,5*666)
78 | out = gen(audio_frames,ref_frames)
79 | show_img(out[0][0])
80 |
81 | if __name__ == '__main__':
82 | unittest.main()
--------------------------------------------------------------------------------
/wav2mov/tests/test_main_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from wav2mov.core.data.collates import get_batch_collate
3 | from wav2mov.params import params
4 | from wav2mov.config import get_config
5 | from wav2mov.logger import get_module_level_logger
6 | from wav2mov.main.data import get_dataloaders
7 | from wav2mov.main.options import Options
8 |
9 | logger = get_module_level_logger(__name__)
10 | def test(options,hparams,config):
11 | collate_fn = get_batch_collate(hparams['data'])
12 | dataloader_pack = get_dataloaders(options,config,params,collate_fn=collate_fn)
13 | train_dl,test_dl = dataloader_pack
14 | logger.debug(f'train : {len(train_dl)} test : {len(test_dl)}')
15 |
16 | dl_iter = iter(train_dl)
17 | for _ in range(min(len(train_dl),10)):
18 | sample = next(dl_iter)
19 | audio,video = sample.audio,sample.video
20 | logger.debug(f'video {video.shape} : {torch.mean(video,dim=[0,1,3,4])} ,{torch.std(video,dim=[0,1,3,4])}')
21 | logger.debug(f'audio {audio.shape} : {torch.mean(audio)} ,{torch.std(audio)}')
22 |
23 | def main():
24 | options = Options().parse()
25 | config = get_config(options.version)
26 | test(options,params,config)
27 |
28 | if __name__ == '__main__':
29 | main()
30 |
--------------------------------------------------------------------------------
/wav2mov/tests/test_seq_disc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from wav2mov.params import params
3 | from wav2mov.models import SequenceDiscriminator
4 |
5 | from wav2mov.logger import get_module_level_logger
6 | logger = get_module_level_logger(__file__)
7 |
8 | def test():
9 | model = SequenceDiscriminator(params['disc']['sequence_disc'])
10 | frames = torch.randn(1,10,3,256,256)
11 | out = model(frames)
12 | logger.debug(f'out shape {out.shape}')
13 | assert(out.shape==(1,256))
14 |
15 | def main():
16 | test()
17 |
18 | if __name__=='__main__':
19 | main()
20 |
--------------------------------------------------------------------------------
/wav2mov/tests/test_sync_disc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from wav2mov.params import params
3 | from wav2mov.config import get_config
4 | from wav2mov.models import SyncDiscriminator
5 |
6 | from wav2mov.logger import get_module_level_logger
7 | logger = get_module_level_logger(__file__)
8 | config = get_config('test_sync')
9 | def test():
10 | model = SyncDiscriminator(params['disc']['sync_disc'],config)
11 | audio_frames = torch.randn(2,12,13)
12 | video_frames = torch.randn(2,5,3,256,256)
13 | out = model(audio_frames,video_frames)[0]
14 | logger.debug(f'out shape {out.shape}')
15 | assert(out.shape==(2,1))
16 |
17 | def main():
18 | test()
19 |
20 | if __name__=='__main__':
21 | main()
22 |
--------------------------------------------------------------------------------
/wav2mov/tests/test_video.py:
--------------------------------------------------------------------------------
1 | import dlib
2 | import cv2
3 | import os
4 | from wav2mov.logger import get_module_level_logger
5 | logger = get_module_level_logger(__name__)
6 |
7 |
8 | DIR = r'D:\dataset_lip\GRID\video_6sub_500'
9 | video = os.path.join(DIR,'s4_p_swbxza.mov')
10 |
11 |
12 | face_detector = dlib.get_frontal_face_detector()
13 |
14 | def convert_and_trim_bb(image, rect):
15 | """ from pyimagesearch
16 | https://www.pyimagesearch.com/2021/04/19/face-detection-with-dlib-hog-and-cnn/
17 | """
18 | # extract the starting and ending (x, y)-coordinates of the
19 | # bounding box
20 | startX = rect.left()
21 | startY = rect.top()
22 | endX = rect.right()
23 | endY = rect.bottom()
24 | # ensure the bounding box coordinates fall within the spatial
25 | # dimensions of the image
26 | startX = max(0, startX)
27 | startY = max(0, startY)
28 | endX = min(endX, image.shape[1])
29 | endY = min(endY, image.shape[0])
30 | # compute the width and height of the bounding box
31 | w = endX - startX
32 | h = endY - startY
33 | # return our bounding box coordinates
34 | return (startX, startY, w, h)
35 | def show_img(image_name,image):
36 | cv2.imshow(str(image_name),image)
37 | cv2.waitKey(0)
38 | cv2.destroyAllWindows()
39 |
40 | def get_video_frames(video_path,img_size:tuple):
41 | try:
42 | logger.debug(video_path)
43 | cap = cv2.VideoCapture(str(video_path))
44 | if(not cap.isOpened()):
45 | logger.error("Cannot open video stream or file!")
46 | frames = []
47 | i = 0
48 | while cap.isOpened():
49 | frameId = cap.get(1)
50 | ret, image = cap.read()
51 | if not ret:
52 | break
53 | try:
54 | i+=1
55 | #image[top_row:bottom_row,left_column:right_column]
56 | image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)#other libraries including matplotlib and dlib expect images in RGB
57 | print(image.shape,i)
58 | face = face_detector(image)[0]#get first face object
59 | print(face.top(),face.bottom())
60 | x,y,w,h = convert_and_trim_bb(image,face)
61 | # image = cv2.resize(image[y:y+h,x:x+w],img_size,interpolation=cv2.INTER_CUBIC)
62 |
63 | # show_img(len(frames),image)
64 | except Exception as e:
65 | # print(e)
66 | # logger.error(e)
67 | continue
68 | frames.append(image)
69 | return frames
70 | except Exception as e:
71 | logger.exception(e)
72 |
73 |
74 | if __name__ == '__main__':
75 | get_video_frames(video,(256,256))
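The resize call commented out inside get_video_frames shows how the trimmed bounding box is normally used to crop the detected face. A small illustrative snippet (the image path is a placeholder, not a repository asset):

import cv2
import dlib

detector = dlib.get_frontal_face_detector()
image = cv2.cvtColor(cv2.imread('face.jpg'), cv2.COLOR_BGR2RGB)  # placeholder image path
rects = detector(image)
if rects:
    # clamp the dlib rectangle to the image and crop/resize the face region
    x, y, w, h = convert_and_trim_bb(image, rects[0])
    face_crop = cv2.resize(image[y:y+h, x:x+w], (256, 256), interpolation=cv2.INTER_CUBIC)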
--------------------------------------------------------------------------------
/wav2mov/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Package for util files"""
--------------------------------------------------------------------------------
/wav2mov/utils/audio.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | class StridedAudio:
5 | def __init__(self,stride,coarticulation_factor=0):
6 | self.coarticulation_factor = coarticulation_factor
7 | self.stride = stride
8 | if isinstance(self.stride,float):
9 | self.stride = math.floor(self.stride)
10 |
11 |
12 | def get_frame_wrapper(self,audio):
13 | # padding = torch.tensor([])
14 | # audio = torch.from_numpy(audio)
15 |
16 | if self.coarticulation_factor!=0:
17 | padding = torch.cat([torch.tensor([0]*self.stride) for _ in range(self.coarticulation_factor)],dim=0)
18 | audio = torch.cat([padding,audio,padding])
19 |
20 | audio = audio.unsqueeze(0)
21 |
22 | def get_frame(idx):
23 | center_frame_pos = idx*self.stride
24 | # print(center_frame_pos)
25 | center_frame = audio[:,center_frame_pos:center_frame_pos+self.stride]
26 | if self.coarticulation_factor == 0:
27 | return center_frame, center_frame_pos+self.stride
28 | # curr_idx = center_frame_pos
29 | # print(f'audio shape {audio.shape} centerframe {center_frame.shape}')
30 | for i in range(self.coarticulation_factor):
31 | center_frame = torch.cat([audio[:,center_frame_pos-(i+1)*self.stride:center_frame_pos-i*self.stride],center_frame],dim=1)
32 | last_pos = 0
33 | for i in range(self.coarticulation_factor):
34 | center_frame = torch.cat([center_frame,audio[:,center_frame_pos+(i+1)*self.stride:center_frame_pos+(i+2)*self.stride]],dim=1)
35 | last_pos = center_frame_pos+(i+2)*self.stride
36 | return center_frame,last_pos
37 | return get_frame
38 |
39 |
40 | class StridedAudioV2:
41 | def __init__(self, stride, coarticulation_factor=0,device='cpu'):
42 | """
43 | 0.15s window ==>16k*0.15 points = 2400
44 | so padding on either side will be 2400//2 = 1200
45 | """
46 | self.coarticulation_factor = coarticulation_factor
47 | self.stride = stride
48 | if isinstance(self.stride, float):
49 | self.stride = math.floor(self.stride)
50 | self.pad_len = self.coarticulation_factor*self.stride
51 | self.device = device
52 |
53 | def get_frame_wrapper(self, audio):
54 | # padding = torch.tensor([])
55 | # audio = torch.from_numpy(audio)
56 |
57 | if self.coarticulation_factor != 0:
58 |
59 | padding = torch.zeros((audio.shape[0],self.pad_len),device=self.device)
60 |
61 | # padding = torch.cat([torch.tensor([0]*self.stride)
62 | # for _ in range(self.coarticulation_factor)], dim=0)
63 |
64 | audio = torch.cat([padding, audio,padding],dim=1)
65 | # print(audio.shape,self.padding.shape)
66 |
67 |
68 | def get_frame_from_idx(idx):
69 | center_idx= (idx) + (self.coarticulation_factor)
70 | start_pos = (center_idx-self.coarticulation_factor)*self.stride
71 | end_pos = (center_idx+self.coarticulation_factor+1)*self.stride
72 | return audio[:,start_pos:end_pos]
73 |
74 | def get_frames_from_range(start_idx,end_idx):
75 | start_pos = start_idx + self.coarticulation_factor
76 | start_idx = (start_pos-self.coarticulation_factor)*self.stride
77 | end_pos = end_idx+self.coarticulation_factor
78 | end_idx = (end_pos+self.coarticulation_factor+1)*self.stride
79 | return audio[:,start_idx:end_idx]
80 |
81 | return get_frame_from_idx,get_frames_from_range
82 |
83 |
84 |
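A minimal usage sketch for StridedAudioV2. The 16 kHz sample rate and 24 fps frame rate below are assumptions chosen only for illustration (the real values come from the project params); the stride is the number of audio samples per video frame and coarticulation_factor widens each window symmetrically with zero padding.

import torch
from wav2mov.utils.audio import StridedAudioV2

stride = 16000 // 24                                   # audio samples per video frame (assumed rates)
strided_audio = StridedAudioV2(stride=stride, coarticulation_factor=2)
audio = torch.randn(1, 16000)                          # (batch, samples), about one second of audio
get_frame, get_frames_range = strided_audio.get_frame_wrapper(audio)
frame = get_frame(0)                                   # zero-padded window centered on frame 0
print(frame.shape)                                     # (1, (2*2 + 1) * stride)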
--------------------------------------------------------------------------------
/wav2mov/utils/cnn_shape_calc.py:
--------------------------------------------------------------------------------
1 | def out_d(in_d,k,p,s):
2 | """calculates out shape of 2d cnn
3 |
4 | Args:
5 | in_d (int): height or width
6 | k (int):kernel height or width
7 |
8 | p (int): padding along height or width
9 |
10 | s (int): stride along height or width
11 |
12 | Returns:
13 | int
14 | """
15 | return ((in_d-k+2*p)//s)+1
16 |
17 | def out_d_transposed(in_d,k,p,s):
18 | """ find z and p
19 | insert z number of zeros between each row and column (2*(i-1)x2*(i-1))
20 | pad with p number of zeros
21 | perform standard convolution
22 | """
23 | return ((in_d-1)*s)+k-(2*p)
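Worked example of the two helpers above, assuming (for illustration only) a 4x4 kernel with stride 2 and padding 1: a 256-pixel dimension halves under the convolution and is restored by the matching transposed convolution.

from wav2mov.utils.cnn_shape_calc import out_d, out_d_transposed

down = out_d(256, k=4, p=1, s=2)               # ((256 - 4 + 2*1)//2) + 1 = 128
up = out_d_transposed(down, k=4, p=1, s=2)     # ((128 - 1)*2) + 4 - 2*1 = 256
print(down, up)                                # 128 256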
--------------------------------------------------------------------------------
/wav2mov/utils/files.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | def move_dir(src,dest):
5 | os.makedirs(dest, exist_ok=True)
6 | print('src :',src)
7 | print('dest :',dest)
8 | shutil.move(src, dest)
9 | print(f'[SHUTIL]: Folder moved : src ({src}) dest ({dest})')
10 |
11 | def mov_log_dir(options,config):
12 | src = os.path.dirname(config['log_fullpath'])
13 | if os.path.exists(src):
14 | dest = os.path.join(os.path.join(config['base_dir'],'logs'),
15 | options.version)
16 | # os.makedirs(dest,exist_ok=True)
17 | move_dir(src,dest)
18 |
19 | def mov_out_dir(options,config):
20 | if not options.test in ['y', 'yes']:
21 | return
22 | run = os.path.basename(options.model_path).split('.')[0].replace('gen_','',1)#remove the 'gen_' prefix; strip('gen_') would drop any of those characters from both ends
23 | version_dir = os.path.dirname(os.path.dirname(options.model_path))
24 | print('version dir',version_dir)
25 | version = os.path.basename(version_dir)
26 | if 'v' not in version:
27 | version = options.version
28 | src = os.path.join(config['out_dir'], run)
29 | if os.path.exists(src):
30 | dest = os.path.join(config['out_dir'], version)
31 | # os.makedirs(dest,exist_ok=True)
32 | move_dir(src, dest)
33 |
--------------------------------------------------------------------------------
/wav2mov/utils/misc.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | def log_run_time(logger):
4 | def timeit(func):
5 | def timed(*args, **kwargs):
6 | start_time = time.time()
7 | res = func(*args, **kwargs)
8 | end_time = time.time()
9 | time_taken = end_time-start_time
10 | logger.log(f'[TIME TAKEN] {func.__name__} took {time_taken:0.2f} seconds (or) {time_taken/60:0.2f} minutes',type="DEBUG")
11 | return res
12 | return timed
13 | return timeit
14 |
15 |
16 |
17 | class AverageMeter:
18 | def __init__(self,name,fmt=':0.4f'):
19 |
20 | self.name = name
21 | self.fmt = fmt
22 | self.reset()
23 |
24 | def reset(self):
25 | self.sum = 0
26 | self.count = 0
27 | self.avg = 0
28 |
29 | def _update_average(self):
30 | self.avg = self.sum/self.count
31 |
32 | def update(self,val,n):
33 | self.count +=n
34 | self.sum += val*n
35 | self._update_average()
36 |
37 | def __str__(self):
38 | fmt_str = '{name} : {avg' + self.fmt + '}'
39 | return fmt_str.format(**self.__dict__)
40 |
41 | def add(self,val):
42 | self.sum += val
43 | self._update_average()
44 |
45 | class AverageMetersList:
46 | def __init__(self, names, fmt=':0.4f'):
47 | self.meters = {name:AverageMeter(name,fmt) for name in names}
48 |
49 | def update(self,d:dict):
50 | """update the average meters
51 |
52 | Args:
53 | d (dict): key is the name of the meter and value is a tuple containing value and the multiplier
54 |
55 | """
56 | for name,(value,n) in d.items():
57 | if n==0:
58 | continue
59 | self.meters[name].update(value,n)
60 |
61 |
62 | def reset(self):
63 | for name in self.meters.keys():
64 | self.meters[name].reset()
65 |
66 | def as_list(self):
67 | return self.meters.values()
68 |
69 | def average(self):
70 | return {name:meter.avg for name,meter in self.meters.items()}
71 |
72 | def get(self,name):
73 | if name not in self.meters:
74 | raise KeyError(f'{name} has no average meter initialized')
75 | return self.meters.get(name)
76 |
77 | def __str__(self):
78 | avg = self.average()
79 | return '\t'.join(f'{key}:{val:0.4f}' for key,val in avg.items())
80 |
81 | class ProgressMeter:
82 | def __init__(self,steps,meters,prefix=''):
83 | self.batch_fmt_str = self._get_epoch_fmt_str(steps)
84 | self.meters = meters
85 | self.prefix = prefix
86 |
87 | def get_display_str(self,step):
88 | entries = [self.prefix + self.batch_fmt_str.format(step)]
89 | entries += [str(meter) for meter in self.meters]
90 | return '\t'.join(entries)
91 |
92 | def _get_epoch_fmt_str(self,steps):
93 | num_digits = len(str(steps))
94 | fmt = '{:' + str(num_digits) + 'd}'
95 | return '[' + fmt +'/'+ fmt.format(steps) + ']'
96 |
97 |
98 |
99 | def get_duration_in_minutes_seconds(self,duration):
100 | return duration//60,duration%60
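A minimal usage sketch for the meter helpers above; the meter names and numbers are made up for illustration.

from wav2mov.utils.misc import AverageMetersList, ProgressMeter

meters = AverageMetersList(('gen_loss', 'disc_loss'), fmt=':0.4f')
meters.update({'gen_loss': (0.75, 4), 'disc_loss': (0.50, 4)})  # (value, count n)
meters.update({'gen_loss': (0.25, 4), 'disc_loss': (0.00, 0)})  # entries with n == 0 are skipped
print(meters.average())                                         # {'gen_loss': 0.5, 'disc_loss': 0.5}

progress = ProgressMeter(steps=10, meters=meters.as_list(), prefix='epoch 1 ')
print(progress.get_display_str(step=3))                         # epoch 1 [ 3/10]  gen_loss : 0.5000  disc_loss : 0.5000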
--------------------------------------------------------------------------------
/wav2mov/utils/plots.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.functional import Tensor
4 | from torchvision.io import write_video
5 |
6 | import imageio
7 | import numpy as np
8 | from matplotlib import pyplot as plt
9 |
10 | from moviepy import editor as mpy
11 | from scipy.io.wavfile import write as write_audio
12 |
13 | import warnings
14 | warnings.filterwarnings( "ignore", module = r"matplotlib\..*" )
15 |
16 | from wav2mov.logger import get_module_level_logger
17 | logger = get_module_level_logger(__name__)
18 |
19 | def no_grad_wrapper(fn):
20 | def wrapper(*args,**kwargs):
21 | with torch.no_grad():
22 | return fn(*args,**kwargs)
23 | return wrapper
24 |
25 | @no_grad_wrapper
26 | def show_img(img,cmap='viridis'):
27 | if isinstance(img,np.ndarray):
28 | img_np = img
29 | else:
30 | if len(img.shape)>3:
31 | img = img.squeeze(0)
32 | img_np = img.cpu().numpy()
33 | img_np = np.transpose(img_np, (1, 2, 0))
34 | # print(img_np.shape)
35 | # print(img_np)
36 | if img_np.shape[2]==1:#if single channel
37 | img_np = img_np.squeeze(2)
38 | plt.imshow(img_np,cmap=cmap)
39 | plt.show()
40 | #
41 |
42 | def save_gif(gif_path,images,duration=0.5):
43 | """creates gif
44 |
45 | Args:
46 | gif_path (str): path where gif file should be saved
47 | images (torch.Tensor): tensor of images of shape (N,C,H,W)
48 | """
49 | if isinstance(images,Tensor):
50 | images = images.numpy()
51 |
52 |
53 | images = images.transpose(0,2,3,1).astype('uint8')
54 |
55 | imageio.mimsave(gif_path,images,duration=duration)
56 |
57 | def save_video(hparams,video_path,audio,video_frames):
58 | """
59 | audio_frames : C,S
60 | video_frames : T,C,H,W
61 | """
62 | if video_frames.shape[1]==1:
63 | video_frames = video_frames.repeat(1,3,1,1)
64 | logger.debug(f'video frames :{video_frames.shape}, audio : {audio.shape}')
65 | video_frames = video_frames.to(torch.uint8)
66 | write_video(filename= video_path,
67 | video_array = video_frames.permute(0,2,3,1),
68 | fps = hparams['video_fps'],
69 | video_codec="h264",
70 | # audio_array= audio,
71 | # audio_fps = hparams['audio_sf'],
72 | # audio_codec = 'mp3'
73 | )
74 | dir_name = os.path.dirname(video_path)
75 | temp_audio_path = os.path.join(dir_name,'temp','temp_audio.wav')
76 | os.makedirs(os.path.dirname(temp_audio_path),exist_ok=True)
77 | write_audio(temp_audio_path,hparams['audio_sf'],audio.cpu().numpy().reshape(-1))
78 |
79 | video_clip = mpy.VideoFileClip(video_path)
80 | audio_clip = mpy.AudioFileClip(temp_audio_path)
81 | video_clip.audio = audio_clip
82 | video_clip.write_videofile(os.path.join(dir_name,'fake_video_with_audio.avi'),fps=hparams['video_fps'],codec='png')
83 |
84 | def save_video_v2(hparams,filepath,audio,video_frames):
85 | def get_video_frames(idx):
86 | idx = min(int(idx*video_fps),num_frames-1)#make_frame receives the time in seconds; convert it to a frame index
87 | # logger.debug(f'{video_frames.shape} ,{video_frames[idx].shape}')
88 | frame = video_frames[idx].permute(1,2,0).squeeze()
89 | return frame.cpu().numpy().astype(np.uint8)
90 |
91 | logger.debug('saving video please wait...')
92 | num_frames = video_frames.shape[0]
93 | video_fps = hparams['data']['video_fps']
94 | audio_sf = hparams['data']['audio_sf']
95 | duration = audio.squeeze().shape[0]/audio_sf
96 | # duration = 10
97 | # duration = math.ceil(num_frames/video_fps)
98 | logger.debug(f'duration {duration} seconds')
99 | dir_name = os.path.dirname(filepath)
100 | temp_audio_path = os.path.join(dir_name,'temp','temp_audio.wav')
101 | os.makedirs(os.path.dirname(temp_audio_path),exist_ok=True)
102 | # print(audio.cpu().numpy().reshape(-1).shape)
103 | write_audio(temp_audio_path,audio_sf,audio.cpu().numpy().reshape(-1))
104 |
105 | video_clip = mpy.VideoClip(make_frame=get_video_frames,duration=duration)
106 | audio_clip = mpy.AudioFileClip(temp_audio_path,fps=audio_sf)
107 | video_clip = video_clip.set_audio(audio_clip)
108 | # print(filepath,video_clip.write_videofile.__doc__)
109 | video_clip.write_videofile( filepath,
110 | fps=video_fps,
111 | codec="png",
112 | bitrate=None,
113 | audio=True,
114 | audio_fps=audio_sf,
115 | preset="medium",
116 | # audio_nbytes=4,
117 | audio_codec=None,
118 | audio_bitrate=None,
119 | # audio_bufsize=2000,
120 | temp_audiofile=None,
121 | # temp_audiofile_path="",
122 | remove_temp=True,
123 | write_logfile=False,
124 | threads=None,
125 | ffmpeg_params=['-s','256x256','-aspect','1:1'],
126 | logger="bar",
127 | # pixel_format='gray
128 | )
129 |
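A minimal usage sketch for save_gif above (frame count, size and duration are arbitrary). For save_video_v2, note that it expects an hparams dict with a 'data' section holding 'video_fps' and 'audio_sf'.

import torch
from wav2mov.utils.plots import save_gif

frames = torch.rand(10, 3, 64, 64) * 255   # (N,C,H,W) with values in the 0..255 range
save_gif('sample.gif', frames, duration=0.2)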
--------------------------------------------------------------------------------