├── .gitignore ├── LICENSE.md ├── README.md ├── app.py ├── blender └── animation_uist2021.blend ├── requirements.txt ├── sg_core ├── config │ ├── multimodal_context_toolkit.yml │ └── parse_args.py ├── output │ ├── h36m_gesture_autoencoder │ │ └── gesture_autoencoder_checkpoint_best.bin │ └── sgtoolkit │ │ └── multimodal_context_checkpoint_best.bin └── scripts │ ├── __init__.py │ ├── data_loader │ ├── data_preprocessor.py │ ├── lmdb_data_loader.py │ └── motion_preprocessor.py │ ├── gesture_generator.py │ ├── model │ ├── embedding_net.py │ ├── embedding_space_evaluator.py │ ├── multimodal_context_net.py │ ├── tcn.py │ └── vocab.py │ ├── train.py │ ├── train_eval │ ├── diff_augment.py │ └── train_gan.py │ └── utils │ ├── average_meter.py │ ├── data_utils.py │ ├── gui_utils.py │ ├── train_utils.py │ ├── tts_helper.py │ └── vocab_utils.py ├── sg_core_api.py ├── static ├── css │ └── index.css ├── favicon.ico ├── js │ ├── avatar.js │ ├── avatarInterface.js │ ├── cell.js │ ├── cellTrack.js │ ├── history.js │ ├── index.js │ ├── motionLibrary.js │ ├── ruleManager.js │ ├── sortedlist.js │ ├── stylePannel.js │ ├── timeline.js │ ├── track.js │ └── util.js ├── mesh │ └── mannequin │ │ ├── Ch36_1001_Diffuse.png │ │ ├── Ch36_1001_Glossiness.png │ │ ├── Ch36_1001_Normal.png │ │ ├── Ch36_1001_Specular.png │ │ └── mannequin.babylon └── screenshot.jpg ├── templates ├── index.html ├── modal │ ├── cell_control_data_modification_invalid_warning.html │ ├── control_data_delete_modal.html │ ├── generate_modal.html │ ├── help_dialog_modal.html │ ├── import_json.html │ ├── loading_progress_modal.html │ ├── motion_library_delete_modal.html │ └── view_rule_modal.html └── sample_text.html └── waitress_server.py /.gitignore: -------------------------------------------------------------------------------- 1 | gentle/ 2 | .idea/ 3 | cached_wav/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | ## AIR License and Service Agreement 2 | 3 | ### Preamble 4 | This License and Service Agreement (LSA) applies to all works and their derivative works based on [source form version] and [object form version] of AIR. Currently, the LSA for AIR has two policies, 'Open Source License' and 'Commercial License'. Therefore, all works including the source code and executable code of AIR and derivative works based thereon are subject to either 'Open Source License' or 'Commercial License' depending on the user's needs and purpose. Details related to the selection of the applicable license are specified in this LSA. 5 | If you use any or all of AIR in any form, you are deemed to have consented to this LSA. If you breach any of the terms and conditions set forth in this LSA, you are solely responsible for any losses or damages incurred by “Electronics and Telecommunications Research Institute” (ETRI), and ETRI assume no responsibility for you or any third party. 6 | 7 | ### Commercial License 8 | If you use the [source form version] or [object form version] of AIR in whole or in part to develop a code or a derivative work, and you want to commercialize the result in some form, you will be covered under a commercial license. 9 | And if you are subject to a commercial license, the contract for the use of AIR is subject to “TECHNOLOGY LICENSE AGREEMENT” of ETRI. 10 | You acknowledge that ETRI has all legal rights, title and interest, including intellectual property rights in the AIR (regardless of whether such intellectual property rights are registered or where such rights exist) and agree with no objection thereto. 11 | Except as provided in a subsidiary agreement, nothing in this LSA grants you the right to use AIR or the name, service mark, logo, domain name and other unique identification marks of ETRI. 12 | 13 | ### Open Source License 14 | If you use the [source form version] or [object form version] of AIR in whole or in part to develop a code or a derivative work, and you do not commercialize the result in any form, you will be covered under an open source license. 15 | AIR is in accordance with Free Software Foundation (FSF)'s open source policy, and is allowed to use it in the appropriate scope and manner, and you must comply with the applicable open source license policy applied to AIR. 16 | AIR is, in principle, subject to GNU General Public License version 3.0(GPLv3). 
If you have acquired all or a part of the AIR in any way and it is subject to a license other than the open source license described above, please contact the following address for the technical support and other inquiries before use, and check the usage information. 17 | 18 | ### Technical support and other inquiries 19 | If you have any questions about licensing and sales of AIR, and other technical support services, please contact the following: 20 | * Name: Minsu Jang 21 | * Phone: +82-42-860-1250 22 | * E-mail: minsu@etri.re.kr 23 | 24 | ### Credit 25 | 26 | #### [Pytorch Tutorial](https://github.com/spro/practical-pytorch) 27 | The MIT License (MIT) 28 | 29 | Copyright (c) 2017 Sean Robertson 30 | 31 | Permission is hereby granted, free of charge, to any person obtaining a copy 32 | of this software and associated documentation files (the "Software"), to deal 33 | in the Software without restriction, including without limitation the rights 34 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 35 | copies of the Software, and to permit persons to whom the Software is 36 | furnished to do so, subject to the following conditions: 37 | 38 | The above copyright notice and this permission notice shall be included in 39 | all copies or substantial portions of the Software. 40 | 41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 47 | THE SOFTWARE. 48 | 49 | #### [PyTorch-Batch-Attention-Seq2seq](https://github.com/AuCson/PyTorch-Batch-Attention-Seq2seq) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SGToolkit 2 | 3 | This is the code for SGToolkit: An Interactive Gesture Authoring Toolkit for Embodied Conversational Agents (UIST'21). 4 | We introduce a new gesture generation toolkit, named SGToolkit, which gives higher-quality output than automatic methods and is more efficient than manual authoring. 5 | For the toolkit, we propose a neural generative model that synthesizes gestures from speech and accommodates fine-level pose controls and coarse-level style controls from users. 6 | 7 | ![SCREENSHOT](static/screenshot.jpg) 8 | 9 | ### [ACM DL](https://doi.org/10.1145/3472749.3474789) | [arXiv](https://arxiv.org/pdf/2108.04636.pdf) | [Presentation video](https://youtu.be/qClSOtLiVlc) 10 | (please visit the ACM DL page for the supplementary video) 11 | 12 | ## Install 13 | 14 | (This code is tested on Ubuntu 18.04 and Python 3.6) 15 | 16 | * Install gentle and add its path to PYTHONPATH 17 | ```bash 18 | sudo apt install gfortran 19 | git clone https://github.com/lowerquality/gentle.git 20 | cd gentle 21 | ./install.sh # try 'sudo ./install.sh' if you encounter permission errors 22 | ``` 23 | 24 | * Set up Google Cloud TTS. Follow [the manual](https://cloud.google.com/docs/authentication/getting-started) and put your key file (`google-key.json`) into the `sg_core` folder.
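A quick sanity check for the two install steps above can save time later. The snippet below is a minimal sketch and not part of the toolkit: the gentle clone path is only an example (use wherever you cloned it), and it relies on the standard `GOOGLE_APPLICATION_CREDENTIALS` environment variable that the Google client library reads by default; the toolkit's own TTS helper may instead load `sg_core/google-key.json` directly.

```python
# Sanity check for the install steps above (a sketch; paths are examples, not fixed by this repo).
import os
import sys

sys.path.append(os.path.expanduser('~/gentle'))  # folder where gentle was cloned
os.environ.setdefault('GOOGLE_APPLICATION_CREDENTIALS', os.path.join('sg_core', 'google-key.json'))

import gentle  # fails here if the gentle path is not visible to Python
from google.cloud import texttospeech  # google-cloud-texttospeech==1.0.1 from requirements.txt

gentle.Resources()                 # loads gentle's alignment resources
texttospeech.TextToSpeechClient()  # raises if the key file is missing or invalid
print('gentle and Google Cloud TTS look ready')
```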
25 | 26 | * Install Python packages 27 | ```bash 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | * Download the model file ([dropbox](https://www.dropbox.com/s/b5rwtn60j4tf2vr/multimodal_context_checkpoint_best.bin?dl=0)) and put it into the `sg_core/output/sgtoolkit` folder 32 | 33 | 34 | ## Quickstart 35 | 36 | You can run SGToolkit on your PC with the pretrained model. Run the Flask server with `python waitress_server.py` and open `localhost:8080` in an HTML5-capable web browser such as Chrome or Edge. 37 | 38 | Enter speech text in the edit box or select an example speech text, 39 | then click the `generate` button to synthesize initial gestures and the `play` button to review them. 40 | You can now add pose and style controls. Select a desired frame in the pose or style track, then add a pose control by editing the mannequin or a style control by adjusting the style values. 41 | Press `apply controls` to get the updated results. 42 | 43 | Note that the motion library and rule functions are not available by default. If you want to use them, set up MongoDB and put the DB address at `app.py` line 18. 44 | 45 | 46 | ## Training 47 | 48 | 1. Download the preprocessed TED dataset ([OneDrive link](https://kaistackr-my.sharepoint.com/:u:/g/personal/zeroyy_kaist_ac_kr/EWwpDefvifdCvVKkExlv12QBoRdjiyqy9BXnLGMzFD-HeQ?e=WPUtgo)) and extract it to `sg_core/data/ted_dataset_2020.07` 49 | 50 | 2. Run the training script 51 | ```bash 52 | cd sg_core/scripts 53 | python train.py --config=../config/multimodal_context_toolkit.yml 54 | ``` 55 | 56 | 57 | ## Animation rendering 58 | 59 | 1. Export the current animation to JSON and audio files by clicking the export button in the upper-right corner of SGToolkit, and put the exported files into a temporary folder. 60 | 2. Open the `blender/animation_uist2021.blend` file with Blender 2.8+. 61 | 3. Modify `data_path` at line 35 of the `render` script to point to the temporary folder containing the exported files, then run the `render` script. 62 | 63 | 64 | ## Related repositories 65 | 66 | * TED DB: https://github.com/youngwoo-yoon/youtube-gesture-dataset 67 | * Base model: https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context 68 | * HEMVIP, the web-based video evaluation tool: https://github.com/jonepatr/genea_webmushra 69 | 70 | 71 | ## Citation 72 | If this work is helpful in your research, please cite: 73 | ```text 74 | @inproceedings{yoon2021sgtoolkit, 75 | author = {Yoon, Youngwoo and Park, Keunwoo and Jang, Minsu and Kim, Jaehong and Lee, Geehyuk}, 76 | title = {SGToolkit: An Interactive Gesture Authoring Toolkit for Embodied Conversational Agents}, 77 | year = {2021}, 78 | publisher = {Association for Computing Machinery}, 79 | url = {https://doi.org/10.1145/3472749.3474789}, 80 | booktitle = {The 34th Annual ACM Symposium on User Interface Software and Technology}, 81 | series = {UIST '21} 82 | } 83 | ``` 84 | 85 | ## Acknowledgement 86 | 87 | * Character asset: [Mixamo](https://www.mixamo.com/) 88 | * This work was supported by the ICT R&D program of MSIP/IITP.
[2017-0-00162, Development of Human-care Robot Technology for Aging Society] 89 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, render_template, request, send_file 2 | from flask_pymongo import PyMongo 3 | import json 4 | import sg_core_api as sgapi 5 | import os 6 | import pathlib 7 | import numpy as np 8 | from bson.json_util import dumps 9 | from bson.objectid import ObjectId 10 | from datetime import datetime 11 | from scipy.interpolate import CubicSpline 12 | 13 | app = Flask(__name__) 14 | 15 | gesture_generator = sgapi.get_gesture_generator() 16 | root_path = pathlib.Path(__file__).parent 17 | 18 | app.config["MONGO_URI"] = "mongodb://localhost" # setup your own db to enable motion library and rule functions 19 | mongo = PyMongo(app) 20 | 21 | 22 | @app.route('/') 23 | def index(): 24 | return render_template('index.html') 25 | 26 | 27 | @app.route('/api/motion', methods=['GET', 'POST']) 28 | def motion_library(): 29 | if request.method == 'POST': 30 | json = request.get_json() 31 | json["motion"] = sgapi.convert_pose_coordinate_for_ui(np.array(json["motion"])).tolist() 32 | result = {} 33 | try: 34 | mongo.db.motion.insert_one(json) 35 | result['msg'] = "success" 36 | except Exception as e: 37 | result['msg'] = "fail" 38 | return result 39 | elif request.method == 'GET': 40 | try: 41 | cursor = mongo.db.motion.find().sort("name", 1) 42 | except AttributeError as e: 43 | return {} # empty library 44 | 45 | motions = sgapi.convert_pose_coordinate_for_ui_for_motion_library(list(cursor)) 46 | return dumps(motions) 47 | else: 48 | assert False 49 | 50 | 51 | @app.route('/api/delete_motion/', methods=['GET']) 52 | def delete_motion_library(id): 53 | result = mongo.db.motion.delete_one({'_id': ObjectId(id)}) 54 | msg = {} 55 | if result.deleted_count > 0: 56 | msg['msg'] = "success" 57 | else: 58 | msg['msg'] = "fail" 59 | return msg 60 | 61 | 62 | @app.route('/api/rule', methods=['GET', 'POST']) 63 | def rule(): 64 | if request.method == 'POST': 65 | json = request.get_json() 66 | result = {} 67 | try: 68 | json['motion'] = ObjectId(json['motion']) 69 | mongo.db.rule.insert_one(json) 70 | result['msg'] = "success" 71 | except Exception as e: 72 | print(json) 73 | print(e) 74 | result['msg'] = "fail" 75 | return result 76 | elif request.method == 'GET': 77 | pipeline = [{'$lookup': 78 | {'from': 'motion', 79 | 'localField': 'motion', 80 | 'foreignField': '_id', 81 | 'as': 'motion_info'}}, 82 | ] 83 | 84 | try: 85 | cursor = mongo.db.rule.aggregate(pipeline) 86 | except AttributeError as e: 87 | return {} # empty rules 88 | 89 | rules = sgapi.convert_pose_coordinate_for_ui_for_rule_library(cursor) 90 | rules = dumps(rules) 91 | return rules 92 | else: 93 | assert False 94 | 95 | 96 | @app.route('/api/delete_rule/', methods=['GET']) 97 | def delete_rule(id): 98 | result = mongo.db.rule.delete_one({'_id': ObjectId(id)}) 99 | msg = {} 100 | if result.deleted_count > 0: 101 | msg['msg'] = "success" 102 | else: 103 | msg['msg'] = "fail" 104 | return msg 105 | 106 | 107 | @app.route('/api/input', methods=['POST']) 108 | def input_text_post(): 109 | content = request.get_json() 110 | input_text = content.get('text-input') 111 | if input_text is None or len(input_text) == 0: 112 | return {'msg': 'empty'} 113 | 114 | print('--------------------------------------------') 115 | print('request time:', datetime.now()) 116 | print('request IP:', 
request.remote_addr) 117 | print(input_text) 118 | 119 | kp_constraints = content.get('keypoint-constraints') 120 | if kp_constraints: 121 | pose_constraints_input = np.array(kp_constraints) 122 | pose_constraints = sgapi.convert_pose_coordinate_for_model(np.copy(pose_constraints_input)) 123 | else: 124 | pose_constraints = None 125 | pose_constraints_input = None 126 | 127 | style_constraints = content.get('style-constraints') 128 | if style_constraints: 129 | style_constraints = np.array(style_constraints) 130 | else: 131 | style_constraints = None 132 | 133 | result = {} 134 | result['msg'] = "success" 135 | result['input-pose-constraints'] = pose_constraints_input.tolist() if pose_constraints_input is not None else None 136 | result['input-style-constraints'] = style_constraints.tolist() if style_constraints is not None else None 137 | result['input-voice'] = content.get('voice') 138 | result['is-manual-scenario'] = content.get('is-manual-scenario') 139 | 140 | if content.get('is-manual-scenario'): 141 | # interpolate key poses 142 | n_frames = pose_constraints_input.shape[0] 143 | n_joints = int((pose_constraints_input.shape[1] - 1) / 3) 144 | key_idxs = [i for i, e in enumerate(pose_constraints_input) if e[-1] == 1] 145 | 146 | if len(key_idxs) >= 2: 147 | out_gesture = np.zeros((n_frames, n_joints * 3)) 148 | xs = np.arange(0, n_frames, 1) 149 | 150 | for i in range(n_joints): 151 | pts = pose_constraints_input[key_idxs, i * 3:(i + 1) * 3] 152 | cs = CubicSpline(key_idxs, pts, bc_type='clamped') 153 | out_gesture[:, i * 3:(i + 1) * 3] = cs(xs) 154 | 155 | result['output-data'] = out_gesture.tolist() 156 | result['audio-filename'] = os.path.split(result['input-voice'])[ 157 | 1] # WARNING: assumed manual mode uses external audio file 158 | else: 159 | result['msg'] = "fail" 160 | else: 161 | # run gesture generation model 162 | output = gesture_generator.generate(input_text, pose_constraints=pose_constraints, 163 | style_values=style_constraints, voice=content.get('voice')) 164 | 165 | if output is None: 166 | # something wrong 167 | result['msg'] = "fail" 168 | else: 169 | gesture, audio, tts_filename, words_with_timestamps = output 170 | gesture = sgapi.convert_pose_coordinate_for_ui(gesture) 171 | 172 | result['audio-filename'] = os.path.split(tts_filename)[1] # filename without path 173 | result['words-with-timestamps'] = words_with_timestamps 174 | result['output-data'] = gesture.tolist() 175 | 176 | return result 177 | 178 | 179 | @app.route('/media//') 180 | def download_audio_file(filename, new_filename): 181 | return send_file(os.path.join('./cached_wav', filename), as_attachment=True, attachment_filename=new_filename, 182 | cache_timeout=0) 183 | 184 | 185 | @app.route('/mesh/') 186 | def download_mesh_file(filename): 187 | mesh_path = root_path.joinpath("static", "mesh", filename) 188 | return send_file(str(mesh_path), as_attachment=True, cache_timeout=0) 189 | 190 | 191 | @app.route('/upload_audio', methods=['POST']) 192 | def upload(): 193 | upload_dir = './cached_wav' 194 | file_names = [] 195 | 196 | for key in request.files: 197 | file = request.files[key] 198 | _, ext = os.path.splitext(file.filename) 199 | print('uploaded: ', file.filename) 200 | try: 201 | upload_path = os.path.join(upload_dir, "uploaded_audio" + ext) 202 | file.save(upload_path) 203 | file_names.append(upload_path) 204 | except: 205 | print('save fail: ' + os.path.join(upload_dir, file.filename)) 206 | 207 | return json.dumps({'filename': [f for f in file_names]}) 208 | 209 | 210 | if __name__ == 
'__main__': 211 | app.run() 212 | -------------------------------------------------------------------------------- /blender/animation_uist2021.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4r/SGToolkit/0c11a54ee64a7b29f4cba79a9d9ff3ae48e3af4e/blender/animation_uist2021.blend -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ConfigArgParse==1.2.3 2 | dataclasses==0.8 3 | fasttext==0.9.2 4 | Flask==1.1.2 5 | Flask-PyMongo==2.3.0 6 | google-cloud-texttospeech==1.0.1 7 | librosa==0.8.0 8 | lmdb==1.0.0 9 | matplotlib==3.3.3 10 | numpy==1.19.4 11 | pandas==1.1.4 12 | pyarrow==0.15.0 13 | pymongo==3.11.0 14 | SoundFile==0.10.3.post1 15 | scikit-learn==0.23.2 16 | scipy==1.5.4 17 | tabulate==0.8.7 18 | torch==1.7.1 19 | torchvision==0.8.2 20 | tqdm==4.51.0 21 | umap==0.1.1 22 | waitress==1.4.4 23 | wheel==0.35.1 -------------------------------------------------------------------------------- /sg_core/config/multimodal_context_toolkit.yml: -------------------------------------------------------------------------------- 1 | name: multimodal_context 2 | 3 | train_data_path: ../data/ted_dataset_2020.07/lmdb_train 4 | val_data_path: ../data/ted_dataset_2020.07/lmdb_val 5 | test_data_path: ../data/ted_dataset_2020.07/lmdb_test 6 | 7 | wordembed_dim: 300 8 | wordembed_path: ../data/fasttext/crawl-300d-2M-subword.bin 9 | 10 | model_save_path: ../output/sgtoolkit 11 | random_seed: 0 12 | save_result_video: True 13 | 14 | # model params 15 | model: multimodal_context 16 | pose_representation: 3d_vec 17 | mean_dir_vec: [-0.00225, -0.98496, 0.16212, 0.01831, -0.79641, 0.52568, 0.02496, -0.65216, -0.67807, -0.87815, 0.40211, -0.06526, -0.38831, 0.85245, 0.13283, 0.35888, -0.16606, 0.70720, 0.87728, 0.41491, -0.00166, 0.38441, 0.85739, 0.14593, -0.39277, -0.17973, 0.69081] 18 | mean_pose: [-0.00000, -0.00002, 0.00004, -0.00055, -0.24976, 0.03882, 0.00152, -0.32251, 0.10291, 0.00430, -0.43652, 0.02527, -0.12537, -0.19055, 0.03108, -0.23547, 0.04413, 0.06726, -0.14551, 0.00403, 0.23596, 0.12585, -0.18445, 0.04031, 0.23547, 0.04749, 0.08014, 0.13293, 0.00299, 0.24744] 19 | normalize_motion_data: True 20 | 21 | n_layers: 4 22 | hidden_size: 300 23 | z_type: style_vector 24 | style_val_mean: [0.00241791, 0.48645255, 0] 25 | style_val_std: [0.00120855, 0.17992376, 1] 26 | style_val_max: [0.01574225, 1.5461352 , 1] 27 | input_context: both 28 | use_pose_control: true 29 | use_style_control: true 30 | 31 | # train params 32 | epochs: 80 33 | batch_size: 128 34 | learning_rate: 0.0005 35 | 36 | loss_l1_weight: 500 37 | loss_gan_weight: 5.0 38 | loss_reg_weight: 0.05 39 | loss_warmup: 10 40 | 41 | # eval params 42 | eval_net_path: ../output/h36m_gesture_autoencoder/gesture_autoencoder_checkpoint_best.bin 43 | 44 | # dataset params 45 | motion_resampling_framerate: 15 46 | n_poses: 60 47 | n_pre_poses: 30 48 | subdivision_stride: 20 49 | loader_workers: 4 50 | -------------------------------------------------------------------------------- /sg_core/config/parse_args.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | 3 | 4 | def str2bool(v): 5 | """ from https://stackoverflow.com/a/43357954/1361529 """ 6 | if isinstance(v, bool): 7 | return v 8 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 9 | return True 10 | elif v.lower() in ('no', 'false', 
'f', 'n', '0'): 11 | return False 12 | else: 13 | raise configargparse.ArgumentTypeError('Boolean value expected.') 14 | 15 | 16 | def parse_args(): 17 | parser = configargparse.ArgParser() 18 | parser.add('-c', '--config', required=True, is_config_file=True, help='Config file path') 19 | parser.add("--name", type=str, default="main") 20 | parser.add("--train_data_path", type=str, required=True) 21 | parser.add("--val_data_path", type=str, required=True) 22 | parser.add("--test_data_path", type=str, required=False) 23 | parser.add("--model_save_path", required=True) 24 | parser.add("--pose_representation", type=str, default='pca') 25 | parser.add("--pose_norm_stats_path", type=str, default=None) 26 | parser.add("--pose_representation_path", type=str, default=None) 27 | parser.add("--mean_dir_vec", action="append", type=float, nargs='*') 28 | parser.add("--mean_pose", action="append", type=float, nargs='*') 29 | parser.add("--style_val_max", action="append", type=float, nargs='*') 30 | parser.add("--style_val_mean", action="append", type=float, nargs='*') 31 | parser.add("--style_val_std", action="append", type=float, nargs='*') 32 | parser.add("--random_seed", type=int, default=-1) 33 | parser.add("--save_result_video", type=str2bool, default=True) 34 | 35 | # word embedding 36 | parser.add("--wordembed_path", type=str, default=None) 37 | parser.add("--wordembed_dim", type=int, default=200) 38 | parser.add("--freeze_wordembed", type=str2bool, default=False) 39 | 40 | # model 41 | parser.add("--model", type=str, required=True) 42 | parser.add("--epochs", type=int, default=10) 43 | parser.add("--batch_size", type=int, default=50) 44 | parser.add("--dropout_prob", type=float, default=0.3) 45 | parser.add("--n_layers", type=int, default=2) 46 | parser.add("--hidden_size", type=int, default=200) 47 | parser.add("--residual_output", type=str2bool, default=False) 48 | parser.add("--z_type", type=str, default='none') 49 | parser.add("--input_context", type=str, default='both') # text, audio, both 50 | parser.add("--use_pose_control", type=str2bool, default=True) 51 | parser.add("--use_style_control", type=str2bool, default=True) 52 | 53 | # dataset 54 | parser.add("--motion_resampling_framerate", type=int, default=24) 55 | parser.add("--n_poses", type=int, default=50) 56 | parser.add("--n_pre_poses", type=int, default=5) 57 | parser.add("--subdivision_stride", type=int, default=5) 58 | parser.add("--normalize_motion_data", type=str2bool, default=False) 59 | parser.add("--augment_data", type=str2bool, default=False) 60 | parser.add("--loader_workers", type=int, default=0) 61 | 62 | # GAN parameter 63 | parser.add("--GAN_noise_size", type=int, default=0) 64 | 65 | # training 66 | parser.add("--diff_augment", type=str2bool, default=True) 67 | parser.add("--learning_rate", type=float, default=0.001) 68 | parser.add("--discriminator_lr_weight", type=float, default=0.2) 69 | parser.add("--loss_l1_weight", type=float, default=50) 70 | parser.add("--loss_gan_weight", type=float, default=1.0) 71 | parser.add("--loss_reg_weight", type=float, default=0.01) 72 | parser.add("--loss_warmup", type=int, default=-1) 73 | 74 | # eval 75 | parser.add("--eval_net_path", type=str, default='') 76 | 77 | args = parser.parse_args() 78 | return args 79 | -------------------------------------------------------------------------------- /sg_core/output/h36m_gesture_autoencoder/gesture_autoencoder_checkpoint_best.bin: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ai4r/SGToolkit/0c11a54ee64a7b29f4cba79a9d9ff3ae48e3af4e/sg_core/output/h36m_gesture_autoencoder/gesture_autoencoder_checkpoint_best.bin -------------------------------------------------------------------------------- /sg_core/output/sgtoolkit/multimodal_context_checkpoint_best.bin: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:eaef7f7179b1019fc7757c5eb76f7ff460e296399198af0bf8b64379f644e486 3 | size 122872161 4 | -------------------------------------------------------------------------------- /sg_core/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4r/SGToolkit/0c11a54ee64a7b29f4cba79a9d9ff3ae48e3af4e/sg_core/scripts/__init__.py -------------------------------------------------------------------------------- /sg_core/scripts/data_loader/data_preprocessor.py: -------------------------------------------------------------------------------- 1 | """ create data samples """ 2 | import logging 3 | from collections import defaultdict 4 | 5 | import lmdb 6 | import math 7 | import numpy as np 8 | import pyarrow 9 | import tqdm 10 | from scipy.signal import savgol_filter 11 | from sklearn.preprocessing import normalize 12 | 13 | import utils.data_utils 14 | from data_loader.motion_preprocessor import MotionPreprocessor 15 | 16 | 17 | class DataPreprocessor: 18 | def __init__(self, clip_lmdb_dir, out_lmdb_dir, n_poses, subdivision_stride, 19 | pose_resampling_fps, mean_pose, mean_dir_vec, disable_filtering=False): 20 | self.n_poses = n_poses 21 | self.subdivision_stride = subdivision_stride 22 | self.skeleton_resampling_fps = pose_resampling_fps 23 | self.mean_pose = mean_pose 24 | self.disable_filtering = disable_filtering 25 | 26 | self.src_lmdb_env = lmdb.open(clip_lmdb_dir, readonly=True, lock=False) 27 | with self.src_lmdb_env.begin() as txn: 28 | self.n_videos = txn.stat()['entries'] 29 | 30 | self.audio_sample_length = int(self.n_poses / self.skeleton_resampling_fps * 16000) 31 | 32 | # create db for samples 33 | map_size = 1024 * 50 # in MB 34 | map_size <<= 20 # in B 35 | self.dst_lmdb_env = lmdb.open(out_lmdb_dir, map_size=map_size) 36 | self.n_out_samples = 0 37 | 38 | def run(self): 39 | n_filtered_out = defaultdict(int) 40 | src_txn = self.src_lmdb_env.begin(write=False) 41 | 42 | # sampling and normalization 43 | cursor = src_txn.cursor() 44 | for key, value in cursor: 45 | video = pyarrow.deserialize(value) 46 | vid = video['vid'] 47 | clips = video['clips'] 48 | print(f'processing data... {vid}') 49 | for clip_idx, clip in enumerate(clips): 50 | filtered_result = self._sample_from_clip(vid, clip) 51 | for type in filtered_result.keys(): 52 | n_filtered_out[type] += filtered_result[type] 53 | 54 | # print stats 55 | with self.dst_lmdb_env.begin() as txn: 56 | logging.info(f"no. of samples: {txn.stat()['entries']}") 57 | n_total_filtered = 0 58 | for type, n_filtered in n_filtered_out.items(): 59 | logging.info('{}: {}'.format(type, n_filtered)) 60 | n_total_filtered += n_filtered 61 | logging.info('no. 
of excluded samples: {} ({:.1f}%)'.format( 62 | n_total_filtered, 100 * n_total_filtered / (txn.stat()['entries'] + n_total_filtered))) 63 | 64 | # close db 65 | self.src_lmdb_env.close() 66 | self.dst_lmdb_env.sync() 67 | self.dst_lmdb_env.close() 68 | 69 | def _sample_from_clip(self, vid, clip): 70 | clip_skeleton_2d = clip['skeletons'] 71 | clip_skeleton = clip['skeletons_3d'] 72 | clip_audio_raw = clip['audio_raw'] 73 | clip_word_list = clip['words'] 74 | clip_s_f, clip_e_f = clip['start_frame_no'], clip['end_frame_no'] 75 | clip_s_t, clip_e_t = clip['start_time'], clip['end_time'] 76 | 77 | n_filtered_out = defaultdict(int) 78 | 79 | # skeleton resampling 80 | clip_skeleton = utils.data_utils.resample_pose_seq(clip_skeleton, clip_e_t - clip_s_t, 81 | self.skeleton_resampling_fps) 82 | 83 | # divide 84 | aux_info = [] 85 | sample_skeletons_list = [] 86 | sample_words_list = [] 87 | sample_audio_list = [] 88 | 89 | num_subdivision = math.floor( 90 | (len(clip_skeleton) - self.n_poses) 91 | / self.subdivision_stride) + 1 # floor((K - (N+M)) / S) + 1 92 | 93 | for i in range(num_subdivision): 94 | start_idx = i * self.subdivision_stride 95 | fin_idx = start_idx + self.n_poses 96 | 97 | sample_skeletons = clip_skeleton[start_idx:fin_idx] 98 | subdivision_start_time = clip_s_t + start_idx / self.skeleton_resampling_fps 99 | subdivision_end_time = clip_s_t + fin_idx / self.skeleton_resampling_fps 100 | sample_words = self.get_words_in_time_range(word_list=clip_word_list, 101 | start_time=subdivision_start_time, 102 | end_time=subdivision_end_time) 103 | 104 | # raw audio 105 | audio_start = math.floor(start_idx / len(clip_skeleton) * len(clip_audio_raw)) 106 | audio_end = audio_start + self.audio_sample_length 107 | if audio_end > len(clip_audio_raw): # correct size mismatch between poses and audio 108 | n_padding = audio_end - len(clip_audio_raw) 109 | padded_data = np.pad(clip_audio_raw, (0, n_padding), mode='symmetric') 110 | sample_audio = padded_data[audio_start:audio_end] 111 | else: 112 | sample_audio = clip_audio_raw[audio_start:audio_end] 113 | 114 | if len(sample_words) >= 2: 115 | # filtering motion skeleton data 116 | sample_skeletons, filtering_message = MotionPreprocessor(sample_skeletons, self.mean_pose).get() 117 | is_correct_motion = (sample_skeletons != []) 118 | motion_info = {'vid': vid, 119 | 'start_frame_no': clip_s_f + start_idx, 120 | 'end_frame_no': clip_s_f + fin_idx, 121 | 'start_time': subdivision_start_time, 122 | 'end_time': subdivision_end_time, 123 | } 124 | 125 | # logging.info('subdivision', clip_s + start_idx, clip_s + fin_idx, filtering_message) 126 | 127 | if is_correct_motion or self.disable_filtering: 128 | sample_skeletons_list.append(sample_skeletons) 129 | sample_words_list.append(sample_words) 130 | sample_audio_list.append(sample_audio) 131 | aux_info.append(motion_info) 132 | else: 133 | n_filtered_out[filtering_message] += 1 134 | 135 | if len(sample_skeletons_list) > 0: 136 | with self.dst_lmdb_env.begin(write=True) as txn: 137 | for words, poses, audio, aux in zip(sample_words_list, sample_skeletons_list, 138 | sample_audio_list, aux_info): 139 | # apply smooth filter and make normalized directional vectors 140 | poses = np.asarray(poses) 141 | poses = savgol_filter(poses, 7, 3, axis=0) 142 | dir_vec = utils.data_utils.convert_pose_seq_to_dir_vec(poses) 143 | 144 | # save 145 | k = '{:010}'.format(self.n_out_samples).encode('ascii') 146 | v = [words, poses, dir_vec, audio, aux] 147 | v = pyarrow.serialize(v).to_buffer() 148 | txn.put(k, v) 149 
| self.n_out_samples += 1 150 | 151 | return n_filtered_out 152 | 153 | @staticmethod 154 | def get_words_in_time_range(word_list, start_time, end_time): 155 | words = [] 156 | 157 | for word in word_list: 158 | _, word_s, word_e = word[0], word[1], word[2] 159 | 160 | if word_s >= end_time: 161 | break 162 | 163 | if word_e <= start_time: 164 | continue 165 | 166 | words.append(word) 167 | 168 | return words 169 | 170 | @staticmethod 171 | def unnormalize_data(normalized_data, data_mean, data_std, dimensions_to_ignore): 172 | """ 173 | this method is from https://github.com/asheshjain399/RNNexp/blob/srnn/structural_rnn/CRFProblems/H3.6m/generateMotionData.py#L12 174 | """ 175 | T = normalized_data.shape[0] 176 | D = data_mean.shape[0] 177 | 178 | origData = np.zeros((T, D), dtype=np.float32) 179 | dimensions_to_use = [] 180 | for i in range(D): 181 | if i in dimensions_to_ignore: 182 | continue 183 | dimensions_to_use.append(i) 184 | dimensions_to_use = np.array(dimensions_to_use) 185 | 186 | origData[:, dimensions_to_use] = normalized_data 187 | 188 | # potentially inefficient, but only done once per experiment 189 | stdMat = data_std.reshape((1, D)) 190 | stdMat = np.repeat(stdMat, T, axis=0) 191 | meanMat = data_mean.reshape((1, D)) 192 | meanMat = np.repeat(meanMat, T, axis=0) 193 | origData = np.multiply(origData, stdMat) + meanMat 194 | 195 | return origData 196 | -------------------------------------------------------------------------------- /sg_core/scripts/data_loader/lmdb_data_loader.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | import pickle 5 | import random 6 | 7 | import librosa 8 | import numpy as np 9 | import lmdb as lmdb 10 | import torch 11 | from scipy.signal import savgol_filter 12 | from scipy.stats import pearsonr 13 | from torch.nn.utils.rnn import pad_sequence 14 | import torch.nn.functional as F 15 | 16 | from torch.utils.data import Dataset, DataLoader 17 | from torch.utils.data.dataloader import default_collate 18 | from tqdm import tqdm 19 | 20 | import utils.train_utils 21 | import utils.data_utils 22 | from model.vocab import Vocab 23 | from data_loader.data_preprocessor import DataPreprocessor 24 | import pyarrow 25 | 26 | 27 | def default_collate_fn(data): 28 | _, text_padded, pose_seq, vec_seq, audio, style_vec, aux_info = zip(*data) 29 | 30 | text_padded = default_collate(text_padded) 31 | pose_seq = default_collate(pose_seq) 32 | vec_seq = default_collate(vec_seq) 33 | audio = default_collate(audio) 34 | style_vec = default_collate(style_vec) 35 | aux_info = {key: default_collate([d[key] for d in aux_info]) for key in aux_info[0]} 36 | 37 | return torch.tensor([0]), torch.tensor([0]), text_padded, pose_seq, vec_seq, audio, style_vec, aux_info 38 | 39 | 40 | def calculate_style_vec(pose_seq, window_size, mean_pose, style_mean_std=None): 41 | if pose_seq.shape[-1] != 3: 42 | pose_seq = pose_seq.reshape(pose_seq.shape[:-1] + (-1, 3)) 43 | 44 | batch_size = pose_seq.shape[0] 45 | n_poses = pose_seq.shape[1] 46 | style_vec = torch.zeros((batch_size, n_poses, 3), dtype=pose_seq.dtype, device=pose_seq.device) 47 | half_window = window_size // 2 48 | 49 | for i in range(n_poses): 50 | start_idx = max(0, i - half_window) 51 | end_idx = min(n_poses, i + half_window) 52 | poses_roi = pose_seq[:, start_idx:end_idx] 53 | 54 | # motion speed 55 | diff = poses_roi[:, 1:] - poses_roi[:, :-1] 56 | motion_speed = torch.mean(torch.abs(diff), dim=(1, 2, 3)) 57 | 58 | # motion 
acceleration 59 | # accel = diff[:, 1:] - diff[:, :-1] 60 | # motion_accel = torch.mean(torch.abs(accel), dim=(1, 2, 3)) 61 | 62 | # space 63 | space = torch.norm(poses_roi[:, :, 6] - poses_roi[:, :, 9], dim=2) # distance between two hands 64 | space = torch.mean(space, dim=1) 65 | 66 | # handedness 67 | left_arm_move = torch.mean(torch.abs(poses_roi[:, 1:, 6] - poses_roi[:, :-1, 6]), dim=(1, 2)) 68 | right_arm_move = torch.mean(torch.abs(poses_roi[:, 1:, 9] - poses_roi[:, :-1, 9]), dim=(1, 2)) 69 | 70 | handedness = torch.where(right_arm_move > left_arm_move, 71 | left_arm_move / right_arm_move - 1, # (-1, 0] 72 | 1 - right_arm_move / left_arm_move) # [0, 1) 73 | handedness *= 3 # to [-3, 3] 74 | 75 | style_vec[:, i, 0] = motion_speed 76 | style_vec[:, i, 1] = space 77 | style_vec[:, i, 2] = handedness 78 | 79 | # normalize 80 | if style_mean_std is not None: 81 | mean, std, max_val = style_mean_std[0], style_mean_std[1], style_mean_std[2] 82 | style_vec = (style_vec - mean) / std 83 | style_vec = torch.clamp(style_vec, -3, 3) # +-3std 84 | # style_vec = style_vec / max_val 85 | # style_vec = torch.clamp(style_vec, -1, 1) 86 | 87 | return style_vec 88 | 89 | 90 | class SpeechMotionDataset(Dataset): 91 | def __init__(self, lmdb_dir, n_poses, subdivision_stride, pose_resampling_fps, mean_pose, mean_dir_vec, 92 | normalize_motion=False, style_stat=None): 93 | self.lmdb_dir = lmdb_dir 94 | self.n_poses = n_poses 95 | self.subdivision_stride = subdivision_stride 96 | self.skeleton_resampling_fps = pose_resampling_fps 97 | 98 | self.expected_audio_length = int(round(n_poses / pose_resampling_fps * 16000)) 99 | 100 | self.lang_model = None 101 | 102 | if mean_dir_vec.shape[-1] != 3: 103 | mean_dir_vec = mean_dir_vec.reshape(mean_dir_vec.shape[:-1] + (-1, 3)) 104 | self.mean_dir_vec = mean_dir_vec 105 | self.normalize_motion = normalize_motion 106 | 107 | logging.info("Reading data '{}'...".format(lmdb_dir)) 108 | preloaded_dir = lmdb_dir + '_cache' 109 | if not os.path.exists(preloaded_dir): 110 | data_sampler = DataPreprocessor(lmdb_dir, preloaded_dir, n_poses, 111 | subdivision_stride, pose_resampling_fps, mean_pose, mean_dir_vec) 112 | data_sampler.run() 113 | else: 114 | logging.info('Found pre-loaded samples from {}'.format(preloaded_dir)) 115 | 116 | # init lmdb 117 | self.lmdb_env = lmdb.open(preloaded_dir, readonly=True, lock=False) 118 | with self.lmdb_env.begin() as txn: 119 | self.n_samples = txn.stat()['entries'] 120 | 121 | # pre-compute style vec 122 | precomputed_style = lmdb_dir + '_style_vec.npy' 123 | if not os.path.exists(precomputed_style): 124 | if style_stat is not None: 125 | logging.info('Calculating style vectors...') 126 | mean_pose = torch.tensor(mean_pose).squeeze() 127 | mean_dir_vec = torch.tensor(mean_dir_vec).squeeze() 128 | style_stat = torch.tensor(style_stat).squeeze() 129 | self.style_vectors = [] 130 | with self.lmdb_env.begin(write=False) as txn: 131 | for i in tqdm(range(self.n_samples)): 132 | key = '{:010}'.format(i).encode('ascii') 133 | sample = txn.get(key) 134 | sample = pyarrow.deserialize(sample) 135 | word_seq, pose_seq, vec_seq, audio, aux_info = sample 136 | 137 | window_size = pose_resampling_fps * 2 138 | poses = torch.from_numpy(vec_seq).unsqueeze(0) 139 | if normalize_motion: 140 | poses += mean_dir_vec # unnormalize 141 | poses = utils.data_utils.convert_dir_vec_to_pose_torch(poses) # normalized bone lengths 142 | style_vec = calculate_style_vec(poses, window_size, mean_pose, style_stat) 143 | self.style_vectors.append(style_vec[0].numpy()) 
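                # Each appended entry has shape (n_poses, 3): per-frame [motion speed, distance between the
                # two hands, handedness], with handedness scaled to [-3, 3] by calculate_style_vec above.
                # When style statistics are given (cf. style_val_mean/style_val_std in the config), the values
                # are z-normalized and clamped to +-3 std. Stacking yields (n_samples, n_poses, 3), which is
                # cached below as '<lmdb_dir>_style_vec.npy'.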
144 | self.style_vectors = np.stack(self.style_vectors) 145 | 146 | with open(precomputed_style, 'wb') as f: 147 | np.save(f, self.style_vectors) 148 | print('style npy mean: ', np.mean(self.style_vectors, axis=(0, 1))) 149 | print('style npy std: ', np.std(self.style_vectors, axis=(0, 1))) 150 | else: 151 | self.style_vectors = None 152 | else: 153 | with open(precomputed_style, 'rb') as f: 154 | self.style_vectors = np.load(f) 155 | 156 | def __len__(self): 157 | return self.n_samples 158 | 159 | def __getitem__(self, idx): 160 | with self.lmdb_env.begin(write=False) as txn: 161 | key = '{:010}'.format(idx).encode('ascii') 162 | sample = txn.get(key) 163 | 164 | sample = pyarrow.deserialize(sample) 165 | word_seq, pose_seq, vec_seq, audio, aux_info = sample 166 | 167 | def extend_word_seq(lang, words, end_time=None): 168 | n_frames = self.n_poses 169 | if end_time is None: 170 | end_time = aux_info['end_time'] 171 | frame_duration = (end_time - aux_info['start_time']) / n_frames 172 | 173 | extended_word_indices = np.zeros(n_frames) # zero is the index of padding token 174 | for word in words: 175 | idx = max(0, int(np.floor((word[1] - aux_info['start_time']) / frame_duration))) 176 | if idx < n_frames: 177 | extended_word_indices[idx] = lang.get_word_index(word[0]) 178 | return torch.Tensor(extended_word_indices).long() 179 | 180 | def words_to_tensor(lang, words, end_time=None): 181 | indexes = [lang.SOS_token] 182 | for word in words: 183 | if end_time is not None and word[1] > end_time: 184 | break 185 | indexes.append(lang.get_word_index(word[0])) 186 | indexes.append(lang.EOS_token) 187 | return torch.Tensor(indexes).long() 188 | 189 | duration = aux_info['end_time'] - aux_info['start_time'] 190 | if self.style_vectors is not None: 191 | style_vec = torch.from_numpy(self.style_vectors[idx]) 192 | else: 193 | style_vec = torch.zeros((self.n_poses, 1)) 194 | 195 | do_clipping = True 196 | if do_clipping: 197 | sample_end_time = aux_info['start_time'] + duration * self.n_poses / vec_seq.shape[0] 198 | audio = utils.data_utils.make_audio_fixed_length(audio, self.expected_audio_length) 199 | vec_seq = vec_seq[0:self.n_poses] 200 | pose_seq = pose_seq[0:self.n_poses] 201 | style_vec = style_vec[0:self.n_poses] 202 | else: 203 | sample_end_time = None 204 | 205 | # motion data normalization 206 | vec_seq = np.copy(vec_seq) 207 | if self.normalize_motion: 208 | vec_seq -= self.mean_dir_vec 209 | 210 | # to tensors 211 | word_seq_tensor = words_to_tensor(self.lang_model, word_seq, sample_end_time) 212 | extended_word_seq = extend_word_seq(self.lang_model, word_seq, sample_end_time) 213 | vec_seq = torch.as_tensor(vec_seq).reshape((vec_seq.shape[0], -1)).float() 214 | pose_seq = torch.as_tensor(np.copy(pose_seq)).reshape((pose_seq.shape[0], -1)).float() 215 | audio = torch.as_tensor(np.copy(audio)).float() 216 | style_vec = style_vec.float() 217 | 218 | return word_seq_tensor, extended_word_seq, pose_seq, vec_seq, audio, style_vec, aux_info 219 | 220 | def set_lang_model(self, lang_model): 221 | self.lang_model = lang_model 222 | 223 | -------------------------------------------------------------------------------- /sg_core/scripts/data_loader/motion_preprocessor.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | 5 | 6 | class MotionPreprocessor: 7 | def __init__(self, skeletons, mean_pose): 8 | self.skeletons = np.array(skeletons) 9 | self.mean_pose = np.array(mean_pose).reshape(-1, 3) 10 | 
self.filtering_message = "PASS" 11 | 12 | def get(self): 13 | assert (self.skeletons is not None) 14 | 15 | # filtering 16 | if self.skeletons != []: 17 | verbose = False 18 | if self.check_frame_diff(verbose): 19 | self.skeletons = [] 20 | self.filtering_message = "frame_diff" 21 | # elif self.check_spine_angle(verbose): 22 | # self.skeletons = [] 23 | # self.filtering_message = "spine_angle" 24 | elif self.check_static_motion(verbose): 25 | if random.random() < 0.9: # keep 10% 26 | self.skeletons = [] 27 | self.filtering_message = "motion_var" 28 | 29 | if self.skeletons != []: 30 | self.skeletons = self.skeletons.tolist() 31 | for i, frame in enumerate(self.skeletons): 32 | # assertion: missing joints 33 | assert not np.isnan(self.skeletons[i]).any() 34 | 35 | return self.skeletons, self.filtering_message 36 | 37 | def check_static_motion(self, verbose=False): 38 | def get_variance(skeleton, joint_idx): 39 | wrist_pos = skeleton[:, joint_idx] 40 | variance = np.sum(np.var(wrist_pos, axis=0)) 41 | return variance 42 | 43 | left_arm_var = get_variance(self.skeletons, 6) 44 | right_arm_var = get_variance(self.skeletons, 9) 45 | 46 | th = 0.002 47 | ret = left_arm_var < th and right_arm_var < th 48 | if verbose: 49 | print('check_static_motion: {}, left var {}, right var {}'.format(ret, left_arm_var, right_arm_var)) 50 | return ret 51 | 52 | def check_frame_diff(self, verbose=False): 53 | diff = np.max(np.abs(np.diff(self.skeletons, axis=0, n=1))) 54 | 55 | th = 0.2 56 | ret = diff > th 57 | if verbose: 58 | print('check_frame_diff: {}, {:.5f}'.format(ret, diff)) 59 | return ret 60 | 61 | def check_spine_angle(self, verbose=False): 62 | def angle_between(v1, v2): 63 | v1_u = v1 / np.linalg.norm(v1) 64 | v2_u = v2 / np.linalg.norm(v2) 65 | return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) 66 | 67 | angles = [] 68 | for i in range(self.skeletons.shape[0]): 69 | spine_vec = self.skeletons[i, 1] - self.skeletons[i, 0] 70 | angle = angle_between(spine_vec, [0, -1, 0]) 71 | angles.append(angle) 72 | 73 | if np.rad2deg(max(angles)) > 30 or np.rad2deg(np.mean(angles)) > 20: 74 | if verbose: 75 | print('skip - check_spine_angle {:.5f}, {:.5f}'.format(max(angles), np.mean(angles))) 76 | return True 77 | else: 78 | if verbose: 79 | print('pass - check_spine_angle {:.5f}'.format(max(angles))) 80 | return False 81 | 82 | 83 | -------------------------------------------------------------------------------- /sg_core/scripts/gesture_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import pickle 4 | import os 5 | import logging 6 | import random 7 | import time 8 | 9 | import soundfile as sf 10 | import librosa 11 | import torch 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import gentle 15 | 16 | from data_loader.data_preprocessor import DataPreprocessor 17 | from utils.data_utils import remove_tags_marks 18 | from utils.train_utils import load_checkpoint_and_model 19 | from utils.tts_helper import TTSHelper 20 | 21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | gentle_resources = gentle.Resources() 23 | 24 | 25 | class GestureGenerator: 26 | def __init__(self, checkpoint_path, audio_cache_path=None): 27 | args, generator, lang_model, out_dim = load_checkpoint_and_model( 28 | checkpoint_path, device) 29 | self.args = args 30 | self.generator = generator 31 | self.lang_model = lang_model 32 | print(vars(args)) 33 | 34 | if audio_cache_path is None: 35 | 
audio_cache_path = '../output/cached_wav' 36 | self.tts = TTSHelper(cache_path=audio_cache_path) 37 | 38 | # load mean vec 39 | self.mean_dir_vec = np.array(args.mean_dir_vec).flatten() 40 | self.mean_pose = np.array(args.mean_pose).flatten() 41 | 42 | @staticmethod 43 | def align_words(audio, text): 44 | # resample audio to 8K 45 | audio_8k = librosa.resample(audio, 16000, 8000) 46 | wave_file = 'temp.wav' 47 | sf.write(wave_file, audio_8k, 8000, 'PCM_16') 48 | 49 | # run gentle to align words 50 | aligner = gentle.ForcedAligner(gentle_resources, text, nthreads=2, disfluency=False, 51 | conservative=False) 52 | gentle_out = aligner.transcribe(wave_file, logging=logging) 53 | words_with_timestamps = [] 54 | for gentle_word in gentle_out.words: 55 | if gentle_word.case == 'success': 56 | words_with_timestamps.append([gentle_word.word, gentle_word.start, gentle_word.end]) 57 | 58 | return words_with_timestamps 59 | 60 | def generate(self, input_text, pose_constraints=None, style_values=None, voice=None): 61 | # voice 62 | voice_lower = str(voice).lower() 63 | if voice_lower == 'none' or voice_lower == 'female': 64 | voice_name = 'en-female_2' 65 | elif voice_lower == 'male': 66 | voice_name = 'en-male_2' 67 | else: 68 | voice_name = voice # file path 69 | 70 | # make audio 71 | text_without_tags = remove_tags_marks(input_text) 72 | print(text_without_tags) 73 | 74 | if '.wav' in voice_name or '.mp3' in voice_name: # use external audio file 75 | tts_filename = voice_name 76 | if not os.path.isfile(tts_filename): 77 | return None 78 | else: # TTS 79 | tts_filename = self.tts.synthesis(input_text, voice_name=voice_name, verbose=True) 80 | 81 | audio, audio_sr = librosa.load(tts_filename, mono=True, sr=16000, res_type='kaiser_fast') 82 | 83 | # get timestamps (use caching) 84 | word_timestamps_cache = tts_filename.replace('.wav', '.json') 85 | if not os.path.exists(word_timestamps_cache): 86 | words_with_timestamps = self.align_words(audio, text_without_tags) 87 | with open(word_timestamps_cache, 'w') as outfile: 88 | json.dump(words_with_timestamps, outfile) 89 | else: 90 | with open(word_timestamps_cache) as json_file: 91 | words_with_timestamps = json.load(json_file) 92 | 93 | # run 94 | output = self.generate_core(audio, words_with_timestamps, 95 | pose_constraints=pose_constraints, style_value=style_values) 96 | 97 | # make output match to the audio length 98 | total_frames = math.ceil(len(audio) / 16000 * self.args.motion_resampling_framerate) 99 | output = output[:total_frames] 100 | 101 | return output, audio, tts_filename, words_with_timestamps 102 | 103 | def generate_core(self, audio, words, audio_sr=16000, pose_constraints=None, style_value=None, fade_out=False): 104 | args = self.args 105 | out_list = [] 106 | n_frames = args.n_poses 107 | clip_length = len(audio) / audio_sr 108 | 109 | # pose constraints 110 | mean_vec = torch.from_numpy(np.array(args.mean_dir_vec).flatten()) 111 | if pose_constraints is not None: 112 | assert pose_constraints.shape[1] == len(args.mean_dir_vec) + 1 113 | pose_constraints = torch.from_numpy(pose_constraints) 114 | mask = pose_constraints[:, -1] == 0 115 | if args.normalize_motion_data: # make sure that un-constrained frames have zero or mean values 116 | pose_constraints[:, :-1] = pose_constraints[:, :-1] - mean_vec 117 | pose_constraints[mask, :-1] = 0 118 | else: 119 | pose_constraints[mask, :-1] = mean_vec 120 | pose_constraints = pose_constraints.unsqueeze(0).to(device) 121 | 122 | # divide into inference units and do inferences 123 | unit_time 
= args.n_poses / args.motion_resampling_framerate 124 | stride_time = (args.n_poses - args.n_pre_poses) / args.motion_resampling_framerate 125 | if clip_length < unit_time: 126 | num_subdivision = 1 127 | else: 128 | num_subdivision = math.ceil((clip_length - unit_time) / stride_time) + 1 129 | audio_sample_length = int(unit_time * audio_sr) 130 | end_padding_duration = 0 131 | 132 | print('{}, {}, {}, {}, {}'.format(num_subdivision, unit_time, clip_length, stride_time, audio_sample_length)) 133 | 134 | out_dir_vec = None 135 | start = time.time() 136 | for i in range(0, num_subdivision): 137 | start_time = i * stride_time 138 | end_time = start_time + unit_time 139 | 140 | # prepare audio input 141 | audio_start = math.floor(start_time / clip_length * len(audio)) 142 | audio_end = audio_start + audio_sample_length 143 | in_audio = audio[audio_start:audio_end] 144 | if len(in_audio) < audio_sample_length: 145 | if i == num_subdivision - 1: 146 | end_padding_duration = audio_sample_length - len(in_audio) 147 | in_audio = np.pad(in_audio, (0, audio_sample_length - len(in_audio)), 'constant') 148 | in_audio = torch.from_numpy(in_audio).unsqueeze(0).to(device).float() 149 | 150 | # prepare text input 151 | word_seq = DataPreprocessor.get_words_in_time_range(word_list=words, start_time=start_time, 152 | end_time=end_time) 153 | extended_word_indices = np.zeros(n_frames) # zero is the index of padding token 154 | frame_duration = (end_time - start_time) / n_frames 155 | for word in word_seq: 156 | print(word[0], end=', ') 157 | idx = max(0, int(np.floor((word[1] - start_time) / frame_duration))) 158 | extended_word_indices[idx] = self.lang_model.get_word_index(word[0]) 159 | print(' ') 160 | in_text_padded = torch.LongTensor(extended_word_indices).unsqueeze(0).to(device) 161 | 162 | # prepare pre constraints 163 | start_frame = (args.n_poses - args.n_pre_poses) * i 164 | end_frame = start_frame + args.n_poses 165 | 166 | if pose_constraints is None: 167 | in_pose_const = torch.zeros((1, n_frames, len(args.mean_dir_vec) + 1)) 168 | if not args.normalize_motion_data: 169 | in_pose_const[:, :, :-1] = mean_vec 170 | else: 171 | in_pose_const = pose_constraints[:, start_frame:end_frame, :] 172 | 173 | if in_pose_const.shape[1] < n_frames: 174 | n_pad = n_frames - in_pose_const.shape[1] 175 | in_pose_const = F.pad(in_pose_const, [0, 0, 0, n_pad, 0, 0], "constant", 0) 176 | 177 | if i > 0: 178 | in_pose_const[0, 0:args.n_pre_poses, :-1] = out_dir_vec.squeeze(0)[-args.n_pre_poses:] 179 | in_pose_const[0, 0:args.n_pre_poses, -1] = 1 # indicating bit for constraints 180 | in_pose_const = in_pose_const.float().to(device) 181 | 182 | # style vector 183 | if style_value is None: 184 | style_vector = None 185 | elif isinstance(style_value, list) or len(style_value.shape) == 1: # global style 186 | style_value = np.nan_to_num(style_value) # nan to zero 187 | style_vector = torch.FloatTensor(style_value).to(device) 188 | style_vector = style_vector.repeat(1, in_text_padded.shape[1], 1) 189 | else: 190 | style_value = np.nan_to_num(style_value) # nan to zero 191 | style_vector = style_value[start_frame:end_frame] 192 | n_pad = in_text_padded.shape[1] - style_vector.shape[0] 193 | if n_pad > 0: 194 | style_vector = np.pad(style_vector, ((0, n_pad), (0, 0)), 'constant', constant_values=0) 195 | style_vector = torch.FloatTensor(style_vector).to(device).unsqueeze(0) 196 | 197 | # inference 198 | print(in_text_padded) 199 | out_dir_vec, *_ = self.generator(in_pose_const, in_text_padded, in_audio, style_vector) 200 | 
out_seq = out_dir_vec[0, :, :].data.cpu().numpy() 201 | 202 | # smoothing motion transition 203 | if len(out_list) > 0: 204 | last_poses = out_list[-1][-args.n_pre_poses:] 205 | out_list[-1] = out_list[-1][:-args.n_pre_poses] # delete last {n_pre_poses} frames 206 | 207 | for j in range(len(last_poses)): 208 | n = len(last_poses) 209 | prev = last_poses[j] 210 | next = out_seq[j] 211 | out_seq[j] = prev * (n - j) / (n + 1) + next * (j + 1) / (n + 1) 212 | 213 | out_list.append(out_seq) 214 | 215 | print('Avg. inference time: {:.2} s'.format((time.time() - start) / num_subdivision)) 216 | 217 | # aggregate results 218 | out_dir_vec = np.vstack(out_list) 219 | 220 | # fade out to the mean pose 221 | if fade_out: 222 | n_smooth = args.n_pre_poses 223 | start_frame = len(out_dir_vec) - int(end_padding_duration / audio_sr * args.motion_resampling_framerate) 224 | end_frame = start_frame + n_smooth * 2 225 | if len(out_dir_vec) < end_frame: 226 | out_dir_vec = np.pad(out_dir_vec, [(0, end_frame - len(out_dir_vec)), (0, 0)], mode='constant') 227 | 228 | # fade out to mean poses 229 | if args.normalize_motion_data: 230 | out_dir_vec[end_frame - n_smooth:] = np.zeros((len(args.mean_dir_vec))) 231 | else: 232 | out_dir_vec[end_frame - n_smooth:] = args.mean_dir_vec 233 | 234 | # interpolation 235 | y = out_dir_vec[start_frame:end_frame] 236 | x = np.array(range(0, y.shape[0])) 237 | w = np.ones(len(y)) 238 | w[0] = 5 239 | w[-1] = 5 240 | coeffs = np.polyfit(x, y, 2, w=w) 241 | fit_functions = [np.poly1d(coeffs[:, k]) for k in range(0, y.shape[1])] 242 | interpolated_y = [fit_functions[k](x) for k in range(0, y.shape[1])] 243 | interpolated_y = np.transpose(np.asarray(interpolated_y)) # (num_frames x dims) 244 | 245 | out_dir_vec[start_frame:end_frame] = interpolated_y 246 | 247 | if args.normalize_motion_data: 248 | output = out_dir_vec + self.mean_dir_vec # unnormalize 249 | else: 250 | output = out_dir_vec 251 | 252 | return output 253 | -------------------------------------------------------------------------------- /sg_core/scripts/model/embedding_net.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def ConvNormRelu(in_channels, out_channels, downsample=False, padding=0, batchnorm=True): 8 | if not downsample: 9 | k = 3 10 | s = 1 11 | else: 12 | k = 4 13 | s = 2 14 | 15 | conv_block = nn.Conv1d(in_channels, out_channels, kernel_size=k, stride=s, padding=padding) 16 | norm_block = nn.BatchNorm1d(out_channels) 17 | 18 | if batchnorm: 19 | net = nn.Sequential( 20 | conv_block, 21 | norm_block, 22 | nn.LeakyReLU(0.2, True) 23 | ) 24 | else: 25 | net = nn.Sequential( 26 | conv_block, 27 | nn.LeakyReLU(0.2, True) 28 | ) 29 | 30 | return net 31 | 32 | 33 | class PoseEncoderConv(nn.Module): 34 | def __init__(self, length, dim): 35 | super().__init__() 36 | 37 | self.net = nn.Sequential( 38 | ConvNormRelu(dim, 32, batchnorm=True), 39 | ConvNormRelu(32, 64, batchnorm=True), 40 | ConvNormRelu(64, 64, True, batchnorm=True), 41 | nn.Conv1d(64, 32, 3) 42 | ) 43 | 44 | self.out_net = nn.Sequential( 45 | nn.Linear(800, 256), # for 60 frames 46 | # nn.Linear(864, 256), # for 64 frames 47 | # nn.Linear(384, 256), # for 34 frames 48 | nn.BatchNorm1d(256), 49 | nn.LeakyReLU(True), 50 | nn.Linear(256, 128), 51 | nn.BatchNorm1d(128), 52 | nn.LeakyReLU(True), 53 | nn.Linear(128, 32), 54 | ) 55 | 56 | self.fc_mu = nn.Linear(32, 32) 57 | self.fc_logvar = nn.Linear(32, 32) 58 | 59 | def forward(self, poses): 
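        # Encodes a pose sequence of shape (batch, n_frames, pose_dim) into a 32-D latent code.
        # Only the mu head is used (z = mu, no sampling), so the encoder is deterministic.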
60 | # encode 61 | poses = poses.transpose(1, 2) # to (bs, dim, seq) 62 | out = self.net(poses) 63 | out = out.flatten(1) 64 | out = self.out_net(out) 65 | 66 | # return out, None, None 67 | mu = self.fc_mu(out) 68 | z = mu 69 | return z, None, None 70 | 71 | 72 | class PoseDecoderConv(nn.Module): 73 | def __init__(self, length, dim, use_pre_poses=False): 74 | super().__init__() 75 | self.use_pre_poses = use_pre_poses 76 | 77 | feat_size = 32 78 | if use_pre_poses: 79 | self.pre_pose_net = nn.Sequential( 80 | nn.Linear(dim * 4, 32), 81 | nn.BatchNorm1d(32), 82 | nn.ReLU(), 83 | nn.Linear(32, 32), 84 | ) 85 | feat_size += 32 86 | 87 | if length <= 34: 88 | self.pre_net = nn.Sequential( 89 | nn.Linear(feat_size, 64), 90 | nn.BatchNorm1d(64), 91 | nn.LeakyReLU(True), 92 | nn.Linear(64, length * 4), 93 | ) 94 | elif 34 < length < 128: 95 | self.pre_net = nn.Sequential( 96 | nn.Linear(feat_size, 128), 97 | nn.BatchNorm1d(128), 98 | nn.LeakyReLU(True), 99 | nn.Linear(128, length * 4), 100 | ) 101 | else: 102 | assert False 103 | 104 | self.net = nn.Sequential( 105 | nn.ConvTranspose1d(4, 32, 3), 106 | nn.BatchNorm1d(32), 107 | nn.LeakyReLU(0.2, True), 108 | nn.ConvTranspose1d(32, 32, 3), 109 | nn.BatchNorm1d(32), 110 | nn.LeakyReLU(0.2, True), 111 | nn.Conv1d(32, 32, 3), 112 | nn.Conv1d(32, dim, 3), 113 | ) 114 | 115 | def forward(self, feat, pre_poses=None): 116 | if self.use_pre_poses: 117 | pre_pose_feat = self.pre_pose_net(pre_poses.reshape(pre_poses.shape[0], -1)) 118 | feat = torch.cat((pre_pose_feat, feat), dim=1) 119 | 120 | out = self.pre_net(feat) 121 | out = out.view(feat.shape[0], 4, -1) 122 | out = self.net(out) 123 | out = out.transpose(1, 2) 124 | return out 125 | 126 | 127 | class EmbeddingNet(nn.Module): 128 | def __init__(self, args, pose_dim, n_frames): 129 | super().__init__() 130 | self.pose_encoder = PoseEncoderConv(n_frames, pose_dim) 131 | self.decoder = PoseDecoderConv(n_frames, pose_dim) 132 | 133 | def forward(self, pre_poses, poses): 134 | # poses 135 | if poses is not None: 136 | poses_feat, _, _ = self.pose_encoder(poses) 137 | else: 138 | poses_feat = None 139 | 140 | # decoder 141 | latent_feat = poses_feat 142 | out_poses = self.decoder(latent_feat, pre_poses) 143 | 144 | return poses_feat, None, None, out_poses 145 | 146 | def freeze_pose_nets(self): 147 | for param in self.pose_encoder.parameters(): 148 | param.requires_grad = False 149 | for param in self.decoder.parameters(): 150 | param.requires_grad = False 151 | 152 | -------------------------------------------------------------------------------- /sg_core/scripts/model/embedding_space_evaluator.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | import umap 7 | from scipy import linalg 8 | from scipy.spatial.distance import cosine 9 | from sklearn import mixture 10 | from sklearn.cluster import KMeans, MiniBatchKMeans 11 | 12 | from model.embedding_net import EmbeddingNet 13 | 14 | import warnings 15 | warnings.filterwarnings("ignore", category=RuntimeWarning) # ignore warnings 16 | 17 | 18 | class EmbeddingSpaceEvaluator: 19 | def __init__(self, args, embed_net_path, lang_model, device, cluster_sizes=None): 20 | if cluster_sizes is None: 21 | # cluster_sizes = [0.005, 0.01, 0.05, 0.1] 22 | cluster_sizes = [0.005] 23 | self.n_pre_poses = args.n_pre_poses 24 | self.cluster_sizes = cluster_sizes 25 | 26 | # init embed net 27 | ckpt = torch.load(embed_net_path, 
map_location=device) 28 | n_frames = args.n_poses 29 | word_embeddings = lang_model.word_embedding_weights 30 | mode = 'pose' 31 | self.pose_dim = ckpt['pose_dim'] 32 | self.net = EmbeddingNet(args, self.pose_dim, n_frames).to(device) 33 | self.net.load_state_dict(ckpt['gen_dict']) 34 | self.net.train(False) 35 | 36 | # storage 37 | self.real_feat_list = [] 38 | self.generated_feat_list = [] 39 | self.recon_err_diff = [] 40 | 41 | def reset(self): 42 | self.real_feat_list = [] 43 | self.generated_feat_list = [] 44 | self.recon_err_diff = [] 45 | 46 | def get_no_of_samples(self): 47 | return len(self.real_feat_list) 48 | 49 | def push_samples(self, generated_poses, real_poses): 50 | # convert poses to latent features 51 | pre_poses = real_poses[:, 0:self.n_pre_poses] 52 | with torch.no_grad(): 53 | real_feat, _, _, real_recon = self.net(pre_poses, real_poses) 54 | generated_feat, _, _, generated_recon = self.net(pre_poses, generated_poses) 55 | 56 | self.real_feat_list.append(real_feat.data.cpu().numpy()) 57 | self.generated_feat_list.append(generated_feat.data.cpu().numpy()) 58 | 59 | # reconstruction error 60 | recon_err_real = F.l1_loss(real_poses, real_recon).item() 61 | recon_err_fake = F.l1_loss(generated_poses, generated_recon).item() 62 | self.recon_err_diff.append(recon_err_fake - recon_err_real) 63 | 64 | def get_features_for_viz(self): 65 | generated_feats = np.vstack(self.generated_feat_list) 66 | real_feats = np.vstack(self.real_feat_list) 67 | 68 | transformed_feats = umap.UMAP().fit_transform(np.vstack((generated_feats, real_feats))) 69 | n = int(transformed_feats.shape[0] / 2) 70 | generated_feats = transformed_feats[0:n, :] 71 | real_feats = transformed_feats[n:, :] 72 | 73 | return real_feats, generated_feats 74 | 75 | def get_scores(self): 76 | generated_feats = np.vstack(self.generated_feat_list) 77 | real_feats = np.vstack(self.real_feat_list) 78 | 79 | # print('recon err diff', np.mean(self.recon_err_diff)) 80 | 81 | def frechet_distance(samples_A, samples_B): 82 | A_mu = np.mean(samples_A, axis=0) 83 | A_sigma = np.cov(samples_A, rowvar=False) 84 | B_mu = np.mean(samples_B, axis=0) 85 | B_sigma = np.cov(samples_B, rowvar=False) 86 | try: 87 | frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma) 88 | # print('[DEBUG] frechet distance') 89 | # print(A_mu, A_sigma, B_mu, B_sigma) 90 | # print(np.sum(np.abs(A_mu - B_mu)), np.trace(A_sigma), np.trace(B_sigma)) 91 | # print(np.sum(np.abs(A_mu - B_mu)), np.trace(A_sigma - B_sigma)) 92 | except ValueError: 93 | frechet_dist = 1e+10 94 | return frechet_dist 95 | 96 | #################################################################### 97 | # frechet distance 98 | frechet_dist = frechet_distance(generated_feats, real_feats) 99 | 100 | #################################################################### 101 | # distance between real and generated samples on the latent feature space 102 | dists = [] 103 | for i in range(real_feats.shape[0]): 104 | d = np.sum(np.absolute(real_feats[i] - generated_feats[i])) # MAE 105 | dists.append(d) 106 | feat_dist = np.mean(dists) 107 | 108 | return frechet_dist, feat_dist 109 | 110 | @staticmethod 111 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 112 | """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """ 113 | """Numpy implementation of the Frechet Distance. 
114 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 115 | and X_2 ~ N(mu_2, C_2) is 116 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 117 | Stable version by Dougal J. Sutherland. 118 | Params: 119 | -- mu1 : Numpy array containing the activations of a layer of the 120 | inception net (like returned by the function 'get_predictions') 121 | for generated samples. 122 | -- mu2 : The sample mean over activations, precalculated on an 123 | representative data set. 124 | -- sigma1: The covariance matrix over activations for generated samples. 125 | -- sigma2: The covariance matrix over activations, precalculated on an 126 | representative data set. 127 | Returns: 128 | -- : The Frechet Distance. 129 | """ 130 | 131 | mu1 = np.atleast_1d(mu1) 132 | mu2 = np.atleast_1d(mu2) 133 | 134 | sigma1 = np.atleast_2d(sigma1) 135 | sigma2 = np.atleast_2d(sigma2) 136 | 137 | assert mu1.shape == mu2.shape, \ 138 | 'Training and test mean vectors have different lengths' 139 | assert sigma1.shape == sigma2.shape, \ 140 | 'Training and test covariances have different dimensions' 141 | 142 | diff = mu1 - mu2 143 | 144 | # Product might be almost singular 145 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 146 | if not np.isfinite(covmean).all(): 147 | msg = ('fid calculation produces singular product; ' 148 | 'adding %s to diagonal of cov estimates') % eps 149 | print(msg) 150 | offset = np.eye(sigma1.shape[0]) * eps 151 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 152 | 153 | # Numerical error might give slight imaginary component 154 | if np.iscomplexobj(covmean): 155 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 156 | m = np.max(np.abs(covmean.imag)) 157 | raise ValueError('Imaginary component {}'.format(m)) 158 | covmean = covmean.real 159 | 160 | tr_covmean = np.trace(covmean) 161 | 162 | return (diff.dot(diff) + np.trace(sigma1) + 163 | np.trace(sigma2) - 2 * tr_covmean) 164 | -------------------------------------------------------------------------------- /sg_core/scripts/model/multimodal_context_net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision.models as models 6 | 7 | from model import vocab 8 | import model.embedding_net 9 | from model.tcn import TemporalConvNet 10 | 11 | 12 | class AudioFeatExtractor(nn.Module): 13 | def __init__(self, feat_dim): 14 | super().__init__() 15 | self.encoder = models.resnet18(pretrained=False) 16 | num_ftrs = self.encoder.fc.in_features 17 | self.encoder.fc = nn.Linear(num_ftrs, feat_dim) 18 | 19 | def forward(self, x): 20 | if len(x.shape) == 3: 21 | x = x.unsqueeze(1) # add channel dim 22 | x = x.repeat(1, 3, 1, 1) # make 3-channels 23 | x = x.float() 24 | out = self.encoder(x) 25 | return out 26 | 27 | 28 | class AudioEncoder(nn.Module): 29 | def __init__(self, n_frames, feat_dim=32): 30 | super().__init__() 31 | self.n_frames = n_frames 32 | self.feat_extractor = AudioFeatExtractor(feat_dim) 33 | 34 | def forward(self, spectrogram): 35 | # divide into blocks and extract features 36 | feat_list = [] 37 | spectrogram_length = spectrogram.shape[2] 38 | block_start_pts = np.array(range(0, self.n_frames)) * spectrogram_length / self.n_frames 39 | for i in range(self.n_frames): 40 | if i-2 < 0: 41 | start = 0 42 | else: 43 | start = np.round(block_start_pts[i-2]) 44 | 45 | if i+1 >= self.n_frames: 46 | end = 
spectrogram_length 47 | else: 48 | end = block_start_pts[i+1] 49 | 50 | start = int(np.floor(start)) 51 | end = int(min(spectrogram_length, np.ceil(end))) 52 | spectrogram_roi = spectrogram[:, :, start:end] 53 | feat = self.feat_extractor(spectrogram_roi) 54 | feat_list.append(feat) 55 | 56 | out = torch.stack(feat_list, dim=1) 57 | return out 58 | 59 | 60 | class WavEncoder(nn.Module): 61 | def __init__(self): 62 | super().__init__() 63 | self.feat_extractor = nn.Sequential( 64 | nn.Conv1d(1, 16, 15, stride=5, padding=1600), 65 | nn.BatchNorm1d(16), 66 | nn.LeakyReLU(0.3, inplace=True), 67 | nn.Conv1d(16, 32, 15, stride=6), 68 | nn.BatchNorm1d(32), 69 | nn.LeakyReLU(0.3, inplace=True), 70 | nn.Conv1d(32, 64, 15, stride=6), 71 | nn.BatchNorm1d(64), 72 | nn.LeakyReLU(0.3, inplace=True), 73 | nn.Conv1d(64, 32, 15, stride=6), 74 | # nn.BatchNorm1d(128), 75 | # nn.LeakyReLU(0.3, inplace=True), 76 | # nn.Conv2d(32, 32, (5, 1), padding=0, stride=1) 77 | ) 78 | 79 | def forward(self, wav_data): 80 | wav_data = wav_data.unsqueeze(1) # add channel dim 81 | out = self.feat_extractor(wav_data) 82 | return out.transpose(1, 2) # to (batch x seq x dim) 83 | 84 | 85 | class TextEncoderTCN(nn.Module): 86 | """ based on https://github.com/locuslab/TCN/blob/master/TCN/word_cnn/model.py """ 87 | def __init__(self, args, n_words, embed_size=300, pre_trained_embedding=None, 88 | kernel_size=2, dropout=0.3, emb_dropout=0.1): 89 | super(TextEncoderTCN, self).__init__() 90 | 91 | if pre_trained_embedding is not None: # use pre-trained embedding (fasttext) 92 | assert pre_trained_embedding.shape[0] == n_words 93 | assert pre_trained_embedding.shape[1] == embed_size 94 | self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(pre_trained_embedding), 95 | freeze=args.freeze_wordembed) 96 | else: 97 | self.embedding = nn.Embedding(n_words, embed_size) 98 | 99 | num_channels = [args.hidden_size] * args.n_layers 100 | self.tcn = TemporalConvNet(embed_size, num_channels, kernel_size, dropout=dropout) 101 | 102 | self.decoder = nn.Linear(num_channels[-1], 32) 103 | self.drop = nn.Dropout(emb_dropout) 104 | self.emb_dropout = emb_dropout 105 | self.init_weights() 106 | 107 | def init_weights(self): 108 | self.decoder.bias.data.fill_(0) 109 | self.decoder.weight.data.normal_(0, 0.01) 110 | 111 | def forward(self, input): 112 | emb = self.drop(self.embedding(input)) 113 | y = self.tcn(emb.transpose(1, 2)).transpose(1, 2) 114 | y = self.decoder(y) 115 | return y.contiguous(), 0 116 | 117 | 118 | class PoseGenerator(nn.Module): 119 | def __init__(self, args, pose_dim, n_words, word_embed_size, word_embeddings): 120 | super().__init__() 121 | self.pre_length = args.n_pre_poses 122 | self.gen_length = args.n_poses - args.n_pre_poses 123 | self.z_type = args.z_type 124 | self.input_context = args.input_context 125 | self.style_vec_size = len(args.style_val_mean)*2 # *2 for indicating bit 126 | 127 | if self.input_context == 'both': 128 | self.in_size = 32 + 32 + pose_dim + 1 # audio_feat + text_feat + last pose + constraint bit 129 | elif self.input_context == 'none': 130 | self.in_size = pose_dim + 1 131 | else: 132 | self.in_size = 32 + pose_dim + 1 # audio or text only 133 | 134 | self.audio_encoder = WavEncoder() 135 | self.text_encoder = TextEncoderTCN(args, n_words, word_embed_size, pre_trained_embedding=word_embeddings, 136 | dropout=args.dropout_prob) 137 | 138 | if self.z_type == 'style_vector': 139 | # self.z_size = 16 + self.style_vec_size 140 | self.z_size = self.style_vec_size 141 | self.in_size += 
self.z_size 142 | 143 | self.hidden_size = args.hidden_size 144 | self.gru = nn.GRU(self.in_size, hidden_size=self.hidden_size, num_layers=args.n_layers, batch_first=True, 145 | bidirectional=True, dropout=args.dropout_prob) 146 | self.out = nn.Sequential( 147 | # nn.Linear(hidden_size, pose_dim) 148 | nn.Linear(self.hidden_size, self.hidden_size//2), 149 | nn.LeakyReLU(True), 150 | nn.Linear(self.hidden_size//2, pose_dim) 151 | ) 152 | 153 | self.do_flatten_parameters = False 154 | if torch.cuda.device_count() > 1: 155 | self.do_flatten_parameters = True 156 | 157 | def forward(self, pose_constraints, in_text, in_audio, style_vector=None): 158 | decoder_hidden = None 159 | if self.do_flatten_parameters: 160 | self.gru.flatten_parameters() 161 | 162 | text_feat_seq = audio_feat_seq = None 163 | if self.input_context != 'none': 164 | # audio 165 | audio_feat_seq = self.audio_encoder(in_audio) # output (bs, n_frames, feat_size) 166 | 167 | # text 168 | text_feat_seq, _ = self.text_encoder(in_text) 169 | assert(audio_feat_seq.shape[1] == text_feat_seq.shape[1]) 170 | 171 | # z vector 172 | z_mu = z_logvar = None 173 | if self.z_type == 'style_vector' or self.z_type == 'random': 174 | z_context = torch.randn(in_text.shape[0], 16, device=in_text.device) 175 | else: # no z 176 | z_context = None 177 | 178 | # make an input 179 | if self.input_context == 'both': 180 | in_data = torch.cat((pose_constraints, audio_feat_seq, text_feat_seq), dim=2) 181 | elif self.input_context == 'audio': 182 | in_data = torch.cat((pose_constraints, audio_feat_seq), dim=2) 183 | elif self.input_context == 'text': 184 | in_data = torch.cat((pose_constraints, text_feat_seq), dim=2) 185 | else: 186 | assert False 187 | 188 | if self.z_type == 'style_vector': 189 | repeated_z = z_context.unsqueeze(1) 190 | repeated_z = repeated_z.repeat(1, in_data.shape[1], 1) 191 | if style_vector is None: 192 | style_vector = torch.zeros((in_data.shape[0], in_data.shape[1], self.style_vec_size), 193 | device=in_data.device, dtype=torch.float32) 194 | else: 195 | ones = torch.ones((in_data.shape[0], in_data.shape[1], self.style_vec_size//2), 196 | device=in_data.device, dtype=torch.float32) 197 | zeros = torch.zeros((in_data.shape[0], in_data.shape[1], self.style_vec_size//2), 198 | device=in_data.device, dtype=torch.float32) 199 | # style_vec_bit = torch.where(torch.isnan(style_vector), zeros, ones) 200 | style_vec_bit = torch.where(style_vector == 0, zeros, ones) 201 | style_vector[~style_vec_bit.bool()] = 0 # set masked elements to zeros 202 | style_vector = torch.cat((style_vector.float(), style_vec_bit), dim=2) 203 | 204 | # masking on frames having constraining poses 205 | constraint_mask = (pose_constraints[:, :, -1] == 1) 206 | style_vector[constraint_mask] = 0 207 | 208 | # in_data = torch.cat((in_data, repeated_z, style_vector), dim=2) 209 | in_data = torch.cat((in_data, style_vector), dim=2) 210 | elif z_context is not None: 211 | repeated_z = z_context.unsqueeze(1) 212 | repeated_z = repeated_z.repeat(1, in_data.shape[1], 1) 213 | in_data = torch.cat((in_data, repeated_z), dim=2) 214 | 215 | # forward 216 | output, decoder_hidden = self.gru(in_data, decoder_hidden) 217 | output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:] # sum bidirectional outputs 218 | output = self.out(output.reshape(-1, output.shape[2])) 219 | decoder_outputs = output.reshape(in_data.shape[0], in_data.shape[1], -1) 220 | # decoder_outputs = torch.tanh(decoder_outputs) 221 | 222 | return decoder_outputs, z_context, z_mu, z_logvar 
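# Illustrative input shapes for PoseGenerator (a sketch; sizes are assumptions, not values from this repo's config):
# assumes n_poses=60, pose_dim=27 (9 directional vectors x 3), 3 style values, and ~4 s of 16 kHz audio so that
# WavEncoder yields one feature per output frame. The constructor call stays commented out because `args`, the
# vocab size and the pre-trained word embeddings come from training and are not constructed here.
#
#   import torch
#   pose_constraints = torch.zeros(2, 60, 27 + 1)          # poses + trailing constraint bit
#   pose_constraints[:, :4, -1] = 1                         # mark seed frames as constrained (bit = 1)
#   in_text = torch.randint(0, 1000, (2, 60))               # one word index per output frame
#   in_audio = torch.randn(2, 64000)                        # raw waveform; 16 kHz assumed
#   style_vector = torch.zeros(2, 60, 3)                    # zeros are treated as "unspecified" style values
#   generator = PoseGenerator(args, 27, n_words, 300, word_embeddings)
#   out_dir_vec, z, z_mu, z_logvar = generator(pose_constraints, in_text, in_audio, style_vector)
#   # out_dir_vec: (2, 60, 27) directional vectors; still mean-normalized if args.normalize_motion_data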
223 | 224 | 225 | class Discriminator(nn.Module): 226 | def __init__(self, args, input_size, n_words=None, word_embed_size=None, word_embeddings=None): 227 | super().__init__() 228 | self.input_size = input_size 229 | 230 | if n_words and word_embed_size: 231 | self.text_encoder = TextEncoderTCN(n_words, word_embed_size, word_embeddings) 232 | input_size += 32 233 | else: 234 | self.text_encoder = None 235 | 236 | self.hidden_size = args.hidden_size 237 | self.gru = nn.GRU(input_size, hidden_size=self.hidden_size, num_layers=args.n_layers, bidirectional=True, 238 | dropout=args.dropout_prob, batch_first=True) 239 | self.out = nn.Linear(self.hidden_size, 1) 240 | self.out2 = nn.Linear(args.n_poses, 1) 241 | 242 | self.do_flatten_parameters = False 243 | if torch.cuda.device_count() > 1: 244 | self.do_flatten_parameters = True 245 | 246 | def forward(self, poses, in_text=None): 247 | decoder_hidden = None 248 | if self.do_flatten_parameters: 249 | self.gru.flatten_parameters() 250 | 251 | # pose_diff = poses[:, 1:] - poses[:, :-1] 252 | 253 | if self.text_encoder: 254 | text_feat_seq, _ = self.text_encoder(in_text) 255 | poses = torch.cat((poses, text_feat_seq), dim=2) 256 | 257 | output, decoder_hidden = self.gru(poses, decoder_hidden) 258 | output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:] # sum bidirectional outputs 259 | 260 | # use the last N outputs 261 | batch_size = poses.shape[0] 262 | # output = output[:, -self.gen_length:] 263 | output = output.contiguous().view(-1, output.shape[2]) 264 | output = self.out(output) # apply linear to every output 265 | output = output.view(batch_size, -1) 266 | output = self.out2(output) 267 | output = torch.sigmoid(output) 268 | 269 | return output 270 | 271 | 272 | class ConvDiscriminator(nn.Module): 273 | def __init__(self, input_size): 274 | super().__init__() 275 | self.input_size = input_size 276 | 277 | self.hidden_size = 64 278 | self.pre_conv = nn.Sequential( 279 | nn.Conv1d(input_size, 16, 3), 280 | nn.BatchNorm1d(16), 281 | nn.LeakyReLU(True), 282 | nn.Conv1d(16, 8, 3), 283 | nn.BatchNorm1d(8), 284 | nn.LeakyReLU(True), 285 | nn.Conv1d(8, 8, 3), 286 | ) 287 | 288 | self.gru = nn.GRU(8, hidden_size=self.hidden_size, num_layers=4, bidirectional=True, 289 | dropout=0.3, batch_first=True) 290 | self.out = nn.Linear(self.hidden_size, 1) 291 | self.out2 = nn.Linear(54, 1) 292 | 293 | self.do_flatten_parameters = False 294 | if torch.cuda.device_count() > 1: 295 | self.do_flatten_parameters = True 296 | 297 | def forward(self, poses, in_text=None): 298 | decoder_hidden = None 299 | if self.do_flatten_parameters: 300 | self.gru.flatten_parameters() 301 | 302 | poses = poses.transpose(1, 2) 303 | feat = self.pre_conv(poses) 304 | feat = feat.transpose(1, 2) 305 | 306 | output, decoder_hidden = self.gru(feat, decoder_hidden) 307 | output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:] # sum bidirectional outputs 308 | 309 | # use the last N outputs 310 | batch_size = poses.shape[0] 311 | # output = output[:, -self.gen_length:] 312 | output = output.contiguous().view(-1, output.shape[2]) 313 | output = self.out(output) # apply linear to every output 314 | output = output.view(batch_size, -1) 315 | output = self.out2(output) 316 | output = torch.sigmoid(output) 317 | 318 | return output 319 | 320 | -------------------------------------------------------------------------------- /sg_core/scripts/model/tcn.py: -------------------------------------------------------------------------------- 1 | """ from 
https://github.com/locuslab/TCN/blob/master/TCN/tcn.py """ 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils import weight_norm 5 | 6 | 7 | class Chomp1d(nn.Module): 8 | def __init__(self, chomp_size): 9 | super(Chomp1d, self).__init__() 10 | self.chomp_size = chomp_size 11 | 12 | def forward(self, x): 13 | return x[:, :, :-self.chomp_size].contiguous() 14 | 15 | 16 | class TemporalBlock(nn.Module): 17 | def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2): 18 | super(TemporalBlock, self).__init__() 19 | self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, 20 | stride=stride, padding=padding, dilation=dilation)) 21 | self.chomp1 = Chomp1d(padding) 22 | self.relu1 = nn.ReLU() 23 | self.dropout1 = nn.Dropout(dropout) 24 | 25 | self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, 26 | stride=stride, padding=padding, dilation=dilation)) 27 | self.chomp2 = Chomp1d(padding) 28 | self.relu2 = nn.ReLU() 29 | self.dropout2 = nn.Dropout(dropout) 30 | 31 | self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, 32 | self.conv2, self.chomp2, self.relu2, self.dropout2) 33 | self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None 34 | self.relu = nn.ReLU() 35 | self.init_weights() 36 | 37 | def init_weights(self): 38 | self.conv1.weight.data.normal_(0, 0.01) 39 | self.conv2.weight.data.normal_(0, 0.01) 40 | if self.downsample is not None: 41 | self.downsample.weight.data.normal_(0, 0.01) 42 | 43 | def forward(self, x): 44 | out = self.net(x) 45 | res = x if self.downsample is None else self.downsample(x) 46 | return self.relu(out + res) 47 | 48 | 49 | class TemporalConvNet(nn.Module): 50 | def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2): 51 | super(TemporalConvNet, self).__init__() 52 | layers = [] 53 | num_levels = len(num_channels) 54 | for i in range(num_levels): 55 | dilation_size = 2 ** i 56 | in_channels = num_inputs if i == 0 else num_channels[i-1] 57 | out_channels = num_channels[i] 58 | layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size, 59 | padding=(kernel_size-1) * dilation_size, dropout=dropout)] 60 | 61 | self.network = nn.Sequential(*layers) 62 | 63 | def forward(self, x): 64 | return self.network(x) 65 | -------------------------------------------------------------------------------- /sg_core/scripts/model/vocab.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import numpy as np 5 | import fasttext 6 | 7 | 8 | class Vocab: 9 | PAD_token = 0 10 | SOS_token = 1 11 | EOS_token = 2 12 | UNK_token = 3 13 | 14 | def __init__(self, name, insert_default_tokens=True): 15 | self.name = name 16 | self.trimmed = False 17 | self.word_embedding_weights = None 18 | self.reset_dictionary(insert_default_tokens) 19 | 20 | def reset_dictionary(self, insert_default_tokens=True): 21 | self.word2index = {} 22 | self.word2count = {} 23 | if insert_default_tokens: 24 | self.index2word = {self.PAD_token: "", self.SOS_token: "", 25 | self.EOS_token: "", self.UNK_token: ""} 26 | else: 27 | self.index2word = {self.UNK_token: ""} 28 | self.n_words = len(self.index2word) # count default tokens 29 | 30 | def index_word(self, word): 31 | if word not in self.word2index: 32 | self.word2index[word] = self.n_words 33 | self.word2count[word] = 1 34 | self.index2word[self.n_words] = word 35 | self.n_words += 1 36 | else: 37 | 
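            # word already indexed: just bump its frequency count (used later by trim())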
self.word2count[word] += 1 38 | 39 | def add_vocab(self, other_vocab): 40 | for word, _ in other_vocab.word2count.items(): 41 | self.index_word(word) 42 | 43 | # remove words below a certain count threshold 44 | def trim(self, min_count): 45 | if self.trimmed: 46 | return 47 | self.trimmed = True 48 | 49 | keep_words = [] 50 | 51 | for k, v in self.word2count.items(): 52 | if v >= min_count: 53 | keep_words.append(k) 54 | 55 | logging.info(' word trimming, kept %s / %s = %.4f' % ( 56 | len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index) 57 | )) 58 | 59 | # reinitialize dictionary 60 | self.reset_dictionary() 61 | for word in keep_words: 62 | self.index_word(word) 63 | 64 | def get_word_index(self, word): 65 | if word in self.word2index: 66 | return self.word2index[word] 67 | else: 68 | return self.UNK_token 69 | 70 | def load_word_vectors(self, pretrained_path, embedding_dim=300): 71 | logging.info(" loading word vectors from '{}'...".format(pretrained_path)) 72 | 73 | # initialize embeddings to random values for special words 74 | init_sd = 1 / np.sqrt(embedding_dim) 75 | weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim]) 76 | weights = weights.astype(np.float32) 77 | 78 | # read word vectors 79 | word_model = fasttext.load_model(pretrained_path) 80 | for word, id in self.word2index.items(): 81 | vec = word_model.get_word_vector(word) 82 | weights[id] = vec 83 | 84 | self.word_embedding_weights = weights 85 | 86 | def __get_embedding_weight(self, pretrained_path, embedding_dim=300): 87 | """ function modified from http://ronny.rest/blog/post_2017_08_04_glove/ """ 88 | logging.info("Loading word embedding '{}'...".format(pretrained_path)) 89 | cache_path = os.path.splitext(pretrained_path)[0] + '_cache.pkl' 90 | weights = None 91 | 92 | # use cached file if it exists 93 | if os.path.exists(cache_path): # 94 | with open(cache_path, 'rb') as f: 95 | logging.info(' using cached result from {}'.format(cache_path)) 96 | weights = pickle.load(f) 97 | if weights.shape != (self.n_words, embedding_dim): 98 | logging.warning(' failed to load word embedding weights. 
reinitializing...') 99 | weights = None 100 | 101 | if weights is None: 102 | # initialize embeddings to random values for special and OOV words 103 | init_sd = 1 / np.sqrt(embedding_dim) 104 | weights = np.random.normal(0, scale=init_sd, size=[self.n_words, embedding_dim]) 105 | weights = weights.astype(np.float32) 106 | 107 | with open(pretrained_path, encoding="utf-8", mode="r") as textFile: 108 | num_embedded_words = 0 109 | for line_raw in textFile: 110 | # extract the word, and embeddings vector 111 | line = line_raw.split() 112 | try: 113 | word, vector = (line[0], np.array(line[1:], dtype=np.float32)) 114 | # if word == 'love': # debugging 115 | # print(word, vector) 116 | 117 | # if it is in our vocab, then update the corresponding weights 118 | id = self.word2index.get(word, None) 119 | if id is not None: 120 | weights[id] = vector 121 | num_embedded_words += 1 122 | except ValueError: 123 | logging.info(' parsing error at {}...'.format(line_raw[:50])) 124 | continue 125 | logging.info(' {} / {} word vectors are found in the embedding'.format(num_embedded_words, len(self.word2index))) 126 | 127 | with open(cache_path, 'wb') as f: 128 | pickle.dump(weights, f) 129 | 130 | return weights 131 | -------------------------------------------------------------------------------- /sg_core/scripts/train_eval/diff_augment.py: -------------------------------------------------------------------------------- 1 | # code modified from https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment_pytorch.py 2 | # Differentiable Augmentation for Data-Efficient GAN Training, https://arxiv.org/pdf/2006.10738 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def DiffAugment(x): 9 | for f in AUGMENT_FNS: 10 | x = f(x) 11 | x = x.contiguous() 12 | return x 13 | 14 | 15 | def rand_gaussian(x): 16 | noise = torch.randn(x.size(0), 1, x.size(2), dtype=x.dtype, device=x.device) 17 | noise *= 0.15 18 | x = x + noise 19 | return x 20 | 21 | 22 | AUGMENT_FNS = [rand_gaussian] 23 | -------------------------------------------------------------------------------- /sg_core/scripts/train_eval/train_gan.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import data_loader.lmdb_data_loader 8 | from utils.data_utils import convert_dir_vec_to_pose_torch 9 | 10 | from sg_core.scripts.train_eval.diff_augment import DiffAugment 11 | 12 | 13 | def add_noise(data): 14 | noise = torch.randn_like(data) * 0.1 15 | return data + noise 16 | 17 | 18 | def train_iter_gan(args, epoch, in_text, in_audio, target_data, style_vector, 19 | pose_decoder, discriminator, 20 | pose_dec_optim, dis_optim): 21 | warm_up_epochs = args.loss_warmup 22 | mean_dir_vec = torch.tensor(args.mean_dir_vec).squeeze().to(target_data.device) 23 | mean_pose = torch.tensor(args.mean_pose).squeeze().to(target_data.device) 24 | 25 | # make pose constraints 26 | pose_constraints = target_data.new_zeros((target_data.shape[0], target_data.shape[1], target_data.shape[2] + 1)) 27 | if not args.normalize_motion_data: 28 | # fill with mean data 29 | pose_constraints[:, :, :-1] = mean_dir_vec.repeat(target_data.shape[0], target_data.shape[1], 1) 30 | pose_constraints[:, 0:args.n_pre_poses, :-1] = target_data[:, 0:args.n_pre_poses] 31 | pose_constraints[:, 0:args.n_pre_poses, -1] = 1 # indicating bit for constraints 32 | if args.use_pose_control and random.random() < 0.5: 33 | n_samples = 
target_data.shape[0] 34 | 35 | copy_length = np.abs(np.random.triangular(-args.n_poses, 0, args.n_poses, n_samples).astype(np.int)) 36 | copy_length = np.clip(copy_length, a_min=1, a_max=args.n_poses - args.n_pre_poses) 37 | 38 | for i in range(n_samples): 39 | copy_point = random.randint(args.n_pre_poses, args.n_poses - copy_length[i]) 40 | pose_constraints[i, copy_point:copy_point + copy_length[i], :-1] = \ 41 | target_data[i, copy_point:copy_point + copy_length[i]] 42 | pose_constraints[i, copy_point:copy_point + copy_length[i], -1] = 1 43 | 44 | if args.use_style_control and random.random() < 0.5: 45 | use_div_reg = True 46 | 47 | # random dropout style element 48 | n_drop = random.randint(0, 2) 49 | if n_drop > 0: 50 | drop_idxs = random.sample(range(style_vector.shape[-1]), k=n_drop) 51 | # style_vector[:, :, drop_idxs] = float('nan') 52 | style_vector[:, :, drop_idxs] = 0 53 | else: 54 | use_div_reg = False 55 | style_vector = None 56 | 57 | ########################################################################################### 58 | # train D 59 | dis_error = None 60 | if epoch > warm_up_epochs and args.loss_gan_weight > 0.0: 61 | dis_optim.zero_grad() 62 | 63 | out_dir_vec, *_ = pose_decoder(pose_constraints, in_text, in_audio, 64 | style_vector) # out shape (batch x seq x dim) 65 | 66 | if args.diff_augment: 67 | dis_real = discriminator(DiffAugment(target_data), in_text) 68 | dis_fake = discriminator(DiffAugment(out_dir_vec.detach()), in_text) 69 | else: 70 | dis_real = discriminator(target_data, in_text) 71 | dis_fake = discriminator(out_dir_vec.detach(), in_text) 72 | 73 | dis_error = torch.sum(-torch.mean(torch.log(dis_real + 1e-8) + torch.log(1 - dis_fake + 1e-8))) # ns-gan 74 | dis_error.backward() 75 | dis_optim.step() 76 | 77 | ########################################################################################### 78 | # train G 79 | pose_dec_optim.zero_grad() 80 | 81 | # decoding 82 | out_dir_vec, z, z_mu, z_logvar = pose_decoder(pose_constraints, in_text, in_audio, style_vector) 83 | 84 | # loss 85 | beta = 0.1 86 | l1_loss = F.smooth_l1_loss(out_dir_vec / beta, target_data / beta) * beta 87 | 88 | if args.diff_augment: 89 | dis_output = discriminator(DiffAugment(out_dir_vec), in_text) 90 | else: 91 | dis_output = discriminator(out_dir_vec, in_text) 92 | 93 | gen_error = -torch.mean(torch.log(dis_output + 1e-8)) 94 | 95 | if args.z_type == 'style_vector' and use_div_reg and args.loss_reg_weight > 0.0: 96 | # calculate style control compliance 97 | style_stat = torch.tensor([args.style_val_mean, args.style_val_std, args.style_val_max]).squeeze().to(out_dir_vec.device) 98 | 99 | if args.normalize_motion_data: 100 | out_dir_vec += mean_dir_vec 101 | 102 | out_joint_poses = convert_dir_vec_to_pose_torch(out_dir_vec) 103 | window_size = args.motion_resampling_framerate * 2 # 2 sec 104 | 105 | out_style = data_loader.lmdb_data_loader.calculate_style_vec(out_joint_poses, window_size, mean_pose, style_stat) 106 | style_compliance = F.l1_loss(style_vector, out_style) 107 | 108 | loss = args.loss_l1_weight * l1_loss + args.loss_reg_weight * style_compliance 109 | else: 110 | loss = args.loss_l1_weight * l1_loss 111 | 112 | if epoch > warm_up_epochs: 113 | loss += args.loss_gan_weight * gen_error 114 | 115 | loss.backward() 116 | pose_dec_optim.step() 117 | 118 | ret_dict = {'loss': args.loss_l1_weight * l1_loss.item()} 119 | 120 | if epoch > warm_up_epochs and args.loss_gan_weight > 0.0: 121 | ret_dict['gen'] = args.loss_gan_weight * gen_error.item() 122 | ret_dict['dis'] 
= dis_error.item() 123 | 124 | return ret_dict 125 | -------------------------------------------------------------------------------- /sg_core/scripts/utils/average_meter.py: -------------------------------------------------------------------------------- 1 | 2 | class AverageMeter(object): 3 | """Computes and stores the average and current value""" 4 | def __init__(self, name, fmt=':f'): 5 | self.name = name 6 | self.fmt = fmt 7 | self.reset() 8 | 9 | def reset(self): 10 | self.val = 0 11 | self.avg = 0 12 | self.sum = 0 13 | self.count = 0 14 | 15 | def update(self, val, n=1): 16 | self.val = val 17 | self.sum += val * n 18 | self.count += n 19 | self.avg = self.sum / self.count 20 | 21 | def __str__(self): 22 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 23 | return fmtstr.format(**self.__dict__) 24 | -------------------------------------------------------------------------------- /sg_core/scripts/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | 4 | import librosa 5 | import numpy as np 6 | import torch 7 | from scipy.interpolate import interp1d 8 | from sklearn.preprocessing import normalize 9 | 10 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 11 | 12 | 13 | skeleton_line_pairs = [(0, 1, 'b'), (1, 2, 'darkred'), (2, 3, 'r'), (3, 4, 'orange'), (1, 5, 'darkgreen'), 14 | (5, 6, 'limegreen'), (6, 7, 'darkseagreen')] 15 | dir_vec_pairs = [(0, 1, 0.26), (1, 2, 0.18), (2, 3, 0.14), (1, 4, 0.22), (4, 5, 0.36), 16 | (5, 6, 0.33), (1, 7, 0.22), (7, 8, 0.36), (8, 9, 0.33)] # adjacency and bone length 17 | 18 | 19 | def normalize_string(s): 20 | """ lowercase, trim, and remove non-letter characters """ 21 | s = s.lower().strip() 22 | s = re.sub(r"([,.!?])", r" \1 ", s) # isolate some marks 23 | s = re.sub(r"(['])", r"", s) # remove apostrophe 24 | s = re.sub(r"[^a-zA-Z0-9,.!?]+", r" ", s) # replace other characters with whitespace 25 | s = re.sub(r"\s+", r" ", s).strip() 26 | return s 27 | 28 | 29 | def remove_tags_marks(text): 30 | reg_expr = re.compile('<.*?>|[.,:;!?]+') 31 | clean_text = re.sub(reg_expr, '', text) 32 | return clean_text 33 | 34 | 35 | def extract_melspectrogram(y, sr=16000): 36 | melspec = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=1024, hop_length=512, power=2) 37 | log_melspec = librosa.power_to_db(melspec, ref=np.max) # mels x time 38 | log_melspec = log_melspec.astype('float16') 39 | return log_melspec 40 | 41 | 42 | def calc_spectrogram_length_from_motion_length(n_frames, fps): 43 | ret = (n_frames / fps * 16000 - 1024) / 512 + 1 44 | return int(round(ret)) 45 | 46 | 47 | def resample_pose_seq(poses, duration_in_sec, fps): 48 | n = len(poses) 49 | x = np.arange(0, n) 50 | y = poses 51 | f = interp1d(x, y, axis=0, kind='linear', fill_value='extrapolate') 52 | expected_n = duration_in_sec * fps 53 | x_new = np.arange(0, n, n / expected_n) 54 | interpolated_y = f(x_new) 55 | if hasattr(poses, 'dtype'): 56 | interpolated_y = interpolated_y.astype(poses.dtype) 57 | return interpolated_y 58 | 59 | 60 | def time_stretch_for_words(words, start_time, speech_speed_rate): 61 | for i in range(len(words)): 62 | if words[i][1] > start_time: 63 | words[i][1] = start_time + (words[i][1] - start_time) / speech_speed_rate 64 | words[i][2] = start_time + (words[i][2] - start_time) / speech_speed_rate 65 | 66 | return words 67 | 68 | 69 | def make_audio_fixed_length(audio, expected_audio_length): 70 | n_padding = expected_audio_length - len(audio) 71 | 
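    # pad short clips symmetrically up to the expected length; longer clips are simply truncated below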
if n_padding > 0: 72 | audio = np.pad(audio, (0, n_padding), mode='symmetric') 73 | else: 74 | audio = audio[0:expected_audio_length] 75 | return audio 76 | 77 | 78 | def pose_pca_transform_npy(poses_npy, pca, out_torch=True): 79 | if len(poses_npy.shape) == 2: 80 | pca_poses = pca.transform(poses_npy).astype(np.float32) # [N x D] -> [N x PCA_D] 81 | else: 82 | n_samples = poses_npy.shape[0] 83 | n_seq = poses_npy.shape[1] 84 | 85 | poses_npy = poses_npy.reshape((-1, poses_npy.shape[-1])) 86 | pca_poses = pca.transform(poses_npy).astype(np.float32) # [N x D] -> [N x PCA_D] 87 | pca_poses = pca_poses.reshape((n_samples, n_seq, -1)) 88 | 89 | if out_torch: 90 | return torch.from_numpy(pca_poses).to(device) 91 | else: 92 | return pca_poses 93 | 94 | 95 | def pose_pca_transform(poses, pca): 96 | poses_npy = poses.data.cpu().numpy() 97 | return pose_pca_transform_npy(poses_npy, pca) 98 | 99 | 100 | def pose_pca_inverse_transform_npy(pca_data_npy, pca, out_torch=True): 101 | if len(pca_data_npy.shape) == 2: # (samples, dim) 102 | poses = pca.inverse_transform(pca_data_npy).astype(np.float32) # [N x PCA_D] -> [N x D] 103 | else: # (samples, seq, dim) 104 | n_samples = pca_data_npy.shape[0] 105 | n_seq = pca_data_npy.shape[1] 106 | 107 | pca_data_npy = pca_data_npy.reshape((-1, pca_data_npy.shape[-1])) 108 | poses = pca.inverse_transform(pca_data_npy).astype(np.float32) # [N x PCA_D] -> [N x D] 109 | poses = poses.reshape((n_samples, n_seq, -1)) 110 | 111 | if out_torch: 112 | return torch.from_numpy(poses).to(device) 113 | else: 114 | return poses 115 | 116 | 117 | def pose_pca_inverse_transform(pca_data, pca): 118 | pca_data_npy = pca_data.data.cpu().numpy() 119 | return pose_pca_inverse_transform_npy(pca_data_npy, pca) 120 | 121 | 122 | def convert_dir_vec_to_pose(vec): 123 | vec = np.array(vec) 124 | 125 | if vec.shape[-1] != 3: 126 | vec = vec.reshape(vec.shape[:-1] + (-1, 3)) 127 | 128 | if len(vec.shape) == 2: 129 | joint_pos = np.zeros((10, 3)) 130 | for j, pair in enumerate(dir_vec_pairs): 131 | joint_pos[pair[1]] = joint_pos[pair[0]] + pair[2] * vec[j] 132 | elif len(vec.shape) == 3: 133 | joint_pos = np.zeros((vec.shape[0], 10, 3)) 134 | for j, pair in enumerate(dir_vec_pairs): 135 | joint_pos[:, pair[1]] = joint_pos[:, pair[0]] + pair[2] * vec[:, j] 136 | elif len(vec.shape) == 4: # (batch, seq, 9, 3) 137 | joint_pos = np.zeros((vec.shape[0], vec.shape[1], 10, 3)) 138 | for j, pair in enumerate(dir_vec_pairs): 139 | joint_pos[:, :, pair[1]] = joint_pos[:, :, pair[0]] + pair[2] * vec[:, :, j] 140 | else: 141 | assert False 142 | 143 | return joint_pos 144 | 145 | 146 | def convert_dir_vec_to_pose_torch(vec): 147 | assert len(vec.shape) == 3 or (len(vec.shape) == 4 and vec.shape[-1] == 3) 148 | 149 | if vec.shape[-1] != 3: 150 | vec = vec.reshape(vec.shape[:-1] + (-1, 3)) 151 | 152 | joint_pos = torch.zeros((vec.shape[0], vec.shape[1], 10, 3), dtype=vec.dtype, device=vec.device) 153 | for j, pair in enumerate(dir_vec_pairs): 154 | joint_pos[:, :, pair[1]] = joint_pos[:, :, pair[0]] + pair[2] * vec[:, :, j] 155 | 156 | return joint_pos 157 | 158 | 159 | def convert_pose_to_line_segments(pose): 160 | line_segments = np.zeros((len(dir_vec_pairs) * 2, 3)) 161 | for j, pair in enumerate(dir_vec_pairs): 162 | line_segments[2 * j] = pose[pair[0]] 163 | line_segments[2 * j + 1] = pose[pair[1]] 164 | 165 | line_segments[:, [1, 2]] = line_segments[:, [2, 1]] # swap y, z 166 | line_segments[:, 2] = -line_segments[:, 2] 167 | return line_segments 168 | 169 | 170 | def 
convert_dir_vec_to_line_segments(dir_vec): 171 | joint_pos = convert_dir_vec_to_pose(dir_vec) 172 | line_segments = np.zeros((len(dir_vec_pairs) * 2, 3)) 173 | for j, pair in enumerate(dir_vec_pairs): 174 | line_segments[2 * j] = joint_pos[pair[0]] 175 | line_segments[2 * j + 1] = joint_pos[pair[1]] 176 | 177 | line_segments[:, [1, 2]] = line_segments[:, [2, 1]] # swap y, z 178 | line_segments[:, 2] = -line_segments[:, 2] 179 | return line_segments 180 | 181 | 182 | def convert_pose_seq_to_dir_vec(pose): 183 | if pose.shape[-1] != 3: 184 | pose = pose.reshape(pose.shape[:-1] + (-1, 3)) 185 | 186 | if len(pose.shape) == 3: 187 | dir_vec = np.zeros((pose.shape[0], len(dir_vec_pairs), 3)) 188 | for i, pair in enumerate(dir_vec_pairs): 189 | dir_vec[:, i] = pose[:, pair[1]] - pose[:, pair[0]] 190 | dir_vec[:, i, :] = normalize(dir_vec[:, i, :], axis=1) # to unit length 191 | elif len(pose.shape) == 4: # (batch, seq, ...) 192 | dir_vec = np.zeros((pose.shape[0], pose.shape[1], len(dir_vec_pairs), 3)) 193 | for i, pair in enumerate(dir_vec_pairs): 194 | dir_vec[:, :, i] = pose[:, :, pair[1]] - pose[:, :, pair[0]] 195 | for j in range(dir_vec.shape[0]): # batch 196 | for i in range(len(dir_vec_pairs)): 197 | dir_vec[j, :, i, :] = normalize(dir_vec[j, :, i, :], axis=1) # to unit length 198 | else: 199 | assert False 200 | 201 | return dir_vec 202 | 203 | 204 | def normalize_3d_pose(kps): 205 | line_pairs = [(1, 0, 'b'), (2, 1, 'b'), (3, 2, 'b'), 206 | (4, 1, 'g'), (5, 4, 'g'), (6, 5, 'g'), 207 | # left (https://github.com/kenkra/3d-pose-baseline-vmd/wiki/body) 208 | (7, 1, 'r'), (8, 7, 'r'), (9, 8, 'r')] # right 209 | 210 | def unit_vector(vector): 211 | """ Returns the unit vector of the vector. """ 212 | return vector / np.linalg.norm(vector) 213 | 214 | def angle_between(v1, v2): 215 | """ Returns the angle in radians between vectors 'v1' and 'v2':: 216 | 217 | >>> angle_between((1, 0, 0), (0, 1, 0)) 218 | 1.5707963267948966 219 | >>> angle_between((1, 0, 0), (1, 0, 0)) 220 | 0.0 221 | >>> angle_between((1, 0, 0), (-1, 0, 0)) 222 | 3.141592653589793 223 | """ 224 | v1_u = unit_vector(v1) 225 | v2_u = unit_vector(v2) 226 | return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) 227 | 228 | def rotation_matrix(axis, theta): 229 | """ 230 | Return the rotation matrix associated with counterclockwise rotation about 231 | the given axis by theta radians. 
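        This is the Euler-Rodrigues formula with a = cos(theta/2) and (b, c, d) = -axis * sin(theta/2).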
232 | """ 233 | axis = np.asarray(axis) 234 | axis = axis / math.sqrt(np.dot(axis, axis)) 235 | a = math.cos(theta / 2.0) 236 | b, c, d = -axis * math.sin(theta / 2.0) 237 | aa, bb, cc, dd = a * a, b * b, c * c, d * d 238 | bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d 239 | return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], 240 | [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], 241 | [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) 242 | 243 | n_frames = kps.shape[0] 244 | for i in range(n_frames): 245 | # refine spine angles 246 | spine_vec = kps[i, 1] - kps[i, 0] 247 | angle = angle_between([0, -1, 0], spine_vec) 248 | th = np.deg2rad(10) 249 | if angle > th: 250 | angle = angle - th 251 | rot = rotation_matrix(np.cross([0, -1, 0], spine_vec), angle) 252 | kps[i] = np.matmul(kps[i], rot) 253 | 254 | # rotate 255 | shoulder_vec = kps[i, 7] - kps[i, 4] 256 | angle = np.pi - np.math.atan2(shoulder_vec[2], shoulder_vec[0]) # angles on XZ plane 257 | # if i == 0: 258 | # print(angle, np.rad2deg(angle)) 259 | if 180 > np.rad2deg(angle) > 20: 260 | angle = angle - np.deg2rad(20) 261 | rotate = True 262 | elif 180 < np.rad2deg(angle) < 340: 263 | angle = angle - np.deg2rad(340) 264 | rotate = True 265 | else: 266 | rotate = False 267 | 268 | if rotate: 269 | rot = rotation_matrix([0, 1, 0], angle) 270 | kps[i] = np.matmul(kps[i], rot) 271 | 272 | # rotate 180 deg 273 | rot = rotation_matrix([0, 1, 0], np.pi) 274 | kps[i] = np.matmul(kps[i], rot) 275 | 276 | # size 277 | bone_lengths = [] 278 | for pair in line_pairs: 279 | bone_lengths.append(np.linalg.norm(kps[i, pair[0], :] - kps[i, pair[1], :])) 280 | scale_factor = 0.2 / np.mean(bone_lengths) 281 | kps[i] *= scale_factor 282 | 283 | return kps 284 | -------------------------------------------------------------------------------- /sg_core/scripts/utils/gui_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.preprocessing import normalize 3 | 4 | from utils.data_utils import convert_dir_vec_to_pose, dir_vec_pairs 5 | 6 | 7 | def convert_pose_to_line_segments(pose): 8 | line_segments = np.zeros((len(dir_vec_pairs) * 2, 3)) 9 | for j, pair in enumerate(dir_vec_pairs): 10 | line_segments[2 * j] = pose[pair[0]] 11 | line_segments[2 * j + 1] = pose[pair[1]] 12 | 13 | line_segments[:, [1, 2]] = line_segments[:, [2, 1]] # swap y, z 14 | line_segments[:, 2] = -line_segments[:, 2] 15 | return line_segments 16 | 17 | 18 | def convert_dir_vec_to_line_segments(dir_vec): 19 | joint_pos = convert_dir_vec_to_pose(dir_vec) 20 | line_segments = np.zeros((len(dir_vec_pairs) * 2, 3)) 21 | for j, pair in enumerate(dir_vec_pairs): 22 | line_segments[2 * j] = joint_pos[pair[0]] 23 | line_segments[2 * j + 1] = joint_pos[pair[1]] 24 | 25 | line_segments[:, [1, 2]] = line_segments[:, [2, 1]] # swap y, z 26 | line_segments[:, 2] = -line_segments[:, 2] 27 | return line_segments 28 | 29 | -------------------------------------------------------------------------------- /sg_core/scripts/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import subprocess 5 | from collections import defaultdict, namedtuple 6 | from logging.handlers import RotatingFileHandler 7 | from textwrap import wrap 8 | 9 | import numpy as np 10 | import time 11 | import math 12 | import soundfile as sf 13 | 14 | import matplotlib 15 | import matplotlib.pyplot as plt 16 | 
import torch 17 | import matplotlib.ticker as ticker 18 | import matplotlib.animation as animation 19 | from mpl_toolkits import mplot3d 20 | 21 | import utils.data_utils 22 | import train 23 | 24 | 25 | matplotlib.rcParams['axes.unicode_minus'] = False 26 | 27 | 28 | def set_logger(log_path=None, log_filename='log'): 29 | for handler in logging.root.handlers[:]: 30 | logging.root.removeHandler(handler) 31 | handlers = [logging.StreamHandler()] 32 | if log_path is not None: 33 | os.makedirs(log_path, exist_ok=True) 34 | handlers.append( 35 | RotatingFileHandler(os.path.join(log_path, log_filename), maxBytes=10 * 1024 * 1024, backupCount=5)) 36 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s: %(message)s', handlers=handlers, 37 | datefmt='%Y%m%d %H:%M:%S') 38 | logging.getLogger("matplotlib").setLevel(logging.WARNING) 39 | 40 | 41 | def as_minutes(s): 42 | m = math.floor(s / 60) 43 | s -= m * 60 44 | return '%dm %ds' % (m, s) 45 | 46 | 47 | def time_since(since): 48 | now = time.time() 49 | s = now - since 50 | return '%s' % as_minutes(s) 51 | 52 | 53 | def create_video_and_save(save_path, epoch, prefix, iter_idx, target, output, title, 54 | audio=None, aux_str=None, clipping_to_shortest_stream=False, delete_audio_file=True): 55 | print('saving a video...') 56 | start = time.time() 57 | 58 | fig = plt.figure(figsize=(8, 4)) 59 | axes = [fig.add_subplot(1, 2, 1, projection='3d'), fig.add_subplot(1, 2, 2, projection='3d')] 60 | axes[0].view_init(elev=20, azim=-60) 61 | axes[1].view_init(elev=20, azim=-60) 62 | fig_title = title 63 | 64 | if aux_str: 65 | fig_title += ('\n' + aux_str) 66 | fig.suptitle('\n'.join(wrap(fig_title, 75)), fontsize='medium') 67 | 68 | # convert to poses 69 | output_poses = utils.data_utils.convert_dir_vec_to_pose(output) 70 | target_poses = None 71 | if target is not None: 72 | target_poses = utils.data_utils.convert_dir_vec_to_pose(target) 73 | 74 | def animate(i): 75 | for k, name in enumerate(['target', 'predicted']): 76 | if name == 'target' and target is not None and i < len(target): 77 | pose = target_poses[i] 78 | elif name == 'predicted' and i < len(output): 79 | pose = output_poses[i] 80 | else: 81 | pose = None 82 | 83 | if pose is not None: 84 | axes[k].clear() 85 | for j, pair in enumerate(utils.data_utils.dir_vec_pairs): 86 | axes[k].plot([pose[pair[0], 0], pose[pair[1], 0]], 87 | [pose[pair[0], 2], pose[pair[1], 2]], 88 | [pose[pair[0], 1], pose[pair[1], 1]], 89 | zdir='z', linewidth=5) 90 | axes[k].set_xlim3d(0.5, -0.5) 91 | axes[k].set_ylim3d(0.5, -0.5) 92 | axes[k].set_zlim3d(0.5, -0.5) 93 | axes[k].set_xlabel('x') 94 | axes[k].set_ylabel('z') 95 | axes[k].set_zlabel('y') 96 | axes[k].set_title('{} ({}/{})'.format(name, i + 1, len(output))) 97 | 98 | if target is not None: 99 | num_frames = max(len(target), len(output)) 100 | else: 101 | num_frames = len(output) 102 | ani = animation.FuncAnimation(fig, animate, interval=30, frames=num_frames, repeat=False) 103 | 104 | # show audio 105 | audio_path = None 106 | if audio is not None: 107 | assert len(audio.shape) == 1 # 1-channel, raw signal 108 | audio = audio.astype(np.float32) 109 | sr = 16000 110 | audio_path = '{}/{}_audio_{:03d}_{}.wav'.format(save_path, prefix, epoch, iter_idx) 111 | sf.write(audio_path, audio, sr) 112 | 113 | # save video 114 | try: 115 | video_path = '{}/temp_{}_{:03d}_{}.mp4'.format(save_path, prefix, epoch, iter_idx) 116 | ani.save(video_path, fps=15, dpi=80) # dpi 150 for a higher resolution 117 | del ani 118 | plt.close(fig) 119 | except RuntimeError: 120 | 
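        # ani.save may raise RuntimeError (e.g. no usable movie writer such as ffmpeg); fail fast rather than continue without a video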
assert False, 'RuntimeError' 121 | 122 | # merge audio and video 123 | if audio is not None: 124 | merged_video_path = '{}/{}_{:03d}_{}.mp4'.format(save_path, prefix, epoch, iter_idx) 125 | cmd = ['ffmpeg', '-loglevel', 'panic', '-y', '-i', video_path, '-i', audio_path, '-strict', '-2', 126 | merged_video_path] 127 | if clipping_to_shortest_stream: 128 | cmd.insert(len(cmd) - 1, '-shortest') 129 | # print(cmd) 130 | subprocess.call(cmd) 131 | if delete_audio_file: 132 | os.remove(audio_path) 133 | os.remove(video_path) 134 | 135 | print('saved, took {:.1f} seconds'.format(time.time() - start)) 136 | return output_poses, target_poses 137 | 138 | 139 | def save_checkpoint(state, filename): 140 | torch.save(state, filename) 141 | logging.info('Saved the checkpoint') 142 | 143 | 144 | def load_checkpoint_and_model(checkpoint_path, _device='cpu'): 145 | print('loading checkpoint {}'.format(checkpoint_path)) 146 | checkpoint = torch.load(checkpoint_path, map_location=_device) 147 | args = checkpoint['args'] 148 | lang_model = checkpoint['lang_model'] 149 | pose_dim = checkpoint['pose_dim'] 150 | 151 | generator, discriminator = train.init_model(args, lang_model, pose_dim, _device) 152 | generator.load_state_dict(checkpoint['gen_dict']) 153 | 154 | # set to eval mode 155 | generator.train(False) 156 | 157 | return args, generator, lang_model, pose_dim 158 | 159 | 160 | def set_random_seed(seed): 161 | torch.manual_seed(seed) 162 | torch.cuda.manual_seed_all(seed) 163 | np.random.seed(seed) 164 | random.seed(seed) 165 | os.environ['PYTHONHASHSEED'] = str(seed) 166 | -------------------------------------------------------------------------------- /sg_core/scripts/utils/tts_helper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import datetime 3 | import os 4 | import re 5 | import time 6 | 7 | from google.cloud import texttospeech 8 | 9 | 10 | class TTSHelper: 11 | """ helper class for google TTS 12 | set the environment variable GOOGLE_APPLICATION_CREDENTIALS first 13 | GOOGLE_APPLICATION_CREDENTIALS = 'path to json key file' 14 | """ 15 | 16 | cache_folder = './cached_wav/' 17 | 18 | def __init__(self, cache_path=None): 19 | if cache_path is not None: 20 | self.cache_folder = cache_path 21 | 22 | # create cache folder 23 | try: 24 | os.makedirs(self.cache_folder) 25 | except OSError: 26 | pass 27 | 28 | # init tts 29 | self.client = texttospeech.TextToSpeechClient() 30 | self.voice_en_standard = texttospeech.types.VoiceSelectionParams( 31 | language_code='en-US', name='en-US-Standard-B') 32 | self.voice_en_female = texttospeech.types.VoiceSelectionParams( 33 | language_code='en-US', name='en-US-Wavenet-F') 34 | self.voice_en_female_2 = texttospeech.types.VoiceSelectionParams( 35 | language_code='en-US', name='en-US-Wavenet-C') 36 | self.voice_gb_female = texttospeech.types.VoiceSelectionParams( 37 | language_code='en-GB', name='en-GB-Wavenet-C') 38 | self.voice_en_male = texttospeech.types.VoiceSelectionParams( 39 | language_code='en-US', name='en-US-Wavenet-D') 40 | self.voice_en_male_2 = texttospeech.types.VoiceSelectionParams( 41 | language_code='en-US', name='en-US-Wavenet-A') 42 | # self.voice_en_male_2 = texttospeech.types.VoiceSelectionParams( 43 | # language_code='en-US', name='en-AU-Wavenet-B') 44 | self.voice_ko_female = texttospeech.types.VoiceSelectionParams( 45 | language_code='ko-KR', name='ko-KR-Wavenet-A') 46 | self.voice_ko_male = texttospeech.types.VoiceSelectionParams( 47 | language_code='ko-KR', 
name='ko-KR-Wavenet-D') 48 | self.audio_config_en = texttospeech.types.AudioConfig( 49 | # speaking_rate=0.67, 50 | speaking_rate=1.0, 51 | audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16) # using WAV takes more time than MP3 (about 0.Xs) 52 | self.audio_config_en_slow = texttospeech.types.AudioConfig( 53 | speaking_rate=0.85, 54 | audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16) # using WAV takes more time than MP3 (about 0.Xs) 55 | self.audio_config_kr = texttospeech.types.AudioConfig( 56 | speaking_rate=1.0, 57 | audio_encoding=texttospeech.enums.AudioEncoding.LINEAR16) 58 | 59 | # clean up cache folder 60 | self._cleanup_cachefolder() 61 | 62 | def _cleanup_cachefolder(self): 63 | """ remove least accessed files in the cache """ 64 | dir_to_search = self.cache_folder 65 | for dirpath, dirnames, filenames in os.walk(dir_to_search): 66 | for file in filenames: 67 | curpath = os.path.join(dirpath, file) 68 | file_accessed = datetime.datetime.fromtimestamp(os.path.getatime(curpath)) 69 | if datetime.datetime.now() - file_accessed > datetime.timedelta(days=30): 70 | os.remove(curpath) 71 | 72 | def _string2numeric_hash(self, text): 73 | import hashlib 74 | return int(hashlib.md5(text.encode('utf-8')).hexdigest()[:16], 16) 75 | 76 | def synthesis(self, ssml_text, voice_name='en-female', verbose=False): 77 | if not ssml_text.startswith(u''): 78 | ssml_text = u'' + ssml_text + u'' 79 | 80 | filename = os.path.join(self.cache_folder, str(self._string2numeric_hash(voice_name + ssml_text)) + '.wav') 81 | 82 | # load or synthesis audio 83 | if not os.path.exists(filename): 84 | if verbose: 85 | start = time.time() 86 | 87 | # let's synthesis 88 | if voice_name == 'en-female': 89 | voice = self.voice_en_female 90 | audio_config = self.audio_config_en 91 | elif voice_name == 'en-female_2': 92 | voice = self.voice_en_female_2 93 | audio_config = self.audio_config_en_slow 94 | # audio_config = self.audio_config_en 95 | elif voice_name == 'gb-female': 96 | voice = self.voice_gb_female 97 | audio_config = self.audio_config_en 98 | elif voice_name == 'en-male': 99 | voice = self.voice_en_male 100 | audio_config = self.audio_config_en 101 | elif voice_name == 'en-male_2': 102 | voice = self.voice_en_male_2 103 | # audio_config = self.audio_config_en_slow 104 | audio_config = self.audio_config_en 105 | elif voice_name == 'kr-female': 106 | voice = self.voice_ko_female 107 | audio_config = self.audio_config_kr 108 | elif voice_name == 'kr-male': 109 | voice = self.voice_ko_male 110 | audio_config = self.audio_config_kr 111 | elif voice_name == 'en-standard': 112 | voice = self.voice_en_standard 113 | audio_config = self.audio_config_en 114 | else: 115 | raise ValueError 116 | 117 | synthesis_input = texttospeech.types.SynthesisInput(ssml=ssml_text) 118 | response = self.client.synthesize_speech(synthesis_input, voice, audio_config) 119 | 120 | if verbose: 121 | print('synthesis: took {0:.2f} seconds'.format(time.time() - start)) 122 | start = time.time() 123 | 124 | # save to a file 125 | with open(filename, 'wb') as out: 126 | out.write(response.audio_content) 127 | if verbose: 128 | print('written to file "{}"'.format(filename)) 129 | 130 | if verbose: 131 | print('save wav file: took {0:.2f} seconds'.format(time.time() - start)) 132 | else: 133 | if verbose: 134 | print('use cached file "{}"'.format(filename)) 135 | 136 | return filename 137 | 138 | 139 | def test_tts_helper(): 140 | tts = TTSHelper() 141 | 142 | voice = 'en-male' # 'kr' 143 | text = 'load a new sound buffer from a 
filename, a python file object' 144 | 145 | # voice = 'kr' 146 | # text = u'나는 나오입니다 안녕하세요.' 147 | 148 | # split into sentences 149 | sentences = list(filter(None, re.split("[.,!?:\-]+", text))) 150 | sentences = [s.strip().lower() for s in sentences] 151 | print(sentences) 152 | 153 | # synthesis 154 | filenames = [] 155 | for s in sentences: 156 | filenames.append(tts.synthesis(s, voice_name=voice, verbose=True)) 157 | 158 | # play 159 | for f in filenames: 160 | sound_obj, duration = tts.get_sound_obj(f) 161 | tts.play(sound_obj) 162 | print('playing... {0:.2f} seconds'.format(duration)) 163 | time.sleep(duration) 164 | 165 | 166 | if __name__ == '__main__': 167 | test_tts_helper() 168 | -------------------------------------------------------------------------------- /sg_core/scripts/utils/vocab_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | 5 | import lmdb 6 | import pyarrow 7 | 8 | from model.vocab import Vocab 9 | 10 | 11 | def build_vocab(name, dataset_list, cache_path, word_vec_path=None, feat_dim=None): 12 | logging.info(' building a language model...') 13 | if not os.path.exists(cache_path): 14 | lang_model = Vocab(name) 15 | for dataset_path in dataset_list: 16 | logging.info(' indexing words from {}'.format(dataset_path)) 17 | index_words(lang_model, dataset_path) 18 | 19 | if word_vec_path is not None: 20 | lang_model.load_word_vectors(word_vec_path, feat_dim) 21 | 22 | with open(cache_path, 'wb') as f: 23 | pickle.dump(lang_model, f) 24 | else: 25 | logging.info(' loaded from {}'.format(cache_path)) 26 | with open(cache_path, 'rb') as f: 27 | lang_model = pickle.load(f) 28 | 29 | if word_vec_path is None: 30 | lang_model.word_embedding_weights = None 31 | elif lang_model.word_embedding_weights.shape[0] != lang_model.n_words: 32 | logging.warning(' failed to load word embedding weights. check this') 33 | assert False 34 | 35 | return lang_model 36 | 37 | 38 | def index_words(lang_model, lmdb_dir): 39 | lmdb_env = lmdb.open(lmdb_dir, readonly=True, lock=False) 40 | txn = lmdb_env.begin(write=False) 41 | cursor = txn.cursor() 42 | 43 | for key, buf in cursor: 44 | video = pyarrow.deserialize(buf) 45 | 46 | for clip in video['clips']: 47 | for word_info in clip['words']: 48 | word = word_info[0] 49 | lang_model.index_word(word) 50 | 51 | lmdb_env.close() 52 | logging.info(' indexed %d words' % lang_model.n_words) 53 | 54 | # filtering vocab 55 | # MIN_COUNT = 3 56 | # lang_model.trim(MIN_COUNT) 57 | 58 | -------------------------------------------------------------------------------- /sg_core_api.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrapper for sg_core modules. 3 | If sg_core changes, only this file should be changed. 
4 | """ 5 | import pathlib 6 | 7 | this_dir_path = pathlib.Path(__file__).parent 8 | 9 | # define model file path 10 | model_path = this_dir_path.joinpath('sg_core', 'output', 'sgtoolkit', 'multimodal_context_checkpoint_best.bin') 11 | assert model_path.exists(), "model file ({}) does not exists:".format(str(model_path)) 12 | model_file_name = str(model_path) 13 | 14 | # add sg_core in path 15 | import sys 16 | import os 17 | 18 | sg_core_scripts_path = this_dir_path.joinpath('sg_core', 'scripts') 19 | sg_core_path = this_dir_path.joinpath('sg_core') 20 | gentle_path = this_dir_path.joinpath('gentle') 21 | gentle_ext_path = this_dir_path.joinpath('gentle', 'ext') 22 | google_key_path = this_dir_path.joinpath('sg_core', 'google-key.json') 23 | sys.path.append(str(sg_core_scripts_path)) 24 | sys.path.append(str(sg_core_path)) 25 | sys.path.append(str(gentle_path)) 26 | sys.path.append(str(gentle_ext_path)) 27 | os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = str(google_key_path) 28 | 29 | from sg_core.scripts.gesture_generator import GestureGenerator 30 | import numpy as np 31 | 32 | 33 | def get_gesture_generator(): 34 | audio_cache_path = './cached_wav' 35 | return GestureGenerator(model_file_name, audio_cache_path) 36 | 37 | 38 | def convert_pose_coordinate_for_ui(pose_mat): 39 | return flip_y_axis_of_pose(pose_mat) 40 | 41 | 42 | def convert_pose_coordinate_for_model(constraint_mat): 43 | mask_col = constraint_mat[:, -1] 44 | mask_col = mask_col[:, np.newaxis] 45 | pose_mat = constraint_mat[:, :-1] 46 | pose_mat = flip_y_axis_of_pose(pose_mat) 47 | return np.hstack((pose_mat, mask_col)) 48 | 49 | 50 | def convert_pose_coordinate_for_ui_for_motion_library(motions_cursor): 51 | converted = [] 52 | for motion in list(motions_cursor): 53 | motion_mat = np.array(motion['motion']) 54 | motion['motion'] = convert_pose_coordinate_for_ui(motion_mat).tolist() 55 | converted.append(motion) 56 | 57 | return converted 58 | 59 | 60 | def convert_pose_coordinate_for_ui_for_rule_library(rules_cursor): 61 | converted = [] 62 | for rule in list(rules_cursor): 63 | if (rule['motion_info'] == []): 64 | continue 65 | motion_mat = np.array(rule['motion_info'][0]['motion']) 66 | rule['motion_info'][0]['motion'] = convert_pose_coordinate_for_ui(motion_mat).tolist() 67 | converted.append(rule) 68 | 69 | return converted 70 | 71 | 72 | def flip_y_axis_of_pose(pose_mat): 73 | n_frame = pose_mat.shape[0] 74 | pose_mat = np.reshape(pose_mat, (n_frame, -1, 3)) 75 | pose_mat[:, :, 1] = -pose_mat[:, :, 1] 76 | pose_mat = np.reshape(pose_mat, (n_frame, -1)) 77 | return pose_mat 78 | -------------------------------------------------------------------------------- /static/css/index.css: -------------------------------------------------------------------------------- 1 | .content { 2 | margin: 10px; 3 | } 4 | 5 | .interactive { 6 | margin-top: 10px; 7 | margin-bottom: 10px; 8 | } 9 | 10 | .modal-content.loading-content { 11 | background: transparent; 12 | position: relative; 13 | margin: 0 auto; 14 | border: none; 15 | width: auto; 16 | } 17 | 18 | .spinner-border { 19 | width: 3rem; 20 | height: 3rem; 21 | } 22 | 23 | .expandable-container { 24 | padding: 0px; 25 | } 26 | 27 | .expandable-container .header { 28 | background-color: transparent; 29 | padding: 0px; 30 | cursor: pointer; 31 | font-weight: bold; 32 | } 33 | 34 | .expandable-container .content { 35 | display: none; 36 | border: 1px solid #d3d3d3; 37 | padding: 10px; 38 | margin: 0px; 39 | } 40 | 41 | region.wavesurfer-region { 42 | border-left: 1px solid black; 
43 | } 44 | 45 | /* for labels */ 46 | region.wavesurfer-region:after { 47 | content: attr(data-region-label); 48 | position: absolute; 49 | top: 0; 50 | padding-left: 5px; 51 | } 52 | 53 | #renderCanvas { 54 | width: 90%; 55 | height: 90%; 56 | } 57 | 58 | #renderInLibrary { 59 | width: 90%; 60 | height: 90%; 61 | } 62 | 63 | #speech-word { 64 | position: relative; 65 | top: -30px; 66 | } 67 | 68 | #word-canvas { 69 | width: 100%; 70 | height: 50px; 71 | } 72 | 73 | .stop-scrolling { 74 | height: 100%; 75 | overflow: hidden; 76 | } 77 | 78 | .alert { 79 | display: none; 80 | position: absolute; 81 | top: 0; 82 | width: 100%; 83 | } 84 | 85 | .floatlefthalf { 86 | float: left; 87 | width: 50%; 88 | } 89 | 90 | .floatrighthalf { 91 | float: right; 92 | width: 50%; 93 | } 94 | 95 | .selected-region { 96 | background-color: "rgb(255, 20, 147, 0.5)"; 97 | } 98 | 99 | .btn-with-margin { 100 | margin-top: 5px; 101 | margin-bottom: 5px; 102 | } 103 | 104 | .slider.slider-horizontal { 105 | width: 50%; 106 | } 107 | 108 | .slider-track-high { 109 | background: #778899; 110 | } 111 | 112 | .slider-selection { 113 | background: #778899; 114 | } 115 | 116 | .flex-grid { 117 | display: flex; 118 | } 119 | 120 | :root { 121 | --cell-border-width: 0.01em; 122 | } 123 | 124 | .track-frame { 125 | flex: 1; 126 | height: 3em; 127 | background: #CCD1D1; 128 | border: var(--cell-border-width) solid white; 129 | } 130 | 131 | .group-start { 132 | border-left: var(--cell-border-width) solid black; 133 | } 134 | 135 | .group-end { 136 | border-right: var(--cell-border-width) solid black; 137 | } 138 | 139 | .group-middle { 140 | border-top: var(--cell-border-width) solid black; 141 | border-bottom: var(--cell-border-width) solid black; 142 | } 143 | 144 | .cell-track.blue div.modified { 145 | background: #4169E1; 146 | cursor: move; 147 | } 148 | 149 | .cell-track.green div.modified { 150 | background: #00FA9A; 151 | cursor: move; 152 | } 153 | 154 | .cell-track div.selected { 155 | border: var(--cell-border-width) solid crimson; 156 | } 157 | -------------------------------------------------------------------------------- /static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4r/SGToolkit/0c11a54ee64a7b29f4cba79a9d9ff3ae48e3af4e/static/favicon.ico -------------------------------------------------------------------------------- /static/js/avatarInterface.js: -------------------------------------------------------------------------------- 1 | /* 2 | Handles mouse interaction with avatar. 3 | Joint controls are done in here. 4 | This should not hold an avatar object. 
5 | */ 6 | class AvatarInterface { 7 | constructor(avatar) { 8 | this.attachInterationCallback(avatar); 9 | //this.handleMeshes = this.attachControlHandles(avatar); 10 | this.skeletonViewer = this.addSkeletonViewer(avatar); 11 | this.rotationGizmos = this.createRotationGizmo(avatar); 12 | this.positionGizmos = this.createPositionGizmo(avatar); 13 | 14 | this.currentGizmoType = 'none'; 15 | this.turnOffEditMode(); 16 | 17 | this.gizmoDragEndCallback = null; 18 | this.gizmoDragCallback = null; 19 | } 20 | 21 | attachControlHandles(avatar) { 22 | // attach invisible spheres to control joints directly 23 | var bones = avatar.getControllableBones(); 24 | var avatarMesh = avatar.getMesh(); 25 | var handleMeshes = []; 26 | for (var i = 0; i < bones.length; i++) { 27 | var sphere = BABYLON.MeshBuilder.CreateSphere("sphere", {diameter: 0.1}, avatar.scene); //default sphere 28 | var bone = bones[i]; 29 | sphere.attachToBone(bone, avatarMesh); 30 | handleMeshes.push(sphere); 31 | sphere.setEnabled(false); 32 | } 33 | return handleMeshes; 34 | } 35 | 36 | addSkeletonViewer(avatar) { 37 | var skeleton = avatar.getSkeleton(); 38 | var mesh = avatar.getMesh(); 39 | var scene = avatar.scene; 40 | var skeletonViewer = new BABYLON.Debug.SkeletonViewer(skeleton, mesh, scene); 41 | skeletonViewer.isEnabled = true; 42 | skeletonViewer.color = BABYLON.Color3.Yellow(); // Change default color from white to red 43 | return skeletonViewer; 44 | } 45 | 46 | attachInterationCallback(avatar) { 47 | avatar.scene.onPointerObservable.add(function (pointerInfo) { 48 | switch (pointerInfo.type) { 49 | case BABYLON.PointerEventTypes.POINTERDOWN: 50 | this.pointerDown(pointerInfo); 51 | case BABYLON.PointerEventTypes.POINTERMOVE: 52 | this.pointerMove(pointerInfo); 53 | 54 | } 55 | }.bind(this)); 56 | } 57 | 58 | pointerDown(pointerInfo) { 59 | 60 | } 61 | 62 | pointerUp() { 63 | //console.log("pointer up"); 64 | } 65 | 66 | pointerMove() { 67 | //console.log("pointer move"); 68 | } 69 | 70 | turnOnEditMode(avatar, gizmoType) { 71 | this.turnOffEditMode(); 72 | 73 | this.skeletonViewer.isEnabled = true; 74 | 75 | if (gizmoType == 'rotation') { 76 | this.attachRotationGizmo(avatar); 77 | } else if (gizmoType == 'position') { 78 | this.attachPostionGizmo(avatar); 79 | } 80 | 81 | this.currentGizmoType = gizmoType; 82 | 83 | } 84 | 85 | turnOffEditMode() { 86 | this.skeletonViewer.isEnabled = false; 87 | if (this.currentGizmoType == 'none') { 88 | return; 89 | } 90 | 91 | if (this.currentGizmoType == 'rotation') { 92 | this.detachGizmos(this.rotationGizmos); 93 | } else if (this.currentGizmoType == 'position') { 94 | this.detachGizmos(this.positionGizmos); 95 | } 96 | 97 | } 98 | 99 | createPositionGizmo(avatar) { 100 | var gizmos = []; 101 | var gizmoScale = 0.75 102 | var utilLayer = new BABYLON.UtilityLayerRenderer(avatar.scene); 103 | var that = this; 104 | 105 | function getOnDragEndCallback(bone) { 106 | // You have to do this weired thing to freeze (or save) 107 | // variable in the loop for callback 108 | // https://stackoverflow.com/questions/7053965/when-using-callbacks-inside-a-loop-in-javascript-is-there-any-way-to-save-a-var 109 | return function () { 110 | if (that.gizmoDragEndCallback != null) { 111 | that.gizmoDragEndCallback(); 112 | } 113 | avatar.refineBoneAfterManualTraslation(bone); 114 | } 115 | } 116 | 117 | function getOnDragStartCallback(bone) { 118 | return function () { 119 | avatar.registerForManualTranslation(bone); 120 | } 121 | } 122 | 123 | var bonesToAttach = 
avatar.getPositionGizmoAttachableBones(); 124 | 125 | for (var i = 0; i < bonesToAttach.length; i++) { 126 | var gizmo = new BABYLON.PositionGizmo(utilLayer); 127 | var bone = bonesToAttach[i]; 128 | gizmo.onDragStartObservable.add(getOnDragStartCallback(bone)); 129 | gizmo.onDragEndObservable.add(getOnDragEndCallback(bone)); 130 | gizmo.scaleRatio = gizmoScale; 131 | gizmos.push(gizmo); 132 | } 133 | 134 | return gizmos; 135 | } 136 | 137 | createRotationGizmo(avatar) { 138 | 139 | var gizmos = []; 140 | var gizmoScale = 0.75 141 | var utilLayer = new BABYLON.UtilityLayerRenderer(avatar.scene); 142 | var that = this; 143 | 144 | var bonesToAttach = avatar.getRotationGizmoAttachbleBones(); 145 | 146 | for (var i = 0; i < bonesToAttach.length; i++) { 147 | var gizmo = new BABYLON.RotationGizmo(utilLayer); 148 | var bone = bonesToAttach[i]; 149 | gizmo.onDragEndObservable.add(function () { 150 | if (that.gizmoDragEndCallback != null) { 151 | that.gizmoDragEndCallback(); 152 | } 153 | }); 154 | gizmo.scaleRatio = gizmoScale; 155 | gizmos.push(gizmo); 156 | } 157 | 158 | return gizmos; 159 | 160 | } 161 | 162 | createGizmos(avatar, bonesToAttach, gizmoClass) { 163 | 164 | } 165 | 166 | attachPostionGizmo(avatar) { 167 | this.attachGizmos(avatar, avatar.getPositionGizmoAttachableBones(), this.positionGizmos); 168 | } 169 | 170 | attachRotationGizmo(avatar) { 171 | this.attachGizmos(avatar, avatar.getRotationGizmoAttachbleBones(), this.rotationGizmos); 172 | } 173 | 174 | attachGizmos(avatar, bonesToAttach, gizmos) { 175 | for (var i = 0; i < bonesToAttach.length; i++) { 176 | var bone = bonesToAttach[i]; 177 | var gizmo = gizmos[i]; 178 | gizmo.attachedNode = bone; 179 | } 180 | } 181 | 182 | detachGizmos(gizmos) { 183 | for (var i = 0; i < gizmos.length; i++) { 184 | var gizmo = gizmos[i]; 185 | gizmo.attachedNode = null; 186 | } 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /static/js/cell.js: -------------------------------------------------------------------------------- 1 | class Cell { 2 | 3 | constructor(frameId) { 4 | var $div = $("
", { 5 | "class": "track-frame", 6 | "data-fid": frameId 7 | }); 8 | this.fid = frameId; 9 | this.div = $div; 10 | this.types = []; 11 | this.group = null; 12 | } 13 | 14 | setData(data) { 15 | this.data = data; 16 | } 17 | 18 | getData(key) { 19 | return this.data[key]; 20 | } 21 | 22 | updateData(key, value) { 23 | this.data[key] = value; 24 | } 25 | 26 | setControlData(data) { 27 | this.controlData = data; 28 | if (!$(this.div).hasClass("modified")) { 29 | $(this.div).addClass("modified"); 30 | } 31 | } 32 | 33 | getControlData(key) { 34 | if (!this.isModified()) { 35 | return null; 36 | } 37 | return this.controlData[key]; 38 | } 39 | 40 | setGroupPartType(type) { 41 | if (this.types.indexOf(type) == -1) { 42 | this.types.push(type); 43 | } 44 | var $div = $(this.div); 45 | if (type == "start") { 46 | $div.addClass("group-start"); 47 | } else if (type == "middle") { 48 | $div.addClass("group-middle"); 49 | } else if (type == "end") { 50 | $div.addClass("group-end"); 51 | } 52 | } 53 | 54 | resetGroupPartType() { 55 | this.types = []; 56 | $(this.div).removeClass("group-start"); 57 | $(this.div).removeClass("group-middle"); 58 | $(this.div).removeClass("group-end"); 59 | } 60 | 61 | removeControlData() { 62 | if (this.isModified()) { 63 | $(this.div).removeClass("modified"); 64 | } 65 | this.resetGroupPartType(); 66 | this.controlData = null; 67 | this.group = null; 68 | } 69 | 70 | select() { 71 | $(this.div).addClass("selected"); 72 | } 73 | 74 | deselect() { 75 | $(this.div).removeClass("selected"); 76 | } 77 | 78 | toggleSelect() { 79 | if (this.isSelected()) { 80 | this.deselect(); 81 | } else { 82 | this.select(); 83 | } 84 | } 85 | 86 | isModified() { 87 | return $(this.div).hasClass("modified"); 88 | } 89 | 90 | isSelected() { 91 | return $(this.div).hasClass("selected"); 92 | } 93 | 94 | isInControlData(key) { 95 | if (!this.isModified()) { 96 | return false; 97 | } 98 | return key in this.controlData; 99 | } 100 | 101 | isInGroup() { 102 | return (this.group != null); 103 | } 104 | 105 | copyControlDataTo(cell) { 106 | var copiedControlData = deepCopyDict(this.controlData); 107 | cell.setControlData(copiedControlData); 108 | } 109 | 110 | } 111 | 112 | class CellGroup { 113 | constructor() { 114 | this.cells = [] 115 | } 116 | 117 | setCells(cells) { 118 | this.cells = []; 119 | 120 | for (var cell of cells) { 121 | this.cells.push(cell); 122 | cell.group = this; 123 | } 124 | this.startFid = this.cells[0].fid; 125 | this.endFid = this.cells[this.cells.length - 1].fid; 126 | this.updateCellGroupType(); 127 | } 128 | 129 | getCell(idx) { 130 | return this.cells[idx]; 131 | } 132 | 133 | hasCell(cell) { 134 | for (var _cell of cells) { 135 | if (_cell.fid == cell.fid) { 136 | return true; 137 | } 138 | } 139 | 140 | return false; 141 | } 142 | 143 | getNumCells() { 144 | return this.cells.length; 145 | } 146 | 147 | setControlData(controlData) { 148 | var numData = controlData.length; 149 | for (var i = 0; i < numData; i++) { 150 | var cell = this.cells[i]; 151 | cell.setControlData(controlData[i]); 152 | } 153 | } 154 | 155 | removeControlData() { 156 | for (var cell of this.cells) { 157 | cell.removeControlData(); 158 | } 159 | } 160 | 161 | moveTo(newCells) { 162 | var numCells = newCells.length; 163 | if (numCells != this.cells.length) { 164 | return; 165 | } 166 | 167 | var controlData = []; 168 | for (var cell of this.cells) { 169 | controlData.push(cell.controlData); 170 | cell.removeControlData(); 171 | cell.group = null; 172 | } 173 | 174 | this.setCells(newCells); 175 
| this.setControlData(controlData); 176 | } 177 | 178 | updateCellGroupType() { 179 | 180 | // This is just for css. 181 | var numCells = this.getNumCells(); 182 | if (numCells == 1) { 183 | var cell = this.cells[0]; 184 | cell.resetGroupPartType(); 185 | cell.setGroupPartType("start"); 186 | cell.setGroupPartType("middle"); 187 | cell.setGroupPartType("end"); 188 | } else { 189 | var firstCell = this.cells[0]; 190 | firstCell.resetGroupPartType(); 191 | firstCell.setGroupPartType("start"); 192 | firstCell.setGroupPartType("middle"); 193 | var lastCell = this.cells[numCells - 1]; 194 | lastCell.resetGroupPartType(); 195 | lastCell.setGroupPartType("end"); 196 | lastCell.setGroupPartType("middle"); 197 | 198 | for (var i = 1; i < numCells - 1; i++) { 199 | var cell = this.cells[i]; 200 | cell.resetGroupPartType(); 201 | cell.setGroupPartType("middle"); 202 | } 203 | } 204 | } 205 | 206 | getMiddleCell() { 207 | var numCells = this.getNumCells(); 208 | var middleIdx = Math.floor(numCells / 2); 209 | return this.cells[middleIdx]; 210 | } 211 | 212 | copyControlDataTo(group) { 213 | var numCells = group.getNumCells(); 214 | if (numCells != this.getNumCells()) { 215 | return; 216 | } 217 | 218 | for (var i = 0; i < numCells; i++) { 219 | this.cells[i].copyControlDataTo(group.cells[i]); 220 | } 221 | } 222 | 223 | remove() { 224 | for (var cell of this.cells) { 225 | cell.removeControlData(); 226 | } 227 | } 228 | } 229 | 230 | class CellGroupSnapshot { 231 | // holds only control data (no frame data). 232 | constructor(group) { 233 | if (group.cells.length == 0) { 234 | return; 235 | } 236 | 237 | this.startFid = group.startFid; 238 | this.endFid = group.endFid; 239 | this.controlDatas = []; 240 | for (var cell of group.cells) { 241 | this.controlDatas.push(deepCopyDict(cell.controlData)); 242 | } 243 | this.length = this.controlDatas.length; 244 | } 245 | } 246 | -------------------------------------------------------------------------------- /static/js/history.js: -------------------------------------------------------------------------------- 1 | class HistoryManager { 2 | constructor() { 3 | this.motionTrackSnapshots = []; 4 | this.styleTrackSnapshots = []; 5 | this.cursor = -1; 6 | } 7 | 8 | addSnapshots(motionTrack, styleTrack) { 9 | 10 | if (this.canRedo()) { 11 | // cursor not at the end. 
12 | this.resetFutureFromNow(); 13 | } 14 | 15 | this.motionTrackSnapshots.push(motionTrack.takeSnapshot()); 16 | this.styleTrackSnapshots.push(styleTrack.takeSnapshot()); 17 | this.cursor += 1; 18 | } 19 | 20 | clearHistory() { 21 | this.motionTrackSnapshots = []; 22 | this.styleTrackSnapshots = []; 23 | this.cursor = -1; 24 | } 25 | 26 | canUndo() { 27 | return this.cursor >= 0; 28 | } 29 | 30 | canRedo() { 31 | return this.cursor + 1 < this.motionTrackSnapshots.length; 32 | } 33 | 34 | undo(motionTrack, styleTrack) { 35 | if (!this.canUndo()) { 36 | return; 37 | } 38 | this.cursor -= 1; 39 | console.log("undo", this.cursor); 40 | this.loadCurrentCursor(motionTrack, styleTrack); 41 | } 42 | 43 | redo(motionTrack, styleTrack) { 44 | if (!this.canRedo()) { 45 | return; 46 | } 47 | this.cursor += 1; 48 | console.log("redo", this.cursor); 49 | this.loadCurrentCursor(motionTrack, styleTrack); 50 | } 51 | 52 | loadCurrentCursor(motionTrack, styleTrack) { 53 | if (this.cursor < 0) { 54 | this.tracksToInitialState(motionTrack, styleTrack); 55 | return; 56 | } 57 | var ms = this.motionTrackSnapshots[this.cursor]; 58 | var ss = this.styleTrackSnapshots[this.cursor]; 59 | 60 | motionTrack.loadSnapshot(ms); 61 | styleTrack.loadSnapshot(ss); 62 | } 63 | 64 | tracksToInitialState(motionTrack, styleTrack) { 65 | // initial state 66 | motionTrack.clearGroups(); 67 | styleTrack.clearGroups(); 68 | } 69 | 70 | resetFutureFromNow() { 71 | this.motionTrackSnapshots = this.motionTrackSnapshots.slice(0, this.cursor + 1); 72 | this.styleTrackSnapshots = this.styleTrackSnapshots.slice(0, this.cursor + 1); 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /static/js/index.js: -------------------------------------------------------------------------------- 1 | var timeline = null; 2 | var motionLibrary = null; 3 | var ruleManager = null; 4 | 5 | var avatar = null; 6 | var avatarInLibrary = null; 7 | var avatarInterface = null; 8 | var stylePanel = null; 9 | var lastGenerationResult = null; 10 | var uploadedAudioPath = null; 11 | 12 | $(window).on('load', function () { 13 | 14 | avatar = setupAvatar("renderCanvas"); 15 | stylePanel = createStylePanel(); 16 | timeline = new Timeline(avatar, stylePanel, "#waveform", "word-canvas", "#motion-ctrl-track", "#style-ctrl-track"); 17 | 18 | avatar.scene.onReadyObservable.addOnce(function () { 19 | avatarInterface = new AvatarInterface(avatar); 20 | avatarInterface.gizmoDragEndCallback = applyMotionControlToTimeline; 21 | }); 22 | avatarInLibrary = setupAvatar("renderInLibrary"); 23 | motionLibrary = new MotionLibrary(avatarInLibrary, timeline, "#btn-open-motion-library", "#btn-apply-motion", "#btn-delete-motion") 24 | ruleManager = new RuleManager(this, "#btn-apply-rule", "#btn-view-rule", "#viewRuleModal") 25 | 26 | // attach button callbacks 27 | $("#btn-generate").click(generate); 28 | $("#btn-generate-sure").click(postInputText); // btn in generate modal 29 | $("#btn-play").click(togglePlay); 30 | $("#btn-update").click(postInputTextAndConstraints); 31 | $("#btn-export").click(exportData); 32 | $("#btn-open-import-dialog").click(function () { 33 | $("#importJson").modal('show'); 34 | }); 35 | $("#btn-import").click(importData); 36 | $('#btn-avatar-edit-off').click(function () { 37 | avatarInterface.turnOffEditMode(); 38 | }); 39 | $('#btn-avatar-edit-position').click(function () { 40 | avatarInterface.turnOnEditMode(avatar, 'position'); 41 | }); 42 | $('#btn-avatar-edit-rotation').click(function () { 43 | 
avatarInterface.turnOnEditMode(avatar, 'rotation'); 44 | }); 45 | $('#btn-help').click(function () { 46 | $('#help-modal').modal('show'); 47 | }); 48 | $('#btn-undo').click(function () { 49 | timeline.undoTracks(); 50 | updateUndoReduButtonState(); 51 | }); 52 | $('#btn-redo').click(function () { 53 | timeline.redoTracks(); 54 | updateUndoReduButtonState(); 55 | }) 56 | 57 | // sample text 58 | $('#sampleTextDropdown .dropdown-item').click(function (e) { 59 | let text = $(this).text(); 60 | if (text[2] == '.') { 61 | text = text.substring(4) 62 | } 63 | 64 | $("#text-input").val(text); 65 | 66 | // set voice btn state 67 | if ($(this).data("audio")) { 68 | uploadedAudioPath = $(this).data("audio"); 69 | $("#voice-female").removeAttr('checked'); 70 | $("#voice-female").parent().removeClass('active'); 71 | $("#voice-file").attr('checked', 'checked'); 72 | $("#voice-file").parent().addClass('active'); 73 | } else { 74 | $("#voice-file").removeAttr('checked'); 75 | $("#voice-file").parent().removeClass('active'); 76 | $("#voice-female").attr('checked', 'checked'); 77 | $("#voice-female").parent().addClass('active'); 78 | } 79 | }) 80 | 81 | // for audio upload form 82 | FilePond.registerPlugin(FilePondPluginFileValidateType); 83 | FilePond.setOptions({ 84 | server: { 85 | url: './', 86 | process: { 87 | url: './upload_audio', 88 | headers: {'X-CSRF-TOKEN': $('input[name="csrf_token"]').val()}, 89 | onload: onAudioUploaded, 90 | } 91 | } 92 | }) 93 | FilePond.parse(document.body); 94 | $("input[name$='voice']").click(function () { 95 | var val = $(this).val(); 96 | if (val == 'file') { 97 | $("#audio-upload-form").show(); 98 | } else { 99 | $("#audio-upload-form").hide(); 100 | } 101 | }); 102 | 103 | // in modal 104 | $("#btn-delete-annotation").click(function () { 105 | $('#deleteModal').modal('hide'); 106 | timeline.deleteSelectedCellControlData(); 107 | }); 108 | 109 | $("#btn-motion-track-bypass").on('change', function () { 110 | var checked = $(this).is(':checked') 111 | timeline.toggleMotionTrackControlDataBypass(checked); 112 | }); 113 | 114 | $("#btn-scenario-selector").on('change', function () { 115 | if ($(this).is(':checked')) { 116 | $("#btn-update").html("Apply Control") 117 | } else { 118 | $("#btn-update").html("Interpolate") 119 | } 120 | }); 121 | 122 | timeline.controlDataModificationInvalidWarningId = "#controlDataWarning"; 123 | 124 | // disable mouse wheel event in babylon canvases 125 | $("#renderCanvas").bind("wheel mousewheel", function (e) { 126 | e.preventDefault() 127 | }); 128 | $("#renderInLibrary").bind("wheel mousewheel", function (e) { 129 | e.preventDefault() 130 | }); 131 | 132 | addStylePanelSliderUpdateCallback(); 133 | 134 | }); 135 | 136 | function createStylePanel() { 137 | var stylePanel = null; 138 | stylePanel = new StylePannel(); 139 | stylePanel.addStyle("speed", "#speed-style-slider", "#speed-style-val-label"); 140 | //stylePanel.addStyle("accel", "#accel-style-slider", "#accel-style-val-label"); 141 | stylePanel.addStyle("space", "#space-style-slider", "#space-style-val-label"); 142 | stylePanel.addStyle("handedness", "#handedness-style-slider", "#handedness-style-val-label"); 143 | 144 | stylePanel.addStylePreset("happy", "#btn-preset-happy", {"speed": 2, "space": 1, "handedness": 0}) 145 | stylePanel.addStylePreset("sad", "#btn-preset-sad", {"speed": -1, "space": 0, "handedness": 0}) 146 | stylePanel.addStylePreset("angry", "#btn-preset-angry", {"speed": 2.5, "space": 2, "handedness": 0}) 147 | 148 | return stylePanel; 149 | } 150 | 151 | 
function addStylePanelSliderUpdateCallback() { 152 | stylePanel.setSliderUpdateCallback(function (data) { 153 | if (timeline.styleTrack == null) { 154 | return; 155 | } 156 | timeline.updateStyleTrackControlData(data); 157 | }); 158 | } 159 | 160 | $(window).keydown(function (e) { 161 | var code = e.code; 162 | var target = $(e.target) 163 | if (!target.is('textarea') && !target.is('input')) { 164 | if (code == "Delete" || code == "KeyD") { 165 | $('#deleteModal').modal('show'); 166 | } else if (code == "KeyF") { 167 | //timeline.fillMotionControl(); 168 | } else if (e.ctrlKey && code == "KeyC") { 169 | timeline.copySelectedCellGroup(); 170 | } else if (e.ctrlKey && code == "KeyV") { 171 | timeline.pasteSelectedCellGroup(); 172 | } 173 | } 174 | }); 175 | 176 | function setupAvatar(id) { 177 | 178 | var canvas = document.getElementById(id); // Get the canvas element 179 | return new Avatar(canvas); 180 | } 181 | 182 | function afterGeneratedCallback(result) { 183 | $("#loading-modal").modal('hide'); 184 | console.log("generated", result); 185 | if (result['msg'] === 'success') { 186 | lastGenerationResult = result; 187 | var audioFilename = result['audio-filename']; 188 | var wordsWithTimestamps = result['words-with-timestamps']; 189 | 190 | var motionKeypoints = result['output-data']; 191 | let is_manual_mode = !$("#btn-scenario-selector").is(':checked') 192 | if (is_manual_mode) { 193 | // set mean pose for all frames 194 | var nFrames = result['output-data'].length 195 | for (var i = 0; i < nFrames; i++) { 196 | motionKeypoints[i] = avatar.meanVec; 197 | } 198 | } 199 | 200 | timeline.load(audioFilename, wordsWithTimestamps, motionKeypoints, stylePanel.getStyleNames()); 201 | 202 | timeline.setPlayCallback(function () { 203 | // bypass motion control when play 204 | var toggleBtn = $("#btn-motion-track-bypass") 205 | toggleBtn.bootstrapToggle('on'); 206 | $("#btn-play").text("Pause"); 207 | }) 208 | 209 | timeline.setPauseCallback(function () { 210 | var toggleBtn = $("#btn-motion-track-bypass") 211 | toggleBtn.bootstrapToggle('off'); 212 | $("#btn-play").text("Play"); 213 | }) 214 | 215 | ruleManager.loadWords(wordsWithTimestamps) 216 | 217 | if (is_manual_mode) { 218 | // set first and last pose controls 219 | // executed after 1 second (wait for track creation) 220 | setTimeout(function () { 221 | var restPose = avatar.meanVec; 222 | timeline.motionTrack.updateMultipleKeypointsAsGroup([restPose], 0); 223 | timeline.motionTrack.updateMultipleKeypointsAsGroup([restPose], timeline.motionTrack.duration); 224 | }, 1000); 225 | } 226 | } else { 227 | bootbox.alert("Hmm! Something went wrong. Please reload the page. If it happens again, please contact the maintainer."); 228 | } 229 | 230 | updateUndoReduButtonState(); 231 | } 232 | 233 | function afterUpdatedCallback(result) { 234 | $("#loading-modal").modal('hide'); 235 | if (result['msg'] === 'success') { 236 | if (!('words-with-timestamps' in result)) { // reuse the previous word information in the manual mode 237 | result['words-with-timestamps'] = lastGenerationResult['words-with-timestamps']; 238 | } 239 | lastGenerationResult = result; 240 | var motionKeypoints = result['output-data']; 241 | timeline.updateMotionKeypoitns(motionKeypoints); 242 | timeline.setCursorToStart(); 243 | } else { 244 | bootbox.alert("Hmm! Something went wrong. Please reload the page. 
If it happens again, please contact the maintainer."); 245 | } 246 | updateUndoReduButtonState(); 247 | } 248 | 249 | function updateUndoReduButtonState() { 250 | if (timeline.canUndo()) { 251 | $("#btn-undo").removeClass("disabled"); 252 | } else { 253 | $("#btn-undo").addClass("disabled"); 254 | } 255 | 256 | if (timeline.canRedo()) { 257 | $("#btn-redo").removeClass("disabled"); 258 | } else { 259 | $("#btn-redo").addClass("disabled"); 260 | } 261 | } 262 | 263 | function applyMotionControlToTimeline() { 264 | if (timeline.motionTrack == null) { 265 | return; 266 | } 267 | var keypoints = avatar.modelKeypointsToArray(); 268 | timeline.motionTrack.updateKeypointsToFirstSelectedCell(keypoints); 269 | } 270 | 271 | function generate() { 272 | if (timeline.isAnyCellModified()) { 273 | $("#generateModal").modal('show'); 274 | return; 275 | } 276 | postInputText(); 277 | } 278 | 279 | function postInputText() { 280 | $("#generateModal").modal('hide'); 281 | 282 | var inputText = $("#text-input").val(); 283 | 284 | if (inputText.length <= 0) { 285 | bootbox.alert(' Please input speech text.'); 286 | return; 287 | } 288 | 289 | var voiceName = document.querySelector('input[name="voice"]:checked').value; 290 | if (voiceName == 'file') { 291 | voiceName = uploadedAudioPath 292 | } 293 | data = {"text-input": inputText, "voice": voiceName}; 294 | $("#loading-modal").one('shown.bs.modal', function () { 295 | console.log("callback attached"); 296 | postInput(data, afterGeneratedCallback); 297 | }); 298 | $("#loading-modal").modal('show'); 299 | 300 | timeline.clearHistory(); 301 | } 302 | 303 | function postInputTextAndConstraints() { 304 | var inputText = $("#text-input").val(); 305 | 306 | if (inputText.length <= 0) { 307 | bootbox.alert(' Please input speech text.'); 308 | return; 309 | } 310 | 311 | var voiceName = document.querySelector('input[name="voice"]:checked').value; 312 | if (voiceName == 'file') { 313 | voiceName = uploadedAudioPath 314 | } 315 | data = {"text-input": inputText, "voice": voiceName}; 316 | 317 | var keypointConstraints = timeline.getMotionConstraints(); 318 | data['keypoint-constraints'] = keypointConstraints; 319 | var styleConstraints = timeline.getStyleConstraints(); 320 | data['style-constraints'] = styleConstraints; 321 | if ($("#btn-scenario-selector").is(':checked')) { 322 | data['is-manual-scenario'] = 0 323 | } else { 324 | data['is-manual-scenario'] = 1 325 | } 326 | 327 | timeline.saveTracksToHistory(); 328 | $("#loading-modal").one('shown.bs.modal', function () { 329 | postInput(data, afterUpdatedCallback); 330 | }); 331 | $("#loading-modal").modal('show'); 332 | } 333 | 334 | function postInput(data, callback) { 335 | post("api/input", data, callback); 336 | } 337 | 338 | function post(url, data, callback) { 339 | $.ajax({ 340 | type: "POST", 341 | contentType: "application/json; charset=utf-8", 342 | url: url, 343 | data: JSON.stringify(data), 344 | dataType: "json", 345 | success: callback 346 | }); 347 | } 348 | 349 | function get(url, callback) { 350 | $.ajax({ 351 | type: "GET", 352 | contentType: "application/json; charset=utf-8", 353 | url: url, 354 | success: callback 355 | }) 356 | } 357 | 358 | function togglePlay() { 359 | if (timeline.isPlaying()) { 360 | timeline.pause(); 361 | } else { 362 | timeline.play(); 363 | } 364 | } 365 | 366 | function importData() { 367 | $("#importJson").modal('hide'); 368 | 369 | var files = document.getElementById('importFilePath').files; 370 | if (files.length <= 0) { 371 | return false; 372 | } 373 | 374 | 
filepath = files.item(0) // use the first one 375 | console.log(filepath) 376 | $("#btn-scenario-selector").bootstrapToggle('on'); // set to auto mode to avoid filling mean poses inside afterGeneratedCallback fn 377 | 378 | var fr = new FileReader(); 379 | fr.onload = function (e) { 380 | var json = JSON.parse(e.target.result); 381 | console.log(json) 382 | afterGeneratedCallback(json) 383 | } 384 | fr.readAsText(filepath); 385 | } 386 | 387 | function exportData() { 388 | if (lastGenerationResult) { 389 | // export filename 390 | var now = new Date(); 391 | var filename = now.toISOString().slice(0, 10) + '_' + now.getTime() 392 | 393 | // audio (.wav) 394 | var audioFilename = lastGenerationResult['audio-filename']; 395 | if (audioFilename) { 396 | var link = document.createElement("a"); 397 | link.download = filename + '.wav'; 398 | link.href = "media/" + audioFilename + "/" + link.download; 399 | link.click(); 400 | } 401 | 402 | // input and output data (.json) 403 | var dataStr = "data:text/json;charset=utf-8," + encodeURIComponent(JSON.stringify(lastGenerationResult)); 404 | var link = document.createElement('a'); 405 | link.href = dataStr 406 | link.download = filename + ".json" 407 | link.click() 408 | } else { 409 | bootbox.alert("Nothing to save. Please synthesize first."); 410 | } 411 | } 412 | 413 | function onAudioUploaded(r) { 414 | r = $.parseJSON(r); 415 | let filepath = r.filename[0]; 416 | uploadedAudioPath = filepath 417 | } 418 | -------------------------------------------------------------------------------- /static/js/motionLibrary.js: -------------------------------------------------------------------------------- 1 | class MotionLibrary { 2 | constructor(_avatar, timeline, open_close_btn_id, apply_btn_id, delete_btn_id) { 3 | var that = this; 4 | this.avatar = _avatar 5 | this.timeline = timeline 6 | this.curFrame = 0 7 | this.motionRepeat = 0 8 | this.openCloseBtn = $(open_close_btn_id); 9 | this.openCloseBtn.click(function () { 10 | var $btn = $(this); 11 | var $content = $(this).parent().next(".content"); 12 | $content.slideToggle(500, () => { 13 | if ($content.is(":visible")) { 14 | $btn.text("Close motion library"); 15 | that.loadMotionFromDB(); 16 | window.dispatchEvent(new Event('resize')); // to adjust avatar resolution 17 | } else { 18 | $btn.text("Open motion library"); 19 | if (that.timer) { 20 | clearInterval(that.timer) 21 | } 22 | } 23 | }); 24 | }); 25 | 26 | $('#motion-library-add button').click(function (e) { 27 | e.preventDefault(); // cancel form submit 28 | if ($(this).attr("value") == "add") { 29 | that.addMotionLibrary(); 30 | } else { // add using array data 31 | that.addMotionLibraryManual(); 32 | } 33 | }); 34 | 35 | $(apply_btn_id).click(function () { 36 | if (that.selectedIndex >= 0) { 37 | var motion = that.motions[that.selectedIndex].motion 38 | var motionSpeed = parseInt(document.querySelector('input[name="motion-speed"]:checked').value); 39 | var flip_lr = $("#btn-flip-left-right").is(':checked') 40 | that.applyMotion(motion, motionSpeed, flip_lr) 41 | } 42 | }); 43 | 44 | $("#btn-delete-library").click(function () { 45 | var modalDialog = $('#deleteLibraryModal') 46 | modalDialog.modal('hide'); 47 | var delete_id = modalDialog.data('delete-id'); 48 | get("api/delete_motion/" + delete_id, function (data) { 49 | console.log('motion deleted', data) 50 | that.loadMotionFromDB() 51 | }); 52 | }); 53 | /* 54 | $(delete_btn_id).click(function (){ 55 | if(that.selectedIndex >= 0){ 56 | var delete_id = that.motions[that.selectedIndex].id; 57 | 
var modalDialog = $('#deleteLibraryModal'); 58 | modalDialog.data('delete-id', delete_id); 59 | modalDialog.modal('show'); 60 | } 61 | }); 62 | */ 63 | this.motions = []; 64 | this.loadMotionFromDB(); 65 | this.selectedIndex = -1 66 | } 67 | 68 | loadMotionFromDB() { 69 | // loading indicator 70 | var list = $('#motion-library-list') 71 | list.empty() 72 | var li = document.createElement('li'); 73 | li.innerHTML = ' Loading...'; 74 | list.append(li); 75 | 76 | // load 77 | var that = this; 78 | that.motions = []; 79 | get("api/motion", function (responseText) { 80 | var json = JSON.parse(responseText); 81 | for (var i = 0; i < json.length; i++) { 82 | that.motions.push(new MotionItem(json[i]._id.$oid, json[i].name, json[i].motion)); 83 | } 84 | 85 | that.displayMotions(); 86 | }); 87 | } 88 | 89 | displayMotions() { 90 | var that = this; 91 | 92 | // remove existing motion list 93 | var list = $('#motion-library-list') 94 | list.empty() 95 | 96 | // create motion list 97 | that.selectedIndex = -1 98 | 99 | for (var i = 0; i < this.motions.length; i++) { 100 | var li = document.createElement('li'); 101 | li.innerHTML = this.motions[i].name + " (" + this.motions[i].motion.length + ")"; 102 | li.onclick = function () { 103 | var itemIdx = $(this).index(); 104 | that.selectedIndex = itemIdx 105 | that.avatar.gestureKeypoints = that.motions[itemIdx].motion; 106 | that.curFrame = 0; 107 | that.motionRepeat = 0; 108 | if (that.timer) { 109 | clearInterval(that.timer) 110 | } 111 | 112 | that.avatar.restPose(); 113 | that.timer = setInterval(function () { 114 | that.play(); 115 | }, 30); 116 | 117 | $('#motion-library-list .list-group-item').removeClass('active'); 118 | this.classList.add('active') 119 | }; 120 | li.classList.add('list-group-item') 121 | li.classList.add('col-sm-4') // multi-column 122 | list.append(li); 123 | } 124 | } 125 | 126 | play() { 127 | this.curFrame++; 128 | if (this.curFrame >= this.avatar.gestureKeypoints.length) { 129 | this.curFrame = 0; 130 | this.avatar.restPose(); 131 | this.motionRepeat++; 132 | 133 | if (this.motionRepeat >= 5) { 134 | clearInterval(this.timer) 135 | this.motionRepeat = 0; 136 | } 137 | } 138 | 139 | this.avatar.moveBody(this.avatar.gestureKeypoints[this.curFrame]); 140 | } 141 | 142 | addMotionLibrary() { 143 | // get selected region from motion track 144 | var motionTrack = this.timeline.motionTrack 145 | if (motionTrack) { 146 | var selectedMotion = motionTrack.getSelectedKeypoints() 147 | 148 | console.log(selectedMotion) 149 | 150 | var motionName = $("#text-motion-name").val(); 151 | if (motionName.length > 0) { 152 | var data = {"name": motionName, "motion": selectedMotion}; 153 | post("api/motion", data, (data) => { 154 | this.loadMotionFromDB() 155 | }); 156 | } 157 | } else { 158 | bootbox.alert("Please select motion region first!"); 159 | } 160 | } 161 | 162 | addMotionLibraryManual() { 163 | var that = this; 164 | bootbox.prompt({ 165 | title: "Please input raw motion array.", 166 | inputType: 'textarea', 167 | callback: function (arrStr) { 168 | var motionName = $("#text-motion-name").val(); 169 | if (motionName.length > 0 && arrStr != null) { 170 | var array = JSON.parse(arrStr); 171 | var data = {"name": motionName, "motion": array}; 172 | post("api/motion", data, (data) => { 173 | that.loadMotionFromDB() 174 | }); 175 | } 176 | } 177 | }); 178 | // bootbox.alert("!"); 179 | } 180 | 181 | applyMotion(motion, speed, flip_lr) { 182 | var sampled_data = null; 183 | if (speed == 2 || speed == 3) { 184 | var tmp = []; 185 | for (var i 
= 0; i < motion.length; i += speed) { 186 | tmp.push(motion[i]); 187 | } 188 | sampled_data = tmp 189 | } else { 190 | sampled_data = motion 191 | } 192 | 193 | if (flip_lr) { 194 | // copy array 195 | sampled_data = cloneGrid(sampled_data) 196 | 197 | // invert x coordinates 198 | let nFrames = sampled_data.length 199 | let dataLength = sampled_data[0].length 200 | for (var i = 0; i < nFrames; i++) { 201 | for (var j = 0; j < dataLength; j += 3) { 202 | sampled_data[i][j] *= -1; 203 | } 204 | } 205 | 206 | // switch left/right joints 207 | for (var i = 0; i < nFrames; i++) { 208 | for (var j = 9; j < 18; j++) { 209 | var val = sampled_data[i][j]; 210 | sampled_data[i][j] = sampled_data[i][j + 9]; 211 | sampled_data[i][j + 9] = val; 212 | } 213 | } 214 | } 215 | 216 | timeline.updateMotionTrackControlData(sampled_data); 217 | } 218 | } 219 | 220 | class MotionItem { 221 | constructor(id, name, motion) { 222 | this.id = id 223 | this.name = name; 224 | this.motion = motion; 225 | } 226 | } 227 | 228 | function cloneGrid(grid) { 229 | // function code from https://ozmoroz.com/2020/07/how-to-copy-array/ 230 | 231 | // Clone the 1st dimension (column) 232 | const newGrid = [...grid] 233 | // Clone each row 234 | newGrid.forEach((row, rowIndex) => newGrid[rowIndex] = [...row]) 235 | return newGrid 236 | } 237 | -------------------------------------------------------------------------------- /static/js/ruleManager.js: -------------------------------------------------------------------------------- 1 | class RuleManager { 2 | constructor(mainWindow, apply_btn_id, view_btn_id, modal_id) { 3 | var that = this 4 | this.mainWindow = mainWindow; 5 | this.modalId = modal_id; 6 | this.ruleDialog = null; 7 | this.rules = {}; // dictionary (key: word, value: list of (ruleid, motion_info)) 8 | this.words = [] 9 | 10 | $(apply_btn_id).click(() => { 11 | this.applyRules(); 12 | }); 13 | 14 | $(view_btn_id).click(() => { 15 | this.ruleDialog = $(this.modalId) 16 | this.updateRuleList(); 17 | this.updateMotionList(); 18 | this.ruleDialog.modal('show'); 19 | }); 20 | 21 | $('#rule-add').submit(function () { 22 | console.log('clicked add rule') 23 | $('#btn-add-rule').html(' Loading...'); 24 | $('#btn-add-rule').prop('disabled', true); 25 | that.addRule(); 26 | return false; 27 | }); 28 | 29 | $("#ruleTable").on('click', '.btnDelete', function () { 30 | $(this).closest('tr').remove() 31 | var deleteId = $(this).closest('tr').data('id') 32 | get("api/delete_rule/" + deleteId, function (data) { 33 | console.log('rule deleted', data); 34 | that.getRules(false); 35 | }); 36 | }); 37 | 38 | this.getRules(false); 39 | } 40 | 41 | applyRules() { 42 | var words = this.words; 43 | var timeline = this.mainWindow.timeline 44 | 45 | for (var i = 0; i < words.length; i++) { 46 | var word = words[i][0].toLowerCase(); 47 | var startTime = words[i][1]; 48 | var endTime = words[i][2]; 49 | 50 | console.log(word); 51 | 52 | if (word in this.rules) { 53 | // select randomly if the rules associated to the same word exist 54 | var nMotions = this.rules[word].length 55 | var randomIdx = Math.floor(Math.random() * nMotions) 56 | var motion = this.rules[word][randomIdx][1].motion 57 | timeline.motionTrack.updateMultipleKeypointsAsGroup(motion, (endTime + startTime) / 2.0, true) 58 | } 59 | } 60 | } 61 | 62 | updateRuleList() { 63 | var dialog = this.ruleDialog; 64 | dialog.find("#ruleTable > tbody").empty(); 65 | 66 | var rules = this.rules 67 | 68 | // sort by name 69 | var rules = Object.keys(rules).sort().reduce(function (Obj, key) 
{ 70 | Obj[key] = rules[key]; 71 | return Obj; 72 | }, {}); 73 | 74 | // add rows 75 | for (var key in rules) { 76 | for (var i = 0; i < rules[key].length; i++) { 77 | var ruleId = rules[key][i][0]; 78 | var motionName = rules[key][i][1].name; 79 | var motion = rules[key][i][1].motion; 80 | dialog.find('#ruleTable > tbody:last-child').append('' + key + '' + motionName + '' + 'length: ' + motion.length + ''); 81 | } 82 | } 83 | } 84 | 85 | updateMotionList() { 86 | var selector = this.ruleDialog.find("#select-motion"); 87 | selector.empty(); 88 | var motions = this.mainWindow.motionLibrary.motions 89 | for (var i = 0; i < motions.length; i++) { 90 | selector.append(new Option(motions[i].name, motions[i].id)) 91 | } 92 | } 93 | 94 | getRules(updateList = false) { 95 | get("api/rule", (responseText) => { 96 | var items = JSON.parse(responseText); 97 | this.rules = {} 98 | for (var i = 0; i < items.length; i++) { 99 | if (!this.rules[items[i].word]) { 100 | this.rules[items[i].word] = []; 101 | } 102 | var ruleId = items[i]._id.$oid 103 | this.rules[items[i].word].push([ruleId, items[i].motion_info[0]]) 104 | } 105 | 106 | if (updateList) { 107 | this.updateRuleList(); 108 | } 109 | 110 | $('#btn-add-rule').html('Add rule'); 111 | $('#btn-add-rule').prop('disabled', false); 112 | }); 113 | } 114 | 115 | loadWords(wordsWithTimestamps) { 116 | this.words = wordsWithTimestamps 117 | } 118 | 119 | addRule() { 120 | var word = $("#rule-name").val(); 121 | if (word.length > 0) { 122 | var motionId = $('#select-motion option:selected').val(); 123 | var data = {"word": word, "description": word, "motion": motionId}; 124 | post("api/rule", data, (data) => { 125 | this.getRules(true); 126 | }); 127 | } 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /static/js/sortedlist.js: -------------------------------------------------------------------------------- 1 | class Node { 2 | constructor(item, prev, next) { 3 | this.item = item; 4 | this.prev = prev; 5 | this.next = next; 6 | } 7 | } 8 | 9 | class SortedLinkedList { 10 | constructor(sortAttr) { 11 | this.head = new Node(null, null); 12 | this.tail = new Node(this.head, null); 13 | this.head.next = this.tail; 14 | 15 | this.sortAttr = sortAttr; 16 | } 17 | 18 | isEmpty() { 19 | return this.head.next == this.tail; 20 | } 21 | 22 | add(node) { 23 | var prevNode = this.findLastNodeBefore(node); 24 | this.addAfter(prevNode, node); 25 | } 26 | 27 | findLastNodeBefore(node) { 28 | if (this.isEmpty()) { 29 | return this.head; 30 | } 31 | var val = node.item[this.sortAttr]; 32 | var cursor = this.head.next; 33 | while (cursor != this.tail) { 34 | var curVal = cursor.item[this.sortAttr]; 35 | if (curVal > val) { 36 | break; 37 | } 38 | cursor = cursor.next; 39 | } 40 | 41 | return cursor.prev; 42 | } 43 | 44 | addAfter(node, nodeToAdd) { 45 | // update forward link; 46 | nodeToAdd.next = node.next; 47 | node.next = nodeToAdd; 48 | 49 | // update backward link; 50 | nodeToAdd.prev = node; 51 | nodeToAdd.next.prev = nodeToAdd; 52 | } 53 | 54 | remove(node) { 55 | if (node == this.head || node == this.tail) { 56 | // head and tail cannot be removed; 57 | return; 58 | } 59 | console.log(node); 60 | 61 | node.prev.next = node.next; 62 | node.next.prev = node.prev; 63 | 64 | node.next = null; 65 | node.prev = null; 66 | 67 | console.log(this.head); 68 | } 69 | 70 | getItems() { 71 | var items = []; 72 | var cursor = this.head.next; 73 | while (cursor != this.tail) { 74 | items.push(cursor.item); 75 | cursor = 
cursor.next; 76 | } 77 | return items; 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /static/js/stylePannel.js: -------------------------------------------------------------------------------- 1 | class SliderAndLabel { 2 | constructor(sliderDivId, setting, labelDivId) { 3 | this.sliderDivId = sliderDivId; 4 | this.slider = new Slider(sliderDivId, setting); 5 | this.labelDivId = labelDivId; 6 | } 7 | 8 | updateLabel(value) { 9 | $(this.labelDivId).text(value); 10 | } 11 | 12 | update(value) { 13 | this.slider.setValue(value); 14 | this.updateLabel(value); 15 | } 16 | 17 | setSliderUpdateCallback(callback) { 18 | $(this.sliderDivId).on("slide", callback); 19 | } 20 | 21 | getValue() { 22 | return this.slider.getValue(); 23 | } 24 | } 25 | 26 | class StylePannel { 27 | constructor() { 28 | this.sliders = {} 29 | this.sliderUpdateCallback = null; 30 | this.styleNames = [] 31 | } 32 | 33 | addStyle(name, sliderDivId, sliderValueLabelId) { 34 | var slider = this.createSlider(name, sliderDivId, sliderValueLabelId); 35 | this.sliders[name] = slider; 36 | this.styleNames.push(name); 37 | } 38 | 39 | addStylePreset(name, btnId, styleVal) { 40 | var that = this 41 | $(btnId).click(function () { 42 | var data = styleVal 43 | that.setValues(data) 44 | if (that.sliderUpdateCallback != null) { 45 | var copiedData = {}; // use copied data because it is altered in sliderUpdateCallback fn 46 | Object.assign(copiedData, data); 47 | that.sliderUpdateCallback(copiedData); 48 | } 49 | }); 50 | } 51 | 52 | createSlider(styleName, divId, valueLabelId) { 53 | var that = this; 54 | 55 | var setting = {min: -3, max: 3, step: 0.1, value: 0}; 56 | var sliderAndLabel = new SliderAndLabel(divId, setting, valueLabelId); 57 | 58 | sliderAndLabel.setSliderUpdateCallback(function (slideEvt) { 59 | sliderAndLabel.updateLabel(slideEvt.value); 60 | if (that.sliderUpdateCallback != null) { 61 | var data = {}; 62 | for (var name of that.styleNames) { 63 | data[name] = that.sliders[name].getValue(); 64 | } 65 | data[styleName] = slideEvt.value; 66 | that.sliderUpdateCallback(data); 67 | } 68 | }); 69 | return sliderAndLabel; 70 | } 71 | 72 | setSliderUpdateCallback(callback) { 73 | this.sliderUpdateCallback = callback; 74 | } 75 | 76 | getStyleNames() { 77 | return this.styleNames; 78 | } 79 | 80 | setValues(data) { 81 | for (var key in data) { 82 | if (key in this.sliders) { 83 | this.sliders[key].update(data[key]); 84 | } 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /static/js/timeline.js: -------------------------------------------------------------------------------- 1 | class Timeline { 2 | constructor(avatar, stylePanel, audio_div_id, wordCanvasId, motionControlContainerId, styleControlContainerId) { 3 | this.avatar = avatar; 4 | this.stylePanel = stylePanel; 5 | this.audio_div_id = audio_div_id; 6 | this.audioTrack = createAudioTrack(audio_div_id); 7 | this.wordCanvasId = wordCanvasId; 8 | this.motionControlContainerId = motionControlContainerId; 9 | this.styleControlContainerId = styleControlContainerId; 10 | this.historyManager = new HistoryManager(); 11 | 12 | this.syncTracksOnPlayback(); 13 | this.syncTracksOnSeek(); 14 | this.controlDataModificationInvalidWarningId = null; 15 | } 16 | 17 | setCursorToStart() { 18 | this.audioTrack.seekTo(0); 19 | } 20 | 21 | load(filename, words, keypoints, styleNames) { 22 | var that = this; 23 | this.audioTrack.loadAudio(filename); 24 | 
this.audioTrack.wavesurfer.once('ready', function () { 25 | that.createControlTracks(keypoints, styleNames); 26 | 27 | function showWarning() { 28 | $(that.controlDataModificationInvalidWarningId).modal('show'); 29 | } 30 | 31 | that.motionTrack.setCellGroupCreationInvalidCallback(showWarning); 32 | that.styleTrack.setCellGroupCreationInvalidCallback(showWarning); 33 | 34 | that.displayWords(words); 35 | }); 36 | } 37 | 38 | displayWords(words) { 39 | var canvas = document.getElementById(this.wordCanvasId); 40 | canvas.width = canvas.clientWidth * 2; 41 | canvas.height = canvas.clientHeight * 2; 42 | var ctx = canvas.getContext('2d'); 43 | // ctx.fillStyle = "gray"; 44 | // ctx.fillRect(0, 0, canvas.width, canvas.height); 45 | 46 | ctx.fillStyle = 'black'; 47 | ctx.textAlign = 'center'; 48 | 49 | if (canvas.clientWidth < 480) { 50 | ctx.font = '14px sans-serif'; 51 | } else if (canvas.clientWidth < 768) { 52 | ctx.font = '18px sans-serif'; 53 | } else { 54 | ctx.font = '25px sans-serif'; 55 | } 56 | 57 | let audioDuration = this.audioTrack.getDuration(); 58 | var wordPosition = 0; // in px 59 | var wordWidth = 0; 60 | var y = 40; 61 | for (var word of words) { 62 | wordWidth = ctx.measureText(word[0]).width; 63 | wordPosition = (word[1] + word[2]) * 0.5 / audioDuration * canvas.width; 64 | 65 | ctx.save(); 66 | ctx.translate(wordPosition, y) 67 | ctx.rotate(-Math.PI / 4); 68 | ctx.fillText(word[0], 0, y / 2); 69 | ctx.restore(); 70 | } 71 | } 72 | 73 | createControlTracks(keypoints, styleNames) { 74 | this.createMotionTrack(keypoints); 75 | this.createStyleTrack(keypoints.length, styleNames); 76 | 77 | this.syncAvatarWithCurrentCursor(); 78 | this.setCellSelectedCallbacks(); 79 | } 80 | 81 | syncAvatarWithCurrentCursor() { 82 | this.motionTrack.syncAvatarWithCursor(this.avatar); 83 | } 84 | 85 | syncStylePanelWithCurrentCursor() { 86 | this.styleTrack.syncStylePanelWithCursor(this.stylePanel); 87 | } 88 | 89 | createMotionTrack(keypoints) { 90 | var setting = { 91 | container: this.motionControlContainerId, 92 | numFrames: keypoints.length, 93 | duration: this.audioTrack.getDuration(), 94 | }; 95 | 96 | if (this.motionTrack != null) { 97 | this.motionTrack.destruct(); 98 | } 99 | this.motionTrack = new MotionCellTrack(setting, keypoints); 100 | this.motionTrack.seekTo(0); 101 | } 102 | 103 | createStyleTrack(nframes, styleNames) { 104 | var setting = { 105 | container: this.styleControlContainerId, 106 | numFrames: nframes, 107 | duration: this.audioTrack.getDuration(), 108 | }; 109 | 110 | if (this.styleTrack != null) { 111 | this.styleTrack.destruct(); 112 | } 113 | 114 | this.styleTrack = new StyleCellTrack(setting, styleNames); 115 | } 116 | 117 | updateMotionKeypoitns(keypoints) { 118 | this.motionTrack.updateBaseKeypoints(keypoints); 119 | } 120 | 121 | syncTracksOnPlayback() { 122 | this.audioTrack.wavesurfer.on('audioprocess', function (currentTime) { 123 | this.motionTrack.seekToIfOneCellSelected(currentTime); 124 | this.styleTrack.seekToIfOneCellSelected(currentTime); 125 | this.syncAvatarWithCurrentCursor(); 126 | }.bind(this)); 127 | } 128 | 129 | addCallbackOn(action, func) { 130 | this.audioTrack.wavesurfer.on(action, func); 131 | } 132 | 133 | getCurrentTime() { 134 | return this.audioTrack.getCurrentTime(); 135 | } 136 | 137 | getDuration() { 138 | return this.audioTrack.getDuration(); 139 | } 140 | 141 | syncTracksOnSeek() { 142 | this.audioTrack.wavesurfer.on('seek', function (position) { 143 | var time = position * this.audioTrack.getDuration(); 144 | 
this.motionTrack.seekToIfOneCellSelected(time); 145 | this.styleTrack.seekToIfOneCellSelected(time); 146 | this.syncAvatarWithCurrentCursor(); 147 | this.syncStylePanelWithCurrentCursor(); 148 | }.bind(this)); 149 | } 150 | 151 | play() { 152 | this.styleTrack.deselectAll(); 153 | this.motionTrack.seekTo(0); 154 | this.avatar.restPose(); 155 | this.audioTrack.play(); 156 | } 157 | 158 | pause() { 159 | this.audioTrack.pause(); 160 | } 161 | 162 | isPlaying() { 163 | return this.audioTrack.isPlaying(); 164 | } 165 | 166 | setCellSelectedCallbacks() { 167 | 168 | function seekTrackToCellMiddle(track, cell) { 169 | var st = cell.getData('start-time'); 170 | var et = cell.getData('end-time'); 171 | var progress = track.timeToProgress((st + et) / 2); 172 | track.seekTo(progress); 173 | }; 174 | 175 | this.motionTrack.addSelectedCallback(function (selectedCell) { 176 | if (selectedCell != null) { 177 | this.styleTrack.deselectAll(); 178 | seekTrackToCellMiddle(this.audioTrack, selectedCell); 179 | } 180 | }.bind(this)); 181 | 182 | this.styleTrack.addSelectedCallback(function (selectedCell) { 183 | if (selectedCell != null) { 184 | this.motionTrack.deselectAll(); 185 | seekTrackToCellMiddle(this.audioTrack, selectedCell); 186 | } 187 | }.bind(this)); 188 | } 189 | 190 | getMotionConstraints() { 191 | return this.motionTrack.getKeypointsConstraint(); 192 | } 193 | 194 | getStyleConstraints() { 195 | return this.styleTrack.getStyleConstraints(); 196 | } 197 | 198 | deleteSelectedCellControlData() { 199 | this.motionTrack.deleteSelectedCellControlData(); 200 | this.styleTrack.deleteSelectedCellControlData(); 201 | this.syncAvatarWithCurrentCursor(); 202 | this.syncStylePanelWithCurrentCursor(); 203 | } 204 | 205 | toggleMotionTrackControlDataBypass(isBypass) { 206 | this.motionTrack.isBypassControlData = isBypass; 207 | this.syncAvatarWithCurrentCursor(); 208 | } 209 | 210 | updateMotionTrackControlData(keypoints) { 211 | this.motionTrack.updateMultipleKeypointsAsGroup(keypoints); 212 | this.syncAvatarWithCurrentCursor(); 213 | } 214 | 215 | updateStyleTrackControlData(data) { 216 | if (this.styleTrack == null) { 217 | return; 218 | } 219 | this.styleTrack.updateStyleControlForSelectedCells(data); 220 | } 221 | 222 | setPlayCallback(func) { 223 | this.audioTrack.wavesurfer.on('play', func); 224 | } 225 | 226 | setPauseCallback(func) { 227 | this.audioTrack.wavesurfer.on('pause', func); 228 | } 229 | 230 | isAnyCellModified() { 231 | if (this.motionTrack == null || this.styleTrack == null) { 232 | return false; 233 | } 234 | return this.motionTrack.isAnyCellModified() || this.styleTrack.isAnyCellModified(); 235 | } 236 | 237 | fillMotionControl() { 238 | this.motionTrack.interpolateTwoPoses(); 239 | } 240 | 241 | copySelectedCellGroup() { 242 | if (this.motionTrack != null && this.motionTrack.isAnyCellSelected()) { 243 | this.motionTrack.copyGroup(); 244 | } else if (this.styleTrack != null && this.styleTrack.isAnyCellSelected()) { 245 | this.styleTrack.copyGroup(); 246 | } 247 | } 248 | 249 | pasteSelectedCellGroup() { 250 | if (this.motionTrack != null && this.motionTrack.isAnyCellSelected()) { 251 | this.motionTrack.pasteGroup(); 252 | } else if (this.styleTrack != null && this.styleTrack.isAnyCellSelected()) { 253 | this.styleTrack.pasteGroup(); 254 | } 255 | } 256 | 257 | saveTracksToHistory() { 258 | this.historyManager.addSnapshots(this.motionTrack, this.styleTrack); 259 | } 260 | 261 | clearHistory() { 262 | this.historyManager.clearHistory(); 263 | } 264 | 265 | undoTracks() { 266 | 
this.historyManager.undo(this.motionTrack, this.styleTrack); 267 | } 268 | 269 | redoTracks() { 270 | this.historyManager.redo(this.motionTrack, this.styleTrack); 271 | } 272 | 273 | canUndo() { 274 | return this.historyManager.canUndo(); 275 | } 276 | 277 | canRedo() { 278 | return this.historyManager.canRedo(); 279 | } 280 | 281 | } 282 | 283 | function createAudioTrack(div_id) { 284 | var setting = { 285 | container: div_id, 286 | waveColor: 'tomato', 287 | progressColor: 'red', 288 | cursorColor: 'red', 289 | cursorWidth: 2, 290 | height: 128, 291 | barWidth: 2, 292 | barHeight: 1.5, 293 | barMinHeight: 0.1, 294 | barGap: null, 295 | responsive: true, 296 | hideScrollbar: true 297 | } 298 | return new Track(setting); 299 | } 300 |
-------------------------------------------------------------------------------- /static/js/track.js: -------------------------------------------------------------------------------- 1 | class Track { 2 | constructor(setting) { 3 | this.wavesurfer = WaveSurfer.create(setting); 4 | } 5 | 6 | getWaveSurfer() { 7 | return this.wavesurfer; 8 | } 9 | 10 | loadAudio(filename) { 11 | this.wavesurfer.load("media/" + filename + "/temp"); 12 | } 13 | 14 | getCurrentTime() { 15 | return this.wavesurfer.getCurrentTime(); 16 | } 17 | 18 | play() { 19 | this.wavesurfer.play(); 20 | } 21 | 22 | pause() { 23 | this.wavesurfer.pause(); 24 | } 25 | 26 | getDuration() { 27 | return this.wavesurfer.getDuration(); 28 | } 29 | 30 | getPosition() { 31 | return this.getCurrentTime() / this.getDuration(); 32 | } 33 | 34 | seekTo(position) { 35 | this.wavesurfer.seekTo(position); 36 | } 37 | 38 | timeToProgress(time) { 39 | return time / this.wavesurfer.getDuration(); 40 | } 41 | 42 | isPlaying() { 43 | return this.wavesurfer.isPlaying(); 44 | } 45 | 46 | } 47 |
-------------------------------------------------------------------------------- /static/js/util.js: -------------------------------------------------------------------------------- 1 | function zerosLike(array) { 2 | var zeros = []; 3 | for (var a of array) { 4 | zeros.push(0); 5 | } 6 | return zeros; 7 | } 8 | 9 | function deepCopyArray(array) { 10 | var newArray = []; 11 | for (var a of array) { 12 | if (Array.isArray(a)) { 13 | newArray.push(deepCopyArray(a)); 14 | } else if (a.constructor == Object) { 15 | newArray.push(deepCopyDict(a)); 16 | } else { 17 | newArray.push(a); 18 | } 19 | } 20 | return newArray; 21 | } 22 | 23 | function deepCopyDict(data) { 24 | var newData = {}; 25 | for (var key in data) { 26 | var d = data[key]; 27 | var copied = null; 28 | if (Array.isArray(d)) { 29 | copied = deepCopyArray(d); 30 | } else if (d.constructor == Object) { // check if dictionary 31 | copied = deepCopyDict(d); 32 | } else { 33 | copied = d; 34 | } 35 | newData[key] = copied; // store the deep copy, not the original reference 36 | } 37 | return newData; 38 | } 39 |
-------------------------------------------------------------------------------- /static/mesh/mannequin/Ch36_1001_Diffuse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4r/SGToolkit/0c11a54ee64a7b29f4cba79a9d9ff3ae48e3af4e/static/mesh/mannequin/Ch36_1001_Diffuse.png
-------------------------------------------------------------------------------- /static/mesh/mannequin/Ch36_1001_Glossiness.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4r/SGToolkit/0c11a54ee64a7b29f4cba79a9d9ff3ae48e3af4e/static/mesh/mannequin/Ch36_1001_Glossiness.png
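Note: Track and createAudioTrack above are thin wrappers around wavesurfer.js; the timeline code earlier in this dump only reaches through them (via .wavesurfer) to wire events such as 'ready', 'audioprocess', and 'seek'. A minimal usage sketch follows; the container selector and audio name are made-up for illustration, and it assumes wavesurfer.js is already loaded and that the Flask app serves the media/<name>/temp route used by loadAudio.
// Illustrative sketch only; '#audio-track' and 'sample_speech' are hypothetical names.
var audioTrack = createAudioTrack('#audio-track');   // Track wrapping a WaveSurfer instance
audioTrack.wavesurfer.once('ready', function () {
    console.log('duration (s): ' + audioTrack.getDuration());
    audioTrack.seekTo(0.5);                          // seekTo takes progress in [0, 1]
    audioTrack.play();
});
audioTrack.loadAudio('sample_speech');               // requests media/sample_speech/temp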
-------------------------------------------------------------------------------- /static/mesh/mannequin/Ch36_1001_Normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4r/SGToolkit/0c11a54ee64a7b29f4cba79a9d9ff3ae48e3af4e/static/mesh/mannequin/Ch36_1001_Normal.png
-------------------------------------------------------------------------------- /static/mesh/mannequin/Ch36_1001_Specular.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4r/SGToolkit/0c11a54ee64a7b29f4cba79a9d9ff3ae48e3af4e/static/mesh/mannequin/Ch36_1001_Specular.png
-------------------------------------------------------------------------------- /static/screenshot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4r/SGToolkit/0c11a54ee64a7b29f4cba79a9d9ff3ae48e3af4e/static/screenshot.jpg
-------------------------------------------------------------------------------- /templates/index.html: --------------------------------------------------------------------------------
[Markup not preserved in this dump; only the visible page text and the Jinja include directives survive.]
4 | SG Toolkit
Visible labels, in page order: Text Input; Select Voice; Audio Output; Motion Control Track; Style Control Track; Style Control with Speed, Space, and Handedness (+: left, -: right) sliders, each with a "Value: 0" readout; Preset; Gesture Output & Control; Flip L&R
224 | {%include "modal/control_data_delete_modal.html"%}
225 | {%include "modal/motion_library_delete_modal.html"%}
226 | {%include "modal/view_rule_modal.html"%}
227 | {%include "modal/loading_progress_modal.html"%}
228 | {%include "modal/cell_control_data_modification_invalid_warning.html"%}
229 | {%include "modal/help_dialog_modal.html"%}
230 | {%include "modal/generate_modal.html"%}
231 | {%include "modal/import_json.html"%}
-------------------------------------------------------------------------------- /templates/modal/cell_control_data_modification_invalid_warning.html: -------------------------------------------------------------------------------- [markup not preserved in this dump; 21 lines]
-------------------------------------------------------------------------------- /templates/modal/control_data_delete_modal.html: -------------------------------------------------------------------------------- [markup not preserved in this dump; 20 lines]
-------------------------------------------------------------------------------- /templates/modal/generate_modal.html: -------------------------------------------------------------------------------- [markup not preserved in this dump; 20 lines]
-------------------------------------------------------------------------------- /templates/modal/help_dialog_modal.html: -------------------------------------------------------------------------------- [markup not preserved in this dump; 26 lines]
-------------------------------------------------------------------------------- /templates/modal/import_json.html: -------------------------------------------------------------------------------- [markup not preserved in this dump; 20 lines]
-------------------------------------------------------------------------------- /templates/modal/loading_progress_modal.html: -------------------------------------------------------------------------------- [markup not preserved in this dump; 10 lines]
-------------------------------------------------------------------------------- /templates/modal/motion_library_delete_modal.html: -------------------------------------------------------------------------------- [markup not preserved in this dump; 20 lines]
-------------------------------------------------------------------------------- /templates/modal/view_rule_modal.html: -------------------------------------------------------------------------------- [markup not preserved in this dump; 44 lines]
-------------------------------------------------------------------------------- /templates/sample_text.html: --------------------------------------------------------------------------------
1 | 01. nights to get those tickets what was that doing that was eliminating 98 percent of the population from even considering going to it so we the mobile unit and took shakespeare to prisons to homeless shelters to community centers in all five boroughs and even in new jersey and westchester county and that program
2 | 02. well meeting bandura was really cathartic for me because i realized that this famous scientist had documented and scientifically validated something that weve seen happen for the last 30 years that we could take people who had the fear that they werent creative and we could take them through a series of steps
3 | 03. so we have a choice we can either choose to start to take climate change seriously and significantly cut and mitigate our greenhouse gas emissions and then we will have to adapt to less of the climate change impacts in future alternatively we can continue to really ignore the climate change problem
4 | 04.
africa he stumbled across an extraordinary complex a complex of abandoned stone buildings and he never quite recovered from what he saw a granite drystone city stranded on an outcrop above an empty savannah great zimbabwe 5 | 05. and were rejected from the communities that we loved because we wanted to make them better because we believed that they could be and i began to expect this reaction from my own people i know what it feels like when you feel like someones trying to change you or criticize you 6 | 06. their beauty is missed because theyre so omnipresent so i dont know commonplace that people dont notice them they dont notice the beauty but they dont even notice the clouds unless they get in the way of the sun and so people think of as things that get in the way 7 | 07. fine jocelyn is doing the shopping tonys doing the gardening melissa and joe are going to come and cook and chat so five circle members had organized themselves to take care of belinda and 80 although she says that she feels 25 inside but she also says that she felt stuck and pretty down when she joined circle 8 | 08. well meaning kenyans came up to those of us in the temperate world and said you know you people have a lot of cold and flu weve designed this great easy to use cheap tool were going to give it to you for free its called a face mask and all you need to do is wear it every day during cold and flu season when you go to school and when you go to work would we do that 9 | 09. speaking of physical connection you guys know i love the hormone oxytocin you release oxytocin you feel bonded to everyone in the room you guys know that the best way to release oxytocin quickly is to hold someone elses hand for at least six seconds you guys were all holding hands for way more than six seconds so we are all now biochemically to love 10 | 10. in terms of space this is a continent in the same way that in north america the rocky mountains everglades and great lakes regions are very distinct so are the subsurface regions of antarctica and in terms of time we now know that ice sheets not only evolve over the of millennia and centuries but theyre also changing over the scale of years and days 11 | 11. this is perhaps the glue that holds all these conditions together the concept is that we speak in the exact same manner about someone whos not in the room as if they are in the room now this seems basic but its an aspirational practice 12 | 12. an excellent way for staying ahead of the reality curve to make possible today what science will make a reality tomorrow as a cyber magician i combine elements of illusion and science to give us a feel of how future technologies might be experienced youve probably all heard of googles project glass its new technology you look through them and the world 13 | -------------------------------------------------------------------------------- /waitress_server.py: -------------------------------------------------------------------------------- 1 | from waitress import serve 2 | import app 3 | 4 | serve(app.app, host='0.0.0.0', port=8080) 5 | --------------------------------------------------------------------------------