├── .gitignore ├── LICENSE ├── README.md ├── anim_data └── __init__.py ├── install_maya.py ├── install_venv.bat ├── predict_data └── __init__.py ├── requirements.txt ├── system ├── __init__.py ├── apply_prediction.py ├── generate_train_data.py ├── gpr_model.py ├── gpr_predict.py ├── irm_ui.py ├── prep_anim_data.py └── train_model.py ├── trained_model └── __init__.py ├── training_data ├── __init__.py ├── jnt │ └── __init__.py └── rig │ └── __init__.py └── utils ├── __init__.py ├── data_gen.py ├── maya.py ├── pytorch.py └── ui.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.pyc 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | 105 | # IDE settings 106 | .vscode/ 107 | 108 | # csv files 109 | *.csv 110 | 111 | # pytorch models 112 | *.pt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 
5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | MIT License 13 | 14 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Inverse Rig Mapping 2 | 3 | ## 1. Introduction 4 | Inverse Rig Mapping is a toolset for Maya to map skeletal animation data e.g. mocap data back to arbitrary character rigs using machine learning. It is based on the paper "Learning an inverse rig mapping for character animation". 
The machine learning model learns the correlation between the rig control parameters and the dependent joint parameters. After learning, it can predict these rig control parameters based on the skeletal animation data. 5 | 6 | The common workflow for mapping skeletal animation back to rigs is to use retargeting tools like Maya's HumanIK, constraint setups, or creating custom scripts that work for the specific character rig. All of these workflows have some sort of limitation, are often time consuming and not flexible for different rigs. The Inverse Rig Mapping (IRM) tool aims to solve this problem by providing a flexible and easy-to-use workflow that works with any rig - at least as long as the model can understand the correlation between rig control and joint parameters. Currently the tool only works with simpler rigs - see the limitations section below for more details. 7 | 8 | 9 | 10 | ## 2. Installation 11 | The toolset has been developed and tested on Windows 10/11 using Autodesk Maya 2022.4, 2023.3 and 2024.0.1 and Python 3.10.11. 12 | 13 | 14 | Place the Inverse Rig Mapping folder in a location of your choice and run the "install_venv.bat" file. This will create a new virtual Python environment in the IRM folder (make sure you run it in that folder), install all the necessary Python libraries and set an environment variable that is essential for the Maya installation. Make sure that a similar Python version (3.10.11 and above) is installed and that "python.exe" is registered as "python" command. 15 | 16 | The machine learning part uses its own Python environment - different from the one Maya uses because it needs libraries that aren't available in Maya. You can use your own Python environment if you like, but make sure you have all the libraries listed in the requirements.txt file installed. 17 | 18 | 19 | You can use your NVIDIA GPU to improve performance when using this tool. This uses the CUDA Toolkit, which needs to be installed first. 
You can download the latest one here: https://developer.nvidia.com/cuda-downloads 20 | 21 | The IRM tool will dynamically use CUDA/GPU if it is available on your workstation. Otherwise it will switch to CPU - you can also force CPU in the train settings. 22 | 23 | More information on using CUDA with PyTorch can be found here - especially if you have an older version of CUDA or a different OS: https://pytorch.org/get-started/locally/ 24 | 25 | 26 | Last but not least, install the tool in Maya: open the file "install_maya.py" in the Maya script editor and run it. This should create a shelf tool called "IRM" in the currently visible shelf. When you click on this button, the tool window will appear and you are ready to use the tool. 27 | 28 | 29 | 30 | 31 | ## 3. Workflow & Settings 32 | ### 3.1 Data Generation 33 | #### 3.1.1 Train Parameters 34 | To get started, you need data that will later be used to train the machine learning model. There are two separate areas in the interface for adding the control rig and the joint parameters. Simply select any control in your rig (NURBS curves/surfaces and meshes are accepted) and press the 'Add' button. This will add every keyable and scalar attribute of the selected control to the tree view. You can remove specific attributes by right-clicking on them and selecting 'Delete' - this also works with multi-selection. You can reset the Parameter UI using the 'Clear All' button. 35 | 36 | The same applies to the joint parameters at the bottom. Simply select all the joints in your rig and press the 'Add' button to include every joint attribute - only joints are accepted. Any attribute that has incoming connections and is therefore influenced by other nodes will be added to the tree view. You can also delete certain joint attributes to exclude them from training. 
Your UI should now look something like this: 37 | 38 | The order of the parameters doesn't matter, as they will be reordered when the training data is generated to ensure the same order every time, regardless of the selection. 39 | 40 | 41 | #### 3.1.2 Min/Max Range 42 | For a better prediction of the animation later on, it is important to set correct and plausible ranges for each rig parameter. Only this range will be used for training and therefore the model can only predict correct values for this range later on - so if your maximum range for translateX is 50, but your expected predicted value would be 100, the model will have a hard time predicting this as it hasn't been trained for this range. You can change the range of a single or multiple parameters by selecting them and using the "Minimum" and "Maximum" fields on the right hand side of the "Generator Settings". For rotation, you don't need to go beyond 180, as a range of -180 to 180 is the full possible range for rotations. 43 | 44 | IMPORTANT: In general, only the given parameters are used for the prediction part, so if you exclude e.g. translateX, it won't be predicted later. The same goes for the attribute ranges you set - if the skeletal animation later goes beyond that range, it will have a hard time predicting the rig control values. Getting this right is vital for the best and most efficient result. 45 | 46 | 47 | #### 3.1.3 Number of Poses 48 | The number of poses defines the number of random steps between the min/max range, so 1000 means there will be 1000 different steps between the min and max value. In general the number of poses should be high enough to get a sufficient training result, more poses and therefore more data can greatly improve the training but will also increase the training time and the amount of memory needed which may be too much for the current workstation. 
1000 poses is a good starting point and you can slowly go up to 5000 if you like - it always depends on the rig and the number of parameters. 49 | 50 | 51 | #### 3.1.4 Output Files 52 | The last step is to specify the paths where the generated data will be stored. By default these will be "IRM_folder/training_data/rig" and "IRM_folder/training_data/jnt" - feel free to use these default paths. 53 | 54 | 55 | #### 3.1.5 Generate Train Data 56 | Once everything is set up, all you need to do is press the 'Generate Train Data' button. This will take a moment, depending on the number of train parameters and the number of poses. In general, it should only take a few seconds to a few minutes. When it is finished, you will find the generated rig and joint data as CSV files in the directories provided. 57 | 58 | 59 | 60 | 61 | ### 3.2 Model Training 62 | #### 3.2.1 Learning Rate 63 | The learning rate controls how large a step the optimiser takes each time it updates the model during training. If it is too high, training may overshoot or settle too quickly on a poor solution, missing important patterns in the data. If it's too low, the model may take too long to learn and/or get stuck. Typical learning rates are between 0.0001 and 0.1. 64 | 65 | 66 | #### 3.2.2 Epochs 67 | This is the number of times the entire dataset is run through the model during training. If you use a small number, your model may not learn everything it needs to know (underfitting). If you use a large number, your model may start to memorise the data (overfitting), and it won't be good at making predictions about data it hasn't seen before. The optimal number of epochs is problem-specific, but a typical range is between 10 and 1000. 68 | 69 | 70 | #### 3.2.3 Force CPU 71 | This option forces the model to use the CPU instead of the GPU for computation. This can be useful if you are experiencing GPU limitations or don't want to use GPU acceleration. 
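The CPU/GPU switch described above can be sketched as follows. This is a minimal illustration, not the tool's actual code: the function name `pick_device_name` and the `force_cpu` argument are assumptions made for this example, and the `torch` import is optional here only so the sketch also runs on a machine without PyTorch.

```python
try:
    import torch  # optional in this sketch; the real training environment installs it
except ImportError:
    torch = None


def pick_device_name(force_cpu: bool = False) -> str:
    """Return "cuda" when a GPU is usable and not overridden, otherwise "cpu".

    `force_cpu` is a hypothetical flag mirroring the 'Force CPU' checkbox.
    """
    if force_cpu or torch is None:
        return "cpu"
    # torch.cuda.is_available() reports whether PyTorch can see a CUDA device
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(pick_device_name(force_cpu=True))
```

The returned name could then be passed to `torch.device(...)` and used with `.to(device)` on the model and the training tensors, so the same script runs unchanged on workstations with or without a CUDA-capable GPU.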
72 | 73 | 74 | #### 3.2.4 Python Exe 75 | This is the path to the Python executable on your machine. The model will use this Python environment to run the script. Make sure you have all the necessary libraries installed - see above under "Installation" for more details. 76 | 77 | 78 | #### 3.2.5 Output Settings 79 | Here you need to specify the control rig and joint CSV files containing the previously generated data. You will also need to specify the path where the trained model will be saved as a PyTorch (PT) file, so that you can make predictions later without having to train the model again. 80 | 81 | 82 | #### 3.2.6 Model Training 83 | Once you have adjusted the settings and defined the paths, you are ready to train the model by pressing the 'Train Model' button. 84 | 85 | This is the process of feeding your data to the model and allowing it to learn from it. During training, the model attempts to minimise its prediction error (loss) through an iterative process of adjusting its internal parameters. The model adjusts these parameters based on the learning rate and the number of epochs. 86 | 87 | During training, the goal is to minimise this loss value. This means adjusting the training settings (training data, learning rate, epochs) so that the difference between the predicted and actual values is as small as possible. A decrease in the loss value over epochs usually indicates that the model is learning and improving its predictions. However, if the loss stops decreasing or increases, it may indicate problems such as overfitting or that the model has reached its capacity for this data. 88 | 89 | You can find the current loss and epoch in the progress bar window that pops up during training. 90 | 91 | 92 | 93 | 94 | ### 3.3 Prediction / Inference 95 | #### 3.3.1 Animation Parameters 96 | Similarly to the train data, you first need to specify all the joints (only joints are accepted) of the animated skeleton that will be mapped back to the control rig. 
To do this, select all the animated joints and add them using the 'Add' button. Again, check the animation parameters listed - all attributes with keys will be added automatically. You can delete them again by right-clicking or using the 'Clear All' button. 97 | 98 | These parameters must be the same as those used for the joint training data - this means that every joint used for the training data must also be added here, otherwise the prediction will not work! 99 | 100 | The names of the animated joints should be close to the names of the trained joints - this is important to get the same order of parameters as in the training data. 101 | 102 | 103 | #### 3.3.2 Train Data and Trained Model 104 | These are the file paths for the training data and the trained model respectively. The training data is the control rig and joint dataset that was used to train the model. The trained model is the result of the training process, a file that contains all the weights and biases that the model has learned. This file is used to load the model for prediction. You specified these paths during the training process. 105 | 106 | 107 | #### 3.3.3 Python Exe 108 | This is the path to the Python executable on your machine. The model will use this Python environment to run the script. Make sure you have all the necessary libraries installed - see above under "Installation" for more details. 109 | 110 | 111 | #### 3.3.4 Mapping Prediction 112 | Once everything is set up, you can start mapping the animation of the animated skeleton back to the rig by pressing the 'Map Prediction' button. 113 | 114 | This is a multi-step process. First it will collect the animation data from the joints provided and modify it to work with the model. Then it will use the trained model to predict the animation values of the control rig and finally it will apply these predicted values to the rig. In the end, the animation of the control rig should closely match that of the animated joints. 
If not, you may need to change some settings in the training process to ensure a better prediction result. 115 | 116 | Note: The mapping of values to the control is limited to the namespace provided in the training process. So if your control had the name "arm_L_wrist_ctrl" in the training process, then this control must have the same name in the prediction process. 117 | 118 | Depending on the complexity of the rig, the predicted result can be quite far from the animated skeleton, or even static. This is a known limitation and bug. 119 | 120 | As a workaround, you can split the training into different parts to reduce the complexity of the data - so only learn and map the leg, arm or spine one at a time. 121 | 122 | 123 | 124 | 125 | ## 4. Limitations and Bugs 126 | - The predicted rig animation can be quite different from the skeletal one - changing the number of poses, learning rate or epochs can improve the mapping. 127 | - A large number of poses and/or many control rig and joint parameters require a lot of memory, sometimes too much for the workstation. 128 | - Mapping skeletal data back to rigs only works with the rig control names that were used in training - so it won't work if you want to map it to rigs with different namespaces to those used in training. 129 | - Training complex character rigs with many parameters can result in a static mapping - (almost) the same values for every frame. 130 | 131 | 132 | 133 | 134 | ## 5. References 135 | 1. Daniel Holden, Jun Saito, Taku Komura. 2015. "Learning an inverse rig mapping for character animation" in SCA '15: Proceedings of the 14th ACM SIGGRAPH / Eurographics Symposium on Computer Animation, August 2015, Pages 165–173, https://dl.acm.org/doi/10.1145/2786784.2786788 136 | 2. Carl Edward Rasmussen, Christopher K. I. Williams. 2005. "Gaussian Processes for Machine Learning", Retrieved June 21, 2023, from https://gaussianprocess.org/gpml/chapters/RW.pdf 137 | 3. GPyTorch. 2023. "GPyTorch Documentation". 
Retrieved June 21, 2023, from https://docs.gpytorch.ai/en/latest/ 138 | 4. PyTorch. 2023. "PyTorch Documentation Release 2.0". Retrieved June 21, 2023, from https://pytorch.org/docs/2.0/ 139 | 5. Unreal Engine. 2023. "How to use the Machine Learning Deformer". Retrieved June 21, 2023, from https://docs.unrealengine.com/5.2/en-US/how-to-use-the-machine-learning-deformer-in-unreal-engine/ 140 | 6. Dmitry Kostiaev. 2020. "Better rotation representations for accurate pose estimation". Retrieved June 21, 2023, from https://towardsdatascience.com/better-rotation-representations-for-accurate-pose-estimation-e890a7e1317f 141 | 7. Eric Perim, Wessel Bruinsma, and Will Tebbutt. 2021. "Gaussian Processes: from one to many outputs". Retrieved June 21, 2023, from https://invenia.github.io/blog/2021/02/19/OILMM-pt1/#:~:text=Then%2C%20a%20multi%2Doutput%20Gaussian,on%20an%20extended%20input%20space 142 | 143 | ## 6. Dependencies 144 | All dependencies are linked dynamically. 145 | - Maya SDK (Maya Commands / OpenMaya) - Autodesk EULA: https://download.autodesk.com/us/FY17/Suites/LSA/en-us/lsa.html 146 | - PyMEL - BSD license: https://github.com/LumaPictures/pymel/blob/master/LICENSE 147 | - PyTorch - BSD license: https://github.com/pytorch/pytorch/blob/main/LICENSE 148 | - GPyTorch - MIT license: https://github.com/cornellius-gp/gpytorch/blob/master/LICENSE 149 | - NumPy - BSD 3-clause license: https://github.com/numpy/numpy/blob/main/LICENSE.txt 150 | - Pandas - BSD 3-clause license: https://github.com/pandas-dev/pandas/blob/main/LICENSE 151 | - PySide2 - LGPLv3 license: https://doc.qt.io/qt-6/lgpl.html 152 | 153 | ## 7. About 154 | This Tool has been developed within the scope of the Technical Director course at Filmakademie Baden-Wuerttemberg. http://technicaldirector.de. Written by Lukas Kapp. Copyright © 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg. 
155 | -------------------------------------------------------------------------------- /anim_data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ -------------------------------------------------------------------------------- /install_maya.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 
5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | import maya.cmds as cmds 13 | import maya.mel as mel 14 | import os 15 | 16 | 17 | def create_shelf_button(): 18 | irm_path = os.environ.get('IRM_PATH') 19 | if irm_path is None: 20 | raise RuntimeError("Environment variable IRM_PATH not found - run install_venv.bat first.") 21 | # setx may store a stray trailing quote when the path ends in a backslash 22 | irm_path = irm_path.strip('"') 23 | 24 | button_command = f''' 25 | import sys 26 | from importlib import reload 27 | 28 | path = {irm_path!r} 29 | 30 | if path not in sys.path: 31 | sys.path.append(path) 32 | 33 | import system.irm_ui as ui 34 | reload(ui) 35 | ''' 36 | 37 | # Find the currently active shelf 38 | top_shelf = mel.eval("$tempVar = $gShelfTopLevel") 39 | active_shelf = cmds.tabLayout(top_shelf, query=True, selectTab=True) 40 | 41 | # Create a new button in the current shelf 42 | cmds.shelfButton( 43 | parent=active_shelf, 44 | command=button_command, 45 | annotation='IRM Tool', 46 | image1='commandButton.png', # Replace with your icon 47 | width=10, 48 | height=10, 49 | label='IRM_button', 50 | imageOverlayLabel='IRM' 51 | ) 52 | 53 | create_shelf_button() 54 | -------------------------------------------------------------------------------- /install_venv.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | goto :start 3 | ----------------------------------------------------------------------------- 4 | This file has been developed within the scope of the 5 | Technical Director course at Filmakademie Baden-Wuerttemberg. 
6 | http://technicaldirector.de 7 | 8 | Written by Lukas Kapp 9 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 10 | ----------------------------------------------------------------------------- 11 | :start 12 | 13 | REM Get the directory of the current batch script 14 | set "SCRIPT_DIR=%~dp0" 15 | 16 | echo Creating a new virtual environment... 17 | python -m venv venv 18 | 19 | echo Activating the virtual environment... 20 | call .\venv\Scripts\activate 21 | 22 | echo Installing requirements... 23 | pip install --upgrade pip 24 | pip install -r requirements.txt 25 | 26 | echo Setting up environment variable... 27 | setx IRM_PATH "%SCRIPT_DIR%" 28 | 29 | echo Setup completed. 30 | pause 31 | -------------------------------------------------------------------------------- /predict_data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 
5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asttokens==2.2.1 2 | backcall==0.2.0 3 | certifi==2022.12.7 4 | charset-normalizer==2.1.1 5 | colorama==0.4.6 6 | comm==0.1.3 7 | contourpy==1.0.7 8 | cycler==0.11.0 9 | debugpy==1.6.7 10 | decorator==5.1.1 11 | executing==1.2.0 12 | filelock==3.9.0 13 | fonttools==4.39.3 14 | gpytorch==1.9.1 15 | idna==3.4 16 | ipykernel==6.22.0 17 | ipython==8.12.0 18 | jedi==0.18.2 19 | Jinja2==3.1.2 20 | joblib==1.2.0 21 | jupyter_client==8.2.0 22 | jupyter_core==5.3.0 23 | kiwisolver==1.4.4 24 | linear-operator==0.3.0 25 | MarkupSafe==2.1.2 26 | mpmath==1.2.1 27 | nest-asyncio==1.5.6 28 | networkx==3.0 29 | numpy==1.24.1 30 | packaging==23.1 31 | pandas==2.0.0 32 | parso==0.8.3 33 | pickleshare==0.7.5 34 | Pillow==9.3.0 35 | platformdirs==3.2.0 36 | prompt-toolkit==3.0.38 37 | psutil==5.9.4 38 | pure-eval==0.2.2 39 | Pygments==2.15.0 40 | pyparsing==3.0.9 41 | python-dateutil==2.8.2 42 | pytz==2023.3 43 | pywin32==306 44 | pyzmq==25.0.2 45 | requests==2.28.1 46 | six==1.16.0 47 | stack-data==0.6.2 48 | sympy==1.11.1 49 | threadpoolctl==3.1.0 50 | tornado==6.2 51 | traitlets==5.9.0 52 | typing_extensions==4.4.0 53 | tzdata==2023.3 54 | urllib3==1.26.13 55 | wcwidth==0.2.6 56 | --extra-index-url https://download.pytorch.org/whl/cu118 57 | torch==2.0.0+cu118 58 | torchaudio==2.0.1+cu118 59 | torchvision==0.15.1+cu118 60 | -------------------------------------------------------------------------------- /system/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ -------------------------------------------------------------------------------- /system/apply_prediction.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | import maya.cmds as cmds 13 | import pymel.core as pm 14 | import maya.api.OpenMaya as om 15 | import math 16 | import csv 17 | import pathlib 18 | import os 19 | import subprocess 20 | 21 | import system.prep_anim_data as prep_anim_data 22 | 23 | 24 | 25 | def get_predict_data(anim_path, model_path, rig_path, py_app): 26 | py_path = pathlib.Path(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent 27 | # repr() keeps backslashes in Windows paths from being read as escape sequences 28 | py_cmd = f"import sys; sys.path.append({str(py_path)!r}); import system.gpr_predict as gpr_predict; gpr_predict.predict_data({str(anim_path)!r}, {str(model_path)!r}, {str(rig_path)!r})" 29 | 30 | command = [py_app, "-u", "-c", py_cmd] 31 | # start subprocess; prevent output window 32 | process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, 33 | text=True, creationflags=subprocess.CREATE_NO_WINDOW) 34 | 35 | # update progress bar 36 | 37 | for line in iter(process.stdout.readline, ""): 38 | if line.startswith("PROGRESS "): 39 | progress = float(line.strip().split(" ")[1]) 40 | if cmds.progressWindow(query=True, isCancelled=True): 41 | process.kill() 42 | break 43 | cmds.progressWindow(edit=True, progress=progress, status=(f'Progress: {progress} %')) 44 | 45 | process.wait() 46 | 47 | if process.returncode: 48 | error_message = process.stderr.read() 49 | raise ValueError(error_message) 50 | else: 51 | om.MGlobal.displayInfo("Data predicted successfully!") 52 | 53 | 54 | 55 | def map_data(anim_input_data, jnt_path, rig_path, model_path, py_app): 56 | # Initialize the progress window 57 | cmds.progressWindow(title='Map Prediction', progress=0, status='Getting Animation Data...', isInterruptable=True) 58 | 59 | # get animation data 60 | anim_path, frames = prep_anim_data.prep_data(anim_input_data, jnt_path) 61 | 62 | # predict data 63 | cmds.progressWindow(edit=True, progress=0, status=('Initialize Prediction...')) 64 | get_predict_data(anim_path, model_path, rig_path, py_app) 65 | 66 | cmds.progressWindow(edit=True, progress=0, status=('Applying Prediction...')) 67 | predict_path = pathlib.PurePath(pathlib.Path(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent, "predict_data/irm_predict_data.csv") 68 | with open(predict_path, "r") as csvfile: 69 | reader = csv.reader(csvfile) 70 | header = next(reader, None) # read header row (column names) 71 | 72 | predict_data = [row for row in reader if len(row) != 0] 73 | ctrl_depth = len(list(dict.fromkeys([int(n[0]) for n in predict_data]))) 74 | 75 | current_frame = cmds.currentTime(q=1) 76 | for i, data in enumerate(predict_data): 77 | ctrl = data[1] 78 | values = data[3:] 79 | # predict data is ctrl depth times frames 80 | # so need to divide i with ctrl depth to get current anim frame 81 | frame = frames[math.floor(i/ctrl_depth)] 82 | rotMtx = [] 83 | for col, value in enumerate(values, start=3): 84 | if value != "nan": 85 | attr_name = header[col] # look up by column index so duplicate values map to the right attribute 86 | if "rotMtx_" in attr_name: 87 | rotMtx.append(value) 88 | 
continue 89 | cmds.setAttr("{}.{}".format(ctrl, attr_name), float(value)) 90 | cmds.setKeyframe(ctrl, t=frame, at=attr_name, v=float(value)) 91 | 92 | if rotMtx: 93 | mtx = pm.dt.TransformationMatrix((float(rotMtx[0]), float(rotMtx[1]), float(rotMtx[2]), 0.0, 94 | float(rotMtx[3]), float(rotMtx[4]), float(rotMtx[5]), 0.0, 95 | float(rotMtx[6]), float(rotMtx[7]), float(rotMtx[8]), 0.0, 96 | 0.0, 0.0, 0.0, 1.0)).euler 97 | rot = [math.degrees(mtx[0]), math.degrees(mtx[1]), math.degrees(mtx[2])] 98 | 99 | for attr in ["rx", "ry", "rz"]: 100 | cmds.xform(ctrl, ro=rot, os=1) 101 | value = cmds.getAttr("{}.{}".format(ctrl, attr)) 102 | cmds.setKeyframe(ctrl, t=frame, at=attr, v=float(value)) 103 | 104 | cmds.xform(ctrl, s=[1.0,1.0,1.0]) 105 | cmds.setAttr("{}.shearXY".format(ctrl), 0.0) 106 | cmds.setAttr("{}.shearXZ".format(ctrl), 0.0) 107 | cmds.setAttr("{}.shearYZ".format(ctrl), 0.0) 108 | 109 | cmds.progressWindow(edit=True, progress=(i/len(predict_data))*100, status=('Applying Prediction...')) 110 | cmds.currentTime(current_frame) 111 | 112 | cmds.progressWindow(endProgress=True) 113 | -------------------------------------------------------------------------------- /system/generate_train_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 
5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | import maya.cmds as cmds 13 | import pymel.core as pm 14 | import maya.api.OpenMaya as om 15 | import os 16 | import random 17 | import csv 18 | 19 | import utils.maya as mUtils 20 | import utils.data_gen as genUtils 21 | import utils.ui as uiUtils 22 | 23 | def generate_data(rig_input_data, jnt_input_data, rig_out, jnt_out, train_poses): 24 | # filter selection into joints and controls 25 | ctrl_list = [ctrl for ctrl in rig_input_data.keys() if mUtils.query_visibility(ctrl)] 26 | ctrl_list.sort() 27 | 28 | jnt_list = [jnt for jnt in jnt_input_data.keys()] 29 | jnt_list.sort() 30 | 31 | for ctrl in ctrl_list: 32 | mUtils.restore_defaults(ctrl) 33 | 34 | ctrl_unique_attrs = list(set([attr for ctrl in ctrl_list for attr in mUtils.filter_attrs_from_dict(rig_input_data[ctrl])])) 35 | # if any rotation axis is included in the "ctrl_unique_attrs" list, remove all rotation axes (X,Y,Z) and add a single "rot_mtx" value 36 | # as a hint for later that a rotation matrix should be included in the train data 37 | if any(axis in ctrl_unique_attrs for axis in ("rotateX", "rotateY", "rotateZ")): 38 | ctrl_unique_attrs = list(set(ctrl_unique_attrs).difference(["rotateX", "rotateY", "rotateZ"])) 39 | ctrl_unique_attrs.append("rot_mtx") 40 | ctrl_unique_attrs.sort() # reorder list to make it independent of control selection 41 | 42 | # list all unique attrs across all joints / controls - that will define the number of column headers 43 | jnt_unique_attrs = list(set([jnt_attr[0] for jnt in jnt_list for jnt_attr in jnt_input_data[jnt]])) 44 | if any(axis in jnt_unique_attrs for axis in ("rotateX", "rotateY", "rotateZ")): 45 | jnt_unique_attrs = list(set(jnt_unique_attrs).difference(["rotateX", "rotateY", "rotateZ"])) 46 | jnt_unique_attrs.append("rot_mtx") 47 | jnt_unique_attrs.sort() # 
reorder list to make it independent of joint selection 48 | 49 | # build data header based on unique attrs of controls and joints 50 | rig_header = genUtils.build_header(base_header=["No.", "rigName", "dimension"], attr_list=ctrl_unique_attrs) 51 | jnt_header = genUtils.build_header(base_header=["No.", "jointName", "dimension"], attr_list=jnt_unique_attrs) 52 | 53 | # create empty csv files, fill them row by row with data 54 | with open(rig_out, "w") as f: 55 | writer = csv.writer(f) 56 | writer.writerow(rig_header) 57 | 58 | with open(jnt_out, "w") as f: 59 | writer = csv.writer(f) 60 | writer.writerow(jnt_header) 61 | 62 | # Initialize the progress window 63 | cmds.progressWindow(title='Generating...', progress=0, status='Starting...', isInterruptable=True) 64 | 65 | rig_data = [] 66 | jnt_data = [] 67 | for i in range(train_poses): 68 | if cmds.progressWindow(query=True, isCancelled=True): 69 | break 70 | 71 | for ctrl_index, ctrl in enumerate(ctrl_list): 72 | # only get integer and float attributes of selected control 73 | attr_list = [attr[0] for attr in rig_input_data[ctrl]] 74 | 75 | # check if rotation is in attr list 76 | rotation = genUtils.check_for_rotation(attr_list) 77 | 78 | # set dimension to length of attr_list; if rotate in list, remove rotate and add matrix3 (9 dimensions) 79 | attr_dimension = genUtils.get_attr_dimension(attr_list, rotation) 80 | 81 | input_range_list = {} 82 | for values in rig_input_data[ctrl]: 83 | input_range_list[values[0]] = [values[1], values[2]] 84 | 85 | random.seed(i+ctrl_index) 86 | for attr in attr_list: 87 | input_rand_min = float(input_range_list[attr][0]) 88 | input_rand_max = float(input_range_list[attr][1]) 89 | 90 | if cmds.attributeQuery(attr, node=ctrl, minExists=1) or cmds.attributeQuery(attr, node=ctrl, maxExists=1): # if min or max range exists, use those as new values for rand min/max 91 | if cmds.attributeQuery(attr, node=ctrl, minExists=1): 92 | rand_min = cmds.attributeQuery(attr, node=ctrl, min=1)[0] 93 
| else: 94 | rand_min = input_rand_min 95 | 96 | if cmds.attributeQuery(attr, node=ctrl, maxExists=1): 97 | rand_max = cmds.attributeQuery(attr, node=ctrl, max=1)[0] 98 | else: 99 | rand_max = input_rand_max 100 | else: 101 | rand_min, rand_max = mUtils.check_transformLimit(ctrl, axis=attr, default_min=input_rand_min, default_max=input_rand_max) 102 | 103 | if rand_min < input_rand_min: 104 | rand_min = input_rand_min 105 | 106 | if rand_max > input_rand_max: 107 | rand_max = input_rand_max 108 | 109 | cmds.setAttr("{}.{}".format(ctrl, attr), round(random.uniform(rand_min, rand_max), 5)) 110 | 111 | 112 | # create list with n/a for every attr in rig_header 113 | rig_data_row = [ctrl_index, ctrl, attr_dimension] 114 | rig_data_row.extend(["n/a" for i in range(len(rig_header)-3)]) 115 | for attr in attr_list: 116 | if not attr in ["rotateX", "rotateY", "rotateZ"]: 117 | # replace only used attr of ctrls in n/a list, rest stays at n/a 118 | rig_data_row[rig_header.index(attr)] = cmds.getAttr("{}.{}".format(ctrl, attr)) 119 | 120 | 121 | if rotation: 122 | ctrl_mtx = pm.dt.TransformationMatrix(cmds.xform(ctrl, m=1, q=1, os=1)) 123 | ctrl_rot_mtx3 = [x for mtx in ctrl_mtx.asRotateMatrix()[:-1] for x in mtx[:-1]] 124 | 125 | start_index = rig_header.index("rotMtx_00") # get index of first rotMtx entry in rig_header and start replacing rotMtx values from there 126 | for mtx_index, rot_mtx in enumerate(ctrl_rot_mtx3): 127 | rig_data_row[start_index + mtx_index] = rot_mtx 128 | 129 | with open(rig_out, "a") as f: 130 | writer = csv.writer(f) 131 | writer.writerow(rig_data_row) 132 | 133 | 134 | for y, jnt in enumerate(jnt_list): 135 | attr_list = [attr[0] for attr in jnt_input_data[jnt]] 136 | rotation = genUtils.check_for_rotation(attr_list) 137 | jnt_dimension = genUtils.get_attr_dimension(attr_list, rotation) 138 | 139 | jnt_data_row = [y, jnt, jnt_dimension] 140 | jnt_data_row.extend(["n/a" for i in range(len(jnt_header)-3)]) 141 | for attr in attr_list: 142 | if not 
attr in ["rotateX", "rotateY", "rotateZ"]: 143 | # replace only used attr of jnts in n/a list, rest stays at n/a 144 | jnt_data_row[jnt_header.index(attr)] = cmds.getAttr("{}.{}".format(jnt, attr)) 145 | 146 | if rotation: 147 | jnt_mtx = pm.dt.TransformationMatrix(cmds.xform(jnt, m=1, q=1, os=1)) 148 | jnt_rot_mtx3 = [x for mtx in jnt_mtx.asRotateMatrix()[:-1] for x in mtx[:-1]] 149 | 150 | start_index = jnt_header.index("rotMtx_00") # get index of first rotMtx entry in jnt_header and start replacing rotMtx values from there 151 | for mtx_index, rot_mtx in enumerate(jnt_rot_mtx3): 152 | jnt_data_row[start_index + mtx_index] = rot_mtx 153 | 154 | with open(jnt_out, "a") as f: 155 | writer = csv.writer(f) 156 | writer.writerow(jnt_data_row) 157 | 158 | #update progress 159 | cmds.progressWindow(edit=True, progress=((float(i+1))/train_poses)*100, 160 | status=f'Generating {i+1}/{train_poses}...') 161 | 162 | 163 | # reset transforms to zero 164 | for ctrl in ctrl_list: 165 | mUtils.restore_defaults(ctrl) 166 | 167 | cmds.progressWindow(endProgress=True) 168 | 169 | -------------------------------------------------------------------------------- /system/gpr_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 
5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | import torch 13 | import gpytorch 14 | import numpy as np 15 | import pandas as pd 16 | import pathlib 17 | import os 18 | 19 | import utils.pytorch as torchUtils 20 | 21 | class BatchIndependentMultitaskGPModel(gpytorch.models.ExactGP): 22 | def __init__(self, train_x, train_y, likelihood, force_cpu, train_x_dimension, train_y_dimension, 23 | x_min,x_max, x_mean, y_min, y_max, y_mean): 24 | super(BatchIndependentMultitaskGPModel, self).__init__(train_x, train_y, likelihood) 25 | self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([train_y_dimension])) 26 | self.covar_module = gpytorch.kernels.ScaleKernel( 27 | gpytorch.kernels.RBFKernel(batch_shape=torch.Size([train_y_dimension])), 28 | batch_shape=torch.Size([train_y_dimension]) 29 | ) 30 | 31 | self.register_buffer('force_cpu', torch.tensor([force_cpu])) 32 | 33 | self.register_buffer('train_x', train_x) 34 | self.register_buffer('x_min', x_min) 35 | self.register_buffer('x_max', x_max) 36 | self.register_buffer('x_mean', x_mean) 37 | self.register_buffer('x_dim', torch.tensor([train_x_dimension])) 38 | 39 | self.register_buffer('train_y', train_y) 40 | self.register_buffer('y_min', y_min) 41 | self.register_buffer('y_max', y_max) 42 | self.register_buffer('y_mean', y_mean) 43 | self.register_buffer('y_dim', torch.tensor([train_y_dimension])) 44 | 45 | 46 | def forward(self, x): 47 | mean_x = self.mean_module(x) 48 | covar_x = self.covar_module(x) 49 | return gpytorch.distributions.MultitaskMultivariateNormal.from_batch_mvn( 50 | gpytorch.distributions.MultivariateNormal(mean_x, covar_x) 51 | ) 52 | 53 | 54 | def build_data_tensor(dataset_path, min_val=None, max_val=None, mean_val=None): 55 | ### BUILD DATA TENSOR ### 56 | dataset = pd.read_csv(dataset_path, 
na_values='?', comment='\t', sep=',', skipinitialspace=True, header=[0]) 57 | 58 | # number of objs = length of unique items of first column 59 | numObjs = len(np.unique(dataset.iloc[:, 0])) 60 | 61 | # get dimensions of each obj and get sum of it 62 | dimension = sum(dataset.filter(items=["dimension"]).values[:numObjs])[0] 63 | highest_dim = max(dataset.filter(items=["dimension"]).values[:numObjs])[0] 64 | 65 | # create list with entries of all attribute columns 66 | raw_tensor = np.array(dataset.iloc[:, 3:].values).reshape(-1, numObjs*highest_dim) 67 | attr_list = dataset.columns.values.tolist() 68 | rotMtx_index_list = [attr_list.index(attr) for attr in attr_list if "rotMtx_" in attr] 69 | 70 | # extract values before and after rotMtx 71 | before_rotMtx_tensor = torch.from_numpy(dataset.iloc[:, 3:rotMtx_index_list[0]].values) 72 | after_rotMtx_tensor = torch.from_numpy(dataset.iloc[:, rotMtx_index_list[-1]+1:].values) 73 | 74 | # extract rot matrix and convert to quaternion 75 | rotMtx_tensor = torch.from_numpy(dataset.iloc[:, rotMtx_index_list[0]:rotMtx_index_list[-1]+1].values).reshape(-1, 3, 3) 76 | quat_tensor = torchUtils.batch_rotation_matrix_to_quaternion(rotMtx_tensor) 77 | #quat_tensor = torchUtils.matrix_to_6d(rotMtx_tensor).reshape(-1, 6) 78 | 79 | # concatenate tensors back; reduce dimension since quat is only 4 entries compared to 9 of rot mtx 80 | concat_tensor = torch.cat((before_rotMtx_tensor, quat_tensor, after_rotMtx_tensor), dim=1) 81 | quat_dim = dimension - (5 * numObjs) 82 | #quat_dim = dimension - (3 * numObjs) 83 | 84 | # remove n/a entries from data 85 | cleaned_tensor = torch.from_numpy(np.array([entry for row in concat_tensor.tolist() for entry in row if str(entry) != "nan"]).reshape(-1, quat_dim)) 86 | 87 | # normalize tensor 88 | if min_val is None or max_val is None or mean_val is None: 89 | min_val, max_val, mean_val = torchUtils.calculate_min_max_mean(cleaned_tensor) 90 | norm_tensor = torchUtils.normalize_tensor(cleaned_tensor, 
min_val, max_val, mean_val) 91 | 92 | return norm_tensor.float(), quat_dim, min_val.float(), max_val.float(), mean_val.float(), concat_tensor.float() 93 | 94 | 95 | 96 | def train_model(rig_path, jnt_path, model_path, lr, epochs, force_cpu): 97 | ## enable GPU/CUDA if available 98 | if torch.cuda.is_available() and not force_cpu: 99 | dev = "cuda:0" 100 | else: 101 | dev = "cpu" 102 | device = torch.device(dev) 103 | 104 | # build tensors 105 | train_x, train_x_dimension, x_min, x_max, x_mean, x_concat = build_data_tensor(jnt_path) 106 | train_x = train_x.to(device) 107 | 108 | train_y, train_y_dimension, y_min, y_max, y_mean, y_concat = build_data_tensor(rig_path) 109 | train_y = train_y.to(device) 110 | 111 | 112 | likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=train_y_dimension).to(device) 113 | model = BatchIndependentMultitaskGPModel(train_x, train_y, likelihood, force_cpu, train_x_dimension, 114 | train_y_dimension, x_min,x_max, x_mean, y_min, y_max, y_mean).to(device) 115 | 116 | 117 | # Find optimal model hyperparameters 118 | model.train() 119 | likelihood.train() 120 | 121 | # Use the adam optimizer 122 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) # lr = learning rate 123 | 124 | # "Loss" for GPs - the marginal log likelihood 125 | mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model) 126 | 127 | for i in range(epochs): 128 | optimizer.zero_grad() 129 | output = model(train_x) 130 | loss = -mll(output, train_y) 131 | loss.backward() 132 | print('Iter %d/%d - Loss: %.3f' % (i + 1, epochs, loss.item())) 133 | optimizer.step() 134 | print(f"PROGRESS {100.0 * (i + 1) / epochs}") 135 | 136 | 137 | # Set into eval mode 138 | model.eval() 139 | likelihood.eval() 140 | 141 | torch.save(model.state_dict(), model_path) # save trained model parameters 142 | -------------------------------------------------------------------------------- /system/gpr_predict.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | import torch 13 | import gpytorch 14 | import numpy as np 15 | import pandas as pd 16 | import csv 17 | import pathlib 18 | import os 19 | 20 | import system.gpr_model as gpr 21 | import utils.pytorch as torchUtils 22 | 23 | 24 | def predict_data(anim_path, model_path, rig_path): 25 | # load trained model on the CPU first, so a GPU-trained model also opens on a CPU-only machine 26 | state_dict = torch.load(model_path, map_location="cpu") 27 | force_cpu = bool(state_dict["force_cpu"].item()) 28 | 29 | # enable GPU/CUDA if available 30 | if torch.cuda.is_available() and not force_cpu: 31 | dev = "cuda:0" 32 | else: 33 | dev = "cpu" 34 | device = torch.device(dev) 35 | 36 | train_x = state_dict["train_x"] 37 | train_x = train_x.to(device) 38 | train_y = state_dict["train_y"] 39 | train_y = train_y.to(device) 40 | train_x_dimension = int(state_dict["x_dim"].item()) 41 | train_y_dimension = int(state_dict["y_dim"].item()) 42 | 43 | x_min = state_dict["x_min"].to(device) 44 | x_max = state_dict["x_max"].to(device) 45 | x_mean = state_dict["x_mean"].to(device) 46 | 47 | y_min = state_dict["y_min"].to(device) 48 | y_max = state_dict["y_max"].to(device) 49 | y_mean = state_dict["y_mean"].to(device) 50 | 51 | likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=train_y_dimension).to(device) 52 | 53 | model = gpr.BatchIndependentMultitaskGPModel(train_x, train_y, likelihood, force_cpu, train_x_dimension, 54 | train_y_dimension, x_min,x_max, x_mean, y_min, y_max, y_mean).float().to(device) 55 | model.load_state_dict(state_dict) 56 | 57 | model.eval() 58 | 
likelihood.eval() 59 | print("PROGRESS 20") 60 | 61 | # load anim data 62 | print("PROGRESS 40") 63 | 64 | anim_dataset = pd.read_csv(anim_path, na_values='?', comment='\t', sep=',', skipinitialspace=True, header=[0]) 65 | print("PROGRESS 40") 66 | 67 | anim_numObjs = len(np.unique(anim_dataset.iloc[:, 1])) 68 | print("PROGRESS 40") 69 | 70 | anim_frame_len = int(len(anim_dataset.iloc[:, 1])/anim_numObjs) 71 | print("PROGRESS 40") 72 | 73 | anim_x, anim_quat_dim, min_val, max_val, mean_val, anim_concat = gpr.build_data_tensor(anim_path, min_val=x_min.to("cpu"), max_val=x_max.to("cpu"), mean_val=x_mean.to("cpu")) 74 | anim_x = anim_x.to(device) 75 | 76 | # get predict values from trained model 77 | print("PROGRESS 40") 78 | predict_mean = [] 79 | with torch.no_grad(), gpytorch.settings.fast_pred_var(): 80 | predict_y = likelihood(model(anim_x)) 81 | predict_mean = predict_y.mean.tolist() 82 | print("PROGRESS 70") 83 | 84 | 85 | # get rig dataset used in training for building predict data 86 | train_rig_df = pd.read_csv(rig_path, na_values='?', comment='\t', sep=',', skipinitialspace=True, header=[0]) 87 | rig_x, rig_quat_dim, rig_min, rig_max, rig_mean, rig_concat = gpr.build_data_tensor(rig_path) 88 | 89 | rig_numObjs = len(np.unique(train_rig_df.iloc[:, 1])) 90 | rig_highest_dim = max(train_rig_df.filter(items=["dimension"]).values[:rig_numObjs])[0] 91 | 92 | # denormalise tensor 93 | predict_tensor = torch.from_numpy(np.array(predict_mean)).to(device) 94 | denorm_predict = torchUtils.denormalize_tensor(predict_tensor, y_min, y_max, y_mean) 95 | 96 | # convert predict tensor to same structure of dataframe (with nan values) 97 | attr_list = train_rig_df.columns.values[3:].tolist() 98 | rotMtx_start = [attr_list.index(attr) for attr in attr_list if "rotMtx_" in attr][0] 99 | 100 | rig_nan_tensor = torch.tensor(rig_concat.reshape(-1, (rig_highest_dim-5)*rig_numObjs)[0]).repeat(anim_frame_len, 1) 101 | for entry_index, entry in enumerate(denorm_predict): 102 | 
replace_index = 0 103 | for nan_index, nan in enumerate(rig_nan_tensor[entry_index]): 104 | if not np.isnan(nan): 105 | rig_nan_tensor[entry_index][nan_index] = entry[replace_index] 106 | replace_index += 1 107 | rig_nan_tensor = rig_nan_tensor.reshape(-1, rig_highest_dim-5) 108 | 109 | # extract values before and after quat 110 | before_quat_tensor = torch.tensor(rig_nan_tensor[:, :rotMtx_start]).to(device) 111 | after_quat_tensor = torch.tensor(rig_nan_tensor[:, rotMtx_start+4:]).to(device) 112 | #after_quat_tensor = torch.tensor(denorm_predict[:, rotMtx_start+6:]).to(device) 113 | 114 | # extract rot matrix and convert to quaternion 115 | quat_tensor = torch.tensor(rig_nan_tensor[:, rotMtx_start:rotMtx_start+4]).reshape(-1, 4) 116 | #quat_tensor = torch.tensor(denorm_predict[:, rotMtx_start:rotMtx_start+6]) 117 | rotMtx_tensor = torchUtils.batch_quaternion_to_rotation_matrix(quat_tensor).reshape(-1, 9).to(device) 118 | #rotMtx_tensor = torchUtils._6d_to_matrix(quat_tensor).reshape(-1, 9).to(device) 119 | rotMtx_predict = torch.cat((before_quat_tensor, rotMtx_tensor, after_quat_tensor), dim=1) 120 | 121 | predict_rowBegin = train_rig_df.iloc[:rig_numObjs, :3].values 122 | predict_data = [predict_rowBegin[i%rig_numObjs].tolist() + data.tolist() for i, data in enumerate(rotMtx_predict)] 123 | 124 | print("PROGRESS 90") 125 | # save anim data for Maya 126 | predict_path = pathlib.PurePath(pathlib.PurePath(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent, "predict_data/irm_predict_data.csv") 127 | predict_header = train_rig_df.columns.values # use header of rig train data for predict header 128 | 129 | with open(predict_path, "w") as f: 130 | writer = csv.writer(f) 131 | writer.writerow(predict_header) 132 | writer.writerows(predict_data) 133 | 134 | print("PROGRESS 100") -------------------------------------------------------------------------------- /system/irm_ui.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | from PySide2 import QtCore, QtWidgets, QtGui 13 | from functools import partial 14 | import json 15 | import pathlib 16 | import os 17 | 18 | import utils.ui as uiUtils 19 | 20 | 21 | 22 | class DataGenWidget(QtWidgets.QWidget): 23 | def __init__(self, parent=None): 24 | super(DataGenWidget, self).__init__(parent) 25 | 26 | self.create_widgets() 27 | self.create_layouts() 28 | self.create_connections() 29 | 30 | self.rig_tree.checkIfEmpty() 31 | self.jnt_tree.checkIfEmpty() 32 | 33 | def create_widgets(self): 34 | # rig widgets 35 | self.rig_param_label = QtWidgets.QLabel("Control Rig Parameters (0)") 36 | self.rig_param_label.setStyleSheet("color: #00ff6e;") 37 | self.rig_clear_btn = QtWidgets.QPushButton("Clear All") 38 | self.rig_add_btn = QtWidgets.QPushButton("Add") 39 | 40 | # rig tree view widget 41 | rig_msg = 'Add rig control parameters using the "Add" button.\nMake sure you have at least one object selected.\nAdjust the parameter range and delete unwanted parameters by right-clicking and choosing "Delete".' 
42 | self.rig_tree = uiUtils.PlaceholderTreeWidget(self, rig_msg) 43 | self.rig_tree.setHeaderLabels(['Control Name', 'Min', 'Max']) 44 | self.rig_tree.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) 45 | self.rig_tree.setItemDelegate(uiUtils.EditableItemDelegate(self.rig_tree)) 46 | self.rig_tree.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection) 47 | 48 | rig_header = self.rig_tree.header() 49 | rig_header.setStretchLastSection(False) 50 | rig_header.setSectionResizeMode(0, QtWidgets.QHeaderView.Stretch) 51 | rig_header.setSectionResizeMode(1, QtWidgets.QHeaderView.Interactive) 52 | rig_header.setSectionResizeMode(2, QtWidgets.QHeaderView.Interactive) 53 | rig_header.resizeSection(1, 100) 54 | rig_header.resizeSection(2, 100) 55 | 56 | 57 | # joint widgets 58 | self.jnt_param_label = QtWidgets.QLabel("Joint Parameters (0)") 59 | self.jnt_param_label.setStyleSheet("color: #00ff6e;") 60 | self.jnt_clear_btn = QtWidgets.QPushButton("Clear All") 61 | self.jnt_add_btn = QtWidgets.QPushButton("Add") 62 | 63 | # joint tree view widget 64 | jnt_msg = 'Add joints using the "Add" button.\nMake sure you have at least one joint selected.\nOnly connected parameters are added. Delete unwanted ones by right-clicking and choosing "Delete".' 
65 | self.jnt_tree = uiUtils.PlaceholderTreeWidget(self, jnt_msg) 66 | self.jnt_tree.setHeaderLabels(['Joint Name']) 67 | self.jnt_tree.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) 68 | self.jnt_tree.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection) 69 | 70 | jnt_header = self.jnt_tree.header() 71 | jnt_header.setStretchLastSection(False) 72 | jnt_header.setSectionResizeMode(0, QtWidgets.QHeaderView.Stretch) 73 | 74 | 75 | # train data settings 76 | self.setting_widget = QtWidgets.QWidget() 77 | self.paramProperties_label = QtWidgets.QLabel("Parameter Properties:") 78 | 79 | self.minRange_label = QtWidgets.QLabel("Minimum:") 80 | self.minRange_line = QtWidgets.QLineEdit("-50.000") 81 | self.minRange_line.setValidator(QtGui.QRegExpValidator(QtCore.QRegExp(r"^-?\d+(\.\d{0,3})?$"))) 82 | self.minRange_line.setEnabled(False) 83 | 84 | self.maxRange_label = QtWidgets.QLabel("Maximum:") 85 | self.maxRange_line = QtWidgets.QLineEdit("50.000") 86 | self.maxRange_line.setValidator(QtGui.QRegExpValidator(QtCore.QRegExp(r"^-?\d+(\.\d{0,3})?$"))) 87 | self.maxRange_line.setEnabled(False) 88 | 89 | self.dataSpacer01_label = QtWidgets.QLabel() 90 | self.dataSpacer01_label.setFixedHeight(10) 91 | 92 | self.numPoses_label = QtWidgets.QLabel("Number of Poses:") 93 | self.numPoses_line = QtWidgets.QLineEdit("1000") 94 | self.numPoses_line.setValidator(QtGui.QIntValidator(1, 9999999)) 95 | 96 | self.dataSpacer02_label = QtWidgets.QLabel() 97 | self.dataSpacer02_label.setFixedHeight(50) 98 | 99 | 100 | # train model settings 101 | self.lr_label = QtWidgets.QLabel("Learning Rate:") 102 | self.lr_line = QtWidgets.QLineEdit("0.01") 103 | self.lr_line.setValidator(QtGui.QRegExpValidator(QtCore.QRegExp(r"^\d+(\.\d{0,8})?$"))) 104 | 105 | self.epoch_label = QtWidgets.QLabel("Epochs:") 106 | self.epoch_line = QtWidgets.QLineEdit("100") 107 | self.epoch_line.setValidator(QtGui.QIntValidator(1, 9999999)) 108 | 109 | self.forceCPU_label = QtWidgets.QLabel("Force CPU:") 110 | 
self.forceCPU_checkBox = QtWidgets.QCheckBox() 111 | 112 | self.trainSpacer01_label = QtWidgets.QLabel() 113 | self.trainSpacer01_label.setFixedHeight(10) 114 | 115 | self.pyApp_label = QtWidgets.QLabel("Python Exe:") 116 | self.pyApp_line = QtWidgets.QLineEdit() 117 | py_path = pathlib.Path(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent 118 | self.pyApp_line.setText(pathlib.PurePath(py_path, pathlib.Path("venv/Scripts/python.exe")).as_posix()) 119 | self.pyApp_line.setReadOnly(True) 120 | self.pyApp_btn = QtWidgets.QPushButton("<") 121 | self.pyApp_btn.setFixedSize(20,20) 122 | 123 | self.trainSpacer02_label = QtWidgets.QLabel() 124 | self.trainSpacer02_label.setFixedHeight(50) 125 | 126 | 127 | # output settings 128 | self.outRig_label = QtWidgets.QLabel("Control Rig File:") 129 | self.outRig_line = QtWidgets.QLineEdit() 130 | self.outRig_line.setText(pathlib.PurePath(pathlib.PurePath(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent, "training_data/rig/irm_rig_data.csv").as_posix()) 131 | self.outRig_line.setReadOnly(True) 132 | self.outRig_btn = QtWidgets.QPushButton("<") 133 | self.outRig_btn.setFixedSize(20,20) 134 | 135 | self.outJnt_label = QtWidgets.QLabel("Joint File:") 136 | self.outJnt_line = QtWidgets.QLineEdit() 137 | self.outJnt_line.setText(pathlib.PurePath(pathlib.PurePath(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent, "training_data/jnt/irm_jnt_data.csv").as_posix()) 138 | self.outJnt_line.setReadOnly(True) 139 | self.outJnt_btn = QtWidgets.QPushButton("<") 140 | self.outJnt_btn.setFixedSize(20,20) 141 | 142 | self.outSpacer01_label = QtWidgets.QLabel() 143 | self.outSpacer01_label.setFixedHeight(10) 144 | 145 | self.outModel_label = QtWidgets.QLabel("Model File:") 146 | self.outModel_line = QtWidgets.QLineEdit() 147 | self.outModel_line.setText(pathlib.PurePath(pathlib.PurePath(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent, 
"trained_model/trained_model.pt").as_posix()) 148 | self.outModel_line.setReadOnly(True) 149 | self.outModel_btn = QtWidgets.QPushButton("<") 150 | self.outModel_btn.setFixedSize(20,20) 151 | 152 | self.outSpacer02_label = QtWidgets.QLabel() 153 | self.outSpacer02_label.setFixedHeight(50) 154 | 155 | 156 | # buttons 157 | self.generate_btn = QtWidgets.QPushButton("Generate Train Data") 158 | self.train_btn = QtWidgets.QPushButton("Train Model") 159 | 160 | 161 | def create_layouts(self): 162 | # main UI layout 163 | main_layout = QtWidgets.QHBoxLayout() 164 | main_layout.setContentsMargins(3,3,3,3) 165 | 166 | # left UI side - rig and joint tree views 167 | tree_widget = QtWidgets.QWidget() 168 | tree_layout = QtWidgets.QVBoxLayout() 169 | tree_widget.setLayout(tree_layout) 170 | 171 | # rig layout 172 | rig_group = QtWidgets.QGroupBox() 173 | tree_layout.addWidget(rig_group) 174 | rig_attr_layout = QtWidgets.QVBoxLayout() 175 | rig_group.setLayout(rig_attr_layout) 176 | 177 | rig_btn_layout = QtWidgets.QHBoxLayout() 178 | rig_btn_layout.addWidget(self.rig_param_label) 179 | rig_btn_layout.addWidget(self.rig_clear_btn) 180 | rig_btn_layout.addWidget(self.rig_add_btn) 181 | rig_attr_layout.addLayout(rig_btn_layout) 182 | 183 | rig_attr_layout.addWidget(self.rig_tree) 184 | 185 | # jnt layout 186 | jnt_group = QtWidgets.QGroupBox() 187 | tree_layout.addWidget(jnt_group) 188 | jnt_attr_layout = QtWidgets.QVBoxLayout() 189 | jnt_group.setLayout(jnt_attr_layout) 190 | 191 | jnt_btn_layout = QtWidgets.QHBoxLayout() 192 | jnt_btn_layout.addWidget(self.jnt_param_label) 193 | jnt_btn_layout.addWidget(self.jnt_clear_btn) 194 | jnt_btn_layout.addWidget(self.jnt_add_btn) 195 | jnt_attr_layout.addLayout(jnt_btn_layout) 196 | 197 | jnt_attr_layout.addWidget(self.jnt_tree) 198 | 199 | 200 | # right UI side - settings for data gen and model training 201 | settings_widget = QtWidgets.QWidget() 202 | settings_layout = QtWidgets.QVBoxLayout() 203 | 
settings_widget.setLayout(settings_layout) 204 | 205 | # data gen settings 206 | data_widget = QtWidgets.QGroupBox("Generator Settings") 207 | data_widget.setObjectName("data_box") 208 | data_widget.setStyleSheet("QGroupBox#data_box { color: #00ff6e; }") 209 | data_layout = QtWidgets.QGridLayout() 210 | data_widget.setLayout(data_layout) 211 | 212 | dataAlign_widget = QtWidgets.QWidget() 213 | dataAlign_widget.setFixedWidth(110) 214 | dataAlign_widget.setFixedHeight(1) 215 | data_layout.addWidget(dataAlign_widget, 0, 0) 216 | 217 | data_layout.addWidget(self.paramProperties_label, 1, 0) 218 | 219 | data_layout.addWidget(self.minRange_label, 2, 0) 220 | data_layout.addWidget(self.minRange_line, 2, 1) 221 | 222 | data_layout.addWidget(self.maxRange_label, 3, 0) 223 | data_layout.addWidget(self.maxRange_line, 3, 1) 224 | 225 | data_layout.addWidget(self.dataSpacer01_label, 4, 0) 226 | 227 | data_layout.addWidget(self.numPoses_label, 5, 0) 228 | data_layout.addWidget(self.numPoses_line, 5, 1) 229 | 230 | data_layout.addWidget(self.dataSpacer02_label, 6, 0) 231 | 232 | 233 | # train settings 234 | train_widget = QtWidgets.QGroupBox("Train Settings") 235 | train_widget.setStyleSheet("QGroupBox { color: #00ff6e; }") 236 | train_layout = QtWidgets.QGridLayout() 237 | train_widget.setLayout(train_layout) 238 | 239 | trainAlign_widget = QtWidgets.QWidget() 240 | trainAlign_widget.setFixedWidth(110) 241 | trainAlign_widget.setFixedHeight(1) 242 | train_layout.addWidget(trainAlign_widget, 0, 0) 243 | 244 | train_layout.addWidget(self.lr_label, 1, 0) 245 | train_layout.addWidget(self.lr_line, 1, 1) 246 | 247 | train_layout.addWidget(self.epoch_label, 2, 0) 248 | train_layout.addWidget(self.epoch_line, 2, 1) 249 | 250 | train_layout.addWidget(self.forceCPU_label, 3, 0) 251 | train_layout.addWidget(self.forceCPU_checkBox, 3, 1) 252 | 253 | train_layout.addWidget(self.trainSpacer01_label, 4, 0) 254 | 255 | train_layout.addWidget(self.pyApp_label, 5, 0) 256 | 
train_layout.addWidget(self.pyApp_line, 5, 1) 257 | train_layout.addWidget(self.pyApp_btn, 5, 2) 258 | 259 | train_layout.addWidget(self.trainSpacer02_label, 6, 0) 260 | 261 | 262 | # output settings 263 | output_widget = QtWidgets.QGroupBox("Output Settings") 264 | output_widget.setStyleSheet("QGroupBox { color: #00ff6e; }") 265 | output_layout = QtWidgets.QGridLayout() 266 | output_widget.setLayout(output_layout) 267 | 268 | outAlign_widget = QtWidgets.QWidget() 269 | outAlign_widget.setFixedWidth(110) 270 | outAlign_widget.setFixedHeight(1) 271 | output_layout.addWidget(outAlign_widget, 0, 0) 272 | 273 | output_layout.addWidget(self.outRig_label, 1, 0) 274 | output_layout.addWidget(self.outRig_line, 1, 1) 275 | output_layout.addWidget(self.outRig_btn, 1, 2) 276 | 277 | output_layout.addWidget(self.outJnt_label, 2, 0) 278 | output_layout.addWidget(self.outJnt_line, 2, 1) 279 | output_layout.addWidget(self.outJnt_btn, 2, 2) 280 | 281 | output_layout.addWidget(self.outSpacer01_label, 3, 0) 282 | 283 | output_layout.addWidget(self.outModel_label, 4, 0) 284 | output_layout.addWidget(self.outModel_line, 4, 1) 285 | output_layout.addWidget(self.outModel_btn, 4, 2) 286 | 287 | output_layout.addWidget(self.outSpacer02_label, 5, 0) 288 | 289 | 290 | # data/train settings splitter 291 | settings_splitter = uiUtils.UnmovableSplitter(QtCore.Qt.Vertical) 292 | settings_splitter.addWidget(data_widget) 293 | settings_splitter.addWidget(train_widget) 294 | settings_splitter.addWidget(output_widget) 295 | settings_splitter.setSizes([5000, 5000, 5000]) 296 | settings_layout.addWidget(settings_splitter) 297 | 298 | # button layout 299 | outBtn_layout = QtWidgets.QHBoxLayout() 300 | outBtn_layout.addWidget(self.generate_btn) 301 | outBtn_layout.addWidget(self.train_btn) 302 | settings_layout.addLayout(outBtn_layout) 303 | 304 | 305 | # main layout splitter 306 | main_splitter = QtWidgets.QSplitter(QtCore.Qt.Horizontal) 307 | main_splitter.addWidget(tree_widget) 308 | 
main_splitter.addWidget(settings_widget) 309 | main_splitter.setSizes([7000, 3000]) 310 | main_layout.addWidget(main_splitter) 311 | self.setLayout(main_layout) 312 | 313 | 314 | 315 | def create_connections(self): 316 | # rig UI 317 | self.rig_clear_btn.clicked.connect(partial(uiUtils.clear_tree, self.rig_tree, self.rig_param_label)) 318 | self.rig_add_btn.clicked.connect(partial(self.add_tree_item, self.rig_tree, self.rig_param_label, jnt_mode=False)) 319 | 320 | self.rig_tree.customContextMenuRequested.connect(self.show_rig_context_menu) 321 | self.rig_tree.itemChanged.connect(self.rig_tree.checkIfEmpty) 322 | self.rig_tree.itemSelectionChanged.connect(self.update_param_range) 323 | 324 | # joint UI 325 | self.jnt_clear_btn.clicked.connect(partial(uiUtils.clear_tree, self.jnt_tree, self.jnt_param_label)) 326 | self.jnt_add_btn.clicked.connect(partial(self.add_tree_item, self.jnt_tree, self.jnt_param_label, jnt_mode=True)) 327 | 328 | self.jnt_tree.customContextMenuRequested.connect(self.show_jnt_context_menu) 329 | self.jnt_tree.itemChanged.connect(self.jnt_tree.checkIfEmpty) 330 | 331 | # data/train settings 332 | self.minRange_line.editingFinished.connect(self.update_min_range) 333 | self.maxRange_line.editingFinished.connect(self.update_max_range) 334 | 335 | self.pyApp_btn.clicked.connect(self.set_pyApp_outPath) 336 | 337 | self.outRig_btn.clicked.connect(self.set_rigData_outPath) 338 | self.outJnt_btn.clicked.connect(self.set_jntData_outPath) 339 | self.outModel_btn.clicked.connect(self.set_model_outPath) 340 | 341 | self.generate_btn.clicked.connect(self.generate_train_data) 342 | self.train_btn.clicked.connect(self.train_model) 343 | 344 | def set_pyApp_outPath(self): 345 | uiUtils.openFileDialog(self, self.pyApp_line, "Set Python Interpreter", "exe") 346 | 347 | def set_rigData_outPath(self): 348 | uiUtils.saveFileDialog(self, self.outRig_line, "Save Control Rig Train Data", "csv") 349 | 350 | def set_jntData_outPath(self): 351 | 
uiUtils.saveFileDialog(self, self.outJnt_line, "Save Joint Train Data", "csv") 352 | 353 | def set_model_outPath(self): 354 | uiUtils.saveFileDialog(self, self.outModel_line, "Save Trained Model", "pt") 355 | 356 | def add_tree_item(self, treeWidget, label, jnt_mode): 357 | uiUtils.add_selection(treeWidget, jnt_mode) 358 | uiUtils.update_param_label(treeWidget, label) 359 | 360 | def show_rig_context_menu(self, pos): 361 | uiUtils.show_context_menu(self, pos, self.rig_tree) 362 | self.rig_tree.checkIfEmpty() 363 | uiUtils.update_param_label(self.rig_tree, self.rig_param_label) 364 | 365 | def show_jnt_context_menu(self, pos): 366 | uiUtils.show_context_menu(self, pos, self.jnt_tree) 367 | self.jnt_tree.checkIfEmpty() 368 | uiUtils.update_param_label(self.jnt_tree, self.jnt_param_label) 369 | 370 | 371 | def update_param_range(self): 372 | selected_items = self.rig_tree.selectedItems() 373 | 374 | if selected_items: 375 | item = selected_items[0] 376 | 377 | if item.childCount() > 0: 378 | return 379 | else: 380 | self.minRange_line.setEnabled(True) 381 | self.maxRange_line.setEnabled(True) 382 | self.minRange_line.setText(item.text(1)) 383 | self.maxRange_line.setText(item.text(2)) 384 | else: 385 | self.minRange_line.setEnabled(False) 386 | self.maxRange_line.setEnabled(False) 387 | 388 | 389 | def ensure_three_digits(self, lineWidget): 390 | try: 391 | lineWidget.setText("{:.3f}".format(float(lineWidget.text()))) 392 | except ValueError: 393 | pass # keep non-numeric input as typed instead of raising 394 | def ensure_notZero(self, lineWidget): 395 | value = int(lineWidget.text()) 396 | # a zero entry is replaced with 1 397 | if value == 0: 398 | lineWidget.setText("1") 399 | 400 | def update_min_range(self): 401 | self.update_selected_items(self.minRange_line, 1) 402 | self.ensure_three_digits(self.minRange_line) 403 | 404 | def update_max_range(self): 405 | self.update_selected_items(self.maxRange_line, 2) 406 | self.ensure_three_digits(self.maxRange_line) 407 | 408 | 409 | def update_selected_items(self,
lineWidget, column): 410 | selected_items = self.rig_tree.selectedItems() 411 | 412 | for item in selected_items: 413 | if item.childCount() > 0: 414 | continue 415 | try: 416 | item.setText(column, "{:.3f}".format(float(lineWidget.text()))) 417 | except ValueError: 418 | pass 419 | 420 | 421 | def generate_train_data(self): 422 | rig_input_data = uiUtils.get_treeItems_as_dict(treeWidget=self.rig_tree) 423 | jnt_input_data = uiUtils.get_treeItems_as_dict(treeWidget=self.jnt_tree) 424 | 425 | rig_path = self.outRig_line.text() 426 | jnt_path = self.outJnt_line.text() 427 | train_poses = int(self.numPoses_line.text()) 428 | 429 | if not uiUtils.check_dir_path(path=rig_path) or not uiUtils.check_dir_path(path=jnt_path): 430 | return 431 | 432 | if not uiUtils.check_train_data(rig_input_data, "control rig") or not uiUtils.check_train_data(jnt_input_data, "joint"): 433 | return 434 | 435 | import system.generate_train_data as gen_train_data 436 | gen_train_data.generate_data(rig_input_data=rig_input_data, jnt_input_data=jnt_input_data, 437 | rig_out=rig_path, jnt_out=jnt_path, train_poses=train_poses) 438 | 439 | 440 | def train_model(self): 441 | rig_path = self.outRig_line.text() 442 | jnt_path = self.outJnt_line.text() 443 | model_path = self.outModel_line.text() 444 | py_app = self.pyApp_line.text() 445 | 446 | if not uiUtils.check_file_path(path=rig_path) or not uiUtils.check_file_path(path=jnt_path): 447 | return 448 | 449 | learning_rate = float(self.lr_line.text()) 450 | epochs = int(self.epoch_line.text()) 451 | force_cpu = self.forceCPU_checkBox.isChecked() 452 | 453 | import system.train_model as train_model 454 | train_model.train_model(py_app=py_app, rig_path=rig_path, jnt_path=jnt_path, model_path=model_path, 455 | lr=learning_rate, epochs=epochs, force_cpu=force_cpu) 456 | 457 | 458 | 459 | class PredictWidget(QtWidgets.QWidget): 460 | def __init__(self, parent=None): 461 | super(PredictWidget, self).__init__(parent) 462 | 463 | self.create_widgets() 464 
| self.create_layouts() 465 | self.create_connections() 466 | 467 | self.anim_tree.checkIfEmpty() 468 | 469 | 470 | def create_widgets(self): 471 | # anim widgets 472 | self.anim_param_label = QtWidgets.QLabel("Animation Parameters (0)") 473 | self.anim_param_label.setStyleSheet("color: #00ff6e;") 474 | self.anim_clear_btn = QtWidgets.QPushButton("Clear All") 475 | self.anim_add_btn = QtWidgets.QPushButton("Add") 476 | 477 | # anim tree view widget 478 | anim_msg = 'Add animated parameters using the "Add" button.\nMake sure you have at least one joint selected.' 479 | self.anim_tree = uiUtils.PlaceholderTreeWidget(self, anim_msg) 480 | self.anim_tree.setHeaderLabels(['Joint Name']) 481 | self.anim_tree.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) 482 | self.anim_tree.setItemDelegate(uiUtils.EditableItemDelegate(self.anim_tree)) 483 | self.anim_tree.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection) 484 | 485 | anim_header = self.anim_tree.header() 486 | anim_header.setStretchLastSection(False) 487 | anim_header.setSectionResizeMode(0, QtWidgets.QHeaderView.Stretch) 488 | 489 | 490 | 491 | # predict settings 492 | self.setting_widget = QtWidgets.QWidget() 493 | 494 | self.outRig_label = QtWidgets.QLabel("Control Rig Data:") 495 | self.outRig_line = QtWidgets.QLineEdit() 496 | self.outRig_line.setText(pathlib.PurePath(pathlib.PurePath(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent, "training_data/rig/irm_rig_data.csv").as_posix()) 497 | self.outRig_line.setReadOnly(True) 498 | self.outRig_btn = QtWidgets.QPushButton("<") 499 | self.outRig_btn.setFixedSize(20,20) 500 | 501 | self.outJnt_label = QtWidgets.QLabel("Joint Data:") 502 | self.outJnt_line = QtWidgets.QLineEdit() 503 | self.outJnt_line.setText(pathlib.PurePath(pathlib.PurePath(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent, "training_data/jnt/irm_jnt_data.csv").as_posix()) 504 | self.outJnt_line.setReadOnly(True) 505 | self.outJnt_btn = 
QtWidgets.QPushButton("<") 506 | self.outJnt_btn.setFixedSize(20,20) 507 | 508 | self.outSpacer01_label = QtWidgets.QLabel() 509 | self.outSpacer01_label.setFixedHeight(10) 510 | 511 | self.outModel_label = QtWidgets.QLabel("Trained Model:") 512 | self.outModel_line = QtWidgets.QLineEdit() 513 | self.outModel_line.setText(pathlib.PurePath(pathlib.PurePath(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent, "trained_model/trained_model.pt").as_posix()) 514 | self.outModel_line.setReadOnly(True) 515 | self.outModel_btn = QtWidgets.QPushButton("<") 516 | self.outModel_btn.setFixedSize(20,20) 517 | 518 | self.pyApp_label = QtWidgets.QLabel("Python Exe:") 519 | self.pyApp_line = QtWidgets.QLineEdit() 520 | py_path = pathlib.Path(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent 521 | self.pyApp_line.setText(pathlib.PurePath(py_path, pathlib.Path("venv/Scripts/python.exe")).as_posix()) 522 | self.pyApp_line.setReadOnly(True) 523 | self.pyApp_btn = QtWidgets.QPushButton("<") 524 | self.pyApp_btn.setFixedSize(20,20) 525 | 526 | self.outSpacer02_label = QtWidgets.QLabel() 527 | self.outSpacer02_label.setFixedHeight(500) 528 | 529 | # buttons 530 | self.predict_btn = QtWidgets.QPushButton("Map Prediction") 531 | 532 | 533 | def create_layouts(self): 534 | # main UI layout 535 | main_layout = QtWidgets.QHBoxLayout() 536 | main_layout.setContentsMargins(3,3,3,3) 537 | 538 | 539 | # left UI side - anim parameters 540 | tree_widget = QtWidgets.QWidget() 541 | tree_layout = QtWidgets.QVBoxLayout() 542 | tree_widget.setLayout(tree_layout) 543 | 544 | 545 | # anim layout 546 | anim_group = QtWidgets.QGroupBox() 547 | tree_layout.addWidget(anim_group) 548 | anim_attr_layout = QtWidgets.QVBoxLayout() 549 | anim_group.setLayout(anim_attr_layout) 550 | 551 | anim_btn_layout = QtWidgets.QHBoxLayout() 552 | anim_btn_layout.addWidget(self.anim_param_label) 553 | anim_btn_layout.addWidget(self.anim_clear_btn) 554 | 
anim_btn_layout.addWidget(self.anim_add_btn) 555 | anim_attr_layout.addLayout(anim_btn_layout) 556 | 557 | anim_attr_layout.addWidget(self.anim_tree) 558 | 559 | 560 | # right UI side - settings for prediction 561 | settings_widget = QtWidgets.QWidget() 562 | settings_layout = QtWidgets.QVBoxLayout() 563 | settings_widget.setLayout(settings_layout) 564 | 565 | predict_widget = QtWidgets.QGroupBox("Predict Settings") 566 | predict_widget.setStyleSheet("QGroupBox { color: #00ff6e; }") 567 | predict_layout = QtWidgets.QGridLayout() 568 | predict_widget.setLayout(predict_layout) 569 | 570 | predictAlign_widget = QtWidgets.QWidget() 571 | predictAlign_widget.setFixedWidth(110) 572 | predictAlign_widget.setFixedHeight(1) 573 | predict_layout.addWidget(predictAlign_widget, 0, 0) 574 | 575 | predictAlign_widget = QtWidgets.QWidget() 576 | predictAlign_widget.setFixedWidth(110) 577 | predictAlign_widget.setFixedHeight(1) 578 | predict_layout.addWidget(predictAlign_widget, 0, 0) 579 | 580 | predict_layout.addWidget(self.outRig_label, 1, 0) 581 | predict_layout.addWidget(self.outRig_line, 1, 1) 582 | predict_layout.addWidget(self.outRig_btn, 1, 2) 583 | 584 | predict_layout.addWidget(self.outJnt_label, 2, 0) 585 | predict_layout.addWidget(self.outJnt_line, 2, 1) 586 | predict_layout.addWidget(self.outJnt_btn, 2, 2) 587 | 588 | predict_layout.addWidget(self.outSpacer01_label, 3, 0) 589 | 590 | predict_layout.addWidget(self.outModel_label, 4, 0) 591 | predict_layout.addWidget(self.outModel_line, 4, 1) 592 | predict_layout.addWidget(self.outModel_btn, 4, 2) 593 | 594 | predict_layout.addWidget(self.pyApp_label, 5, 0) 595 | predict_layout.addWidget(self.pyApp_line, 5, 1) 596 | predict_layout.addWidget(self.pyApp_btn, 5, 2) 597 | 598 | predict_layout.addWidget(self.outSpacer02_label, 6, 0) 599 | 600 | settings_layout.addWidget(predict_widget) 601 | 602 | # button layout 603 | outBtn_layout = QtWidgets.QHBoxLayout() 604 | outBtn_layout.addWidget(self.predict_btn) 605 | 
settings_layout.addLayout(outBtn_layout) 606 | 607 | 608 | # main layout splitter 609 | main_splitter = QtWidgets.QSplitter(QtCore.Qt.Horizontal) 610 | main_splitter.addWidget(tree_widget) 611 | main_splitter.addWidget(settings_widget) 612 | main_splitter.setSizes([7000, 3000]) 613 | main_layout.addWidget(main_splitter) 614 | self.setLayout(main_layout) 615 | 616 | 617 | def create_connections(self): 618 | # anim UI 619 | self.anim_clear_btn.clicked.connect(partial(uiUtils.clear_tree, self.anim_tree, self.anim_param_label)) 620 | self.anim_add_btn.clicked.connect(partial(self.add_tree_item, self.anim_tree, self.anim_param_label, jnt_mode=True)) 621 | 622 | self.anim_tree.customContextMenuRequested.connect(self.show_anim_context_menu) 623 | self.anim_tree.itemChanged.connect(self.anim_tree.checkIfEmpty) 624 | 625 | # predict settings 626 | self.outRig_btn.clicked.connect(self.set_rigData_outPath) 627 | self.outJnt_btn.clicked.connect(self.set_jntData_outPath) 628 | self.outModel_btn.clicked.connect(self.set_model_outPath) 629 | self.pyApp_btn.clicked.connect(self.set_pyApp_outPath) 630 | 631 | self.predict_btn.clicked.connect(self.map_prediction) 632 | 633 | 634 | def set_rigData_outPath(self): 635 | uiUtils.openFileDialog(self, self.outRig_line, "Load Control Rig Train Data", "csv") 636 | 637 | def set_jntData_outPath(self): 638 | uiUtils.openFileDialog(self, self.outJnt_line, "Load Joint Train Data", "csv") 639 | 640 | def set_model_outPath(self): 641 | uiUtils.openFileDialog(self, self.outModel_line, "Load Trained Model", "pt") 642 | 643 | def set_pyApp_outPath(self): 644 | uiUtils.openFileDialog(self, self.pyApp_line, "Set Python Interpreter", "exe") 645 | 646 | def add_tree_item(self, treeWidget, label, jnt_mode): 647 | uiUtils.add_selection(treeWidget, jnt_mode) 648 | uiUtils.update_param_label(treeWidget, label) 649 | 650 | def show_anim_context_menu(self, pos): 651 | uiUtils.show_context_menu(self, pos, self.anim_tree) 652 | self.anim_tree.checkIfEmpty() 653 
| uiUtils.update_param_label(self.anim_tree, self.anim_param_label) 654 | 655 | 656 | def map_prediction(self): 657 | anim_input_data = uiUtils.get_treeItems_as_dict(treeWidget=self.anim_tree) 658 | jnt_path = self.outJnt_line.text() 659 | rig_path = self.outRig_line.text() 660 | model_path = self.outModel_line.text() 661 | py_app = self.pyApp_line.text() 662 | 663 | if not uiUtils.check_file_path(path=jnt_path) or not uiUtils.check_file_path(path=model_path) or not uiUtils.check_file_path(path=rig_path) or not uiUtils.check_file_path(path=py_app): 664 | return 665 | 666 | if not uiUtils.check_train_data(anim_input_data, "animation"): 667 | return 668 | 669 | import system.apply_prediction as apply_prediction 670 | apply_prediction.map_data(anim_input_data, jnt_path, rig_path, model_path, py_app) 671 | 672 | 673 | 674 | class IRM_UI(QtWidgets.QDialog): 675 | 676 | dlg_instance = None 677 | 678 | @classmethod 679 | def show_dialog(cls): 680 | if not cls.dlg_instance: 681 | cls.dlg_instance = IRM_UI() 682 | 683 | if cls.dlg_instance.isHidden(): 684 | cls.dlg_instance.show() 685 | else: 686 | cls.dlg_instance.raise_() 687 | cls.dlg_instance.activateWindow() 688 | 689 | 690 | def __init__(self, parent=uiUtils.maya_main_window()): 691 | super(IRM_UI, self).__init__(parent) 692 | 693 | self.setWindowTitle("Inverse Rig Mapping Tool") 694 | self.resize(1000, 700) 695 | 696 | self.setWindowFlags(self.windowFlags() ^ QtCore.Qt.WindowContextHelpButtonHint) 697 | self.setWindowFlags(self.windowFlags() | QtCore.Qt.WindowMinimizeButtonHint) 698 | 699 | self.new_window = None 700 | 701 | self.create_widgets() 702 | self.create_layouts() 703 | self.create_connections() 704 | 705 | 706 | def create_widgets(self): 707 | # menu bar 708 | self.menuBar = QtWidgets.QMenuBar() 709 | self.file_menu = QtWidgets.QMenu("File", self) 710 | self.menu_load = QtWidgets.QAction("Load Config", self) 711 | self.menu_save = QtWidgets.QAction("Save Config", self) 712 | self.menu_recent = 
QtWidgets.QMenu("Recent Configs", self) 713 | 714 | self.file_menu.addAction(self.menu_load) 715 | self.file_menu.addAction(self.menu_save) 716 | self.file_menu.addSeparator() 717 | #self.file_menu.addMenu(self.menu_recent) 718 | self.menuBar.addMenu(self.file_menu) 719 | 720 | # tab widgets 721 | self.dataGen_wdg = DataGenWidget() 722 | self.predict_wdg = PredictWidget() 723 | 724 | self.tab_widget = QtWidgets.QTabWidget(self) 725 | self.tab_widget.addTab(self.dataGen_wdg, "Training Setup") 726 | self.tab_widget.addTab(self.predict_wdg, "Predict Animation") 727 | 728 | 729 | def create_layouts(self): 730 | main_layout = QtWidgets.QVBoxLayout(self) 731 | main_layout.setContentsMargins(3,3,3,3) 732 | main_layout.addWidget(self.menuBar) 733 | main_layout.addWidget(self.tab_widget) 734 | 735 | 736 | def create_connections(self): 737 | self.menu_load.triggered.connect(self.load_config) 738 | self.menu_save.triggered.connect(self.save_config) 739 | 740 | 741 | def load_config(self): 742 | file_name, _ = QtWidgets.QFileDialog.getOpenFileName(self, "Open IRM Config", QtCore.QStandardPaths.writableLocation(QtCore.QStandardPaths.DocumentsLocation), "JSON Files (*.json)") 743 | if file_name: 744 | pass 745 | 746 | 747 | def save_config(self): 748 | file_name, _ = QtWidgets.QFileDialog.getSaveFileName(self, "Save IRM Config", QtCore.QStandardPaths.writableLocation(QtCore.QStandardPaths.DocumentsLocation), "JSON Files (*.json)") 749 | if file_name: 750 | pass 751 | 752 | 753 | 754 | 755 | try: 756 | irm_dialog.close() 757 | irm_dialog.deleteLater() 758 | except: 759 | pass 760 | 761 | irm_dialog = IRM_UI() 762 | irm_dialog.show() 763 | -------------------------------------------------------------------------------- /system/prep_anim_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | 
Technical Director course at Filmakademie Baden-Wuerttemberg. 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | import maya.cmds as cmds 13 | import pymel.core as pm 14 | import os 15 | import csv 16 | import maya.api.OpenMaya as om 17 | import pathlib 18 | 19 | 20 | 21 | def prep_data(anim_input_data, jnt_path): 22 | jnt_list = [jnt for jnt in anim_input_data.keys()] 23 | jnt_list.sort() 24 | 25 | with open(jnt_path, "r") as f: 26 | reader = csv.reader(f) 27 | anim_header = next(reader, None) # get header 28 | jnt_data = [row for row in reader if len(row) != 0] 29 | 30 | jnt_names = list(set([entry[1] for entry in jnt_data])) # get unique jnt names 31 | jnt_dimension_list = [entry[2] for entry in jnt_data[:len(jnt_names)]] 32 | 33 | 34 | anim_data = [] 35 | current_frame = cmds.currentTime(q=1) 36 | frames = [] 37 | for jnt in jnt_list: 38 | frames.extend(cmds.keyframe(jnt, q=1) or []) # keyframe returns None for joints without keys 39 | frames = sorted(set(frames)) # unique key times in chronological order 40 | 41 | for frame_index, frame in enumerate(frames): 42 | for i, jnt in enumerate(jnt_list): 43 | cmds.currentTime(frame) 44 | frame_data = [i, jnt, jnt_dimension_list[i]] 45 | frame_data.extend(["n/a"] * (len(anim_header) - 3)) 46 | 47 | train_values = jnt_data[i][3:] 48 | attr_list = [anim_header[col + 3] for col, value in enumerate(train_values) if value != "n/a"] # enumerate keeps the column index correct when values repeat 49 | 50 | for attr in attr_list: 51 | if "rotMtx_" not in attr: 52 | # replace only used attr of ctrls in n/a list, rest stays at n/a 53 | frame_data[anim_header.index(attr)] = cmds.getAttr("{}.{}".format(jnt, attr)) 54 | 55 | rotation = [attr for attr in attr_list if "rotMtx_" in attr] 56 | if rotation: 57 | jnt_mtx = pm.dt.TransformationMatrix(cmds.xform(jnt, m=1, q=1, os=1)) 58 | jnt_rot_mtx3 = [x for mtx in jnt_mtx.asRotateMatrix()[:-1] for x in mtx[:-1]] 59 | 60 |
start_index = anim_header.index("rotMtx_00") # get index of first rotMtx entry in anim_header and start replacing rotMtx values from there 61 | for mtx_index, rot_mtx in enumerate(jnt_rot_mtx3): 62 | frame_data[start_index + mtx_index] = rot_mtx 63 | 64 | anim_data.append(frame_data) 65 | cmds.progressWindow(edit=True, progress=(frame_index/len(frames))*100, status=('Getting Animation Data...')) 66 | 67 | 68 | cmds.currentTime(current_frame) 69 | 70 | # save anim data as csv 71 | anim_path = pathlib.PurePath(pathlib.PurePath(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent, "anim_data/irm_anim_data.csv") 72 | with open(anim_path, "w", newline="") as f: # newline="" stops the csv module from writing blank rows on Windows 73 | writer = csv.writer(f) 74 | writer.writerow(anim_header) 75 | writer.writerows(anim_data) 76 | 77 | return anim_path.as_posix(), frames 78 | 79 | -------------------------------------------------------------------------------- /system/train_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg.
 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | import subprocess 13 | import os 14 | import pathlib 15 | import maya.api.OpenMaya as om 16 | import maya.cmds as cmds 17 | 18 | 19 | def train_model(py_app, rig_path, jnt_path, model_path, lr=0.01, epochs=100, force_cpu=False): 20 | py_path = pathlib.Path(os.path.normpath(os.path.dirname(os.path.realpath(__file__)))).parent.as_posix() # forward slashes so the path survives being embedded in py_cmd 21 | py_cmd = f"import sys; sys.path.append('{py_path}'); import system.gpr_model as gpr; gpr.train_model('{rig_path}', '{jnt_path}', '{model_path}', {lr}, {epochs}, {force_cpu})" 22 | command = [py_app, "-u", "-c", py_cmd] 23 | 24 | # Initialize the progress window 25 | cmds.progressWindow(title='Training Model', progress=0, status='Initialize Training...', isInterruptable=True) 26 | 27 | # start subprocess; prevent output window 28 | process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, 29 | text=True, creationflags=subprocess.CREATE_NO_WINDOW) 30 | current_epoch, current_loss = 0, "n/a" # defaults until the first "Iter" line is parsed 31 | # update progress bar 32 | for line in iter(process.stdout.readline, ""): 33 | if line.startswith("Iter "): 34 | current_epoch = line.split("/")[0].partition("Iter ")[2] 35 | current_loss = line.split("Loss: ")[1].strip() # drop the trailing newline 36 | 37 | if line.startswith("PROGRESS "): 38 | progress = float(line.strip().split(" ")[1]) 39 | if cmds.progressWindow(query=True, isCancelled=True): 40 | process.kill() 41 | break 42 | cmds.progressWindow(edit=True, progress=progress, status=(f'Epochs: {current_epoch}/{epochs} - Loss: {current_loss}')) 43 | 44 | process.wait() 45 | 46 | cmds.progressWindow(endProgress=True) 47 | 48 | if process.returncode: 49 | error_message = process.stderr.read() 50 | raise ValueError(error_message) 51 | else: 52 | om.MGlobal.displayInfo(f"Model trained successfully in {epochs} epochs and end loss of {current_loss}") 53
| -------------------------------------------------------------------------------- /trained_model/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ -------------------------------------------------------------------------------- /training_data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ -------------------------------------------------------------------------------- /training_data/jnt/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 
5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ -------------------------------------------------------------------------------- /training_data/rig/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ -------------------------------------------------------------------------------- /utils/data_gen.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 
 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | def build_header(base_header=None, attr_list=()): 13 | base_header = ["No.", "rigName", "dimension"] if base_header is None else base_header # build a fresh default list on every call 14 | for attr in attr_list: 15 | if attr in ["translateX", "translateY", "translateZ"]: 16 | base_header.append(attr) 17 | 18 | if "rot_mtx" in attr_list: 19 | base_header.extend(["rotMtx_00", "rotMtx_01", "rotMtx_02", 20 | "rotMtx_10", "rotMtx_11", "rotMtx_12", 21 | "rotMtx_20", "rotMtx_21", "rotMtx_22"]) 22 | 23 | for attr in attr_list: 24 | if attr in ["scaleX", "scaleY", "scaleZ"]: 25 | base_header.append(attr) 26 | 27 | base_header.extend([attr for attr in attr_list if attr not in ["translateX", "translateY", "translateZ", "rot_mtx", "scaleX", "scaleY", "scaleZ"]]) 28 | 29 | return base_header 30 | 31 | 32 | def get_attr_dimension(attr_list, rotation=True): 33 | # set dimension to length of attr_list; if rotate in list, remove rotate and add matrix3 (9 dimension) 34 | if rotation: 35 | return len([attr for attr in attr_list if attr not in ["rotateX", "rotateY", "rotateZ"]]) + 9 36 | else: 37 | return len(attr_list) 38 | 39 | 40 | def check_for_rotation(attr_list): 41 | rotate_attrs = ("rotateX", "rotateY", "rotateZ") 42 | # any() tests every rotate channel; ("rotateX" or ...) in attr_list would only test "rotateX" 43 | return any(attr in attr_list for attr in rotate_attrs) 44 | 45 | -------------------------------------------------------------------------------- /utils/maya.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg.
 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | import maya.cmds as cmds 13 | import pymel.core as pm 14 | import maya.api.OpenMaya as om 15 | import math 16 | import os 17 | import random 18 | import csv 19 | import pathlib 20 | 21 | 22 | def get_all_attributes(obj, unlocked=True): 23 | ''' 24 | return all attributes of obj that are keyable, scalar, not of type enum or bool, and (if unlocked is set) unlocked 25 | ''' 26 | return [attr for attr in (cmds.listAttr(obj, unlocked=unlocked, keyable=True, scalar=True) or []) if cmds.attributeQuery(attr, node=obj, at=1) not in ["enum", "bool"]] 27 | 28 | def filter_attrs_from_dict(dict_entry): 29 | return [attr[0] for attr in dict_entry] 30 | 31 | 32 | def restore_defaults(ctrl): 33 | ''' 34 | set all attributes on ctrl to their default values 35 | ''' 36 | for attr in get_all_attributes(ctrl): 37 | cmds.setAttr("{}.{}".format(ctrl, attr), cmds.attributeQuery(attr, node=ctrl, ld=1)[0]) 38 | 39 | 40 | def check_source_connection(obj, attr): 41 | ''' 42 | check given attribute for incoming connections 43 | ''' 44 | connection = cmds.listConnections("{}.{}".format(obj, attr), s=1, d=0, p=1) 45 | 46 | # check if attr is part of a compound one and check the compound one as well for any connections 47 | if not connection: 48 | attr_parent = cmds.attributeQuery(attr, node=obj, lp=1) # empty result when attr has no compound parent 49 | if attr_parent: 50 | connection = cmds.listConnections("{}.{}".format(obj, attr_parent[0]), s=1, d=0, p=1) 51 | 52 | return connection 53 | 54 | 55 | def check_transformLimit(ctrl, axis="rotateX", default_min=-180, default_max=180): 56 | limit_dict = {"translateX":[1,0,0, 0,0,0, 0,0,0], "translateY":[0,1,0, 0,0,0, 0,0,0], "translateZ":[0,0,1, 0,0,0, 0,0,0], 57 | "rotateX":[0,0,0, 1,0,0, 0,0,0], "rotateY":[0,0,0, 0,1,0, 0,0,0], "rotateZ":[0,0,0, 0,0,1, 0,0,0], 58 | "scaleX":[0,0,0,
0,0,0, 1,0,0], "scaleY":[0,0,0, 0,0,0, 0,1,0], "scaleZ":[0,0,0, 0,0,0, 0,0,1]} 59 | 60 | if cmds.transformLimits(ctrl, q=1, etx=limit_dict[axis][0], ety=limit_dict[axis][1], etz=limit_dict[axis][2], 61 | erx=limit_dict[axis][3], ery=limit_dict[axis][4], erz=limit_dict[axis][5], 62 | esx=limit_dict[axis][6], esy=limit_dict[axis][7], esz=limit_dict[axis][8])[0]: 63 | limit_min = cmds.transformLimits(ctrl, q=1, tx=limit_dict[axis][0], ty=limit_dict[axis][1], tz=limit_dict[axis][2], 64 | rx=limit_dict[axis][3], ry=limit_dict[axis][4], rz=limit_dict[axis][5], 65 | sx=limit_dict[axis][6], sy=limit_dict[axis][7], sz=limit_dict[axis][8])[0] 66 | else: 67 | limit_min = default_min 68 | 69 | if cmds.transformLimits(ctrl, q=1, etx=limit_dict[axis][0], ety=limit_dict[axis][1], etz=limit_dict[axis][2], 70 | erx=limit_dict[axis][3], ery=limit_dict[axis][4], erz=limit_dict[axis][5], 71 | esx=limit_dict[axis][6], esy=limit_dict[axis][7], esz=limit_dict[axis][8])[1]: 72 | limit_max = cmds.transformLimits(ctrl, q=1, tx=limit_dict[axis][0], ty=limit_dict[axis][1], tz=limit_dict[axis][2], 73 | rx=limit_dict[axis][3], ry=limit_dict[axis][4], rz=limit_dict[axis][5], 74 | sx=limit_dict[axis][6], sy=limit_dict[axis][7], sz=limit_dict[axis][8])[1] 75 | else: 76 | limit_max = default_max 77 | 78 | return limit_min, limit_max 79 | 80 | 81 | def query_visibility(obj): # check obj parents for vis flag 82 | if cmds.getAttr("{}.v".format(obj)): 83 | while cmds.listRelatives(obj, p=1): 84 | parent = cmds.listRelatives(obj, p=1)[0] 85 | if cmds.getAttr("{}.v".format(parent)): 86 | obj = parent 87 | else: 88 | return False 89 | return True 90 | else: 91 | return False 92 | 93 | 94 | -------------------------------------------------------------------------------- /utils/pytorch.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope 
of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | import torch 13 | import numpy as np 14 | 15 | def normalize_tensor(tensor, min_val, max_val, mean_val): 16 | eps = 1e-7 17 | return 2 * (tensor - mean_val) / (max_val - min_val + eps) 18 | 19 | 20 | def denormalize_tensor(tensor, min_val, max_val, mean_val): 21 | return tensor * (max_val - min_val) / 2 + mean_val 22 | 23 | 24 | def calculate_min_max_mean(tensor): 25 | min_val, _ = torch.min(tensor, dim=0, keepdim=True) 26 | max_val, _ = torch.max(tensor, dim=0, keepdim=True) 27 | mean_val = torch.mean(tensor, dim=0, keepdim=True) 28 | 29 | return min_val, max_val, mean_val 30 | 31 | 32 | def rotation_matrix_to_quaternion(rot_matrix): 33 | trace = torch.trace(rot_matrix) 34 | 35 | if trace > 0: 36 | s = torch.sqrt(trace + 1.0) * 2 37 | quat_w = 0.25 * s 38 | quat_x = (rot_matrix[2,1] - rot_matrix[1,2]) / s 39 | quat_y = (rot_matrix[0,2] - rot_matrix[2,0]) / s 40 | quat_z = (rot_matrix[1,0] - rot_matrix[0,1]) / s 41 | else: 42 | diagonal_elements = torch.tensor([rot_matrix[0,0], rot_matrix[1,1], rot_matrix[2,2]]) 43 | max_index = torch.argmax(diagonal_elements).item() 44 | if max_index == 0: 45 | s = torch.sqrt(1.0 + rot_matrix[0,0] - rot_matrix[1,1] - rot_matrix[2,2]) * 2 46 | quat_w = (rot_matrix[2,1] - rot_matrix[1,2]) / s 47 | quat_x = 0.25 * s 48 | quat_y = (rot_matrix[0,1] + rot_matrix[1,0]) / s 49 | quat_z = (rot_matrix[0,2] + rot_matrix[2,0]) / s 50 | elif max_index == 1: 51 | s = torch.sqrt(1.0 + rot_matrix[1,1] - rot_matrix[0,0] - rot_matrix[2,2]) * 2 52 | quat_w = (rot_matrix[0,2] - rot_matrix[2,0]) / s 53 | quat_x = (rot_matrix[0,1] + rot_matrix[1,0]) / s 54 | quat_y = 0.25 * s 55 | quat_z = (rot_matrix[1,2] + rot_matrix[2,1]) / s 56 | else: 57 | 
s = torch.sqrt(1.0 + rot_matrix[2,2] - rot_matrix[0,0] - rot_matrix[1,1]) * 2 58 | quat_w = (rot_matrix[1,0] - rot_matrix[0,1]) / s 59 | quat_x = (rot_matrix[0,2] + rot_matrix[2,0]) / s 60 | quat_y = (rot_matrix[1,2] + rot_matrix[2,1]) / s 61 | quat_z = 0.25 * s 62 | 63 | quaternion = torch.tensor([quat_w, quat_x, quat_y, quat_z]) 64 | 65 | return quaternion 66 | 67 | 68 | def batch_rotation_matrix_to_quaternion(rot_matrices): 69 | quaternions = [] 70 | for i in range(rot_matrices.shape[0]): 71 | q = rotation_matrix_to_quaternion(rot_matrices[i]) 72 | quaternions.append(q) 73 | return torch.stack(quaternions) 74 | 75 | 76 | def quaternion_to_rotation_matrix(quat): 77 | # Normalize the quaternion to unit length 78 | quat = quat / torch.sqrt(torch.sum(quat**2)) 79 | 80 | w, x, y, z = quat[0], quat[1], quat[2], quat[3] 81 | 82 | # Compute the rotation matrix elements 83 | r00 = 1 - 2*(y**2 + z**2) 84 | r01 = 2*(x*y - z*w) 85 | r02 = 2*(x*z + y*w) 86 | 87 | r10 = 2*(x*y + z*w) 88 | r11 = 1 - 2*(x**2 + z**2) 89 | r12 = 2*(y*z - x*w) 90 | 91 | r20 = 2*(x*z - y*w) 92 | r21 = 2*(y*z + x*w) 93 | r22 = 1 - 2*(x**2 + y**2) 94 | 95 | rotation_matrix = torch.tensor([[r00, r01, r02], 96 | [r10, r11, r12], 97 | [r20, r21, r22]]) 98 | 99 | return rotation_matrix 100 | 101 | 102 | def batch_quaternion_to_rotation_matrix(quaternions): 103 | rotation_matrices = [] 104 | for i in range(quaternions.shape[0]): 105 | R = quaternion_to_rotation_matrix(quaternions[i]) 106 | rotation_matrices.append(R) 107 | return torch.stack(rotation_matrices) 108 | 109 | 110 | def matrix_to_6d(rot_mat): 111 | # Use the first two columns of the rotation matrix to get the 6D representation 112 | return rot_mat[:, :2].reshape(-1) 113 | 114 | 115 | def _6d_to_matrix(rot_6d): 116 | # Reshape the 6D representation back to a 3x2 matrix 117 | mat = rot_6d.view(-1, 3, 2) 118 | 119 | # Calculate the third column of the rotation matrix as the cross product of the first two columns 120 | third_col = 
torch.cross(mat[:, :, 0], mat[:, :, 1], dim=1).unsqueeze(2) 121 | 122 | # Construct the full rotation matrix 123 | return torch.cat((mat, third_col), dim=2) 124 | -------------------------------------------------------------------------------- /utils/ui.py: -------------------------------------------------------------------------------- 1 | """ 2 | ----------------------------------------------------------------------------- 3 | This file has been developed within the scope of the 4 | Technical Director course at Filmakademie Baden-Wuerttemberg. 5 | http://technicaldirector.de 6 | 7 | Written by Lukas Kapp 8 | Copyright (c) 2023 Animationsinstitut of Filmakademie Baden-Wuerttemberg 9 | ----------------------------------------------------------------------------- 10 | """ 11 | 12 | from PySide2 import QtCore, QtWidgets, QtGui 13 | from shiboken2 import wrapInstance 14 | import maya.OpenMayaUI as omui 15 | import maya.OpenMaya as om 16 | import maya.cmds as cmds 17 | from functools import partial 18 | import os 19 | from importlib import reload 20 | 21 | import utils.maya as mUtils 22 | reload(mUtils) 23 | 24 | 25 | def maya_main_window(): 26 | """ 27 | Return the Maya main window widget as a Python object 28 | """ 29 | main_window_ptr = omui.MQtUtil.mainWindow() 30 | return wrapInstance(int(main_window_ptr), QtWidgets.QWidget) 31 | 32 | 33 | class EditableItemDelegate(QtWidgets.QItemDelegate): 34 | def setModelData(self, editor, model, index): 35 | text = editor.text().strip() # Remove leading/trailing spaces 36 | if not text or text == index.data(): # Do not update if text is empty or same as before 37 | return 38 | super().setModelData(editor, model, index) 39 | 40 | 41 | class PlaceholderTreeWidget(QtWidgets.QTreeWidget): 42 | def __init__(self, parent=None, label_msg="No items added"): 43 | super(PlaceholderTreeWidget, self).__init__(parent) 44 | self.emptyLabel = QtWidgets.QLabel(label_msg, self) 45 | self.emptyLabel.setAlignment(QtCore.Qt.AlignCenter) 46 | 
self.emptyLabel.hide() 47 | 48 | def resizeEvent(self, event): 49 | super(PlaceholderTreeWidget, self).resizeEvent(event) 50 | header_height = self.header().height() 51 | self.emptyLabel.setGeometry(0, header_height, self.width(), self.height() - header_height) 52 | 53 | def checkIfEmpty(self): 54 | if self.topLevelItemCount() == 0: 55 | self.emptyLabel.show() 56 | else: 57 | self.emptyLabel.hide() 58 | 59 | 60 | class UnmovableSplitterHandle(QtWidgets.QSplitterHandle): 61 | def __init__(self, orientation, parent): 62 | super(UnmovableSplitterHandle, self).__init__(orientation, parent) 63 | 64 | def mouseMoveEvent(self, event): 65 | pass 66 | 67 | class UnmovableSplitter(QtWidgets.QSplitter): 68 | def __init__(self, orientation, parent=None): 69 | super(UnmovableSplitter, self).__init__(orientation, parent) 70 | 71 | def createHandle(self): 72 | return UnmovableSplitterHandle(self.orientation(), self) 73 | 74 | 75 | 76 | ### FUNCTIONS ### 77 | 78 | 79 | def add_selection(treeWidget, jnt_mode): 80 | if jnt_mode: 81 | sel = cmds.ls(sl=1, typ="joint") 82 | else: 83 | raw_sel = cmds.ls(sl=1, transforms=True) 84 | sel = [] 85 | for transform in raw_sel: 86 | child_nodes = cmds.listRelatives(transform, children=True, fullPath=True) or [] 87 | has_desired_types = any(cmds.nodeType(child) in ['nurbsCurve', 'mesh', 'nurbsSurface'] for child in child_nodes) 88 | if has_desired_types: 89 | sel.append(transform) 90 | for obj in sel: 91 | add_tree_item(treeWidget=treeWidget, name=obj, jnt_mode=jnt_mode) 92 | 93 | 94 | def add_tree_item(treeWidget, name, jnt_mode=False): 95 | root = treeWidget.invisibleRootItem() 96 | for i in range(root.childCount()): 97 | if root.child(i).text(0) == name: 98 | om.MGlobal.displayWarning(f"Item '{name}' already exists. 
Skipping...") 99 | return 100 | 101 | if jnt_mode: 102 | attr_list = [attr for attr in mUtils.get_all_attributes(name, unlocked=False) if mUtils.check_source_connection(name, attr)] 103 | else: 104 | attr_list = mUtils.get_all_attributes(name) 105 | 106 | if attr_list: 107 | parent = QtWidgets.QTreeWidgetItem(treeWidget) 108 | parent.setFlags(parent.flags() & ~QtCore.Qt.ItemIsEditable) 109 | parent.setText(0, name) 110 | font = parent.font(0) 111 | font.setBold(True) 112 | parent.setFont(0, font) 113 | 114 | 115 | for attr in attr_list: 116 | child = QtWidgets.QTreeWidgetItem(parent) 117 | parent.addChild(child) 118 | #child.setFlags(child.flags() | QtCore.Qt.ItemIsEditable) 119 | child.setText(0, attr) 120 | if treeWidget.columnCount() > 1: 121 | child.setText(1, "-50.000") 122 | child.setText(2, "50.000") 123 | 124 | treeWidget.expandItem(parent) 125 | 126 | 127 | def saveFileDialog(widget, lineEdit, dialog_header, file_types): 128 | start_dir = lineEdit.text() 129 | options = QtWidgets.QFileDialog.Options() 130 | fileName, _ = QtWidgets.QFileDialog.getSaveFileName(widget, dialog_header, start_dir, f"{file_types.upper()} Files (*.{file_types})", options=options) 131 | if fileName: 132 | if not fileName.endswith(f".{file_types}"): 133 | fileName += f".{file_types}" 134 | lineEdit.setText(fileName) 135 | 136 | 137 | def openFileDialog(widget, lineEdit, dialog_header, file_types): 138 | start_dir = lineEdit.text() 139 | options = QtWidgets.QFileDialog.Options() 140 | fileName, _ = QtWidgets.QFileDialog.getOpenFileName(widget, dialog_header, start_dir, f"{file_types.upper()} Files (*.{file_types})", options=options) 141 | if fileName: 142 | if not fileName.endswith(f".{file_types}"): 143 | fileName += f".{file_types}" 144 | lineEdit.setText(fileName) 145 | 146 | 147 | def update_param_label(treeWidget, label): 148 | count = 0 149 | for i in range(0, treeWidget.topLevelItemCount()): 150 | parent = treeWidget.topLevelItem(i) 151 | count += parent.childCount() 152 | 153 
| label.setText("{}({})".format(label.text().rpartition("(")[0], count)) 154 | 155 | 156 | def clear_tree(treeWidget, label): 157 | treeWidget.clear() 158 | treeWidget.checkIfEmpty() 159 | update_param_label(treeWidget, label) 160 | 161 | 162 | def show_context_menu(widget, pos, treeWidget): 163 | menu = QtWidgets.QMenu(widget) 164 | delete_action = menu.addAction("Delete") 165 | delete_action.triggered.connect(partial(delete_items, treeWidget)) 166 | menu.exec_(treeWidget.viewport().mapToGlobal(pos)) 167 | 168 | 169 | def delete_items(treeWidget): 170 | selected_items = treeWidget.selectedItems() 171 | for item in selected_items: 172 | (item.parent() or treeWidget.invisibleRootItem()).removeChild(item) 173 | 174 | 175 | def get_treeItems_as_dict(treeWidget): 176 | item_dict = {} 177 | root = treeWidget.invisibleRootItem() 178 | for i in range(root.childCount()): 179 | parent_item = root.child(i) 180 | parent_name = parent_item.text(0) 181 | item_dict[parent_name] = [] 182 | get_treeChildren_as_list(parent_item, item_dict[parent_name]) 183 | 184 | return item_dict 185 | 186 | 187 | def get_treeChildren_as_list(parent, children_list): 188 | for i in range(parent.childCount()): 189 | child_item = parent.child(i) 190 | child_name = child_item.text(0) 191 | 192 | num_columns = child_item.columnCount() 193 | if num_columns == 3: 194 | child_min_range = child_item.text(1) 195 | child_max_range = child_item.text(2) 196 | children_list.append([child_name, child_min_range, child_max_range]) 197 | elif num_columns == 1: 198 | children_list.append([child_name]) 199 | 200 | get_treeChildren_as_list(child_item, children_list) 201 | 202 | 203 | def is_valid_dir(path): 204 | dir_path = os.path.dirname(path) 205 | return os.path.isdir(dir_path) 206 | 207 | def is_valid_file(path): 208 | return os.path.isfile(path) 209 | 210 | 211 | def check_dir_path(path): 212 | if not is_valid_dir(path): 213 | msg = QtWidgets.QMessageBox() 214 | msg.setIcon(QtWidgets.QMessageBox.Critical) 215 | 
msg.setText("Invalid file path: " + path) 216 | msg.setWindowTitle("File Error") 217 | msg.exec_() 218 | return False 219 | return True 220 | 221 | 222 | def check_train_data(data, data_type): 223 | parameters = [param for param in data.values() if param] 224 | if not parameters: 225 | msg = QtWidgets.QMessageBox() 226 | msg.setIcon(QtWidgets.QMessageBox.Critical) 227 | msg.setText(f"No {data_type} parameters found!") 228 | msg.setWindowTitle("Data Error") 229 | msg.exec_() 230 | return False 231 | return True 232 | 233 | 234 | def check_file_path(path): 235 | if not is_valid_file(path): 236 | msg = QtWidgets.QMessageBox() 237 | msg.setIcon(QtWidgets.QMessageBox.Critical) 238 | msg.setText("Invalid file path: " + path) 239 | msg.setWindowTitle("File Error") 240 | msg.exec_() 241 | return False 242 | return True 243 | --------------------------------------------------------------------------------
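The quaternion helpers in utils/pytorch.py can be sanity-checked without Maya or PyTorch. The following is a minimal sketch in plain Python that mirrors the same (w, x, y, z) ordering and the same branch structure (positive trace, otherwise the largest diagonal element). The names `quat_to_matrix` and `matrix_to_quat` are hypothetical and are not part of the repository; nested lists stand in for tensors.

```python
import math


def quat_to_matrix(q):
    """Convert a (w, x, y, z) quaternion to a 3x3 rotation matrix (nested lists)."""
    # Normalize to unit length first, as quaternion_to_rotation_matrix does.
    norm = math.sqrt(sum(c * c for c in q))
    w, x, y, z = (c / norm for c in q)
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ]


def matrix_to_quat(m):
    """Convert a 3x3 rotation matrix to a (w, x, y, z) quaternion."""
    trace = m[0][0] + m[1][1] + m[2][2]
    if trace > 0:
        s = math.sqrt(trace + 1.0) * 2
        return (
            0.25 * s,
            (m[2][1] - m[1][2]) / s,
            (m[0][2] - m[2][0]) / s,
            (m[1][0] - m[0][1]) / s,
        )
    # Otherwise pivot on the largest diagonal element; (i, j, k) is a cyclic
    # permutation of (0, 1, 2), which reproduces the three explicit branches.
    i = max(range(3), key=lambda n: m[n][n])
    j, k = (i + 1) % 3, (i + 2) % 3
    s = math.sqrt(1.0 + m[i][i] - m[j][j] - m[k][k]) * 2
    quat = [0.0, 0.0, 0.0, 0.0]
    quat[0] = (m[k][j] - m[j][k]) / s
    quat[i + 1] = 0.25 * s
    quat[j + 1] = (m[j][i] + m[i][j]) / s
    quat[k + 1] = (m[k][i] + m[i][k]) / s
    return tuple(quat)


def same_rotation(q1, q2, tol=1e-9):
    """q and -q encode the same rotation, so compare up to a global sign."""
    dot = sum(a * b for a, b in zip(q1, q2))
    return abs(abs(dot) - 1.0) < tol
```

Because a quaternion and its negation describe the same rotation, a round trip through the matrix form should be compared with `same_rotation` rather than element by element.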