├── .gitignore ├── README.md ├── animator.py ├── app.py ├── data ├── fom │ ├── vox-256.yaml │ └── vox-adv-256.yaml ├── images │ ├── Geralt │ │ ├── 0.jpg │ │ ├── 1.jpg │ │ ├── 2.jpg │ │ └── fx.jpg │ ├── Other │ │ └── marilyn.jpg │ └── Yennefer │ │ ├── 0.jpg │ │ ├── 1.jpg │ │ ├── 2.jpg │ │ └── 3.jpg └── sda │ └── image.bmp ├── deepstory.py ├── generate.py ├── interface ├── 1.png ├── 2.png ├── 3.png ├── 4.png └── 5.png ├── modules ├── dctts │ ├── __init__.py │ ├── audio.py │ ├── hparams.py │ ├── layers.py │ ├── ssrn.py │ └── text2mel.py ├── fom │ ├── __init__.py │ ├── animate.py │ ├── dense_motion.py │ ├── generator.py │ ├── keypoint_detector.py │ ├── sync_batchnorm │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ └── replicate.py │ └── util.py └── sda │ ├── __init__.py │ ├── encoder_audio.py │ ├── encoder_image.py │ ├── img_generator.py │ ├── rnn_audio.py │ ├── sda.py │ └── utils.py ├── requirements.txt ├── result.gif ├── result.mp4 ├── static ├── bootstrap │ ├── css │ │ └── bootstrap.min.css │ └── js │ │ └── bootstrap.min.js ├── css │ └── styles.css └── js │ └── jquery.min.js ├── templates ├── animate.html ├── deepstory.js ├── gen_sentences.html ├── gpt2.html ├── index.html ├── map.html ├── models.html ├── sentences.html ├── status.html └── video.html ├── test.py ├── text.txt ├── util.py └── voice.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | __pycache__/
3 | **/*.pth
4 | **/*.tar
5 | **/*.dat
6 | export/
7 | data/gpt2/
8 | temp/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deepstory
2 | Deepstory is an artwork that combines Natural Language Generation (NLG) with GPT-2, Text-to-Speech (TTS) with Deep Convolutional TTS, speech-to-animation with Speech-Driven Animation, and image animation with the First Order Motion Model into one media application.
3 | 
4 | To put it simply, it turns text (or generated text) into a video in which the character is animated to speak your story using his or her voice.
5 | 
6 | You can convert an image into a video like this:
7 | 
8 | ![result](https://raw.githubusercontent.com/thetobysiu/deepstory/master/result.gif)
9 | 
10 | It provides a comfortable web interface and a backend written with Flask to create your own story.
11 | 
12 | It supports transformers models and pytorch-dctts models.
13 | 
14 | ## Live Demo
15 | Colab (flask-ngrok): https://colab.research.google.com/drive/1HYCPUmFw5rN8kvZdwzFpfBlaUMWPNHas?usp=sharing
16 | 
17 | Video (in case you need instructions): https://blog.thetobysiu.com/video/
18 | 
19 | ## Updates
20 | 
21 | 1. Redesigned the interface, especially the whole GPT2 interface
22 | 2. GPT2 now supports loading text from the original data, so it can continue to generate a story based on the book
23 | 3. Figured out the token limit in GPT2; generation now only feeds the model the nearest 1024 - predict_length tokens
24 | 4. GPT2 supports an interactive mode that generates several batches of sentences and provides an interface to add those sentences
25 | 5. Sentence-to-speaker mapping system; speakers are no longer all replaced by default
26 | 6. Text normalization now happens at the synthesis stage, so punctuation is preserved and can be referenced to give a variable duration in the synthesized audio
27 | 7. Audio synthesis now happens entirely in the temp folder, and synthesized audio is trimmed so that its animated video is more accurate (the sda model's training data are short as well)
28 | 8.
Combined audio now have variable silences according to punctuation 29 | 9. Basically, rewrite the web interface and lots of codes... 30 | 31 | Colab version will be available soon! 32 | 33 | ## Interface 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | ## Folder structure 44 | ``` 45 | Deepstory 46 | ├── animator.py 47 | ├── app.py 48 | ├── data 49 | │   ├── dctts 50 | │   │   ├── Geralt 51 | │   │   │   ├── ssrn.pth 52 | │   │   │   └── t2m.pth 53 | │   │   ├── LJ 54 | │   │   │   ├── ssrn.pth 55 | │   │   │   └── t2m.pth 56 | │   │   └── Yennefer 57 | │   │   ├── ssrn.pth 58 | │   │   └── t2m.pth 59 | │   ├── fom 60 | │   │   ├── vox-256.yaml 61 | │   │   ├── vox-adv-256.yaml 62 | │   │   ├── vox-adv-cpk.pth.tar 63 | │   │   └── vox-cpk.pth.tar 64 | │   ├── gpt2 65 | │   │   ├── Waiting for Godot 66 | │   │   │   ├── config.json 67 | │   │   │   ├── default.txt 68 | │   │   │   ├── merges.txt 69 | │   │   │   ├── pytorch_model.bin 70 | │   │   │   ├── special_tokens_map.json 71 | │   │   │   ├── text.txt 72 | │   │   │   ├── tokenizer_config.json 73 | │   │   │   └── vocab.json 74 | │   │   └── Witcher Books 75 | │   │   ├── config.json 76 | │   │   ├── default.txt 77 | │   │   ├── merges.txt 78 | │   │   ├── pytorch_model.bin 79 | │   │   ├── special_tokens_map.json 80 | │   │   ├── text.txt 81 | │   │   ├── tokenizer_config.json 82 | │   │   └── vocab.json 83 | │   ├── images 84 | │   │   ├── Geralt 85 | │   │   │   ├── 0.jpg 86 | │   │   │   └── fx.jpg 87 | │   │   └── Yennefer 88 | │   │   ├── 0.jpg 89 | │   │   ├── 1.jpg 90 | │   │   ├── 2.jpg 91 | │   │   ├── 3.jpg 92 | │   │   ├── 4.jpg 93 | │   │   └── 5.jpg 94 | │   └── sda 95 | │   ├── grid.dat 96 | │   └── image.bmp 97 | ├── deepstory.py 98 | ├── generate.py 99 | ├── modules 100 | │   ├── dctts 101 | │   │   ├── audio.py 102 | │   │   ├── hparams.py 103 | │   │   ├── __init__.py 104 | │   │   ├── layers.py 105 | │   │   ├── ssrn.py 106 | │   │   └── text2mel.py 107 | │   ├── fom 108 | │   │   ├── animate.py 109 | │   │   ├── dense_motion.py 110 | │   │   ├── generator.py 111 | │   │   ├── __init__.py 112 | │   │   ├── keypoint_detector.py 113 | │   │   ├── sync_batchnorm 114 | │   │   │   ├── batchnorm.py 115 | │   │   │   ├── comm.py 116 | │   │   │   ├── __init__.py 117 | │   │   │   └── replicate.py 118 | │   │   └── util.py 119 | │   └── sda 120 | │   ├── encoder_audio.py 121 | │   ├── encoder_image.py 122 | │   ├── img_generator.py 123 | │   ├── __init__.py 124 | │   ├── rnn_audio.py 125 | │   ├── sda.py 126 | │   └── utils.py 127 | ├── README.md 128 | ├── requirements.txt 129 | ├── static 130 | │   ├── bootstrap 131 | │   │   ├── css 132 | │   │   │   └── bootstrap.min.css 133 | │   │   └── js 134 | │   │   └── bootstrap.min.js 135 | │   ├── css 136 | │   │   └── styles.css 137 | │   └── js 138 | │   └── jquery.min.js 139 | ├── templates 140 | │   ├── animate.html 141 | │   ├── deepstory.js 142 | │   ├── gen_sentences.html 143 | │   ├── gpt2.html 144 | │   ├── index.html 145 | │   ├── map.html 146 | │   ├── models.html 147 | │   ├── sentences.html 148 | │   ├── status.html 149 | │   └── video.html 150 | ├── test.py 151 | ├── text.txt 152 | ├── util.py 153 | └── voice.py 154 | ``` 155 | 156 | ## Complete project download 157 | They are available at the google drive version of this project. All the models(including Geralt, Yennefer) are included. 158 | 159 | You have to download the spacy english model first. 160 | 161 | make sure you have ffmpeg installed in your computer, and ffmpeg-python installed. 
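You can sanity-check both prerequisites with a short script like the one below (a minimal sketch; the spaCy model name en_core_web_sm and the file name check_env.py are assumptions, not part of the project):

```python
# check_env.py (hypothetical helper): verify ffmpeg and the spaCy English model.
import shutil

import spacy
from spacy.cli import download

# ffmpeg-python is only a wrapper, so the ffmpeg binary itself must be on PATH.
if shutil.which("ffmpeg") is None:
    raise RuntimeError("ffmpeg binary not found on PATH; install ffmpeg first.")

# Download the English spaCy model once, then confirm that it loads.
try:
    spacy.load("en_core_web_sm")
except OSError:
    download("en_core_web_sm")
    spacy.load("en_core_web_sm")

print("ffmpeg and the spaCy English model are ready.")
```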
162 | 
163 | https://drive.google.com/drive/folders/1AxORLF-QFd2wSORzMOKlvCQSFhdZSODJ?usp=sharing
164 | 
165 | To simplify things, a Google Colab version will be released soon...
166 | 
167 | ## Requirements
168 | An NVIDIA GPU with at least 4GB of VRAM is required to run this project.
169 | 
170 | ## Credits
171 | https://github.com/tugstugi/pytorch-dc-tts
172 | 
173 | https://github.com/DinoMan/speech-driven-animation
174 | 
175 | https://github.com/AliaksandrSiarohin/first-order-model
176 | 
177 | https://github.com/huggingface/transformers
178 | 
179 | ## Notes
180 | The whole project uses PyTorch. TensorFlow is listed in requirements.txt only because transformers needs it to convert a model trained with gpt-2-simple into a PyTorch model.
181 | 
182 | Only the files inside the modules folder are slightly modified from their originals. The remaining files are all written by me, except for a few referenced parts.
183 | 
184 | ## Bugs
185 | There are still some memory issues if you synthesize sentences over and over within a session, but it takes at least ten rounds to cause a memory overflow.
186 | 
187 | ## Training models
188 | There are other repos of tools that I created to preprocess the files. They can be found in my profile.
--------------------------------------------------------------------------------
/animator.py:
--------------------------------------------------------------------------------
1 | # SIU KING WAI SM4701 Deepstory
2 | # mostly referenced from demo.py of the first order model GitHub repo; loading optimized to stay in GPU VRAM
3 | import imageio
4 | import yaml
5 | import torch
6 | import torch.nn.functional as F
7 | import numpy as np
8 | 
9 | from modules.fom import OcclusionAwareGenerator, KPDetector, DataParallelWithCallback, normalize_kp
10 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
11 | 
12 | 
13 | class ImageAnimator:
14 |     def __init__(self):
15 |         self.config_path = 'data/fom/vox-256.yaml'
16 |         self.checkpoint_path = 'data/fom/vox-cpk.pth.tar'
17 |         self.generator = None
18 |         self.kp_detector = None
19 | 
20 |     def __enter__(self):
21 |         self.load()
22 |         return self
23 | 
24 |     def __exit__(self, exc_type, exc_val, exc_tb):
25 |         self.close()
26 | 
27 |     def load(self):
28 |         with open(self.config_path) as f:
29 |             config = yaml.safe_load(f)  # safe_load avoids relying on yaml.load's deprecated default Loader
30 | 
31 |         self.generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
32 |                                                  **config['model_params']['common_params']).to(device)
33 | 
34 |         self.kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
35 |                                       **config['model_params']['common_params']).to(device)
36 | 
37 |         checkpoint = torch.load(self.checkpoint_path)
38 | 
39 |         self.generator.load_state_dict(checkpoint['generator'])
40 |         self.kp_detector.load_state_dict(checkpoint['kp_detector'])
41 | 
42 |         del checkpoint
43 | 
44 |         self.generator = DataParallelWithCallback(self.generator)
45 |         self.kp_detector = DataParallelWithCallback(self.kp_detector)
46 | 
47 |         self.generator.eval()
48 |         self.kp_detector.eval()
49 | 
50 |     def close(self):
51 |         del self.generator
52 |         del self.kp_detector
53 |         torch.cuda.empty_cache()
54 | 
55 |     def animate_image(self, source_image, driving_video, output_path, relative=True, adapt_movement_scale=True):
56 |         with torch.no_grad():
57 |             predictions = []
58 |             # ====================================================================================
59 |             # adapted from the original to load data into GPU memory instead of CPU memory
60 |             source_image = imageio.imread(source_image)
61 |             # normalize color to float 0-1
62 |             source
= torch.from_numpy(source_image[np.newaxis].astype(np.float32)).to('cuda') / 255 63 | del source_image 64 | source = source.permute(0, 3, 1, 2) 65 | # resize 66 | source = F.interpolate(source, size=(256, 256), mode='area') 67 | 68 | # modified to fit speech driven animation 69 | driving = torch.from_numpy(driving_video).to('cuda') / 255 70 | del driving_video 71 | driving = F.interpolate(driving, scale_factor=2, mode='bilinear', align_corners=False) 72 | # pad the left and right side of the scaled 128x96->256x192 to fit 256x256 73 | driving = F.pad(input=driving, pad=(32, 32, 0, 0, 0, 0, 0, 0), mode='constant', value=0) 74 | driving = driving.permute(1, 0, 2, 3).unsqueeze(0) 75 | # ==================================================================================== 76 | kp_source = self.kp_detector(source) 77 | kp_driving_initial = self.kp_detector(driving[:, :, 0]) 78 | 79 | for frame_idx in range(driving.shape[2]): 80 | driving_frame = driving[:, :, frame_idx] 81 | kp_driving = self.kp_detector(driving_frame) 82 | kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving, 83 | kp_driving_initial=kp_driving_initial, use_relative_movement=relative, 84 | use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale) 85 | out = self.generator(source, kp_source=kp_source, kp_driving=kp_norm) 86 | out['prediction'] *= 255 87 | out['prediction'] = out['prediction'].byte() 88 | # predictions.append(out['prediction'][0].cpu().numpy()) 89 | predictions.append(out['prediction'].permute(0, 2, 3, 1)[0].cpu().numpy()) 90 | imageio.mimsave(output_path, predictions, fps=25) 91 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # SIU KING WAI SM4701 Deepstory 2 | from flask import Flask, render_template, request, make_response, send_from_directory 3 | from deepstory import Deepstory 4 | app = Flask(__name__) 5 | deepstory = Deepstory() 6 | print('Deepstory instance created') 7 | 8 | 9 | def send_message(message, status=200): 10 | response = make_response(message, status) 11 | response.mimetype = "text/plain" 12 | return response 13 | 14 | 15 | @app.route('/') 16 | def index(): 17 | return render_template('index.html', model_list=deepstory.model_list, gpt2_list=deepstory.gpt2_list, 18 | image_dict=deepstory.image_dict, speaker_map=deepstory.speaker_map_dict) 19 | 20 | 21 | @app.route('/map') 22 | def map_page(): 23 | return render_template('map.html', model_list=deepstory.model_list, speaker_map=deepstory.speaker_map_dict) 24 | 25 | 26 | @app.route('/deepstory.js') 27 | def deepstoryjs(): 28 | response = make_response(render_template('deepstory.js')) 29 | response.mimetype = "text/javascript" 30 | return response 31 | 32 | 33 | @app.route('/status') 34 | def status(): 35 | return render_template('status.html', synthsized=deepstory.is_synthesized, 36 | combined=deepstory.is_processed, base=deepstory.is_base, animated=deepstory.is_animated) 37 | 38 | 39 | @app.route('/gpt2') 40 | def gpt2(): 41 | return render_template('gpt2.html', gpt2=deepstory.current_gpt2, generated_text=deepstory.generated_text) 42 | 43 | 44 | @app.route('/gen_sents') 45 | def gen_sents(): 46 | return render_template('gen_sentences.html', sentences=deepstory.generated_sentences) 47 | 48 | 49 | @app.route('/sentences') 50 | def sentences(): 51 | return render_template('sentences.html', 52 | sentences=deepstory.sentence_dicts, model_list=deepstory.model_list, 53 | 
speaker_map=deepstory.speaker_map_dict) 54 | 55 | 56 | @app.route('/load_text') 57 | def load_text(): 58 | model = request.args.get('model') 59 | lines_no = int(request.args.get('lines')) 60 | try: 61 | deepstory.load_text(model, lines_no) 62 | return send_message(f'{lines_no} in {model} loaded.') 63 | except FileNotFoundError: 64 | return send_message(f'Text file not found.', 403) 65 | 66 | 67 | @app.route('/load_gpt2', methods=['GET']) 68 | def load_gpt2(): 69 | model = request.args.get('model') 70 | if deepstory.gpt2: 71 | if deepstory.current_gpt2 == model: 72 | return send_message(f'{model} is already loaded.', 403) 73 | deepstory.load_gpt2(model) 74 | return send_message(f'{model} loaded.') 75 | 76 | 77 | @app.route('/generate_text', methods=['POST']) 78 | def generate_text(): 79 | if deepstory.gpt2: 80 | text = request.form.get('text') 81 | predict_length = int(request.form.get('predict_length')) 82 | top_p = float(request.form.get('top_p')) 83 | top_k = int(request.form.get('top_k')) 84 | temperature = float(request.form.get('temperature')) 85 | do_sample = bool(request.form.get('do_sample')) 86 | deepstory.generate_text_gpt2(text, predict_length, top_p, top_k, temperature, do_sample) 87 | return send_message(f'Generated.') 88 | else: 89 | return send_message(f'Please load a GPT2 model first.', 403) 90 | 91 | 92 | @app.route('/generate_sents', methods=['POST']) 93 | def generate_sents(): 94 | if deepstory.gpt2: 95 | text = request.form.get('text') 96 | predict_length = int(request.form.get('predict_length')) 97 | top_p = float(request.form.get('top_p')) 98 | top_k = int(request.form.get('top_k')) 99 | temperature = float(request.form.get('temperature')) 100 | do_sample = bool(request.form.get('do_sample')) 101 | batches = int(request.form.get('batches')) 102 | max_sentences = int(request.form.get('max_sentences')) 103 | deepstory.generate_sents_gpt2( 104 | text, predict_length, top_p, top_k, temperature, do_sample, batches, max_sentences) 105 | return send_message(f'Generated.') 106 | else: 107 | return send_message(f'Please load a GPT2 model first.', 403) 108 | 109 | 110 | @app.route('/add_sent', methods=['GET']) 111 | def add_sent(): 112 | sent_id = int(request.args.get('id')) 113 | deepstory.add_sent(sent_id) 114 | return send_message(f'Sentence {sent_id} added.') 115 | 116 | 117 | @app.route('/load_sentence', methods=['POST']) 118 | def load_sentence(): 119 | text = request.form.get('text') 120 | speaker = request.form.get('speaker') 121 | if text: 122 | is_comma = bool(request.form.get('isComma')) 123 | is_chopped = bool(request.form.get('isChopped')) 124 | is_speaker = bool(request.form.get('isSpeaker')) 125 | force = bool(request.form.get('force')) 126 | n = int(request.form.get('n')) 127 | deepstory.parse_text(text, 128 | n_gram=n, 129 | default_speaker=speaker, 130 | separate_comma=is_comma, 131 | separate_sentence=is_chopped, 132 | parse_speaker=is_speaker, 133 | force_parse=force) 134 | return send_message(f'Sentences loaded.') 135 | else: 136 | return send_message('Please enter text.', 403) 137 | 138 | 139 | @app.route('/animate', methods=['GET', 'POST']) 140 | def animate(): 141 | if request.method == 'POST': 142 | deepstory.animate_image(request.form) 143 | return send_message(f'Images animated.') 144 | elif request.method == 'GET': 145 | return render_template('animate.html', 146 | image_dict=deepstory.image_dict, 147 | loaded_speakers=deepstory.get_base_speakers()) 148 | 149 | 150 | @app.route('/modify', methods=['POST']) 151 | def modify(): 152 | 
deepstory.modify_speaker(request.json) 153 | return send_message(f'Speaker updated.') 154 | 155 | 156 | @app.route('/clear') 157 | def clear(): 158 | deepstory.clear_cache() 159 | return send_message(f'Cache cleared.') 160 | 161 | 162 | @app.route('/update_map', methods=['POST']) 163 | def update_map(): 164 | deepstory.update_mapping(request.form) 165 | return send_message(f'Mapping updated.') 166 | 167 | 168 | @app.route('/synthesize') 169 | def synthesize(): 170 | if deepstory.sentence_dicts: 171 | try: 172 | deepstory.synthesize_wavs() 173 | return send_message(f'Sentences synthesized.') 174 | except FileNotFoundError: 175 | return send_message("One of the model doesn't exist, please modify to something else.", 403) 176 | else: 177 | return send_message('Please enter text.', 403) 178 | 179 | 180 | @app.route('/combine') 181 | def combine(): 182 | try: 183 | deepstory.process_wavs() 184 | return send_message(f'Clip created.') 185 | except (KeyError, ValueError): 186 | return send_message('No audio is synthesized to be combined', 403) 187 | # except: 188 | # return send_message('Unknown Error.', 403) 189 | 190 | 191 | @app.route('/create_base') 192 | def create_base(): 193 | if deepstory.is_processed: 194 | deepstory.wav_to_vid() 195 | return send_message(f'Base video created.') 196 | else: 197 | return send_message('No audio is synthesized to be processed', 403) 198 | 199 | 200 | @app.route("/wav/") 201 | def stream(sentence_id): 202 | response = make_response(deepstory.stream(sentence_id), 200) 203 | response.mimetype = "audio/x-wav" 204 | return response 205 | 206 | 207 | @app.route('/image/') 208 | def image_viewer(filename): 209 | return send_from_directory(f'data/images/', filename) 210 | 211 | 212 | @app.route('/video') 213 | def video(): 214 | return render_template('video.html', animated=deepstory.is_animated) 215 | 216 | 217 | @app.route('/get_video') 218 | def video_viewer(): 219 | return send_from_directory(f'export', 'animated.mp4') 220 | 221 | 222 | if __name__ == '__main__': 223 | app.run() 224 | -------------------------------------------------------------------------------- /data/fom/vox-256.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | root_dir: data/vox-png 3 | frame_shape: [256, 256, 3] 4 | id_sampling: True 5 | pairs_list: data/vox256.csv 6 | augmentation_params: 7 | flip_param: 8 | horizontal_flip: True 9 | time_flip: True 10 | jitter_param: 11 | brightness: 0.1 12 | contrast: 0.1 13 | saturation: 0.1 14 | hue: 0.1 15 | 16 | 17 | model_params: 18 | common_params: 19 | num_kp: 10 20 | num_channels: 3 21 | estimate_jacobian: True 22 | kp_detector_params: 23 | temperature: 0.1 24 | block_expansion: 32 25 | max_features: 1024 26 | scale_factor: 0.25 27 | num_blocks: 5 28 | generator_params: 29 | block_expansion: 64 30 | max_features: 512 31 | num_down_blocks: 2 32 | num_bottleneck_blocks: 6 33 | estimate_occlusion_map: True 34 | dense_motion_params: 35 | block_expansion: 64 36 | max_features: 1024 37 | num_blocks: 5 38 | scale_factor: 0.25 39 | discriminator_params: 40 | scales: [1] 41 | block_expansion: 32 42 | max_features: 512 43 | num_blocks: 4 44 | sn: True 45 | 46 | train_params: 47 | num_epochs: 100 48 | num_repeats: 75 49 | epoch_milestones: [60, 90] 50 | lr_generator: 2.0e-4 51 | lr_discriminator: 2.0e-4 52 | lr_kp_detector: 2.0e-4 53 | batch_size: 40 54 | scales: [1, 0.5, 0.25, 0.125] 55 | checkpoint_freq: 50 56 | transform_params: 57 | sigma_affine: 0.05 58 | sigma_tps: 0.005 59 | points_tps: 5 
60 | loss_weights: 61 | generator_gan: 0 62 | discriminator_gan: 1 63 | feature_matching: [10, 10, 10, 10] 64 | perceptual: [10, 10, 10, 10, 10] 65 | equivariance_value: 10 66 | equivariance_jacobian: 10 67 | 68 | reconstruction_params: 69 | num_videos: 1000 70 | format: '.mp4' 71 | 72 | animate_params: 73 | num_pairs: 50 74 | format: '.mp4' 75 | normalization_params: 76 | adapt_movement_scale: False 77 | use_relative_movement: True 78 | use_relative_jacobian: True 79 | 80 | visualizer_params: 81 | kp_size: 5 82 | draw_border: True 83 | colormap: 'gist_rainbow' 84 | -------------------------------------------------------------------------------- /data/fom/vox-adv-256.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | root_dir: data/vox-png 3 | frame_shape: [256, 256, 3] 4 | id_sampling: True 5 | pairs_list: data/vox256.csv 6 | augmentation_params: 7 | flip_param: 8 | horizontal_flip: True 9 | time_flip: True 10 | jitter_param: 11 | brightness: 0.1 12 | contrast: 0.1 13 | saturation: 0.1 14 | hue: 0.1 15 | 16 | 17 | model_params: 18 | common_params: 19 | num_kp: 10 20 | num_channels: 3 21 | estimate_jacobian: True 22 | kp_detector_params: 23 | temperature: 0.1 24 | block_expansion: 32 25 | max_features: 1024 26 | scale_factor: 0.25 27 | num_blocks: 5 28 | generator_params: 29 | block_expansion: 64 30 | max_features: 512 31 | num_down_blocks: 2 32 | num_bottleneck_blocks: 6 33 | estimate_occlusion_map: True 34 | dense_motion_params: 35 | block_expansion: 64 36 | max_features: 1024 37 | num_blocks: 5 38 | scale_factor: 0.25 39 | discriminator_params: 40 | scales: [1] 41 | block_expansion: 32 42 | max_features: 512 43 | num_blocks: 4 44 | use_kp: True 45 | 46 | 47 | train_params: 48 | num_epochs: 150 49 | num_repeats: 75 50 | epoch_milestones: [] 51 | lr_generator: 2.0e-4 52 | lr_discriminator: 2.0e-4 53 | lr_kp_detector: 2.0e-4 54 | batch_size: 36 55 | scales: [1, 0.5, 0.25, 0.125] 56 | checkpoint_freq: 50 57 | transform_params: 58 | sigma_affine: 0.05 59 | sigma_tps: 0.005 60 | points_tps: 5 61 | loss_weights: 62 | generator_gan: 1 63 | discriminator_gan: 1 64 | feature_matching: [10, 10, 10, 10] 65 | perceptual: [10, 10, 10, 10, 10] 66 | equivariance_value: 10 67 | equivariance_jacobian: 10 68 | 69 | reconstruction_params: 70 | num_videos: 1000 71 | format: '.mp4' 72 | 73 | animate_params: 74 | num_pairs: 50 75 | format: '.mp4' 76 | normalization_params: 77 | adapt_movement_scale: False 78 | use_relative_movement: True 79 | use_relative_jacobian: True 80 | 81 | visualizer_params: 82 | kp_size: 5 83 | draw_border: True 84 | colormap: 'gist_rainbow' 85 | -------------------------------------------------------------------------------- /data/images/Geralt/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/data/images/Geralt/0.jpg -------------------------------------------------------------------------------- /data/images/Geralt/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/data/images/Geralt/1.jpg -------------------------------------------------------------------------------- /data/images/Geralt/2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/data/images/Geralt/2.jpg -------------------------------------------------------------------------------- /data/images/Geralt/fx.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/data/images/Geralt/fx.jpg -------------------------------------------------------------------------------- /data/images/Other/marilyn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/data/images/Other/marilyn.jpg -------------------------------------------------------------------------------- /data/images/Yennefer/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/data/images/Yennefer/0.jpg -------------------------------------------------------------------------------- /data/images/Yennefer/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/data/images/Yennefer/1.jpg -------------------------------------------------------------------------------- /data/images/Yennefer/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/data/images/Yennefer/2.jpg -------------------------------------------------------------------------------- /data/images/Yennefer/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/data/images/Yennefer/3.jpg -------------------------------------------------------------------------------- /data/sda/image.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/data/sda/image.bmp -------------------------------------------------------------------------------- /deepstory.py: -------------------------------------------------------------------------------- 1 | # SIU KING WAI SM4701 Deepstory 2 | import re 3 | import os 4 | import numpy as np 5 | import scipy 6 | import modules.sda as sda 7 | import glob 8 | import torch 9 | import ffmpeg 10 | import random 11 | 12 | from io import BytesIO 13 | from util import separate, fix_text, trim_text, split_audio_to_list, get_duration 14 | from voice import Voice 15 | from generate import Generator 16 | from animator import ImageAnimator 17 | from modules.dctts import hp 18 | 19 | 20 | class Deepstory: 21 | def __init__(self): 22 | # remove previously created video 23 | # self.clear_cache() 24 | # self.text = 'Geralt|I hate portals. A round of Gwent maybe?' 
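# Per-session state: GPT2 text and candidate sentences, parsed sentence dicts,
# speaker-to-model mappings, and the image/voice/GPT2 model folders found on disk.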
25 | self.generated_text = '' 26 | self.generated_sentences = [] 27 | self.speaker_dict = {} 28 | self.speaker_map_dict = {} 29 | self.image_dict = { 30 | os.path.basename(os.path.dirname(path)): sorted( 31 | [os.path.basename(file) for file in glob.glob(f'{path}/*.*')]) 32 | for path in glob.glob('data/images/*/') 33 | } 34 | self.sentence_dicts = [] 35 | self.gpt2 = False 36 | self.gpt2_list = [os.path.split(os.path.split(path)[0])[-1] for path in glob.glob('data/gpt2/*/')] 37 | self.speaker_list = [] 38 | self.model_list = [os.path.split(os.path.split(path)[0])[-1] for path in glob.glob('data/dctts/*/')] 39 | 40 | def load_gpt2(self, model_name): 41 | if self.gpt2: 42 | del self.gpt2 43 | torch.cuda.empty_cache() 44 | self.gpt2 = Generator(model_name) 45 | self.generated_text = self.gpt2.default 46 | self.generated_sentences = [] 47 | 48 | def load_text(self, model_name, lines_no): 49 | with open(f'data/gpt2/{model_name}/text.txt', 'r') as f: 50 | lines = f.readlines() 51 | start_index = random.randint(0, len(lines) - 1 - lines_no) 52 | text = ''.join(lines[start_index:start_index+lines_no]) 53 | if text[-1] == '\n': 54 | text = text[:-1] 55 | self.generated_text = text 56 | 57 | @property 58 | def current_gpt2(self): 59 | return self.gpt2.model_name if self.gpt2 else False 60 | 61 | def generate_text_gpt2(self, text, predict_length, top_p, top_k, temperature, do_sample): 62 | self.generated_sentences = [] 63 | script = self.current_gpt2 == 'Waiting for Godot' 64 | result = trim_text(self.gpt2.generate(text, predict_length, top_p, top_k, temperature, do_sample)[0], 65 | script=script) 66 | self.generated_text = text + result 67 | 68 | def generate_sents_gpt2(self, text, predict_length, top_p, top_k, temperature, do_sample, batches, max_sentences): 69 | self.generated_text = text 70 | script = self.current_gpt2 == 'Waiting for Godot' 71 | sents = self.gpt2.generate(text, predict_length, top_p, top_k, temperature, do_sample, num=batches) 72 | self.generated_sentences = [trim_text(sent, max_sentences, script=script) for sent in sents] 73 | 74 | def add_sent(self, sent_id): 75 | self.generated_text += self.generated_sentences[sent_id] 76 | self.generated_sentences = [] 77 | 78 | def parse_text(self, text, default_speaker, force_parse=False, separate_comma=False, 79 | n_gram=2, separate_sentence=False, parse_speaker=True, normalize=True): 80 | """ 81 | Parse the input text into suitable data structure 82 | :param force_parse: forced to replace all speaker that are not in model list as the default speaker 83 | :param n_gram: concat sentences of this max length in a line 84 | :param text: source 85 | :param default_speaker: the default speaker if no speaker in specified 86 | :param separate_comma: split by comma 87 | :param separate_sentence: split sentence if multiple clauses exist 88 | :param parse_speaker: bool for turn on/off parse speaker 89 | :param normalize: to convert common punctuation besides comma to comma 90 | """ 91 | 92 | lines = re.split(r'\r\n|\n\r|\r|\n', text) 93 | 94 | line_speaker_dict = {} 95 | self.speaker_list = [] 96 | self.speaker_map_dict = {} 97 | if parse_speaker: 98 | # re.match(r'^.*(?=:)', text) 99 | for i, line in enumerate(lines): 100 | if re.search(r':|\|', line): 101 | # ?: non capture group of : and | 102 | speaker, line = re.split(r'\s*(?::|\|)\s*', line, 1) 103 | # add entry only if the voice model exist in the folder, 104 | # the unrecognized one will need to mapped so as to be able to be synthesized 105 | if force_parse: 106 | if speaker in 
self.model_list: 107 | line_speaker_dict[i] = speaker 108 | else: 109 | if speaker not in self.speaker_list: 110 | self.speaker_list.append(speaker) 111 | line_speaker_dict[i] = speaker 112 | lines[i] = line 113 | 114 | for i, speaker in enumerate(self.speaker_list): 115 | if speaker not in self.model_list: 116 | self.speaker_map_dict[speaker] = self.model_list[i % len(self.model_list)] 117 | 118 | # separate by spacy sentencizer 119 | lines = [separate(fix_text(line), n_gram, comma=separate_comma) for line in lines] 120 | 121 | self.sentence_dicts = [] 122 | for i, line in enumerate(lines): 123 | for j, sent in enumerate(line): 124 | if self.sentence_dicts: 125 | # might be buggy, forgot why I wrote this at all 126 | while sent[0].is_punct and not any(sent[0].text == punct for punct in ['“', '‘']): 127 | self.sentence_dicts[-1]['punct'] = self.sentence_dicts[-1]['punct'] + sent.text[0] 128 | sent = sent[1:] 129 | continue 130 | 131 | sentence_dict = { 132 | 'text': sent.text, 133 | 'begin': True if j == 0 else False, 134 | 'punct': '', 135 | 'speaker': line_speaker_dict.get(i, default_speaker) 136 | } 137 | 138 | while not sentence_dict['text'][-1].isalpha(): 139 | sentence_dict['punct'] = sentence_dict['punct'] + sentence_dict['text'][-1] 140 | sentence_dict['text'] = sentence_dict['text'][:-1] 141 | # Reverse the punctuation order since I add it based on the last item 142 | sentence_dict['punct'] = sentence_dict['punct'][::-1] 143 | sentence_dict['text'] = sentence_dict['text'] + sentence_dict['punct'] 144 | self.sentence_dicts.append(sentence_dict) 145 | 146 | self.update_speaker_dict() 147 | 148 | def update_speaker_dict(self): 149 | self.speaker_dict = {} 150 | for i, sentence_dict in enumerate(self.sentence_dicts): 151 | if sentence_dict['speaker'] not in self.speaker_dict: 152 | self.speaker_dict[sentence_dict['speaker']] = [] 153 | self.speaker_dict[sentence_dict['speaker']].append(i) 154 | 155 | def update_mapping(self, map_dict): 156 | for speaker, mapped in map_dict.items(): 157 | self.speaker_map_dict[speaker] = mapped 158 | 159 | def modify_speaker(self, speaker_list): 160 | for i, speaker in enumerate(speaker_list): 161 | self.sentence_dicts[i]['speaker'] = speaker 162 | self.update_speaker_dict() 163 | 164 | def synthesize_wavs(self): 165 | # clear model from vram to prevent out of memory error 166 | if self.current_gpt2: 167 | del self.gpt2 168 | self.gpt2 = None 169 | torch.cuda.empty_cache() 170 | 171 | speaker_dict_mapped = {} 172 | for speaker, sentence_ids in self.speaker_dict.items(): 173 | mapped_speaker = self.speaker_map_dict.get(speaker, speaker) 174 | if mapped_speaker in speaker_dict_mapped: 175 | speaker_dict_mapped[mapped_speaker].extend(sentence_ids) 176 | else: 177 | speaker_dict_mapped[mapped_speaker] = sentence_ids 178 | 179 | for speaker, sentence_ids in speaker_dict_mapped.items(): 180 | with Voice(speaker) as voice: 181 | for i in sentence_ids: 182 | self.sentence_dicts[i]['wav'] = voice.synthesize(self.sentence_dicts[i]['text']) 183 | 184 | @property 185 | def is_synthesized(self): 186 | return 'wav' in self.sentence_dicts[0] if self.sentence_dicts else False 187 | 188 | def process_wavs(self): 189 | """ 190 | Prepare wavs 191 | 1. Add silence at beginning if the sentence is the beginning of a line 192 | 2. Add silence at the end based on punctuation in the sentence 193 | 3. Finely split the audio again so that sentence with comma can be chopped (increasing sda model performance) 194 | 4. 
Create a cache of the whole combined audio clips 195 | """ 196 | # can be adjusted as you like, in seconds. 197 | punctuation_dict = { 198 | '.': 0.3, 199 | ',': 0.15, 200 | '!': 0.3, 201 | '?': 0.4, 202 | '"': 0.1, 203 | '…': 0.6, 204 | ':': 0.15, 205 | ';': 0.2, 206 | '’': 0.05, 207 | '‘': 0.05, 208 | '”': 0.05, 209 | '“': 0.05 210 | } 211 | 212 | if not os.path.isdir(f'export'): 213 | os.mkdir(f'export') 214 | 215 | if not os.path.isdir(f'temp'): 216 | os.mkdir(f'temp') 217 | 218 | if os.path.isdir(f'temp/audio'): 219 | for path in glob.glob('temp/audio/*'): 220 | os.remove(path) 221 | else: 222 | os.mkdir(f'temp/audio') 223 | 224 | wavs_dicts = [] 225 | for sentence_dict in self.sentence_dicts: 226 | # Add silence between lines 227 | pad_begin = get_duration(0.5) if sentence_dict['begin'] else 0 228 | pad_end = get_duration(sum(float(punctuation_dict.get(punct, 0)) for punct in sentence_dict['punct'])) 229 | wav = np.pad(sentence_dict['wav'], (pad_begin, pad_end), 'constant') 230 | split_list = split_audio_to_list(wav) 231 | for i, wav_slice in enumerate(split_list): 232 | wav_part = wav[slice(*wav_slice)] 233 | # add some more silence so that the video generated would not be that awkward 234 | if i == 0: 235 | wav_part = np.pad(wav_part, (0, get_duration(0.1)), 'constant') 236 | elif i == len(split_list) - 1: 237 | wav_part = np.pad(wav_part, (get_duration(0.1), 0), 'constant') 238 | else: 239 | # append silence at the beginning of slice 240 | wav_part = np.pad(wav_part, (get_duration(0.1), get_duration(0.1)), 'constant') 241 | wavs_dicts.append({ 242 | 'speaker': sentence_dict['speaker'], 243 | 'wav': wav_part 244 | }) 245 | 246 | for i, wav_dict in enumerate(wavs_dicts): 247 | scipy.io.wavfile.write(f'temp/audio/{i:03d}|{wav_dict["speaker"]}.wav', hp.sr, wav_dict['wav']) 248 | 249 | scipy.io.wavfile.write('export/combined.wav', hp.sr, 250 | np.concatenate([wavs_dict['wav'] for wavs_dict in wavs_dicts], axis=None)) 251 | 252 | @property 253 | def is_processed(self): 254 | try: 255 | return bool(os.listdir('temp/audio')) 256 | except FileNotFoundError: 257 | return False 258 | 259 | def stream(self, sentence_id=0): 260 | with BytesIO() as f: 261 | scipy.io.wavfile.write(f, hp.sr, self.sentence_dicts[sentence_id]['wav']) 262 | return f.getvalue() 263 | 264 | @staticmethod 265 | def wav_to_vid(): 266 | """Create Base video for First Order Motion Model""" 267 | if os.path.isdir(f'temp/base'): 268 | for path in glob.glob('temp/base/*'): 269 | os.remove(path) 270 | else: 271 | os.mkdir(f'temp/base') 272 | 273 | va = sda.VideoAnimator(gpu=0) # Instantiate the animator 274 | for audio_path in sorted( 275 | glob.glob('temp/audio/*.wav'), 276 | key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split("|")[0]) 277 | ): 278 | np.save( 279 | f'temp/base/{os.path.splitext(os.path.basename(audio_path))[0]}.npy', 280 | va('data/sda/image.bmp', audio_path) 281 | ) 282 | del va 283 | torch.cuda.empty_cache() 284 | 285 | @property 286 | def is_base(self): 287 | try: 288 | return bool(os.listdir('temp/base')) 289 | except FileNotFoundError: 290 | return False 291 | 292 | @staticmethod 293 | def get_base_speakers(): 294 | return set( 295 | os.path.splitext(os.path.basename(base_path))[0].split("|")[1] 296 | for base_path in glob.glob('temp/base/*.npy') 297 | ) 298 | 299 | @staticmethod 300 | def animate_image(image_dict): 301 | if os.path.isdir(f'temp/animated'): 302 | for path in glob.glob('temp/animated/*'): 303 | os.remove(path) 304 | else: 305 | os.mkdir(f'temp/animated') 306 | 307 | with 
ImageAnimator() as animator: 308 | for i, base_path in enumerate(sorted( 309 | glob.glob('temp/base/*.npy'), 310 | key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split("|")[0]))): 311 | speaker = os.path.splitext(os.path.basename(base_path))[0].split("|")[1] 312 | animator.animate_image( 313 | f'data/images/{image_dict[speaker]}', 314 | np.load(base_path), 315 | f'temp/animated/{i}.mp4' 316 | ) 317 | 318 | audio = ffmpeg.input('export/combined.wav').audio 319 | videos = [ffmpeg.input(clip).video for clip in sorted( 320 | glob.glob('temp/animated/*'), key=lambda x: int(os.path.basename(x)[:-4]))] 321 | ffmpeg.concat(*videos).output('export/combined.mp4', loglevel="panic").overwrite_output().run() 322 | video = ffmpeg.input('export/combined.mp4').video 323 | ffmpeg.output( 324 | video, audio, 'export/animated.mp4', loglevel="panic", vcodec="copy", ar=hp.sr, **{'b:a': '128k'} 325 | ).overwrite_output().run() 326 | 327 | @property 328 | def is_animated(self): 329 | return os.path.isfile('export/animated.mp4') 330 | 331 | def clear_cache(self): 332 | # remove previously created video 333 | if self.is_animated: 334 | for path in glob.glob('temp/animated/*'): 335 | os.remove(path) 336 | os.remove('export/animated.mp4') 337 | os.remove('export/combined.mp4') 338 | if self.is_processed: 339 | for path in glob.glob('temp/audio/*'): 340 | os.remove(path) 341 | os.remove('export/combined.wav') 342 | if self.is_base: 343 | for path in glob.glob('temp/base/*'): 344 | os.remove(path) 345 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # SIU KING WAI SM4701 Deepstory 2 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 3 | 4 | 5 | class Generator: 6 | def __init__(self, model_name): 7 | self.model_name = model_name 8 | self.tokenizer = GPT2Tokenizer.from_pretrained(f'data/gpt2/{model_name}') 9 | self.model = GPT2LMHeadModel.from_pretrained(f'data/gpt2/{model_name}').to('cuda') 10 | with open(f'data/gpt2/{model_name}/default.txt', 'r') as f: 11 | text = f.read() 12 | if text[-1] == '\n': 13 | text = text[:-1] 14 | self.default = text 15 | 16 | def generate(self, text, predict_length, top_p, top_k, temperature, do_sample, num=1): 17 | 18 | if text: 19 | # encode input context to gpt2 tokens 20 | input_ids = self.tokenizer.encode(text, return_tensors='pt').to('cuda') 21 | # gpt2 model can only infer to maximum of 1024 tokens 22 | if len(input_ids[0]) + predict_length > 1024: 23 | # take the nearest (1024 - predict_length) tokens from the end while reserving space to generate. 
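# e.g. with predict_length=200 only the last 1024 - 200 = 824 prompt tokens are kept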
24 | input_ids = input_ids[0][-(1024 - predict_length):].unsqueeze(0) 25 | input_length = len(input_ids[0]) 26 | else: 27 | input_ids = None 28 | input_length = 0 29 | outputs = self.model.generate(input_ids=input_ids, 30 | max_length=input_length + predict_length, 31 | top_p=top_p, 32 | top_k=top_k, 33 | temperature=temperature, 34 | do_sample=do_sample, 35 | num_return_sequences=num) 36 | return [ 37 | self.tokenizer.decode( 38 | outputs[i][input_length:], clean_up_tokenization_spaces=True, skip_special_tokens=True) 39 | for i in range(num) 40 | ] 41 | -------------------------------------------------------------------------------- /interface/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/interface/1.png -------------------------------------------------------------------------------- /interface/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/interface/2.png -------------------------------------------------------------------------------- /interface/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/interface/3.png -------------------------------------------------------------------------------- /interface/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/interface/4.png -------------------------------------------------------------------------------- /interface/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/interface/5.png -------------------------------------------------------------------------------- /modules/dctts/__init__.py: -------------------------------------------------------------------------------- 1 | from .text2mel import Text2Mel 2 | from .ssrn import SSRN 3 | from .hparams import HParams as hp 4 | from .audio import spectrogram2wav 5 | -------------------------------------------------------------------------------- /modules/dctts/audio.py: -------------------------------------------------------------------------------- 1 | """These methods are copied from https://github.com/Kyubyong/dc_tts/""" 2 | 3 | import copy 4 | import librosa 5 | import scipy.io.wavfile 6 | import numpy as np 7 | 8 | from scipy import signal 9 | from .hparams import HParams as hp 10 | 11 | 12 | def spectrogram2wav(mag): 13 | '''# Generate wave file from linear magnitude spectrogram 14 | Args: 15 | mag: A numpy array of (T, 1+n_fft//2) 16 | Returns: 17 | wav: A 1-D numpy array. 
18 | ''' 19 | # transpose 20 | mag = mag.T 21 | 22 | # de-noramlize 23 | mag = (np.clip(mag, 0, 1) * hp.max_db) - hp.max_db + hp.ref_db 24 | 25 | # to amplitude 26 | mag = np.power(10.0, mag * 0.05) 27 | 28 | # wav reconstruction 29 | wav = griffin_lim(mag ** hp.power) 30 | 31 | # de-preemphasis 32 | wav = signal.lfilter([1], [1, -hp.preemphasis], wav) 33 | 34 | # trim 35 | wav, _ = librosa.effects.trim(wav) 36 | 37 | # output as PCM 16 bit 38 | wav *= 32767 39 | return wav.astype(np.int16) 40 | 41 | 42 | def griffin_lim(spectrogram): 43 | '''Applies Griffin-Lim's raw.''' 44 | X_best = copy.deepcopy(spectrogram) 45 | for i in range(hp.n_iter): 46 | X_t = invert_spectrogram(X_best) 47 | est = librosa.stft(X_t, hp.n_fft, hp.hop_length, win_length=hp.win_length) 48 | phase = est / np.maximum(1e-8, np.abs(est)) 49 | X_best = spectrogram * phase 50 | X_t = invert_spectrogram(X_best) 51 | y = np.real(X_t) 52 | 53 | return y 54 | 55 | 56 | def invert_spectrogram(spectrogram): 57 | '''Applies inverse fft. 58 | Args: 59 | spectrogram: [1+n_fft//2, t] 60 | ''' 61 | return librosa.istft(spectrogram, hp.hop_length, win_length=hp.win_length, window="hann") 62 | 63 | 64 | def save_to_wav(mag, filename): 65 | """Generate and save an audio file from the given linear spectrogram using Griffin-Lim.""" 66 | wav = spectrogram2wav(mag) 67 | scipy.io.wavfile.write(filename, hp.sr, wav) 68 | -------------------------------------------------------------------------------- /modules/dctts/hparams.py: -------------------------------------------------------------------------------- 1 | """Hyper parameters.""" 2 | __author__ = 'Erdene-Ochir Tuguldur' 3 | 4 | 5 | class HParams: 6 | """Hyper parameters""" 7 | 8 | disable_progress_bar = False # set True if you don't want the progress bar in the console 9 | 10 | logdir = "logdir" # log dir where the checkpoints and tensorboard files are saved 11 | max_load_memory = 4000000000 # h5 file size larger than this will not be load into memory 12 | vocab = "PE abcdefghijklmnopqrstuvwxyz'.,!?" # P: Padding, E: EOS. 13 | char2idx = {char: idx for idx, char in enumerate(vocab)} 14 | idx2char = {idx: char for idx, char in enumerate(vocab)} 15 | 16 | # audio.py options, these values are from https://github.com/Kyubyong/dc_tts/blob/master/hyperparams.py 17 | reduction_rate = 4 # melspectrogram reduction rate, don't change because SSRN is using this rate 18 | n_fft = 2048 # fft points (samples) 19 | n_mels = 80 # Number of Mel banks to generate 20 | power = 1.2 # Exponent for amplifying the predicted magnitude 21 | n_iter = 50 # Number of inversion iterations 22 | preemphasis = .97 23 | max_db = 140 24 | ref_db = 20 25 | sr = 22050 # Sampling rate 26 | frame_shift = 0.0125 # seconds 27 | frame_length = 0.05 # seconds 28 | hop_length = int(sr * frame_shift) # samples. =276. 29 | win_length = int(sr * frame_length) # samples. =1102. 30 | max_N = 259 # Maximum number of characters. 31 | max_T = 326 # Maximum number of mel frames. 
32 | 33 | e = 128 # embedding dimension 34 | d = 512 # Text2Mel hidden unit dimension 35 | c = 512+128 # SSRN hidden unit dimension 36 | 37 | dropout_rate = 0.05 # dropout 38 | 39 | # Text2Mel network options 40 | text2mel_lr = 0.005 # learning rate 41 | text2mel_batch_size = 32 42 | text2mel_max_iteration = 300000 # max train step 43 | text2mel_weight_init = 'none' # 'kaiming', 'xavier' or 'none' 44 | text2mel_normalization = 'layer' # 'layer', 'weight' or 'none' 45 | text2mel_basic_block = 'gated_conv' # 'highway', 'gated_conv' or 'residual' 46 | 47 | # SSRN network options 48 | ssrn_lr = 0.0005 # learning rate 49 | ssrn_batch_size = 32 50 | ssrn_max_iteration = 300000 # max train step 51 | ssrn_weight_init = 'kaiming' # 'kaiming', 'xavier' or 'none' 52 | ssrn_normalization = 'weight' # 'layer', 'weight' or 'none' 53 | ssrn_basic_block = 'residual' # 'highway', 'gated_conv' or 'residual' 54 | -------------------------------------------------------------------------------- /modules/dctts/layers.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Erdene-Ochir Tuguldur' 2 | __all__ = ['E', 'D', 'C', 'HighwayBlock', 'GatedConvBlock', 'ResidualBlock'] 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .hparams import HParams as hp 8 | 9 | 10 | class LayerNorm(nn.LayerNorm): 11 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): 12 | """Layer Norm.""" 13 | super(LayerNorm, self).__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine) 14 | 15 | def forward(self, x): 16 | x = x.permute(0, 2, 1) # PyTorch LayerNorm seems to be expect (B, T, C) 17 | y = super(LayerNorm, self).forward(x) 18 | y = y.permute(0, 2, 1) # reverse 19 | return y 20 | 21 | 22 | class D(nn.Module): 23 | def __init__(self, in_channels, out_channels, kernel_size, dilation, weight_init='none', normalization='weight', nonlinearity='linear'): 24 | """1D Deconvolution.""" 25 | super(D, self).__init__() 26 | self.deconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, 27 | stride=2, # paper: stride of deconvolution is always 2 28 | dilation=dilation) 29 | 30 | if normalization == 'weight': 31 | self.deconv = nn.utils.weight_norm(self.deconv) 32 | elif normalization == 'layer': 33 | self.layer_norm = LayerNorm(out_channels) 34 | 35 | self.nonlinearity = nonlinearity 36 | if weight_init == 'kaiming': 37 | nn.init.kaiming_normal_(self.deconv.weight, mode='fan_out', nonlinearity=nonlinearity) 38 | elif weight_init == 'xavier': 39 | nn.init.xavier_uniform_(self.deconv.weight, nn.init.calculate_gain(nonlinearity)) 40 | 41 | def forward(self, x, output_size=None): 42 | y = self.deconv(x, output_size=output_size) 43 | if hasattr(self, 'layer_norm'): 44 | y = self.layer_norm(y) 45 | y = F.dropout(y, p=hp.dropout_rate, training=self.training, inplace=True) 46 | if self.nonlinearity == 'relu': 47 | y = F.relu(y, inplace=True) 48 | return y 49 | 50 | 51 | class C(nn.Module): 52 | def __init__(self, in_channels, out_channels, kernel_size, dilation, causal=False, weight_init='none', normalization='weight', nonlinearity='linear'): 53 | """1D convolution. 54 | The argument 'causal' indicates whether the causal convolution should be used or not. 
55 | """ 56 | super(C, self).__init__() 57 | self.causal = causal 58 | if causal: 59 | self.padding = (kernel_size - 1) * dilation 60 | else: 61 | self.padding = (kernel_size - 1) * dilation // 2 62 | 63 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, 64 | stride=1, # paper: 'The stride of convolution is always 1.' 65 | padding=self.padding, dilation=dilation) 66 | 67 | if normalization == 'weight': 68 | self.conv = nn.utils.weight_norm(self.conv) 69 | elif normalization == 'layer': 70 | self.layer_norm = LayerNorm(out_channels) 71 | 72 | self.nonlinearity = nonlinearity 73 | if weight_init == 'kaiming': 74 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity=nonlinearity) 75 | elif weight_init == 'xavier': 76 | nn.init.xavier_uniform_(self.conv.weight, nn.init.calculate_gain(nonlinearity)) 77 | 78 | def forward(self, x): 79 | y = self.conv(x) 80 | padding = self.padding 81 | if self.causal and padding > 0: 82 | y = y[:, :, :-padding] 83 | 84 | if hasattr(self, 'layer_norm'): 85 | y = self.layer_norm(y) 86 | y = F.dropout(y, p=hp.dropout_rate, training=self.training, inplace=True) 87 | if self.nonlinearity == 'relu': 88 | y = F.relu(y, inplace=True) 89 | return y 90 | 91 | 92 | class E(nn.Module): 93 | def __init__(self, num_embeddings, embedding_dim): 94 | super(E, self).__init__() 95 | self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0) 96 | 97 | def forward(self, x): 98 | return self.embedding(x) 99 | 100 | 101 | class HighwayBlock(nn.Module): 102 | def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight'): 103 | """Highway Network like layer: https://arxiv.org/abs/1505.00387 104 | The input and output shapes remain same. 105 | Args: 106 | d: input channel 107 | k: kernel size 108 | delta: dilation 109 | causal: causal convolution or not 110 | """ 111 | super(HighwayBlock, self).__init__() 112 | self.d = d 113 | self.C = C(in_channels=d, out_channels=2 * d, kernel_size=k, dilation=delta, causal=causal, weight_init=weight_init, normalization=normalization) 114 | 115 | def forward(self, x): 116 | L = self.C(x) 117 | H1 = L[:, :self.d, :] 118 | H2 = L[:, self.d:, :] 119 | sigH1 = F.sigmoid(H1) 120 | return sigH1 * H2 + (1 - sigH1) * x 121 | 122 | 123 | class GatedConvBlock(nn.Module): 124 | def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight'): 125 | """Gated convolutional layer: https://arxiv.org/abs/1612.08083 126 | The input and output shapes remain same. 127 | Args: 128 | d: input channel 129 | k: kernel size 130 | delta: dilation 131 | causal: causal convolution or not 132 | """ 133 | super(GatedConvBlock, self).__init__() 134 | self.C = C(in_channels=d, out_channels=2 * d, kernel_size=k, dilation=delta, causal=causal, 135 | weight_init=weight_init, normalization=normalization) 136 | self.glu = nn.GLU(dim=1) 137 | 138 | def forward(self, x): 139 | L = self.C(x) 140 | return self.glu(L) + x 141 | 142 | 143 | class ResidualBlock(nn.Module): 144 | def __init__(self, d, k, delta, causal=False, weight_init='none', normalization='weight', 145 | widening_factor=2): 146 | """Residual block: https://arxiv.org/abs/1512.03385 147 | The input and output shapes remain same. 
148 | Args: 149 | d: input channel 150 | k: kernel size 151 | delta: dilation 152 | causal: causal convolution or not 153 | """ 154 | super(ResidualBlock, self).__init__() 155 | self.C1 = C(in_channels=d, out_channels=widening_factor * d, kernel_size=k, dilation=delta, causal=causal, 156 | weight_init=weight_init, normalization=normalization, nonlinearity='relu') 157 | self.C2 = C(in_channels=widening_factor * d, out_channels=d, kernel_size=k, dilation=delta, causal=causal, 158 | weight_init=weight_init, normalization=normalization, nonlinearity='relu') 159 | 160 | def forward(self, x): 161 | return self.C2(self.C1(x)) + x 162 | -------------------------------------------------------------------------------- /modules/dctts/ssrn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hideyuki Tachibana, Katsuya Uenoyama, Shunsuke Aihara 3 | Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention 4 | https://arxiv.org/abs/1710.08969 5 | 6 | SSRN Network. 7 | """ 8 | __author__ = 'Erdene-Ochir Tuguldur' 9 | __all__ = ['SSRN'] 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from .hparams import HParams as hp 15 | from .layers import D, C, HighwayBlock, GatedConvBlock, ResidualBlock 16 | 17 | 18 | def Conv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'): 19 | return C(in_channels, out_channels, kernel_size, dilation, causal=False, 20 | weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization, nonlinearity=nonlinearity) 21 | 22 | 23 | def DeConv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'): 24 | return D(in_channels, out_channels, kernel_size, dilation, 25 | weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization, nonlinearity=nonlinearity) 26 | 27 | 28 | def BasicBlock(d, k, delta): 29 | if hp.ssrn_basic_block == 'gated_conv': 30 | return GatedConvBlock(d, k, delta, causal=False, 31 | weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization) 32 | elif hp.ssrn_basic_block == 'highway': 33 | return HighwayBlock(d, k, delta, causal=False, 34 | weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization) 35 | else: 36 | return ResidualBlock(d, k, delta, causal=False, 37 | weight_init=hp.ssrn_weight_init, normalization=hp.ssrn_normalization, 38 | widening_factor=1) 39 | 40 | 41 | class SSRN(nn.Module): 42 | def __init__(self, c=hp.c, f=hp.n_mels, f_prime=(1 + hp.n_fft // 2)): 43 | """Spectrogram super-resolution network. 
44 | Args: 45 | c: SSRN dim 46 | f: Number of mel bins 47 | f_prime: full spectrogram dim 48 | Input: 49 | Y: (B, f, T) predicted melspectrograms 50 | Outputs: 51 | Z_logit: logit of Z 52 | Z: (B, f_prime, 4*T) full spectrograms 53 | """ 54 | super(SSRN, self).__init__() 55 | self.layers = nn.Sequential( 56 | Conv(f, c, 1, 1), 57 | 58 | BasicBlock(c, 3, 1), BasicBlock(c, 3, 3), 59 | 60 | DeConv(c, c, 2, 1), BasicBlock(c, 3, 1), BasicBlock(c, 3, 3), 61 | DeConv(c, c, 2, 1), BasicBlock(c, 3, 1), BasicBlock(c, 3, 3), 62 | 63 | Conv(c, 2 * c, 1, 1), 64 | 65 | BasicBlock(2 * c, 3, 1), BasicBlock(2 * c, 3, 1), 66 | 67 | Conv(2 * c, f_prime, 1, 1), 68 | 69 | # Conv(f_prime, f_prime, 1, 1, nonlinearity='relu'), 70 | # Conv(f_prime, f_prime, 1, 1, nonlinearity='relu'), 71 | BasicBlock(f_prime, 1, 1), 72 | 73 | Conv(f_prime, f_prime, 1, 1) 74 | ) 75 | 76 | def forward(self, x): 77 | Z_logit = self.layers(x) 78 | Z = F.sigmoid(Z_logit) 79 | return Z_logit, Z -------------------------------------------------------------------------------- /modules/dctts/text2mel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hideyuki Tachibana, Katsuya Uenoyama, Shunsuke Aihara 3 | Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention 4 | https://arxiv.org/abs/1710.08969 5 | 6 | Text2Mel Network. 7 | """ 8 | __author__ = 'Erdene-Ochir Tuguldur' 9 | __all__ = ['Text2Mel'] 10 | 11 | import numpy as np 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | from .hparams import HParams as hp 18 | from .layers import E, C, HighwayBlock, GatedConvBlock, ResidualBlock 19 | 20 | 21 | def Conv(in_channels, out_channels, kernel_size, dilation, causal=False, nonlinearity='linear'): 22 | return C(in_channels, out_channels, kernel_size, dilation, causal=causal, 23 | weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization, nonlinearity=nonlinearity) 24 | 25 | 26 | def BasicBlock(d, k, delta, causal=False): 27 | if hp.text2mel_basic_block == 'gated_conv': 28 | return GatedConvBlock(d, k, delta, causal=causal, 29 | weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization) 30 | elif hp.text2mel_basic_block == 'highway': 31 | return HighwayBlock(d, k, delta, causal=causal, 32 | weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization) 33 | else: 34 | return ResidualBlock(d, k, delta, causal=causal, 35 | weight_init=hp.text2mel_weight_init, normalization=hp.text2mel_normalization, 36 | widening_factor=2) 37 | 38 | 39 | def CausalConv(in_channels, out_channels, kernel_size, dilation, nonlinearity='linear'): 40 | return Conv(in_channels, out_channels, kernel_size, dilation, causal=True, nonlinearity=nonlinearity) 41 | 42 | 43 | def CausalBasicBlock(d, k, delta): 44 | return BasicBlock(d, k, delta, causal=True) 45 | 46 | 47 | class TextEnc(nn.Module): 48 | 49 | def __init__(self, vocab, e=hp.e, d=hp.d): 50 | """Text encoder network. 
51 | Args: 52 | vocab: vocabulary 53 | e: embedding dim 54 | d: Text2Mel dim 55 | Input: 56 | L: (B, N) text inputs 57 | Outputs: 58 | K: (B, d, N) keys 59 | V: (N, d, N) values 60 | """ 61 | super(TextEnc, self).__init__() 62 | self.d = d 63 | self.embedding = E(len(vocab), e) 64 | 65 | self.layers = nn.Sequential( 66 | Conv(e, 2 * d, 1, 1, nonlinearity='relu'), 67 | Conv(2 * d, 2 * d, 1, 1), 68 | 69 | BasicBlock(2 * d, 3, 1), BasicBlock(2 * d, 3, 3), BasicBlock(2 * d, 3, 9), BasicBlock(2 * d, 3, 27), 70 | BasicBlock(2 * d, 3, 1), BasicBlock(2 * d, 3, 3), BasicBlock(2 * d, 3, 9), BasicBlock(2 * d, 3, 27), 71 | 72 | BasicBlock(2 * d, 3, 1), BasicBlock(2 * d, 3, 1), 73 | 74 | BasicBlock(2 * d, 1, 1), BasicBlock(2 * d, 1, 1) 75 | ) 76 | 77 | def forward(self, x): 78 | out = self.embedding(x) 79 | out = out.permute(0, 2, 1) # change to (B, e, N) 80 | out = self.layers(out) # (B, 2*d, N) 81 | K = out[:, :self.d, :] # (B, d, N) 82 | V = out[:, self.d:, :] # (B, d, N) 83 | return K, V 84 | 85 | 86 | class AudioEnc(nn.Module): 87 | def __init__(self, d=hp.d, f=hp.n_mels): 88 | """Audio encoder network. 89 | Args: 90 | d: Text2Mel dim 91 | f: Number of mel bins 92 | Input: 93 | S: (B, f, T) melspectrograms 94 | Output: 95 | Q: (B, d, T) queries 96 | """ 97 | super(AudioEnc, self).__init__() 98 | self.layers = nn.Sequential( 99 | CausalConv(f, d, 1, 1, nonlinearity='relu'), 100 | CausalConv(d, d, 1, 1, nonlinearity='relu'), 101 | CausalConv(d, d, 1, 1), 102 | 103 | CausalBasicBlock(d, 3, 1), CausalBasicBlock(d, 3, 3), CausalBasicBlock(d, 3, 9), CausalBasicBlock(d, 3, 27), 104 | CausalBasicBlock(d, 3, 1), CausalBasicBlock(d, 3, 3), CausalBasicBlock(d, 3, 9), CausalBasicBlock(d, 3, 27), 105 | 106 | CausalBasicBlock(d, 3, 3), CausalBasicBlock(d, 3, 3), 107 | ) 108 | 109 | def forward(self, x): 110 | return self.layers(x) 111 | 112 | 113 | class AudioDec(nn.Module): 114 | def __init__(self, d=hp.d, f=hp.n_mels): 115 | """Audio decoder network. 116 | Args: 117 | d: Text2Mel dim 118 | f: Number of mel bins 119 | Input: 120 | R_prime: (B, 2d, T) [V*Attention, Q] paper says: "we found it beneficial in our pilot study." 121 | Output: 122 | Y: (B, f, T) 123 | """ 124 | super(AudioDec, self).__init__() 125 | self.layers = nn.Sequential( 126 | CausalConv(2 * d, d, 1, 1), 127 | 128 | CausalBasicBlock(d, 3, 1), CausalBasicBlock(d, 3, 3), CausalBasicBlock(d, 3, 9), CausalBasicBlock(d, 3, 27), 129 | 130 | CausalBasicBlock(d, 3, 1), CausalBasicBlock(d, 3, 1), 131 | 132 | # CausalConv(d, d, 1, 1, nonlinearity='relu'), 133 | # CausalConv(d, d, 1, 1, nonlinearity='relu'), 134 | CausalBasicBlock(d, 1, 1), 135 | CausalConv(d, d, 1, 1, nonlinearity='relu'), 136 | 137 | CausalConv(d, f, 1, 1) 138 | ) 139 | 140 | def forward(self, x): 141 | return self.layers(x) 142 | 143 | 144 | class Text2Mel(nn.Module): 145 | def __init__(self, vocab, d=hp.d): 146 | """Text to melspectrogram network. 
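The attention step that couples the text and audio encoders can be exercised in isolation. The sketch below mirrors the corresponding lines of `forward`, using random placeholder tensors and illustrative sizes:

```python
import numpy as np
import torch
import torch.nn.functional as F

B, d, N, T = 1, 256, 40, 120             # illustrative sizes
K = torch.randn(B, d, N)                 # keys from TextEnc
V = torch.randn(B, d, N)                 # values from TextEnc
Q = torch.randn(B, d, T)                 # queries from AudioEnc

A = F.softmax(torch.bmm(K.permute(0, 2, 1), Q) / np.sqrt(d), dim=1)  # (B, N, T), softmax over text positions
R = torch.bmm(V, A)                      # (B, d, T) attended text representation
R_prime = torch.cat((R, Q), 1)           # (B, 2d, T), the input AudioDec expects
```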
147 | Args: 148 | vocab: vocabulary 149 | d: Text2Mel dim 150 | Input: 151 | L: (B, N) text inputs 152 | S: (B, f, T) melspectrograms 153 | Outputs: 154 | Y_logit: logit of Y 155 | Y: predicted melspectrograms 156 | A: (B, N, T) attention matrix 157 | """ 158 | super(Text2Mel, self).__init__() 159 | self.d = d 160 | self.text_enc = TextEnc(vocab) 161 | self.audio_enc = AudioEnc() 162 | self.audio_dec = AudioDec() 163 | 164 | def forward(self, L, S, monotonic_attention=False): 165 | K, V = self.text_enc(L) 166 | Q = self.audio_enc(S) 167 | A = torch.bmm(K.permute(0, 2, 1), Q) / np.sqrt(self.d) 168 | 169 | if monotonic_attention: 170 | # TODO: vectorize instead of loops 171 | B, N, T = A.size() 172 | for i in range(B): 173 | prva = -1 # previous attention 174 | for t in range(T): 175 | _, n = torch.max(A[i, :, t], 0) 176 | if not (-1 <= n - prva <= 3): 177 | A[i, :, t] = -2 ** 20 # some small numbers 178 | A[i, min(N - 1, prva + 1), t] = 1 179 | _, prva = torch.max(A[i, :, t], 0) 180 | 181 | A = F.softmax(A, dim=1) 182 | R = torch.bmm(V, A) 183 | R_prime = torch.cat((R, Q), 1) 184 | Y_logit = self.audio_dec(R_prime) 185 | Y = F.sigmoid(Y_logit) 186 | return Y_logit, Y, A 187 | -------------------------------------------------------------------------------- /modules/fom/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator import OcclusionAwareGenerator 2 | from .keypoint_detector import KPDetector 3 | from .sync_batchnorm import DataParallelWithCallback 4 | from .animate import normalize_kp 5 | -------------------------------------------------------------------------------- /modules/fom/animate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.spatial import ConvexHull 3 | import numpy as np 4 | 5 | 6 | def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, 7 | use_relative_movement=False, use_relative_jacobian=False): 8 | if adapt_movement_scale: 9 | source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume 10 | driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume 11 | adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) 12 | else: 13 | adapt_movement_scale = 1 14 | 15 | kp_new = {k: v for k, v in kp_driving.items()} 16 | 17 | if use_relative_movement: 18 | kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) 19 | kp_value_diff *= adapt_movement_scale 20 | kp_new['value'] = kp_value_diff + kp_source['value'] 21 | 22 | if use_relative_jacobian: 23 | jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) 24 | kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) 25 | 26 | return kp_new 27 | -------------------------------------------------------------------------------- /modules/fom/dense_motion.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | import torch 4 | from .util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian 5 | 6 | 7 | class DenseMotionNetwork(nn.Module): 8 | """ 9 | Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving 10 | """ 11 | 12 | def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False, 13 | scale_factor=1, kp_variance=0.01): 14 | 
super(DenseMotionNetwork, self).__init__() 15 | self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1), 16 | max_features=max_features, num_blocks=num_blocks) 17 | 18 | self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3)) 19 | 20 | if estimate_occlusion_map: 21 | self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3)) 22 | else: 23 | self.occlusion = None 24 | 25 | self.num_kp = num_kp 26 | self.scale_factor = scale_factor 27 | self.kp_variance = kp_variance 28 | 29 | if self.scale_factor != 1: 30 | self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) 31 | 32 | def create_heatmap_representations(self, source_image, kp_driving, kp_source): 33 | """ 34 | Eq 6. in the paper H_k(z) 35 | """ 36 | spatial_size = source_image.shape[2:] 37 | gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance) 38 | gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance) 39 | heatmap = gaussian_driving - gaussian_source 40 | 41 | #adding background feature 42 | zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()) 43 | heatmap = torch.cat([zeros, heatmap], dim=1) 44 | heatmap = heatmap.unsqueeze(2) 45 | return heatmap 46 | 47 | def create_sparse_motions(self, source_image, kp_driving, kp_source): 48 | """ 49 | Eq 4. in the paper T_{s<-d}(z) 50 | """ 51 | bs, _, h, w = source_image.shape 52 | identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type()) 53 | identity_grid = identity_grid.view(1, 1, h, w, 2) 54 | coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2) 55 | if 'jacobian' in kp_driving: 56 | jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) 57 | jacobian = jacobian.unsqueeze(-3).unsqueeze(-3) 58 | jacobian = jacobian.repeat(1, 1, h, w, 1, 1) 59 | coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) 60 | coordinate_grid = coordinate_grid.squeeze(-1) 61 | 62 | driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2) 63 | 64 | #adding background feature 65 | identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1) 66 | sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) 67 | return sparse_motions 68 | 69 | def create_deformed_source_image(self, source_image, sparse_motions): 70 | """ 71 | Eq 7. 
in the paper \hat{T}_{s<-d}(z) 72 | """ 73 | bs, _, h, w = source_image.shape 74 | source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1) 75 | source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w) 76 | sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1)) 77 | sparse_deformed = F.grid_sample(source_repeat, sparse_motions) 78 | sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w)) 79 | return sparse_deformed 80 | 81 | def forward(self, source_image, kp_driving, kp_source): 82 | if self.scale_factor != 1: 83 | source_image = self.down(source_image) 84 | 85 | bs, _, h, w = source_image.shape 86 | 87 | out_dict = dict() 88 | heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source) 89 | sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source) 90 | deformed_source = self.create_deformed_source_image(source_image, sparse_motion) 91 | out_dict['sparse_deformed'] = deformed_source 92 | 93 | input = torch.cat([heatmap_representation, deformed_source], dim=2) 94 | input = input.view(bs, -1, h, w) 95 | 96 | prediction = self.hourglass(input) 97 | 98 | mask = self.mask(prediction) 99 | mask = F.softmax(mask, dim=1) 100 | out_dict['mask'] = mask 101 | mask = mask.unsqueeze(2) 102 | sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3) 103 | deformation = (sparse_motion * mask).sum(dim=1) 104 | deformation = deformation.permute(0, 2, 3, 1) 105 | 106 | out_dict['deformation'] = deformation 107 | 108 | # Sec. 3.2 in the paper 109 | if self.occlusion: 110 | occlusion_map = torch.sigmoid(self.occlusion(prediction)) 111 | out_dict['occlusion_map'] = occlusion_map 112 | 113 | return out_dict 114 | -------------------------------------------------------------------------------- /modules/fom/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from .util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d 5 | from .dense_motion import DenseMotionNetwork 6 | 7 | 8 | class OcclusionAwareGenerator(nn.Module): 9 | """ 10 | Generator that given source image and and keypoints try to transform image according to movement trajectories 11 | induced by keypoints. Generator follows Johnson architecture. 
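The deform-and-mask operation this generator performs can be shown on its own. The sketch below stands in for `deform_input` plus the occlusion masking in `forward`, with random tensors in place of the encoder features and the DenseMotionNetwork outputs:

```python
import torch
import torch.nn.functional as F

feat = torch.randn(1, 64, 32, 32)               # encoded source features (B, C, H, W)
deformation = torch.rand(1, 32, 32, 2) * 2 - 1  # backward warp grid in [-1, 1], shape (B, H, W, 2)
occlusion_map = torch.rand(1, 1, 32, 32)        # per-pixel confidence in [0, 1]

warped = F.grid_sample(feat, deformation)       # sample source features at driven positions
out = warped * occlusion_map                    # down-weight regions the warp cannot explain
```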
12 | """ 13 | 14 | def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks, 15 | num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): 16 | super(OcclusionAwareGenerator, self).__init__() 17 | 18 | if dense_motion_params is not None: 19 | self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels, 20 | estimate_occlusion_map=estimate_occlusion_map, 21 | **dense_motion_params) 22 | else: 23 | self.dense_motion_network = None 24 | 25 | self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) 26 | 27 | down_blocks = [] 28 | for i in range(num_down_blocks): 29 | in_features = min(max_features, block_expansion * (2 ** i)) 30 | out_features = min(max_features, block_expansion * (2 ** (i + 1))) 31 | down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 32 | self.down_blocks = nn.ModuleList(down_blocks) 33 | 34 | up_blocks = [] 35 | for i in range(num_down_blocks): 36 | in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i))) 37 | out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1))) 38 | up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 39 | self.up_blocks = nn.ModuleList(up_blocks) 40 | 41 | self.bottleneck = torch.nn.Sequential() 42 | in_features = min(max_features, block_expansion * (2 ** num_down_blocks)) 43 | for i in range(num_bottleneck_blocks): 44 | self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) 45 | 46 | self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) 47 | self.estimate_occlusion_map = estimate_occlusion_map 48 | self.num_channels = num_channels 49 | 50 | def deform_input(self, inp, deformation): 51 | _, h_old, w_old, _ = deformation.shape 52 | _, _, h, w = inp.shape 53 | if h_old != h or w_old != w: 54 | deformation = deformation.permute(0, 3, 1, 2) 55 | deformation = F.interpolate(deformation, size=(h, w), mode='bilinear') 56 | deformation = deformation.permute(0, 2, 3, 1) 57 | return F.grid_sample(inp, deformation) 58 | 59 | def forward(self, source_image, kp_driving, kp_source): 60 | # Encoding (downsampling) part 61 | out = self.first(source_image) 62 | for i in range(len(self.down_blocks)): 63 | out = self.down_blocks[i](out) 64 | 65 | # Transforming feature representation according to deformation and occlusion 66 | output_dict = {} 67 | if self.dense_motion_network is not None: 68 | dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving, 69 | kp_source=kp_source) 70 | output_dict['mask'] = dense_motion['mask'] 71 | output_dict['sparse_deformed'] = dense_motion['sparse_deformed'] 72 | 73 | if 'occlusion_map' in dense_motion: 74 | occlusion_map = dense_motion['occlusion_map'] 75 | output_dict['occlusion_map'] = occlusion_map 76 | else: 77 | occlusion_map = None 78 | deformation = dense_motion['deformation'] 79 | out = self.deform_input(out, deformation) 80 | 81 | if occlusion_map is not None: 82 | if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: 83 | occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') 84 | out = out * occlusion_map 85 | 86 | output_dict["deformed"] = self.deform_input(source_image, deformation) 87 | 88 | # Decoding part 89 | out = self.bottleneck(out) 90 | for i in range(len(self.up_blocks)): 
91 | out = self.up_blocks[i](out) 92 | out = self.final(out) 93 | out = F.sigmoid(out) 94 | 95 | output_dict["prediction"] = out 96 | 97 | return output_dict 98 | -------------------------------------------------------------------------------- /modules/fom/keypoint_detector.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | from .util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d 5 | 6 | 7 | class KPDetector(nn.Module): 8 | """ 9 | Detecting a keypoints. Return keypoint position and jacobian near each keypoint. 10 | """ 11 | 12 | def __init__(self, block_expansion, num_kp, num_channels, max_features, 13 | num_blocks, temperature, estimate_jacobian=False, scale_factor=1, 14 | single_jacobian_map=False, pad=0): 15 | super(KPDetector, self).__init__() 16 | 17 | self.predictor = Hourglass(block_expansion, in_features=num_channels, 18 | max_features=max_features, num_blocks=num_blocks) 19 | 20 | self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7), 21 | padding=pad) 22 | 23 | if estimate_jacobian: 24 | self.num_jacobian_maps = 1 if single_jacobian_map else num_kp 25 | self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, 26 | out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad) 27 | self.jacobian.weight.data.zero_() 28 | self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) 29 | else: 30 | self.jacobian = None 31 | 32 | self.temperature = temperature 33 | self.scale_factor = scale_factor 34 | if self.scale_factor != 1: 35 | self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) 36 | 37 | def gaussian2kp(self, heatmap): 38 | """ 39 | Extract the mean and from a heatmap 40 | """ 41 | shape = heatmap.shape 42 | heatmap = heatmap.unsqueeze(-1) 43 | grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) 44 | value = (heatmap * grid).sum(dim=(2, 3)) 45 | kp = {'value': value} 46 | 47 | return kp 48 | 49 | def forward(self, x): 50 | if self.scale_factor != 1: 51 | x = self.down(x) 52 | 53 | feature_map = self.predictor(x) 54 | prediction = self.kp(feature_map) 55 | 56 | final_shape = prediction.shape 57 | heatmap = prediction.view(final_shape[0], final_shape[1], -1) 58 | heatmap = F.softmax(heatmap / self.temperature, dim=2) 59 | heatmap = heatmap.view(*final_shape) 60 | 61 | out = self.gaussian2kp(heatmap) 62 | 63 | if self.jacobian is not None: 64 | jacobian_map = self.jacobian(feature_map) 65 | jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], 66 | final_shape[3]) 67 | heatmap = heatmap.unsqueeze(2) 68 | 69 | jacobian = heatmap * jacobian_map 70 | jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) 71 | jacobian = jacobian.sum(dim=-1) 72 | jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) 73 | out['jacobian'] = jacobian 74 | 75 | return out 76 | -------------------------------------------------------------------------------- /modules/fom/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
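For reference, the soft-argmax performed by `gaussian2kp` in the keypoint detector above can be written as a self-contained helper. This is an illustrative re-implementation, not the module's own code: it reproduces the softmax-then-expected-coordinate logic with a locally built [-1, 1] grid.

```python
import torch
import torch.nn.functional as F

def soft_argmax(heatmap):
    """heatmap: (B, K, H, W) unnormalized scores -> (B, K, 2) keypoints in [-1, 1]."""
    B, K, H, W = heatmap.shape
    probs = F.softmax(heatmap.view(B, K, -1), dim=2).view(B, K, H, W, 1)
    ys = torch.linspace(-1, 1, H).view(1, 1, H, 1, 1).expand(1, 1, H, W, 1)
    xs = torch.linspace(-1, 1, W).view(1, 1, 1, W, 1).expand(1, 1, H, W, 1)
    grid = torch.cat([xs, ys], dim=-1)          # (1, 1, H, W, 2), entry [.., i, j] = (x_j, y_i)
    return (probs * grid).sum(dim=(2, 3))       # expected coordinate under the heatmap
```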
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /modules/fom/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 
85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | 132 | .. math:: 133 | 134 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 135 | 136 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 137 | standard-deviation are reduced across all devices during training. 138 | 139 | For example, when one uses `nn.DataParallel` to wrap the network during 140 | training, PyTorch's implementation normalize the tensor on each device using 141 | the statistics only on that device, which accelerated the computation and 142 | is also easy to implement, but the statistics might be inaccurate. 143 | Instead, in this synchronized version, the statistics will be computed 144 | over all training samples distributed on multiple devices. 145 | 146 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 147 | as the built-in PyTorch implementation. 148 | 149 | The mean and standard-deviation are calculated per-dimension over 150 | the mini-batches and gamma and beta are learnable parameter vectors 151 | of size C (where C is the input size). 152 | 153 | During training, this layer keeps a running estimate of its computed mean 154 | and variance. The running sum is kept with a default momentum of 0.1. 155 | 156 | During evaluation, this running mean/variance is used for normalization. 
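The reduction done by `_data_parallel_master` above amounts to combining per-device sums and square-sums into global statistics, exactly as if the whole batch lived on one device. A small standalone check of that identity with plain tensors (no GPUs involved):

```python
import torch

chunks = [torch.randn(8, 10, 5) for _ in range(2)]   # two per-device mini-batches, layout (B, C, L)
sums = [c.sum(dim=(0, 2)) for c in chunks]           # per-device sum over batch and length
ssums = [(c ** 2).sum(dim=(0, 2)) for c in chunks]   # per-device square-sum
n = sum(c.shape[0] * c.shape[2] for c in chunks)

mean = sum(sums) / n
var = sum(ssums) / n - mean ** 2                     # biased variance, as used for normalization
torch.testing.assert_close(mean, torch.cat(chunks, 0).mean(dim=(0, 2)))
```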
157 | 158 | Because the BatchNorm is done over the `C` dimension, computing statistics 159 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 160 | 161 | Args: 162 | num_features: num_features from an expected input of size 163 | `batch_size x num_features [x width]` 164 | eps: a value added to the denominator for numerical stability. 165 | Default: 1e-5 166 | momentum: the value used for the running_mean and running_var 167 | computation. Default: 0.1 168 | affine: a boolean value that when set to ``True``, gives the layer learnable 169 | affine parameters. Default: ``True`` 170 | 171 | Shape: 172 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 173 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 174 | 175 | Examples: 176 | >>> # With Learnable Parameters 177 | >>> m = SynchronizedBatchNorm1d(100) 178 | >>> # Without Learnable Parameters 179 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 180 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 181 | >>> output = m(input) 182 | """ 183 | 184 | def _check_input_dim(self, input): 185 | if input.dim() != 2 and input.dim() != 3: 186 | raise ValueError('expected 2D or 3D input (got {}D input)' 187 | .format(input.dim())) 188 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 189 | 190 | 191 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 192 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 193 | of 3d inputs 194 | 195 | .. math:: 196 | 197 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 198 | 199 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 200 | standard-deviation are reduced across all devices during training. 201 | 202 | For example, when one uses `nn.DataParallel` to wrap the network during 203 | training, PyTorch's implementation normalize the tensor on each device using 204 | the statistics only on that device, which accelerated the computation and 205 | is also easy to implement, but the statistics might be inaccurate. 206 | Instead, in this synchronized version, the statistics will be computed 207 | over all training samples distributed on multiple devices. 208 | 209 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 210 | as the built-in PyTorch implementation. 211 | 212 | The mean and standard-deviation are calculated per-dimension over 213 | the mini-batches and gamma and beta are learnable parameter vectors 214 | of size C (where C is the input size). 215 | 216 | During training, this layer keeps a running estimate of its computed mean 217 | and variance. The running sum is kept with a default momentum of 0.1. 218 | 219 | During evaluation, this running mean/variance is used for normalization. 220 | 221 | Because the BatchNorm is done over the `C` dimension, computing statistics 222 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 223 | 224 | Args: 225 | num_features: num_features from an expected input of 226 | size batch_size x num_features x height x width 227 | eps: a value added to the denominator for numerical stability. 228 | Default: 1e-5 229 | momentum: the value used for the running_mean and running_var 230 | computation. Default: 0.1 231 | affine: a boolean value that when set to ``True``, gives the layer learnable 232 | affine parameters. 
Default: ``True`` 233 | 234 | Shape: 235 | - Input: :math:`(N, C, H, W)` 236 | - Output: :math:`(N, C, H, W)` (same shape as input) 237 | 238 | Examples: 239 | >>> # With Learnable Parameters 240 | >>> m = SynchronizedBatchNorm2d(100) 241 | >>> # Without Learnable Parameters 242 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 243 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 244 | >>> output = m(input) 245 | """ 246 | 247 | def _check_input_dim(self, input): 248 | if input.dim() != 4: 249 | raise ValueError('expected 4D input (got {}D input)' 250 | .format(input.dim())) 251 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 252 | 253 | 254 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 255 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 256 | of 4d inputs 257 | 258 | .. math:: 259 | 260 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 261 | 262 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 263 | standard-deviation are reduced across all devices during training. 264 | 265 | For example, when one uses `nn.DataParallel` to wrap the network during 266 | training, PyTorch's implementation normalize the tensor on each device using 267 | the statistics only on that device, which accelerated the computation and 268 | is also easy to implement, but the statistics might be inaccurate. 269 | Instead, in this synchronized version, the statistics will be computed 270 | over all training samples distributed on multiple devices. 271 | 272 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 273 | as the built-in PyTorch implementation. 274 | 275 | The mean and standard-deviation are calculated per-dimension over 276 | the mini-batches and gamma and beta are learnable parameter vectors 277 | of size C (where C is the input size). 278 | 279 | During training, this layer keeps a running estimate of its computed mean 280 | and variance. The running sum is kept with a default momentum of 0.1. 281 | 282 | During evaluation, this running mean/variance is used for normalization. 283 | 284 | Because the BatchNorm is done over the `C` dimension, computing statistics 285 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 286 | or Spatio-temporal BatchNorm 287 | 288 | Args: 289 | num_features: num_features from an expected input of 290 | size batch_size x num_features x depth x height x width 291 | eps: a value added to the denominator for numerical stability. 292 | Default: 1e-5 293 | momentum: the value used for the running_mean and running_var 294 | computation. Default: 0.1 295 | affine: a boolean value that when set to ``True``, gives the layer learnable 296 | affine parameters. 
Default: ``True`` 297 | 298 | Shape: 299 | - Input: :math:`(N, C, D, H, W)` 300 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 301 | 302 | Examples: 303 | >>> # With Learnable Parameters 304 | >>> m = SynchronizedBatchNorm3d(100) 305 | >>> # Without Learnable Parameters 306 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 307 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 308 | >>> output = m(input) 309 | """ 310 | 311 | def _check_input_dim(self, input): 312 | if input.dim() != 5: 313 | raise ValueError('expected 5D input (got {}D input)' 314 | .format(input.dim())) 315 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 316 | -------------------------------------------------------------------------------- /modules/fom/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 
72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /modules/fom/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 
36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /modules/fom/util.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | from .sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d 7 | 8 | 9 | def kp2gaussian(kp, spatial_size, kp_variance): 10 | """ 11 | Transform a keypoint into gaussian like representation 12 | """ 13 | mean = kp['value'] 14 | 15 | coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) 16 | number_of_leading_dimensions = len(mean.shape) - 1 17 | shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape 18 | coordinate_grid = coordinate_grid.view(*shape) 19 | repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1) 20 | coordinate_grid = coordinate_grid.repeat(*repeats) 21 | 22 | # Preprocess kp shape 23 | shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2) 24 | mean = mean.view(*shape) 25 | 26 | mean_sub = (coordinate_grid - mean) 27 | 28 | out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) 29 | 30 | return out 31 | 32 | 33 | def make_coordinate_grid(spatial_size, type): 34 | """ 35 | Create a meshgrid [-1,1] x [-1,1] of given spatial_size. 
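A worked example of the grid this helper returns, assuming the function is in scope (for instance imported from `modules/fom/util.py`); entry `[i, j]` of the `(H, W, 2)` result is `(x_j, y_i)` with both coordinates normalized to [-1, 1]:

```python
import torch

grid = make_coordinate_grid((3, 3), 'torch.FloatTensor')
print(grid[..., 0])  # x: every row is [-1, 0, 1]
print(grid[..., 1])  # y: [[-1, -1, -1], [0, 0, 0], [1, 1, 1]]
```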
36 | """ 37 | h, w = spatial_size 38 | x = torch.arange(w).type(type) 39 | y = torch.arange(h).type(type) 40 | 41 | x = (2 * (x / (w - 1)) - 1) 42 | y = (2 * (y / (h - 1)) - 1) 43 | 44 | yy = y.view(-1, 1).repeat(1, w) 45 | xx = x.view(1, -1).repeat(h, 1) 46 | 47 | meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) 48 | 49 | return meshed 50 | 51 | 52 | class ResBlock2d(nn.Module): 53 | """ 54 | Res block, preserve spatial resolution. 55 | """ 56 | 57 | def __init__(self, in_features, kernel_size, padding): 58 | super(ResBlock2d, self).__init__() 59 | self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 60 | padding=padding) 61 | self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 62 | padding=padding) 63 | self.norm1 = BatchNorm2d(in_features, affine=True) 64 | self.norm2 = BatchNorm2d(in_features, affine=True) 65 | 66 | def forward(self, x): 67 | out = self.norm1(x) 68 | out = F.relu(out) 69 | out = self.conv1(out) 70 | out = self.norm2(out) 71 | out = F.relu(out) 72 | out = self.conv2(out) 73 | out += x 74 | return out 75 | 76 | 77 | class UpBlock2d(nn.Module): 78 | """ 79 | Upsampling block for use in decoder. 80 | """ 81 | 82 | def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): 83 | super(UpBlock2d, self).__init__() 84 | 85 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 86 | padding=padding, groups=groups) 87 | self.norm = BatchNorm2d(out_features, affine=True) 88 | 89 | def forward(self, x): 90 | out = F.interpolate(x, scale_factor=2) 91 | out = self.conv(out) 92 | out = self.norm(out) 93 | out = F.relu(out) 94 | return out 95 | 96 | 97 | class DownBlock2d(nn.Module): 98 | """ 99 | Downsampling block for use in encoder. 100 | """ 101 | 102 | def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): 103 | super(DownBlock2d, self).__init__() 104 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 105 | padding=padding, groups=groups) 106 | self.norm = BatchNorm2d(out_features, affine=True) 107 | self.pool = nn.AvgPool2d(kernel_size=(2, 2)) 108 | 109 | def forward(self, x): 110 | out = self.conv(x) 111 | out = self.norm(out) 112 | out = F.relu(out) 113 | out = self.pool(out) 114 | return out 115 | 116 | 117 | class SameBlock2d(nn.Module): 118 | """ 119 | Simple block, preserve spatial resolution. 
120 | """ 121 | 122 | def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): 123 | super(SameBlock2d, self).__init__() 124 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, 125 | kernel_size=kernel_size, padding=padding, groups=groups) 126 | self.norm = BatchNorm2d(out_features, affine=True) 127 | 128 | def forward(self, x): 129 | out = self.conv(x) 130 | out = self.norm(out) 131 | out = F.relu(out) 132 | return out 133 | 134 | 135 | class Encoder(nn.Module): 136 | """ 137 | Hourglass Encoder 138 | """ 139 | 140 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 141 | super(Encoder, self).__init__() 142 | 143 | down_blocks = [] 144 | for i in range(num_blocks): 145 | down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), 146 | min(max_features, block_expansion * (2 ** (i + 1))), 147 | kernel_size=3, padding=1)) 148 | self.down_blocks = nn.ModuleList(down_blocks) 149 | 150 | def forward(self, x): 151 | outs = [x] 152 | for down_block in self.down_blocks: 153 | outs.append(down_block(outs[-1])) 154 | return outs 155 | 156 | 157 | class Decoder(nn.Module): 158 | """ 159 | Hourglass Decoder 160 | """ 161 | 162 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 163 | super(Decoder, self).__init__() 164 | 165 | up_blocks = [] 166 | 167 | for i in range(num_blocks)[::-1]: 168 | in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) 169 | out_filters = min(max_features, block_expansion * (2 ** i)) 170 | up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1)) 171 | 172 | self.up_blocks = nn.ModuleList(up_blocks) 173 | self.out_filters = block_expansion + in_features 174 | 175 | def forward(self, x): 176 | out = x.pop() 177 | for up_block in self.up_blocks: 178 | out = up_block(out) 179 | skip = x.pop() 180 | out = torch.cat([out, skip], dim=1) 181 | return out 182 | 183 | 184 | class Hourglass(nn.Module): 185 | """ 186 | Hourglass architecture. 187 | """ 188 | 189 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 190 | super(Hourglass, self).__init__() 191 | self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) 192 | self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) 193 | self.out_filters = self.decoder.out_filters 194 | 195 | def forward(self, x): 196 | return self.decoder(self.encoder(x)) 197 | 198 | 199 | class AntiAliasInterpolation2d(nn.Module): 200 | """ 201 | Band-limited downsampling, for better preservation of the input signal. 202 | """ 203 | def __init__(self, channels, scale): 204 | super(AntiAliasInterpolation2d, self).__init__() 205 | sigma = (1 / scale - 1) / 2 206 | kernel_size = 2 * round(sigma * 4) + 1 207 | self.ka = kernel_size // 2 208 | self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka 209 | 210 | kernel_size = [kernel_size, kernel_size] 211 | sigma = [sigma, sigma] 212 | # The gaussian kernel is the product of the 213 | # gaussian function of each dimension. 214 | kernel = 1 215 | meshgrids = torch.meshgrid( 216 | [ 217 | torch.arange(size, dtype=torch.float32) 218 | for size in kernel_size 219 | ] 220 | ) 221 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 222 | mean = (size - 1) / 2 223 | kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) 224 | 225 | # Make sure sum of values in gaussian kernel equals 1. 
226 | kernel = kernel / torch.sum(kernel) 227 | # Reshape to depthwise convolutional weight 228 | kernel = kernel.view(1, 1, *kernel.size()) 229 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 230 | 231 | self.register_buffer('weight', kernel) 232 | self.groups = channels 233 | self.scale = scale 234 | 235 | def forward(self, input): 236 | if self.scale == 1.0: 237 | return input 238 | 239 | out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) 240 | out = F.conv2d(out, weight=self.weight, groups=self.groups) 241 | out = F.interpolate(out, scale_factor=(self.scale, self.scale)) 242 | 243 | return out 244 | -------------------------------------------------------------------------------- /modules/sda/__init__.py: -------------------------------------------------------------------------------- 1 | from .sda import VideoAnimator, get_audio_feature_extractor, cut_audio_sequence, tempdir 2 | -------------------------------------------------------------------------------- /modules/sda/encoder_audio.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | from .utils import calculate_padding, prime_factors, calculate_output_size 4 | 5 | class Encoder(nn.Module): 6 | def __init__(self, code_size, rate, feat_length, init_kernel=None, init_stride=None, num_feature_maps=16, 7 | increasing_stride=True): 8 | super(Encoder, self).__init__() 9 | 10 | self.code_size = code_size 11 | self.cl = nn.ModuleList() 12 | self.activations = nn.ModuleList() 13 | self.strides = [] 14 | self.kernels = [] 15 | 16 | features = feat_length * rate 17 | strides = prime_factors(features) 18 | kernels = [2 * s for s in strides] 19 | 20 | if init_kernel is not None and init_stride is not None: 21 | self.strides.append(int(init_stride * rate)) 22 | self.kernels.append(int(init_kernel * rate)) 23 | padding = calculate_padding(init_kernel * rate, stride=init_stride * rate, in_size=features) 24 | init_features = calculate_output_size(features, init_kernel * rate, stride=init_stride * rate, 25 | padding=padding) 26 | strides = prime_factors(init_features) 27 | kernels = [2 * s for s in strides] 28 | 29 | if not increasing_stride: 30 | strides.reverse() 31 | kernels.reverse() 32 | 33 | self.strides.extend(strides) 34 | self.kernels.extend(kernels) 35 | 36 | for i in range(len(self.strides) - 1): 37 | padding = calculate_padding(self.kernels[i], stride=self.strides[i], in_size=features) 38 | features = calculate_output_size(features, self.kernels[i], stride=self.strides[i], padding=padding) 39 | pad = int(math.ceil(padding / 2.0)) 40 | 41 | if i == 0: 42 | self.cl.append( 43 | nn.Conv1d(1, num_feature_maps, self.kernels[i], stride=self.strides[i], padding=pad)) 44 | self.activations.append(nn.Sequential(nn.BatchNorm1d(num_feature_maps), nn.ReLU(True))) 45 | else: 46 | self.cl.append(nn.Conv1d(num_feature_maps, 2 * num_feature_maps, self.kernels[i], 47 | stride=self.strides[i], padding=pad)) 48 | self.activations.append(nn.Sequential(nn.BatchNorm1d(2 * num_feature_maps), nn.ReLU(True))) 49 | 50 | num_feature_maps *= 2 51 | 52 | self.cl.append(nn.Conv1d(num_feature_maps, self.code_size, features)) 53 | self.activations.append(nn.Tanh()) 54 | 55 | def forward(self, x): 56 | for i in range(len(self.strides)): 57 | x = self.cl[i](x) 58 | x = self.activations[i](x) 59 | 60 | return x.squeeze() 61 | -------------------------------------------------------------------------------- /modules/sda/encoder_image.py: 
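The audio encoder above derives its stride schedule from the prime factorization of the feature window length in samples, with each kernel twice its stride. The sketch below uses a stand-in `prime_factors` with the behaviour the code assumes (ascending prime factors) and made-up rate/window values, since the real ones are read from the model's `.dat` file:

```python
def prime_factors(n):
    """Stand-in for the helper in modules/sda/utils.py: ascending prime factorization of n."""
    factors, p = [], 2
    while p * p <= n:
        while n % p == 0:
            factors.append(p)
            n //= p
        p += 1
    if n > 1:
        factors.append(n)
    return factors

samples = int(0.2 * 8000)            # e.g. a 0.2 s window at 8 kHz -> 1600 samples (illustrative)
strides = prime_factors(samples)     # [2, 2, 2, 2, 2, 2, 5, 5]
kernels = [2 * s for s in strides]   # the conv stack shrinks the window by these factors
```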
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | from .utils import calculate_padding, is_power2 4 | 5 | class Encoder(nn.Module): 6 | def __init__(self, code_size, img_size, kernel_size=4, num_input_channels=3, num_feature_maps=64, batch_norm=True): 7 | super(Encoder, self).__init__() 8 | 9 | # Get the dimension which is a power of 2 10 | if is_power2(max(img_size)): 11 | stable_dim = max(img_size) 12 | else: 13 | stable_dim = min(img_size) 14 | 15 | if isinstance(img_size, tuple): 16 | self.img_size = img_size 17 | self.final_size = tuple(int(4 * x // stable_dim) for x in self.img_size) 18 | else: 19 | self.img_size = (img_size, img_size) 20 | self.final_size = (4, 4) 21 | 22 | self.code_size = code_size 23 | self.num_feature_maps = num_feature_maps 24 | self.cl = nn.ModuleList() 25 | self.num_layers = int(np.log2(max(self.img_size))) - 2 26 | 27 | stride = 2 28 | # This ensures that we have same padding no matter if we have even or odd kernels 29 | padding = calculate_padding(kernel_size, stride) 30 | 31 | if batch_norm: 32 | self.cl.append(nn.Sequential( 33 | nn.Conv2d(num_input_channels, self.num_feature_maps, kernel_size, stride=stride, padding=padding // 2, 34 | bias=False), 35 | nn.BatchNorm2d(self.num_feature_maps), 36 | nn.ReLU(True))) 37 | else: 38 | self.cl.append(nn.Sequential( 39 | nn.Conv2d(num_input_channels, self.num_feature_maps, kernel_size, stride=stride, padding=padding // 2, 40 | bias=False), 41 | nn.ReLU(True))) 42 | 43 | self.channels = [self.num_feature_maps] 44 | for i in range(self.num_layers - 1): 45 | 46 | if batch_norm: 47 | self.cl.append(nn.Sequential( 48 | nn.Conv2d(self.channels[-1], self.channels[-1] * 2, kernel_size, stride=stride, 49 | padding=padding // 2, 50 | bias=False), 51 | nn.BatchNorm2d(self.channels[-1] * 2), 52 | nn.ReLU(True))) 53 | else: 54 | self.cl.append(nn.Sequential( 55 | nn.Conv2d(self.channels[-1], self.channels[-1] * 2, kernel_size, stride=stride, 56 | padding=padding // 2, bias=False), 57 | nn.ReLU(True))) 58 | 59 | self.channels.append(2 * self.channels[-1]) 60 | 61 | self.cl.append(nn.Sequential( 62 | nn.Conv2d(self.channels[-1], code_size, self.final_size, stride=1, padding=0, bias=False), 63 | nn.Tanh())) 64 | 65 | def forward(self, x, retain_intermediate=False): 66 | if retain_intermediate: 67 | h = [x] 68 | for conv_layer in self.cl: 69 | h.append(conv_layer(h[-1])) 70 | return h[-1].view(-1, self.code_size), h[1:-1] 71 | else: 72 | for conv_layer in self.cl: 73 | x = conv_layer(x) 74 | 75 | return x.view(-1, self.code_size) -------------------------------------------------------------------------------- /modules/sda/img_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | from .utils import calculate_padding 5 | 6 | 7 | class Deconv(nn.Module): 8 | def __init__(self, in_channels, out_channels, in_size, kernel_size, stride=1, batch_norm=True): 9 | super(Deconv, self).__init__() 10 | # This ensures that we have same padding no matter if we have even or odd kernels 11 | padding = calculate_padding(kernel_size, stride) 12 | self.dcl = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding // 2, 13 | bias=False) 14 | 15 | if batch_norm: 16 | self.activation = nn.Sequential(nn.BatchNorm2d(out_channels), nn.ReLU(True)) 17 | else: 18 | self.activation = nn.ReLU(True) 19 | 20 | self.required_channels = 
out_channels 21 | self.out_size_required = tuple(x * stride for x in in_size) 22 | 23 | def forward(self, x): 24 | x = self.dcl(x, 25 | output_size=[-1, self.required_channels, self.out_size_required[0], self.out_size_required[1]]) 26 | 27 | return self.activation(x) 28 | 29 | 30 | class UnetBlock(nn.Module): 31 | def __init__(self, in_channels, out_channels, skip_channels, in_size, kernel_size, stride=1, batch_norm=True): 32 | super(UnetBlock, self).__init__() 33 | # This ensures that we have same padding no matter if we have even or odd kernels 34 | padding = calculate_padding(kernel_size, stride) 35 | self.dcl1 = nn.ConvTranspose2d(in_channels + skip_channels, in_channels, 3, padding=1, bias=False) 36 | self.dcl2 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, 37 | padding=padding // 2, bias=False) 38 | if batch_norm: 39 | self.activation1 = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(True)) 40 | self.activation2 = nn.Sequential(nn.BatchNorm2d(out_channels), nn.ReLU(True)) 41 | else: 42 | self.activation1 = nn.ReLU(True) 43 | self.activation2 = nn.ReLU(True) 44 | 45 | self.required_channels = out_channels 46 | self.out_size_required = tuple(x * stride for x in in_size) 47 | 48 | def forward(self, x, s): 49 | s = s.view(x.size()) 50 | 51 | x = torch.cat([x, s], 1) 52 | 53 | x = self.dcl1(x) 54 | x = self.activation1(x) 55 | 56 | x = self.dcl2(x, output_size=[-1, self.required_channels, self.out_size_required[0], self.out_size_required[1]]) 57 | x = self.activation2(x) 58 | return x 59 | 60 | 61 | class Generator(nn.Module): 62 | def __init__(self, img_size, latent_size, condition_size=0, aux_size=0, kernel_size=4, num_channels=3, 63 | num_gen_channels=1024, skip_channels=[], batch_norm=True, sequential_noise=False, 64 | aux_only_on_top=False): 65 | super(Generator, self).__init__() 66 | # If we have a tuple make sure we maintain the aspect ratio 67 | if isinstance(img_size, tuple): 68 | self.img_size = img_size 69 | self.init_size = tuple(int(4 * x / max(img_size)) for x in self.img_size) 70 | else: 71 | self.img_size = (img_size, img_size) 72 | self.init_size = (4, 4) 73 | 74 | self.latent_size = latent_size 75 | self.condition_size = condition_size 76 | self.aux_size = aux_size 77 | 78 | self.rnn_noise = None 79 | if self.aux_size > 0 and sequential_noise: 80 | self.rnn_noise = nn.GRU(self.aux_size, self.aux_size, batch_first=True) 81 | self.rnn_noise_squashing = nn.Tanh() 82 | 83 | self.num_layers = int(np.log2(max(self.img_size))) - 1 84 | self.num_channels = num_channels 85 | self.num_gen_channels = num_gen_channels 86 | 87 | self.dcl = nn.ModuleList() 88 | 89 | self.aux_only_on_top = aux_only_on_top 90 | self.total_latent_size = self.latent_size + self.condition_size 91 | 92 | if self.aux_size > 0 and self.aux_only_on_top: 93 | self.aux_dcl = nn.Sequential( 94 | nn.ConvTranspose2d(self.aux_size, num_gen_channels, (self.init_size[0] // 2, self.init_size[1]), 95 | bias=False), 96 | nn.BatchNorm2d(num_gen_channels), 97 | nn.ReLU(True), 98 | nn.ConstantPad2d((0, 0, 0, self.init_size[0] // 2), 0)) 99 | else: 100 | self.total_latent_size += self.aux_size 101 | 102 | stride = 2 103 | if batch_norm: 104 | self.dcl.append( 105 | nn.Sequential( 106 | nn.ConvTranspose2d(self.total_latent_size, num_gen_channels, self.init_size, bias=False), 107 | nn.BatchNorm2d(num_gen_channels), 108 | nn.ReLU(True))) 109 | else: 110 | self.dcl.append( 111 | nn.Sequential( 112 | nn.ConvTranspose2d(self.total_latent_size, num_gen_channels, self.init_size, bias=False), 113 | 
nn.ReLU(True))) 114 | 115 | num_input_channels = self.num_gen_channels 116 | in_size = self.init_size 117 | for i in range(self.num_layers - 2): 118 | if not skip_channels: 119 | self.dcl.append(Deconv(num_input_channels, num_input_channels // 2, in_size, kernel_size, stride=stride, 120 | batch_norm=batch_norm)) 121 | else: 122 | self.dcl.append( 123 | UnetBlock(num_input_channels, num_input_channels // 2, skip_channels[i], in_size, 124 | kernel_size, stride=stride, batch_norm=batch_norm)) 125 | 126 | num_input_channels //= 2 127 | in_size = tuple(2 * x for x in in_size) 128 | 129 | padding = calculate_padding(kernel_size, stride) 130 | self.dcl.append(nn.ConvTranspose2d(num_input_channels, self.num_channels, kernel_size, 131 | stride=stride, padding=padding // 2, bias=False)) 132 | self.final_activation = nn.Tanh() 133 | 134 | def forward(self, x, c=None, aux=None, skip=[]): 135 | if aux is not None: 136 | if self.rnn_noise is not None: 137 | aux, h = self.rnn_noise(aux) 138 | aux = self.rnn_noise_squashing(aux) 139 | 140 | if self.aux_only_on_top: 141 | aux = self.aux_dcl(aux.view(-1, self.aux_size, 1, 1)) 142 | else: 143 | x = torch.cat([x, aux], 2) 144 | 145 | if c is not None: 146 | x = torch.cat([x, c], 2) 147 | 148 | x = x.view(-1, self.total_latent_size, 1, 1) 149 | x = self.dcl[0](x) 150 | 151 | if self.aux_only_on_top: 152 | x = x + aux 153 | 154 | if not skip: 155 | for i in range(1, self.num_layers - 1): 156 | x = self.dcl[i](x) 157 | else: 158 | for i in range(1, self.num_layers - 1): 159 | x = self.dcl[i](x, skip[i - 1]) 160 | 161 | x = self.dcl[-1](x, output_size=[-1, 3, self.img_size[0], self.img_size[1]]) 162 | return self.final_activation(x) 163 | -------------------------------------------------------------------------------- /modules/sda/rnn_audio.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .encoder_audio import Encoder 3 | 4 | 5 | class RNN(nn.Module): 6 | def __init__(self, feat_length, enc_code_size, rnn_code_size, rate, n_layers=2, init_kernel=None, 7 | init_stride=None): 8 | super(RNN, self).__init__() 9 | self.audio_feat_samples = int(rate * feat_length) 10 | self.enc_code_size = enc_code_size 11 | self.rnn_code_size = rnn_code_size 12 | self.encoder = Encoder(self.enc_code_size, rate, feat_length, init_kernel=init_kernel, 13 | init_stride=init_stride) 14 | self.rnn = nn.GRU(self.enc_code_size, self.rnn_code_size, n_layers, batch_first=True) 15 | 16 | def forward(self, x, lengths): 17 | seq_length = x.size()[1] 18 | x = x.view(-1, 1, self.audio_feat_samples) 19 | x = self.encoder(x) 20 | x = x.view(-1, seq_length, self.enc_code_size) 21 | x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True) 22 | x, h = self.rnn(x) 23 | x, lengths = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) 24 | return x.contiguous() 25 | -------------------------------------------------------------------------------- /modules/sda/sda.py: -------------------------------------------------------------------------------- 1 | # modified VideoAnimator return 2 | from torchvision import transforms 3 | import torch 4 | from .encoder_image import Encoder 5 | from .img_generator import Generator 6 | from .rnn_audio import RNN 7 | 8 | from scipy import signal 9 | from skimage import transform as tf 10 | import numpy as np 11 | from PIL import Image 12 | import contextlib 13 | import os 14 | import shutil 15 | import tempfile 16 | import skvideo.io as sio 17 | import scipy.io.wavfile as wav 18 | import 
ffmpeg 19 | import face_alignment 20 | from pydub import AudioSegment 21 | from pydub.utils import mediainfo 22 | 23 | 24 | @contextlib.contextmanager 25 | def cd(newdir, cleanup=lambda: True): 26 | prevdir = os.getcwd() 27 | os.chdir(os.path.expanduser(newdir)) 28 | try: 29 | yield 30 | finally: 31 | os.chdir(prevdir) 32 | cleanup() 33 | 34 | 35 | @contextlib.contextmanager 36 | def tempdir(): 37 | dirpath = tempfile.mkdtemp() 38 | 39 | def cleanup(): 40 | shutil.rmtree(dirpath) 41 | 42 | with cd(dirpath, cleanup): 43 | yield dirpath 44 | 45 | 46 | def get_audio_feature_extractor(model_path="grid", gpu=-1): 47 | if model_path == "grid": 48 | model_path = os.path.split(__file__)[0] + "/data/grid.dat" 49 | elif model_path == "timit": 50 | model_path = os.path.split(__file__)[0] + "/data/timit.dat" 51 | elif model_path == "crema": 52 | model_path = os.path.split(__file__)[0] + "/data/crema.dat" 53 | 54 | if gpu < 0: 55 | device = torch.device("cpu") 56 | model_dict = torch.load(model_path, map_location=lambda storage, loc: storage) 57 | else: 58 | device = torch.device("cuda:" + str(gpu)) 59 | model_dict = torch.load(model_path, map_location=lambda storage, loc: storage.cuda(gpu)) 60 | 61 | audio_rate = model_dict["audio_rate"] 62 | audio_feat_len = model_dict['audio_feat_len'] 63 | rnn_gen_dim = model_dict['rnn_gen_dim'] 64 | aud_enc_dim = model_dict['aud_enc_dim'] 65 | video_rate = model_dict["video_rate"] 66 | 67 | encoder = RNN(audio_feat_len, aud_enc_dim, rnn_gen_dim, audio_rate, init_kernel=0.005, init_stride=0.001) 68 | encoder.to(device) 69 | encoder.load_state_dict(model_dict['encoder']) 70 | 71 | overlap = audio_feat_len - 1.0 / video_rate 72 | return encoder, {"rate": audio_rate, "feature length": audio_feat_len, "overlap": overlap} 73 | 74 | 75 | def cut_audio_sequence(seq, feature_length, overlap, rate): 76 | seq = seq.view(-1, 1) 77 | snip_length = int(feature_length * rate) 78 | cutting_stride = int((feature_length - overlap) * rate) 79 | pad_samples = snip_length - cutting_stride 80 | 81 | pad_left = torch.zeros(pad_samples // 2, 1, device=seq.device) 82 | pad_right = torch.zeros(pad_samples - pad_samples // 2, 1, device=seq.device) 83 | 84 | seq = torch.cat((pad_left, seq), 0) 85 | seq = torch.cat((seq, pad_right), 0) 86 | 87 | stacked = seq.narrow(0, 0, snip_length).unsqueeze(0) 88 | iterations = (seq.size()[0] - snip_length) // cutting_stride + 1 89 | for i in range(1, iterations): 90 | stacked = torch.cat((stacked, seq.narrow(0, i * cutting_stride, snip_length).unsqueeze(0))) 91 | return stacked 92 | 93 | 94 | class VideoAnimator: 95 | def __init__(self, model_path='data/sda/grid.dat', gpu=-1): 96 | if gpu < 0: 97 | self.device = torch.device("cpu") 98 | model_dict = torch.load(model_path, map_location=lambda storage, loc: storage) 99 | self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device="cpu", flip_input=False) 100 | else: 101 | self.device = torch.device("cuda:" + str(gpu)) 102 | model_dict = torch.load(model_path, map_location=lambda storage, loc: storage.cuda(gpu)) 103 | self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device="cuda:" + str(gpu), 104 | flip_input=False) 105 | 106 | self.stablePntsIDs = [33, 36, 39, 42, 45] 107 | self.mean_face = model_dict["mean_face"] 108 | self.img_size = model_dict["img_size"] 109 | self.audio_rate = model_dict["audio_rate"] 110 | self.video_rate = model_dict["video_rate"] 111 | self.audio_feat_len = model_dict['audio_feat_len'] 112 | self.audio_feat_samples = 
model_dict['audio_feat_samples'] 113 | self.id_enc_dim = model_dict['id_enc_dim'] 114 | self.rnn_gen_dim = model_dict['rnn_gen_dim'] 115 | self.aud_enc_dim = model_dict['aud_enc_dim'] 116 | self.aux_latent = model_dict['aux_latent'] 117 | self.sequential_noise = model_dict['sequential_noise'] 118 | self.conversion_dict = {'s16': np.int16, 's32': np.int32} 119 | 120 | self.img_transform = transforms.Compose([ 121 | transforms.ToPILImage(), 122 | transforms.Resize((self.img_size[0], self.img_size[1])), 123 | transforms.ToTensor(), 124 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 125 | 126 | self.encoder = RNN(self.audio_feat_len, self.aud_enc_dim, self.rnn_gen_dim, 127 | self.audio_rate, init_kernel=0.005, init_stride=0.001) 128 | self.encoder.to(self.device) 129 | self.encoder.load_state_dict(model_dict['encoder']) 130 | 131 | self.encoder_id = Encoder(self.id_enc_dim, self.img_size) 132 | self.encoder_id.to(self.device) 133 | self.encoder_id.load_state_dict(model_dict['encoder_id']) 134 | 135 | skip_channels = list(self.encoder_id.channels) 136 | skip_channels.reverse() 137 | 138 | self.generator = Generator(self.img_size, self.rnn_gen_dim, condition_size=self.id_enc_dim, 139 | num_gen_channels=self.encoder_id.channels[-1], 140 | skip_channels=skip_channels, aux_size=self.aux_latent, 141 | sequential_noise=self.sequential_noise) 142 | 143 | self.generator.to(self.device) 144 | self.generator.load_state_dict(model_dict['generator']) 145 | 146 | self.encoder.eval() 147 | self.encoder_id.eval() 148 | self.generator.eval() 149 | 150 | def save_video(self, video, audio, path, fs, overwrite=True, experimental_ffmpeg=False, scale=None): 151 | if not os.path.isabs(path): 152 | path = os.getcwd() + "/" + path 153 | 154 | with tempdir() as dirpath: 155 | # Save the video file 156 | writer = sio.FFmpegWriter(dirpath + "/tmp.avi", 157 | inputdict={'-r': str(self.video_rate) + "/1", }, 158 | outputdict={'-r': str(self.video_rate) + "/1", } 159 | ) 160 | for i in range(video.shape[0]): 161 | frame = np.rollaxis(video[i, :, :, :], 0, 3) 162 | 163 | if scale is not None: 164 | frame = tf.rescale(frame, scale, anti_aliasing=True, multichannel=True, mode='reflect') 165 | 166 | writer.writeFrame(frame) 167 | writer.close() 168 | 169 | # Save the audio file 170 | wav.write(dirpath + "/tmp.wav", fs, audio) 171 | 172 | in1 = ffmpeg.input(dirpath + "/tmp.avi") 173 | in2 = ffmpeg.input(dirpath + "/tmp.wav") 174 | if experimental_ffmpeg: 175 | out = ffmpeg.output(in1['v'], in2['a'], path, strict='-2', loglevel="panic") 176 | else: 177 | out = ffmpeg.output(in1['v'], in2['a'], path, loglevel="panic") 178 | 179 | if overwrite: 180 | out = out.overwrite_output() 181 | out.run() 182 | 183 | def preprocess_img(self, img): 184 | src = self.fa.get_landmarks(img)[0][self.stablePntsIDs, :] 185 | dst = self.mean_face[self.stablePntsIDs, :] 186 | tform = tf.estimate_transform('similarity', src, dst) # find the transformation matrix 187 | warped = tf.warp(img, inverse_map=tform.inverse, output_shape=self.img_size) # wrap the frame image 188 | warped = warped * 255 # note output from wrap is double image (value range [0,1]) 189 | warped = warped.astype('uint8') 190 | 191 | return warped 192 | 193 | def _cut_sequence_(self, seq, cutting_stride, pad_samples): 194 | pad_left = torch.zeros(pad_samples // 2, 1) 195 | pad_right = torch.zeros(pad_samples - pad_samples // 2, 1) 196 | 197 | seq = torch.cat((pad_left, seq), 0) 198 | seq = torch.cat((seq, pad_right), 0) 199 | 200 | stacked = seq.narrow(0, 0, 
self.audio_feat_samples).unsqueeze(0) 201 | iterations = (seq.size()[0] - self.audio_feat_samples) // cutting_stride + 1 202 | for i in range(1, iterations): 203 | stacked = torch.cat((stacked, seq.narrow(0, i * cutting_stride, self.audio_feat_samples).unsqueeze(0))) 204 | return stacked.to(self.device) 205 | 206 | def _broadcast_elements_(self, batch, repeat_no): 207 | total_tensors = [] 208 | for i in range(0, batch.size()[0]): 209 | total_tensors += [torch.stack(repeat_no * [batch[i]])] 210 | 211 | return torch.stack(total_tensors) 212 | 213 | def __call__(self, img, audio, fs=None, aligned=False): 214 | if isinstance(img, str): # if we have a path then grab the image 215 | frm = Image.open(img) 216 | frm.thumbnail((400, 400)) 217 | frame = np.array(frm) 218 | else: 219 | frame = img 220 | 221 | if not aligned: 222 | frame = self.preprocess_img(frame) 223 | 224 | if isinstance(audio, str): # if we have a path then grab the audio clip 225 | info = mediainfo(audio) 226 | fs = int(info['sample_rate']) 227 | audio = np.array(AudioSegment.from_file(audio, info['format_name']).set_channels(1).get_array_of_samples()) 228 | 229 | if info['sample_fmt'] in self.conversion_dict: 230 | audio = audio.astype(self.conversion_dict[info['sample_fmt']]) 231 | else: 232 | if max(audio) > np.iinfo(np.int16).max: 233 | audio = audio.astype(np.int32) 234 | else: 235 | audio = audio.astype(np.int16) 236 | 237 | if fs is None: 238 | raise AttributeError("Audio provided without specifying the rate. Specify rate or use audio file!") 239 | 240 | if audio.ndim > 1 and audio.shape[1] > 1: 241 | audio = audio[:, 0] 242 | 243 | max_value = np.iinfo(audio.dtype).max 244 | if fs != self.audio_rate: 245 | seq_length = audio.shape[0] 246 | speech = torch.from_numpy( 247 | signal.resample(audio, int(seq_length * self.audio_rate / float(fs))) / float(max_value)).float() 248 | speech = speech.view(-1, 1) 249 | else: 250 | audio = torch.from_numpy(audio / float(max_value)).float() 251 | speech = audio.view(-1, 1) 252 | 253 | frame = self.img_transform(frame).to(self.device) 254 | 255 | cutting_stride = int(self.audio_rate / float(self.video_rate)) 256 | audio_seq_padding = self.audio_feat_samples - cutting_stride 257 | 258 | # Create new sequences of the audio windows 259 | audio_feat_seq = self._cut_sequence_(speech, cutting_stride, audio_seq_padding) 260 | frame = frame.unsqueeze(0) 261 | audio_feat_seq = audio_feat_seq.unsqueeze(0) 262 | audio_feat_seq_length = audio_feat_seq.size()[1] 263 | 264 | z = self.encoder(audio_feat_seq, [audio_feat_seq_length]) # Encoding for the motion 265 | noise = torch.FloatTensor(1, audio_feat_seq_length, self.aux_latent).normal_(0, 0.33).to(self.device) 266 | z_id, skips = self.encoder_id(frame, retain_intermediate=True) 267 | skip_connections = [] 268 | for skip_variable in skips: 269 | skip_connections.append(self._broadcast_elements_(skip_variable, z.size()[1])) 270 | skip_connections.reverse() 271 | 272 | z_id = self._broadcast_elements_(z_id, z.size()[1]) 273 | gen_video = self.generator(z, c=z_id, aux=noise, skip=skip_connections) 274 | 275 | # returned_audio = ((2 ** 15) * speech.detach().cpu().numpy()).astype(np.int16) 276 | gen_video = 125 * gen_video.squeeze().detach().cpu().numpy() + 125 277 | # return gen_video, returned_audio 278 | return gen_video 279 | -------------------------------------------------------------------------------- /modules/sda/utils.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | 4 | def 
prime_factors(number): 5 | factor = 2 6 | factors = [] 7 | while factor * factor <= number: 8 | if number % factor: 9 | factor += 1 10 | else: 11 | number //= factor 12 | factors.append(int(factor)) 13 | if number > 1: 14 | factors.append(int(number)) 15 | return factors 16 | 17 | 18 | def calculate_padding(kernel_size, stride=1, in_size=0): 19 | out_size = ceil(float(in_size) / float(stride)) 20 | return int((out_size - 1) * stride + kernel_size - in_size) 21 | 22 | 23 | def calculate_output_size(in_size, kernel_size, stride, padding): 24 | return int((in_size + padding - kernel_size) / stride) + 1 25 | 26 | 27 | def is_power2(num): 28 | return num != 0 and ((num & (num - 1)) == 0) 29 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | astunparse==1.6.3 3 | audioread==2.1.8 4 | blis==0.4.1 5 | cachetools==4.1.0 6 | catalogue==1.0.0 7 | certifi==2020.4.5.1 8 | cffi==1.14.0 9 | chardet==3.0.4 10 | click==7.1.1 11 | cycler==0.10.0 12 | cymem==2.0.3 13 | decorator==4.4.2 14 | face-alignment==1.0.0 15 | ffmpeg-python==0.2.0 16 | filelock==3.0.12 17 | Flask==1.1.2 18 | future==0.18.2 19 | gast==0.3.3 20 | google-auth==1.14.2 21 | google-auth-oauthlib==0.4.1 22 | google-pasta==0.2.0 23 | grpcio==1.28.1 24 | h5py==2.10.0 25 | idna==2.9 26 | imageio==2.8.0 27 | importlib-metadata==1.6.0 28 | itsdangerous==1.1.0 29 | Jinja2==2.11.1 30 | joblib==0.14.1 31 | Keras-Preprocessing==1.1.0 32 | kiwisolver==1.2.0 33 | librosa==0.7.2 34 | Markdown==3.2.2 35 | MarkupSafe==1.1.1 36 | matplotlib==3.2.1 37 | more-itertools==8.2.0 38 | murmurhash==1.0.2 39 | networkx==2.4 40 | numba==0.48.0 41 | numpy==1.18.3 42 | oauthlib==3.1.0 43 | opencv-python==4.2.0.34 44 | opt-einsum==3.2.1 45 | Pillow==7.1.2 46 | plac==1.1.3 47 | preshed==3.0.2 48 | protobuf==3.11.3 49 | pyasn1==0.4.8 50 | pyasn1-modules==0.2.8 51 | pycparser==2.20 52 | pydub==0.23.1 53 | pyparsing==2.4.7 54 | python-dateutil==2.8.1 55 | PyWavelets==1.1.1 56 | PyYAML==5.3.1 57 | regex==2020.5.7 58 | requests==2.23.0 59 | requests-oauthlib==1.3.0 60 | resampy==0.2.2 61 | rsa==4.0 62 | sacremoses==0.0.43 63 | scikit-image==0.16.2 64 | scikit-learn==0.22.2.post1 65 | scikit-video==1.1.11 66 | scipy==1.4.1 67 | sentencepiece==0.1.86 68 | six==1.14.0 69 | SoundFile==0.10.3.post1 70 | spacy==2.2.4 71 | srsly==1.0.2 72 | tensorboard==2.2.1 73 | tensorboard-plugin-wit==1.6.0.post3 74 | tensorflow==2.2.0 75 | tensorflow-estimator==2.2.0 76 | termcolor==1.1.0 77 | thinc==7.4.0 78 | tokenizers==0.7.0 79 | torch==1.5.0 80 | torchvision==0.6.0 81 | tqdm==4.45.0 82 | transformers==2.9.0 83 | Unidecode==1.1.1 84 | urllib3==1.25.9 85 | wasabi==0.6.0 86 | Werkzeug==1.0.1 87 | wrapt==1.12.1 88 | zipp==3.1.0 89 | -------------------------------------------------------------------------------- /result.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/result.gif -------------------------------------------------------------------------------- /result.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thetobysiu/deepstory/34eb7b1771479b996f361c291dc36b88ca25bd17/result.mp4 -------------------------------------------------------------------------------- /static/css/styles.css: 
-------------------------------------------------------------------------------- 1 | .container { 2 | max-width: 900px; 3 | } 4 | 5 | .tab-content { 6 | margin: 10px; 7 | } 8 | 9 | .col, .col-1, .col-10, .col-11, .col-12, .col-2, .col-3, .col-4, .col-5, .col-6, .col-7, .col-8, .col-9, .col-auto, .col-lg, .col-lg-1, .col-lg-10, .col-lg-11, .col-lg-12, .col-lg-2, .col-lg-3, .col-lg-4, .col-lg-5, .col-lg-6, .col-lg-7, .col-lg-8, .col-lg-9, .col-lg-auto, .col-md, .col-md-1, .col-md-10, .col-md-11, .col-md-12, .col-md-2, .col-md-3, .col-md-4, .col-md-5, .col-md-6, .col-md-7, .col-md-8, .col-md-9, .col-md-auto, .col-sm, .col-sm-1, .col-sm-10, .col-sm-11, .col-sm-12, .col-sm-2, .col-sm-3, .col-sm-4, .col-sm-5, .col-sm-6, .col-sm-7, .col-sm-8, .col-sm-9, .col-sm-auto, .col-xl, .col-xl-1, .col-xl-10, .col-xl-11, .col-xl-12, .col-xl-2, .col-xl-3, .col-xl-4, .col-xl-5, .col-xl-6, .col-xl-7, .col-xl-8, .col-xl-9, .col-xl-auto { 10 | position: relative; 11 | width: 100%; 12 | padding-right: 15px; 13 | padding-left: 15px; 14 | margin: 10px 0px; 15 | } 16 | 17 | .table td, .table th { 18 | padding: .75rem; 19 | vertical-align: top; 20 | border-top: 0px; 21 | } 22 | 23 | textarea.form-control { 24 | overflow-y:hidden; 25 | } 26 | 27 | img { 28 | width: 256px; 29 | height: auto; 30 | } 31 | -------------------------------------------------------------------------------- /templates/animate.html: -------------------------------------------------------------------------------- 1 | {% if loaded_speakers %} 2 |
3 | {% for speaker in loaded_speakers %} 4 | {{ speaker }}: 5 | … 18 | 19 | {% endfor %} 20 | 21 |
22 | {% else %} 23 | No speakers loaded.
24 | {% endif %} -------------------------------------------------------------------------------- /templates/deepstory.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function () { 2 | function refresh_status() { 3 | $("#status").load("{{ url_for('status') }}", function () { 4 | $("#clearCache").click(function () { 5 | $.ajax({ 6 | url: "{{ url_for('clear') }}", 7 | success: function (message) { 8 | alert(message); 9 | refresh_status(); 10 | refresh_animate(); 11 | refresh_video(); 12 | }, 13 | error: function (response) { 14 | alert(response.responseText); 15 | } 16 | }); 17 | }); 18 | }); 19 | } 20 | function refresh_sent() { 21 | $("#sentences").load("{{ url_for('sentences') }}"); 22 | } 23 | function refresh_animate() { 24 | $("#tab-3").load("{{ url_for('animate') }}", function () { 25 | $("#animate").find("select").each(function () { 26 | let img = $('', {src: "image/" + $(this).val()}); 27 | img.insertAfter($(this)); 28 | }).on('change', function () { 29 | $(this).parent().find("img").attr('src', "image/" + $(this).val()); 30 | }); 31 | $("#animate").submit(function (e) { 32 | e.preventDefault(); 33 | let button = $(this).find('button') 34 | let tempText = button.text(); 35 | button.text("Animating..."); 36 | button.prop('disabled', true); 37 | $.ajax({ 38 | type: "POST", 39 | url: this.action, 40 | data: $(this).serialize(), 41 | success: function (message) { 42 | alert(message); 43 | button.text(tempText); 44 | button.prop('disabled', false); 45 | refresh_status(); 46 | refresh_video(); 47 | }, 48 | error: function (response) { 49 | alert(response.responseText); 50 | button.text(tempText); 51 | button.prop('disabled', false); 52 | } 53 | }); 54 | }); 55 | }); 56 | } 57 | function refresh_video() { 58 | $("#tab-4").load("{{ url_for('video') }}", function () { 59 | $("#view").click(function () { 60 | if (!$(this).parent().has('video').length) { 61 | let video = $('