├── data └── .gitkeep ├── output └── .gitkeep ├── saved_weights └── .gitkeep ├── viz ├── .gitignore ├── src │ ├── helpers │ │ └── fetchJson.ts │ ├── style.css │ ├── index.ts │ └── renderBVH.ts ├── tsconfig.json ├── package.json └── webpack.config.js ├── generator_backend ├── requirements.txt ├── __init__.py └── model.py ├── generator_frontend ├── .gitignore ├── src │ ├── helpers │ │ └── fetchJson.ts │ ├── generate.ts │ ├── style.css │ ├── controls.ts │ ├── index.ts │ ├── renderBVH.ts │ └── motionEditor.ts ├── tsconfig.json ├── package.json └── webpack.config.js ├── static └── model.jpg ├── .gitignore ├── requirements.txt ├── util ├── plot.py ├── math.py ├── interpolation │ ├── fixed_points.py │ ├── interpolation_factory.py │ ├── linear_interpolation.py │ └── spherical_interpolation.py ├── read_config.py ├── smoothing │ └── moving_average_smoothing.py ├── load_data.py ├── lafan1.py ├── conversion.py ├── extract.py └── quaternions.py ├── constants.py ├── config └── default.yml ├── test_spherical_interpolation.py ├── model ├── encoding │ ├── linear_encoding.py │ ├── input_encoder.py │ ├── output_decoder.py │ └── positional_encoding.py ├── loss │ ├── fk_loss.py │ ├── npss_loss.py │ └── l2_loss.py └── transformer.py ├── train_stats.py ├── LICENSE ├── README.md ├── visualize.py ├── train.py └── evaluate.py /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /saved_weights/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /viz/.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | node_modules -------------------------------------------------------------------------------- /generator_backend/requirements.txt: -------------------------------------------------------------------------------- 1 | flask == 2.1.* -------------------------------------------------------------------------------- /generator_frontend/.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | node_modules -------------------------------------------------------------------------------- /static/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pavi114/motion-completion-using-transformers/HEAD/static/model.jpg -------------------------------------------------------------------------------- /viz/src/helpers/fetchJson.ts: -------------------------------------------------------------------------------- 1 | export default async function (url: string): Promise<any> { 2 | const response = await fetch(url); 3 | 4 | const json = await response.json(); 5 | 6 | return json; 7 | } -------------------------------------------------------------------------------- /generator_frontend/src/helpers/fetchJson.ts: -------------------------------------------------------------------------------- 1 | export default async function (url: string): Promise<any> { 2 | const response = await fetch(url); 3 | 4 | const json = await response.json(); 5 | 6 | return json; 7 | } --------------------------------------------------------------------------------
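Both `fetchJson.ts` helpers simply wrap `fetch` and return the parsed JSON body. A minimal usage sketch is shown below; the animation path mirrors the ones requested in `viz/src/index.ts`, and the wrapping async IIFE is only for illustration:

```
import fetchJson from './helpers/fetchJson';

(async () => {
    // Load one of the exported animation files served from the static directory.
    const motionSequence = await fetchJson('./static/animations/1/ground_truth.json') as number[][][];

    // motionSequence is indexed as [frame][joint][x, y, z].
    console.log(`Loaded ${motionSequence.length} frames`);
})();
```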
/.gitignore: -------------------------------------------------------------------------------- 1 | /data/* 2 | !/data/.gitkeep 3 | 4 | /saved_weights/* 5 | !/saved_weights/.gitkeep 6 | 7 | /output/* 8 | !/output/.gitkeep 9 | 10 | /config/* 11 | !/config/default.yml 12 | 13 | __pycache__ 14 | venv 15 | 16 | .vscode -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib == 3.5.* 2 | numpy == 1.22.* 3 | pillow == 9.0.* 4 | pyquaternion == 0.9.* 5 | pyyaml == 6.0.* 6 | torch == 1.11.* 7 | torchaudio == 0.11.* 8 | torchvision == 0.12.* 9 | tqdm == 4.62.* 10 | typing-extensions == 4.0.* -------------------------------------------------------------------------------- /util/plot.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import matplotlib.pyplot as plt 3 | 4 | def plot_loss(loss_history): 5 | plt.plot(loss_history, '-r', label='loss') 6 | plt.show() 7 | 8 | 9 | if __name__ == '__main__': 10 | loss_history = [1.2, 1, 0.4, 0.3, 0.1, 0.001, 0.0005] 11 | plot_loss(loss_history) -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | LAFAN1_DIRECTORY='data/lafan1' 4 | NUM_JOINTS = 22 5 | PARENTS = [-1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 12, 11, 14, 15, 16, 11, 18, 19, 20] 6 | 7 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | MODEL_SAVE_DIRECTORY='saved_weights' 10 | OUTPUT_DIRECTORY='output' -------------------------------------------------------------------------------- /viz/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "outDir": "./dist/", 4 | "noImplicitAny": true, 5 | "module": "es6", 6 | "target": "es5", 7 | "jsx": "react", 8 | "allowJs": true, 9 | "moduleResolution": "node", 10 | "resolveJsonModule": true, 11 | "allowSyntheticDefaultImports": true 12 | } 13 | } -------------------------------------------------------------------------------- /generator_frontend/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "outDir": "./dist/", 4 | "noImplicitAny": true, 5 | "module": "es6", 6 | "target": "es5", 7 | "jsx": "react", 8 | "allowJs": true, 9 | "moduleResolution": "node", 10 | "resolveJsonModule": true, 11 | "allowSyntheticDefaultImports": true 12 | } 13 | } -------------------------------------------------------------------------------- /config/default.yml: -------------------------------------------------------------------------------- 1 | embedding_size: 2 | p: 3 3 | q: 3 4 | v: 3 5 | 6 | dataset: 7 | batch_size: 8 8 | files_to_read: 1 9 | keyframe_gap: 30 10 | num_workers: 4 11 | window_size: 64 12 | 13 | model: 14 | num_heads: 1 15 | num_encoder_layers: 1 16 | num_decoder_layers: 1 17 | dropout_p: 0.2 18 | 19 | hyperparameters: 20 | epochs: 1 21 | learning_rate: 1.0e-3 22 | interpolation: linear -------------------------------------------------------------------------------- /util/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | def round_tensor(x: Tensor, decimals: int = 0) -> Tensor: 5 | """Rounds the given tensor to `decimals` decimal places 6 | 
7 | Args: 8 | x (Tensor): Tensor to round. 9 | decimals (int): Precision (Default = 0). 10 | 11 | Returns: 12 | Tensor: Rounded tensor 13 | """ 14 | return torch.round(x * (10 ** decimals)) / 10 ** decimals -------------------------------------------------------------------------------- /util/interpolation/fixed_points.py: -------------------------------------------------------------------------------- 1 | import imp 2 | from torch import LongTensor 3 | 4 | from constants import DEVICE 5 | 6 | def get_fixed_points(window_size, keyframe_gap): 7 | fixed_points = list(range(0, window_size, keyframe_gap)) 8 | 9 | if (window_size - 1) % keyframe_gap != 0: 10 | fixed_points.append(window_size - 1) 11 | 12 | fixed_points = LongTensor(fixed_points).to(DEVICE) 13 | 14 | return fixed_points -------------------------------------------------------------------------------- /util/read_config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import yaml 3 | 4 | def read_config(config_name='default') -> object: 5 | config = Path(f'./config/{config_name}.yml') 6 | 7 | if not config.exists(): 8 | raise Exception("ConfigNotExistsException") 9 | 10 | with config.open() as stream: 11 | try: 12 | return yaml.safe_load(stream) 13 | except yaml.YAMLError as exc: 14 | print(exc) 15 | -------------------------------------------------------------------------------- /viz/src/style.css: -------------------------------------------------------------------------------- 1 | body, html { 2 | margin: 0; 3 | padding: 0; 4 | width: 100vw; 5 | height: 100vh; 6 | overflow: hidden; 7 | } 8 | 9 | .container { 10 | width: 100%; 11 | height: 100%; 12 | 13 | display: flex; 14 | align-items: center; 15 | justify-content: space-evenly; 16 | 17 | background-color: red; 18 | } 19 | 20 | .canvas { 21 | width: 30%; 22 | height: 100%; 23 | border: 1px solid black; 24 | } -------------------------------------------------------------------------------- /util/interpolation/interpolation_factory.py: -------------------------------------------------------------------------------- 1 | from re import I 2 | 3 | from util.interpolation.spherical_interpolation import spherical_interpolation 4 | from .linear_interpolation import linear_interpolation 5 | 6 | interpolations = { 7 | 'linear': linear_interpolation, 8 | 'spherical': spherical_interpolation 9 | } 10 | 11 | def get_p_interpolation(interpolation: str): 12 | return linear_interpolation 13 | 14 | def get_q_interpolation(interpolation: str): 15 | return interpolations[interpolation] -------------------------------------------------------------------------------- /generator_frontend/src/generate.ts: -------------------------------------------------------------------------------- 1 | export default async function generateMotionSequence(gpos: number[][][]) { 2 | const response = await fetch('http://127.0.0.1:5000/generate', { 3 | method: 'POST', 4 | headers: { 5 | 'Accept': 'application/json', 6 | 'Content-Type': 'application/json' 7 | }, 8 | body: JSON.stringify({ 9 | gpos 10 | }), 11 | mode: 'cors' 12 | }) 13 | 14 | const json = await response.json() 15 | 16 | return json 17 | } -------------------------------------------------------------------------------- /test_spherical_interpolation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from util.interpolation.spherical_interpolation import spherical_interpolation 4 | 5 | q = torch.Tensor([ 6 | [ 7 | [0, 0.408, 0.408, 
0.816], 8 | [1, 0, 0, 0], 9 | ], 10 | [ 11 | [0, 0, 0, 0], 12 | [0, 0, 0, 0], 13 | ], 14 | [ 15 | [0, 0, 0, 0], 16 | [0, 0, 0, 0], 17 | ], 18 | [ 19 | [0, 1, 0, 0], 20 | [1, 0, 0, 0], 21 | ] 22 | ]) 23 | fixed_points = torch.LongTensor([0, 3]) 24 | 25 | print(q.shape) 26 | 27 | out = spherical_interpolation(q, -3, fixed_points) 28 | 29 | print(q, q.shape) 30 | 31 | print(out, out.shape) -------------------------------------------------------------------------------- /model/encoding/linear_encoding.py: -------------------------------------------------------------------------------- 1 | from torch import nn, Tensor 2 | 3 | class LinearEncoding(nn.Module): 4 | """nn.Module that performs linear encoding. 5 | Specify [input_size, hidden_size, output_size] for a 2 layer NN 6 | """ 7 | 8 | def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None: 9 | super(LinearEncoding, self).__init__() 10 | 11 | self.l1 = nn.Linear(in_features=input_size, out_features=hidden_size) 12 | self.l2 = nn.Linear(in_features=hidden_size, out_features=output_size) 13 | 14 | self.relu = nn.ReLU() 15 | 16 | def forward(self, x: Tensor) -> Tensor: 17 | x = self.l1(x) 18 | x = self.relu(x) 19 | x = self.l2(x) 20 | 21 | return x 22 | -------------------------------------------------------------------------------- /train_stats.py: -------------------------------------------------------------------------------- 1 | from genericpath import exists 2 | import pickle 3 | from constants import LAFAN1_DIRECTORY, OUTPUT_DIRECTORY 4 | 5 | from util.extract import get_train_stats 6 | 7 | def save_stats(): 8 | x_mean, x_std, _, _ = get_train_stats(LAFAN1_DIRECTORY, ['subject1', 'subject2', 'subject3', 'subject4']) 9 | 10 | with open(f'{OUTPUT_DIRECTORY}/stats.pkl', 'wb') as f: 11 | pickle.dump( 12 | {'x_mean': x_mean, 'x_std': x_std}, 13 | f, 14 | protocol=pickle.HIGHEST_PROTOCOL 15 | ) 16 | 17 | def load_stats(): 18 | p = f'{OUTPUT_DIRECTORY}/stats.pkl' 19 | 20 | if not exists(p): 21 | save_stats() 22 | 23 | with open(f'{OUTPUT_DIRECTORY}/stats.pkl', 'rb') as f: 24 | stats = pickle.load(f) 25 | 26 | return stats['x_mean'], stats['x_std'] 27 | 28 | if __name__ == '__main__': 29 | save_stats() -------------------------------------------------------------------------------- /viz/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "viz", 3 | "version": "1.0.0", 4 | "description": "This repository contains the code accompaniying the thesis project \"Transformer based Motion In-betweening\".", 5 | "main": "index.js", 6 | "scripts": { 7 | "test": "echo \"Error: no test specified\" && exit 1", 8 | "dev": "webpack serve", 9 | "build": "webpack", 10 | "start": "http-server ./dist" 11 | }, 12 | "author": "", 13 | "license": "ISC", 14 | "devDependencies": { 15 | "@types/three": "^0.137.0", 16 | "copy-webpack-plugin": "^10.2.4", 17 | "css-loader": "^6.6.0", 18 | "html-webpack-plugin": "^5.5.0", 19 | "http-server": "^14.1.0", 20 | "style-loader": "^3.3.1", 21 | "ts-loader": "^9.2.6", 22 | "typescript": "^4.5.5", 23 | "webpack": "^5.69.1", 24 | "webpack-cli": "^4.9.2", 25 | "webpack-dev-server": "^4.7.4" 26 | }, 27 | "dependencies": { 28 | "three": "^0.137.5" 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /generator_frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "viz", 3 | "version": "1.0.0", 4 | "description": "This repository contains the 
code accompaniying the thesis project \"Transformer based Motion In-betweening\".", 5 | "main": "index.js", 6 | "scripts": { 7 | "test": "echo \"Error: no test specified\" && exit 1", 8 | "dev": "webpack serve", 9 | "build": "webpack", 10 | "start": "http-server ./dist" 11 | }, 12 | "author": "", 13 | "license": "ISC", 14 | "devDependencies": { 15 | "@types/three": "^0.137.0", 16 | "copy-webpack-plugin": "^10.2.4", 17 | "css-loader": "^6.6.0", 18 | "html-webpack-plugin": "^5.5.0", 19 | "http-server": "^14.1.0", 20 | "style-loader": "^3.3.1", 21 | "ts-loader": "^9.2.6", 22 | "typescript": "^4.5.5", 23 | "webpack": "^5.69.1", 24 | "webpack-cli": "^4.9.2", 25 | "webpack-dev-server": "^4.7.4" 26 | }, 27 | "dependencies": { 28 | "three": "^0.137.5" 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /generator_frontend/src/style.css: -------------------------------------------------------------------------------- 1 | body, html { 2 | margin: 0; 3 | padding: 0; 4 | width: 100vw; 5 | height: 100vh; 6 | overflow: hidden; 7 | } 8 | 9 | .container { 10 | width: 100%; 11 | height: 100%; 12 | 13 | display: flex; 14 | align-items: center; 15 | justify-content: space-evenly; 16 | 17 | background-color: red; 18 | } 19 | 20 | .canvas { 21 | width: 30%; 22 | height: 100%; 23 | border: 1px solid black; 24 | } 25 | 26 | .controls { 27 | position: fixed; 28 | 29 | top: 50px; 30 | right: 50px; 31 | 32 | background-color: bisque; 33 | 34 | z-index: 1000; 35 | 36 | display: flex; 37 | flex-direction: column; 38 | align-items: center; 39 | } 40 | 41 | .controls.hidden { 42 | display: none; 43 | } 44 | 45 | .generate { 46 | position: fixed; 47 | 48 | top: 50px; 49 | left: 50px; 50 | 51 | background-color: green; 52 | color: white; 53 | 54 | padding: 8px; 55 | border-radius: 4px; 56 | 57 | z-index: 1000; 58 | } 59 | 60 | -------------------------------------------------------------------------------- /util/smoothing/moving_average_smoothing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | def moving_average_smoothing(x: Tensor, dim: int = -1, window_size: int = 1) -> Tensor: 5 | """Applies moving average smoothing to the given tensor. 6 | 7 | Args: 8 | x (Tensor): Tensor to smoothen. 9 | dim (int): Dimension to average. 10 | window_size (int): Window of moving average. 
11 | 12 | Returns: 13 | Tensor: Smoothened tensor 14 | """ 15 | index_tensor = torch.arange(x.shape[dim]).to(x.device) 16 | 17 | x_index = [torch.index_select(x, dim, index_tensor[i]) for i in range(x.shape[dim])] 18 | 19 | averaged_tensors = [] 20 | 21 | for i in range(len(x_index)): 22 | n = 0 23 | s = 0 24 | for j in range(-window_size, window_size + 1): 25 | if i + j >= 0 and i + j < len(x_index): 26 | n += 1 27 | s += x_index[i + j] 28 | 29 | averaged_tensors.append(s / n) 30 | 31 | return torch.cat(averaged_tensors, dim=dim) -------------------------------------------------------------------------------- /viz/webpack.config.js: -------------------------------------------------------------------------------- 1 | const path = require('path'); 2 | const HtmlWebpackPlugin = require('html-webpack-plugin'); 3 | const CopyWebpackPlugin = require('copy-webpack-plugin'); 4 | 5 | module.exports = { 6 | mode: 'development', 7 | entry: './src/index.ts', 8 | module: { 9 | rules: [ 10 | { 11 | test: /\.tsx?$/, 12 | use: 'ts-loader', 13 | exclude: /node_modules/, 14 | }, 15 | { 16 | test: /\.css$/i, 17 | use: ["style-loader", "css-loader"], 18 | }, 19 | ], 20 | }, 21 | plugins: [ 22 | new HtmlWebpackPlugin({ 23 | title: 'Visualization', 24 | }), 25 | new CopyWebpackPlugin({ 26 | patterns: [ 27 | { from: "./public", to: "./static" } 28 | ], 29 | }), 30 | ], 31 | resolve: { 32 | extensions: ['.tsx', '.ts', '.js'], 33 | }, 34 | output: { 35 | filename: 'bundle.js', 36 | path: path.resolve(__dirname, 'dist'), 37 | }, 38 | devServer: { 39 | static: { 40 | directory: path.join(__dirname, 'public'), 41 | } 42 | }, 43 | }; -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2022 Pavithra S, Aananth V and Madhav Aggarwal 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /generator_frontend/webpack.config.js: -------------------------------------------------------------------------------- 1 | const path = require('path'); 2 | const HtmlWebpackPlugin = require('html-webpack-plugin'); 3 | const CopyWebpackPlugin = require('copy-webpack-plugin'); 4 | 5 | module.exports = { 6 | mode: 'development', 7 | entry: './src/index.ts', 8 | module: { 9 | rules: [ 10 | { 11 | test: /\.tsx?$/, 12 | use: 'ts-loader', 13 | exclude: /node_modules/, 14 | }, 15 | { 16 | test: /\.css$/i, 17 | use: ["style-loader", "css-loader"], 18 | }, 19 | ], 20 | }, 21 | plugins: [ 22 | new HtmlWebpackPlugin({ 23 | title: 'Visualization', 24 | }), 25 | new CopyWebpackPlugin({ 26 | patterns: [ 27 | { from: "./public", to: "./static" } 28 | ], 29 | }), 30 | ], 31 | resolve: { 32 | extensions: ['.tsx', '.ts', '.js'], 33 | }, 34 | output: { 35 | filename: 'bundle.js', 36 | path: path.resolve(__dirname, 'dist'), 37 | }, 38 | devServer: { 39 | static: { 40 | directory: path.join(__dirname, 'public'), 41 | } 42 | }, 43 | }; -------------------------------------------------------------------------------- /model/loss/fk_loss.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from torch.nn import Module 3 | from torch.nn.functional import l1_loss 4 | from constants import PARENTS 5 | 6 | from util.quaternions import quat_fk_tensor 7 | 8 | class FKLoss(Module): 9 | """nn.Module that calculates forward kinematics loss 10 | """ 11 | def __init__(self) -> None: 12 | super(FKLoss, self).__init__() 13 | 14 | def forward(self, local_p: Tensor, local_q: Tensor, local_p_cap: Tensor, local_q_cap: Tensor) -> Tensor: 15 | """ 16 | Args: 17 | local_p (Tensor): Local positions [..., J, 3] 18 | local_q (Tensor): Local quaternions [..., J, 4] 19 | local_p_cap (Tensor): Predicted Local positions [..., J, 3] 20 | local_q_cap (Tensor): Predicted Local quaternions [..., J, 4] 21 | 22 | Returns: 23 | Tensor: FK Loss. 
24 | """ 25 | 26 | # Get globals 27 | q, x = quat_fk_tensor(local_q, local_p, PARENTS) 28 | 29 | q_cap, x_cap = quat_fk_tensor(local_q_cap, local_p_cap, PARENTS) 30 | 31 | # Calculate Loss 32 | return l1_loss(x, x_cap) -------------------------------------------------------------------------------- /viz/src/index.ts: -------------------------------------------------------------------------------- 1 | import { Clock } from 'three'; 2 | import RenderBVH from './renderBVH'; 3 | 4 | import './style.css'; 5 | import fetchJson from './helpers/fetchJson'; 6 | 7 | (async () => { 8 | const animation_id = prompt("Enter animation id: ", "1"); 9 | // const animation_id = 1; 10 | 11 | const motionSequences = [ 12 | await fetchJson(`./static/animations/${animation_id}/ground_truth.json`) as number[][][], 13 | await fetchJson(`./static/animations/${animation_id}/input.json`) as number[][][], 14 | await fetchJson(`./static/animations/${animation_id}/output.json`) as number[][][], 15 | await fetchJson(`./static/animations/${animation_id}/output_smoothened.json`) as number[][][] 16 | ]; 17 | 18 | const container = document.createElement('div'); 19 | container.classList.add('container'); 20 | 21 | motionSequences.forEach((motionSequence, index) => { 22 | const canvas = document.createElement('canvas'); 23 | 24 | canvas.classList.add('canvas'); 25 | 26 | container.appendChild(canvas); 27 | 28 | const clock = new Clock(); 29 | 30 | new RenderBVH(canvas, motionSequence, clock, index); 31 | }); 32 | 33 | document.body.appendChild(container); 34 | })(); -------------------------------------------------------------------------------- /generator_backend/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from flask import Flask, request 5 | from flask_cors import CORS 6 | 7 | from constants import DEVICE 8 | from generator_backend.model import Model 9 | 10 | def create_app(test_config=None): 11 | # create and configure the app 12 | app = Flask(__name__, instance_relative_config=True) 13 | 14 | CORS(app) 15 | 16 | app.config.from_mapping( 17 | SECRET_KEY='dev', 18 | DATABASE=os.path.join(app.instance_path, 'flaskr.sqlite'), 19 | ) 20 | 21 | if test_config is None: 22 | # load the instance config, if it exists, when not testing 23 | app.config.from_pyfile('config.py', silent=True) 24 | else: 25 | # load the test config if passed in 26 | app.config.from_mapping(test_config) 27 | 28 | # ensure the instance folder exists 29 | try: 30 | os.makedirs(app.instance_path) 31 | except OSError: 32 | pass 33 | 34 | model = Model('2e_2d_2h_30k_linear') 35 | 36 | print(model) 37 | 38 | # a simple page that says hello 39 | @app.route('/hello') 40 | def hello(): 41 | return f'Hello, World! {DEVICE}' 42 | 43 | @app.post('/generate') 44 | def generate(): 45 | print("Generate called") 46 | 47 | body = json.loads(request.data) 48 | 49 | z_x, in_x = model.generate(body['gpos']) 50 | 51 | res = { 52 | 'z_x': z_x, 53 | 'in_x': in_x 54 | } 55 | 56 | return json.dumps(res) 57 | 58 | return app -------------------------------------------------------------------------------- /model/encoding/input_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from constants import NUM_JOINTS 4 | 5 | from model.encoding.linear_encoding import LinearEncoding 6 | 7 | class InputEncoder(nn.Module): 8 | """Encodes the input sequence. 
9 | """ 10 | 11 | def __init__(self, embedding_size) -> None: 12 | super(InputEncoder, self).__init__() 13 | 14 | self.embedding_size = embedding_size 15 | 16 | self.q_encoder = LinearEncoding(input_size=4, hidden_size=16, output_size=self.embedding_size['q']) 17 | self.p_encoder = LinearEncoding(input_size=3, hidden_size=8, output_size=self.embedding_size['p']) 18 | self.v_encoder = LinearEncoding(input_size=3, hidden_size=8, output_size=self.embedding_size['v']) 19 | 20 | def forward(self, local_q: Tensor, root_p: Tensor, root_v: Tensor) -> Tensor: 21 | """ 22 | Args: 23 | local_q (Tensor): Local quaternions. [batch_size, seq_len, J, 4] 24 | root_p (Tensor): Global Root Position. [batch_size, seq_len, 3] 25 | root_v (Tensor): Global Root Velocity. [batch_size, seq_len, 3] 26 | 27 | Returns: 28 | Tensor: Encoded Input 29 | """ 30 | local_q = self.q_encoder(local_q) 31 | root_p = self.p_encoder(root_p) 32 | root_v = self.v_encoder(root_v) 33 | 34 | # Reshape Q 35 | local_q = local_q.reshape((local_q.shape[0], local_q.shape[1], NUM_JOINTS * self.embedding_size['q'])) 36 | 37 | # Concatenate tensors 38 | x = torch.cat([root_p, root_v, local_q], dim=-1) 39 | 40 | return x -------------------------------------------------------------------------------- /model/encoding/output_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch import nn, Tensor 3 | from constants import NUM_JOINTS 4 | from model.encoding.linear_encoding import LinearEncoding 5 | 6 | class OutputDecoder(nn.Module): 7 | """Decodes the output sequence. 8 | """ 9 | 10 | def __init__(self, embedding_size) -> None: 11 | super(OutputDecoder, self).__init__() 12 | 13 | self.embedding_size = embedding_size 14 | 15 | self.q_decoder = LinearEncoding(input_size=self.embedding_size['q'], hidden_size=16, output_size=4) 16 | self.p_decoder = LinearEncoding(input_size=self.embedding_size['p'], hidden_size=8, output_size=3) 17 | self.v_decoder = LinearEncoding(input_size=self.embedding_size['v'], hidden_size=8, output_size=3) 18 | 19 | def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: 20 | """Decodes the output tensor into local quaternions, root position, and root velocity. 21 | 22 | Args: 23 | x (Tensor): Output Tensor. 
24 | [batch_size, seq_len, J * Q_EMBEDDING_DIM + P_EMBEDDING_DIM + V_EMBEDDING_DIM] 25 | 26 | Returns: 27 | Tuple[Tensor, Tensor, Tensor]: [local_q, root_p, root_v] 28 | """ 29 | # Extract three components 30 | root_p = x[:, :, :self.embedding_size['p']] 31 | root_v = x[:, :, self.embedding_size['p']:self.embedding_size['p'] + self.embedding_size['v']] 32 | local_q = x[:, :, self.embedding_size['p'] + self.embedding_size['v']:] 33 | 34 | # Reshape Q 35 | local_q = local_q.reshape((local_q.shape[0], local_q.shape[1], NUM_JOINTS, self.embedding_size['q'])) 36 | 37 | # Decode Tensors 38 | local_q = self.q_decoder(local_q) 39 | root_p = self.p_decoder(root_p) 40 | root_v = self.v_decoder(root_v) 41 | 42 | return local_q, root_p, root_v -------------------------------------------------------------------------------- /model/encoding/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Module 6 | 7 | class PositionalEncoding(Module): 8 | """nn.Module that performs positional encoding 9 | """ 10 | 11 | def __init__(self, d_model: int, max_len: int = 5000, device: torch.device = torch.device('cpu')) -> None: 12 | """ 13 | Args: 14 | d_model (int): Embedding dimension. Must be even. 15 | max_len (int, optional): Maximum sequence length. Defaults to 5000. 16 | """ 17 | super(PositionalEncoding, self).__init__() 18 | 19 | max_len = max(max_len, 64) 20 | 21 | # Tensor[max_len, 1] 22 | position = torch.arange(max_len).unsqueeze(1) 23 | 24 | # Tensor[1, d_model / 2] 25 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model)) 26 | 27 | # Tensor[max_len, d_model] 28 | pe = torch.zeros(max_len, d_model) 29 | 30 | # Set all even terms to Tensor[max_len, d_model / 2] 31 | pe[:, 0::2] = torch.sin(position * div_term) 32 | 33 | # Set all odd terms to Tensor[max_len, d_model / 2] 34 | pe[:, 1::2] = torch.cos(position * div_term) 35 | 36 | pe = pe.to(device) 37 | 38 | self.register_buffer('pe', pe) 39 | 40 | def forward(self, x: Tensor) -> Tensor: 41 | """ 42 | Args: 43 | x: Tensor, shape [batch_size, seq_len, embedding_dim] 44 | """ 45 | return x + self.pe[:x.size(1)] 46 | 47 | if __name__ == '__main__': 48 | d_model = 256 49 | seq_len = 128 50 | 51 | model = PositionalEncoding(d_model, seq_len) 52 | x = torch.zeros(1, seq_len, d_model) 53 | 54 | y = model(x).squeeze() 55 | 56 | plt.imshow(y) 57 | plt.savefig('random.png') -------------------------------------------------------------------------------- /util/interpolation/linear_interpolation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import LongTensor, Tensor 3 | 4 | def linear_interpolation(x: Tensor, dim: int, fixed_points: LongTensor) -> Tensor: 5 | """Perform linear interpolation fixed_points on a tensor 6 | 7 | This function accepts a tensor and a list of fixed indices. 8 | The fixed indices are preserved as-is and the positions in 9 | between are filled with linear interpolation values. 10 | 11 | TODO: Optimize further, maybe use cuda, parallelize? 12 | 13 | Args: 14 | x (Tensor): Input tensor to interpolate. [..., N, ...]. 15 | dim (int): Dimension to index. 16 | fixed_points (LongTensor): List of fixed indices. 
[i: 0 <= i < N] 17 | First and Last Indices MUST BE 0 and N - 1 18 | 19 | Returns: 20 | Tensor: Linear Interpolated Tensor 21 | """ 22 | fixed_values = x.index_select(dim, fixed_points) 23 | 24 | xi = [] 25 | 26 | index_tensor = torch.arange(len(fixed_points)).unsqueeze(dim=1).to(fixed_points.device) 27 | 28 | for i in range(len(fixed_points) - 1): 29 | n = fixed_points[i + 1] - fixed_points[i] 30 | 31 | delta = (fixed_values.index_select(dim, index_tensor[i + 1]) - fixed_values.index_select(dim, index_tensor[i])) / n 32 | 33 | d_range = [] 34 | 35 | # TODO: Optimize 36 | for j in range(n): 37 | d_range.append((fixed_values.index_select(dim, index_tensor[i]) + delta * j)) 38 | 39 | xi.append(torch.cat(d_range, dim=dim)) 40 | 41 | xi.append(fixed_values.index_select(dim, index_tensor[fixed_values.shape[dim] - 1])) 42 | 43 | return torch.cat(xi, dim=dim) 44 | 45 | if __name__ == '__main__': 46 | x = Tensor([[[[1, 2]], [[0, 0]], [[3, 6]], [[0, 0]], [[5, 10]]]]) 47 | fixed_points = LongTensor([0, 2, 4]) 48 | 49 | out = linear_interpolation(x, 1, fixed_points) 50 | 51 | print(out) 52 | 53 | print(x.shape, out.shape) -------------------------------------------------------------------------------- /generator_frontend/src/controls.ts: -------------------------------------------------------------------------------- 1 | import { Vector3 } from "three"; 2 | 3 | export default class Controls { 4 | position: Vector3; 5 | container: HTMLDivElement; 6 | sliders: HTMLInputElement[]; 7 | 8 | constructor() { 9 | this.position = new Vector3(0, 0, 0); 10 | 11 | this.initHTML(); 12 | this.initControls(); 13 | } 14 | 15 | initHTML() { 16 | this.container = document.createElement('div'); 17 | this.container.className = 'controls'; 18 | 19 | const axes = ['x', 'y', 'z']; 20 | 21 | this.sliders = []; 22 | 23 | for (let i = 0; i < 3; i++) { 24 | const slider = document.createElement('input'); 25 | slider.type = 'range'; 26 | slider.min = '-100'; 27 | slider.max = '100'; 28 | slider.id = `slider-${axes[i]}` 29 | this.sliders.push(slider); 30 | 31 | const label = document.createElement('label'); 32 | label.innerText = axes[i]; 33 | label.setAttribute('for', `slider-${axes[i]}`); 34 | 35 | this.container.appendChild(label); 36 | this.container.appendChild(slider); 37 | } 38 | 39 | document.body.appendChild(this.container); 40 | } 41 | 42 | initControls() { 43 | this.container.addEventListener('click', (event: MouseEvent) => { 44 | event.stopPropagation(); 45 | }) 46 | 47 | this.sliders.forEach((slider, index) => { 48 | slider.addEventListener('input', (event: InputEvent) => { 49 | // @ts-ignore 50 | this.position.setComponent(index, parseFloat(event.target.value)); 51 | }) 52 | }) 53 | } 54 | 55 | show() { 56 | this.container.classList.remove('hidden'); 57 | } 58 | 59 | hide() { 60 | this.container.classList.add('hidden'); 61 | } 62 | 63 | set(position: Vector3) { 64 | this.position = position; 65 | 66 | this.sliders[0].value = position.x.toString() 67 | this.sliders[1].value = position.y.toString() 68 | this.sliders[2].value = position.z.toString() 69 | } 70 | } -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from torch.nn import Module 5 | from constants import DEVICE, NUM_JOINTS 6 | 7 | from .encoding.positional_encoding import PositionalEncoding 8 | 9 | 10 | class Transformer(Module): 11 | """nn.Module for transformer""" 12 | 13 | 
# Constructor 14 | def __init__(self, config): 15 | super().__init__() 16 | 17 | self.config = config 18 | self.dim_model = NUM_JOINTS * config['embedding_size']['q'] + config['embedding_size']['p'] + config['embedding_size']['v'] 19 | 20 | self.positional_encoder = PositionalEncoding( 21 | d_model=self.dim_model, 22 | max_len=self.config['dataset']['window_size'], 23 | device=DEVICE) 24 | 25 | self.transformer = nn.Transformer( 26 | d_model=self.dim_model, 27 | nhead=self.config['model']['num_heads'], 28 | num_encoder_layers=self.config['model']['num_encoder_layers'], 29 | num_decoder_layers=self.config['model']['num_decoder_layers'], 30 | dropout=self.config['model']['dropout_p'], 31 | batch_first=True) 32 | 33 | self.register_buffer( 34 | 'mask', 35 | self.get_target_mask(self.config['dataset']['window_size']).to(DEVICE)) 36 | 37 | def forward(self, src, target): 38 | 39 | src = self.positional_encoder(src) 40 | target = self.positional_encoder(target) 41 | 42 | # Transformer blocks - Out size = (batch_size, sequence length, num_tokens) 43 | return self.transformer( 44 | src, 45 | target, 46 | # src_mask=self.mask, 47 | tgt_mask=self.mask) 48 | 49 | def get_target_mask(self, size) -> torch.Tensor: 50 | 51 | mask = torch.full((1, size), float("-inf")) 52 | # mask[0, 0] = 0 53 | # mask[0, 1] = 0 54 | # mask[0, 2] = 0 55 | # mask[0, -3] = 0 56 | # mask[0, -2] = 0 57 | # mask[0, -1] = 0 58 | 59 | # Unmask every KEYFRAME_GAP frame 60 | mask[0, ::self.config['dataset']['keyframe_gap']] = 0 61 | 62 | mask = mask.repeat(size, 1) 63 | 64 | return mask 65 | -------------------------------------------------------------------------------- /generator_frontend/src/index.ts: -------------------------------------------------------------------------------- 1 | import { Clock } from 'three'; 2 | import MotionEditor from './motionEditor'; 3 | 4 | import './style.css'; 5 | import fetchJson from './helpers/fetchJson'; 6 | import Controls from './controls'; 7 | import generateMotionSequence from './generate'; 8 | import RenderBVH from './renderBVH'; 9 | 10 | (async () => { 11 | // const animation_id = prompt("Enter animation id: ", "1"); 12 | const animation_id = 1; 13 | 14 | const motionSequence = await fetchJson(`./static/animations/${animation_id}/ground_truth.json`) as number[][][]; 15 | 16 | const container = document.createElement('div'); 17 | container.classList.add('container'); 18 | 19 | const controls = new Controls(); 20 | 21 | const canvas = document.createElement('canvas'); 22 | 23 | canvas.classList.add('canvas'); 24 | 25 | container.appendChild(canvas); 26 | 27 | const clock = new Clock(); 28 | 29 | const motionEditor = new MotionEditor(canvas, motionSequence, controls); 30 | 31 | document.body.appendChild(container); 32 | 33 | const generateButton = document.createElement('button'); 34 | generateButton.className = 'generate'; 35 | generateButton.innerText = 'Generate'; 36 | 37 | document.body.appendChild(generateButton); 38 | 39 | generateButton.addEventListener('click', async () => { 40 | const res = await generateMotionSequence(motionEditor.track); 41 | 42 | const motionSequences = [ 43 | res['z_x'], 44 | res['in_x'] 45 | ]; 46 | 47 | document.body.removeChild(container); 48 | document.body.removeChild(generateButton); 49 | 50 | motionEditor.controls.hide(); 51 | 52 | const ncontainer = document.createElement('div'); 53 | ncontainer.classList.add('container'); 54 | 55 | motionSequences.forEach((motionSequence, index) => { 56 | const canvas = document.createElement('canvas'); 57 | 58 | 
canvas.classList.add('canvas'); 59 | 60 | ncontainer.appendChild(canvas); 61 | 62 | const clock = new Clock(); 63 | 64 | new RenderBVH(canvas, motionSequence, clock, index); 65 | }); 66 | 67 | document.body.appendChild(ncontainer); 68 | }); 69 | })(); -------------------------------------------------------------------------------- /model/loss/npss_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Module 6 | 7 | from constants import PARENTS 8 | from util.quaternions import quat_fk_tensor 9 | 10 | class NPSSLoss(Module): 11 | """nn.Module that calculates NPSS Loss 12 | 13 | Based on: https://arxiv.org/abs/1809.03036 14 | """ 15 | 16 | def __init__(self) -> None: 17 | super(NPSSLoss, self).__init__() 18 | 19 | def forward(self, local_p: Tensor, local_q: Tensor, local_p_cap: Tensor, local_q_cap: Tensor) -> Tensor: 20 | """ 21 | Args: 22 | local_p (Tensor): Local positions [..., J, 3] 23 | local_q (Tensor): Local quaternions [..., J, 4] 24 | local_p_cap (Tensor): Predicted Local positions [..., J, 3] 25 | local_q_cap (Tensor): Predicted Local quaternions [..., J, 4] 26 | 27 | Returns: 28 | Tensor: NPSS Loss. 29 | """ 30 | 31 | # Get global quaternions 32 | q, _ = quat_fk_tensor(local_q, local_p, PARENTS) 33 | q_cap, _ = quat_fk_tensor(local_q_cap, local_p_cap, PARENTS) 34 | 35 | x = q 36 | x_cap = q_cap 37 | 38 | # Reshape to have all features in one dimension 39 | x = x.reshape((x.shape[0], x.shape[1], -1)) 40 | x_cap = x_cap.reshape((x_cap.shape[0], x_cap.shape[1], -1)) 41 | 42 | # compute fourier coefficients 43 | x_ftt_coeff = torch.real(torch.fft.fft(x, axis=1)) 44 | x_cap_fft_coeff = torch.real(torch.fft.fft(x_cap, axis=1)) 45 | 46 | #Sq the coeff 47 | x_ftt_coeff_sq = torch.square(x_ftt_coeff) 48 | x_cap_ftt_coeff_sq = torch.square(x_cap_fft_coeff) 49 | 50 | # sum the tensor 51 | x_tot = torch.sum(x_ftt_coeff_sq, axis=1, keepdim=True) 52 | x_cap_tot = torch.sum(x_cap_ftt_coeff_sq, axis=1, keepdim=True) 53 | 54 | # normalize 55 | x_norm = x_ftt_coeff_sq / x_tot 56 | x_cap_norm = x_cap_ftt_coeff_sq / x_cap_tot 57 | 58 | # Compute emd 59 | emd = torch.norm(x_norm - x_cap_norm, dim=1, p=1) 60 | 61 | # Find total norm 62 | x_norm_tot = torch.sum(x_norm, axis=1) 63 | 64 | # Weighted Avg (NPSS) 65 | npss_loss = torch.sum(emd * x_norm_tot) / torch.sum(x_norm_tot) 66 | 67 | return npss_loss -------------------------------------------------------------------------------- /util/interpolation/spherical_interpolation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import LongTensor, Tensor 3 | 4 | from util.quaternions import quat_exp, quat_inv_tensor, quat_mul_tensor 5 | 6 | def quat_slerp(q1: Tensor, q2: Tensor, n: int, dim: int) -> Tensor: 7 | """Perform slerp on two quaternions to get n interpolated values. 8 | 9 | Slerp(q_1, q_2, t) = q_1 . (q_1^(-1) . q_2)^t 10 | 11 | Args: 12 | q1 (Tensor): Quaternion 1 [..., 4] 13 | q2 (Tensor): Quaternion 2 [..., 4] 14 | n (int): Number of inbetween values 15 | dim (int): Dimension to index. 
16 | 17 | Returns: 18 | Tensor: Spherical Interpolated Tensor [..., n, 4] 19 | """ 20 | q1_inv = quat_inv_tensor(q1) 21 | 22 | # print("q1_inv", q1_inv) 23 | 24 | q = quat_mul_tensor(q1_inv, q2) 25 | 26 | # print("q", q) 27 | 28 | quats = [] 29 | 30 | for i in range(n): 31 | t = (i + 1) / (n + 1) 32 | # print("quat_exp(q, t)", quat_exp(q, t)) 33 | quats.append(quat_mul_tensor(q1, quat_exp(q, t))) 34 | 35 | return torch.cat(quats, dim) 36 | 37 | def spherical_interpolation(x: Tensor, dim: int, fixed_points: LongTensor) -> Tensor: 38 | """Perform spherical interpolation fixed_points on a tensor 39 | 40 | This function accepts a tensor and a list of fixed indices. 41 | The fixed indices are preserved as-is and the positions in 42 | between are filled with pherical interpolation values. 43 | 44 | TODO: Optimize further, maybe use cuda, parallelize? 45 | 46 | Args: 47 | x (Tensor): Input tensor to interpolate. [..., N, ...]. 48 | dim (int): Dimension to index. 49 | fixed_points (LongTensor): List of fixed indices. [i: 0 <= i < N] 50 | First and Last Indices MUST BE 0 and N - 1 51 | 52 | Returns: 53 | Tensor: Spherical Interpolated Tensor 54 | """ 55 | fixed_values = x.index_select(dim, fixed_points) 56 | 57 | index_tensor = torch.arange(len(fixed_points)).unsqueeze(dim=1).to(fixed_points.device) 58 | 59 | xi = [] 60 | 61 | for i in range(len(fixed_points) - 1): 62 | n = fixed_points[i + 1] - fixed_points[i] - 1 63 | 64 | d_range = quat_slerp(fixed_values.index_select(dim, index_tensor[i]), fixed_values.index_select(dim, index_tensor[i + 1]), n, dim) 65 | 66 | xi.append(fixed_values.index_select(dim, index_tensor[i])) 67 | xi.append(d_range) 68 | 69 | xi.append(fixed_values.index_select(dim, index_tensor[fixed_values.shape[dim] - 1])) 70 | 71 | return torch.cat(xi, dim=dim) -------------------------------------------------------------------------------- /model/loss/l2_loss.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | import torch 3 | from torch.nn import Module 4 | from torch.nn.functional import mse_loss 5 | from constants import DEVICE, PARENTS 6 | 7 | from util.quaternions import quat_fk_tensor 8 | from train_stats import load_stats 9 | 10 | class L2PLoss(Module): 11 | """nn.Module that calculates L2P Loss 12 | """ 13 | 14 | def __init__(self) -> None: 15 | super(L2PLoss, self).__init__() 16 | x_mean_np, x_std_np = load_stats() 17 | 18 | self.x_mean = Tensor(x_mean_np).to(DEVICE) 19 | self.x_std = Tensor(x_std_np).to(DEVICE) 20 | 21 | def forward(self, local_p: Tensor, local_q: Tensor, local_p_cap: Tensor, local_q_cap: Tensor) -> Tensor: 22 | """ 23 | Args: 24 | local_p (Tensor): Local positions [..., J, 3] 25 | local_q (Tensor): Local quaternions [..., J, 4] 26 | local_p_cap (Tensor): Predicted Local positions [..., J, 3] 27 | local_q_cap (Tensor): Predicted Local quaternions [..., J, 4] 28 | 29 | Returns: 30 | Tensor: L2P Loss. 
31 | """ 32 | 33 | # Get globals 34 | _, x = quat_fk_tensor(local_q, local_p, PARENTS) 35 | 36 | _, x_cap = quat_fk_tensor(local_q_cap, local_p_cap, PARENTS) 37 | 38 | # Normalize 39 | x = (x - self.x_mean) / self.x_std 40 | x_cap = (x_cap - self.x_mean) / self.x_std 41 | 42 | return torch.mean(torch.sqrt(torch.sum((x - x_cap)**2, axis=(2, 3)))) 43 | 44 | # Calculate Loss 45 | return mse_loss(x, x_cap) 46 | 47 | class L2QLoss(Module): 48 | """nn.Module that calculates L2Q Loss 49 | """ 50 | 51 | def __init__(self) -> None: 52 | super(L2QLoss, self).__init__() 53 | 54 | def forward(self, local_p: Tensor, local_q: Tensor, local_p_cap: Tensor, local_q_cap: Tensor) -> Tensor: 55 | """ 56 | Args: 57 | local_p (Tensor): Local positions [..., J, 3] 58 | local_q (Tensor): Local quaternions [..., J, 4] 59 | local_p_cap (Tensor): Predicted Local positions [..., J, 3] 60 | local_q_cap (Tensor): Predicted Local quaternions [..., J, 4] 61 | 62 | Returns: 63 | Tensor: L2P Loss. 64 | """ 65 | 66 | # Get globals 67 | q, _ = quat_fk_tensor(local_q, local_p, PARENTS) 68 | 69 | q_cap, _ = quat_fk_tensor(local_q_cap, local_p_cap, PARENTS) 70 | 71 | # Normalize 72 | q = q / torch.norm(q, dim=-1, keepdim=True) 73 | q_cap = q_cap / torch.norm(q_cap, dim=-1, keepdim=True) 74 | 75 | return torch.mean(torch.sqrt(torch.sum((q - q_cap)**2, axis=(2, 3)))) 76 | 77 | # Calculate Loss 78 | return mse_loss(q, q_cap) -------------------------------------------------------------------------------- /util/load_data.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch.utils.data import DataLoader 3 | from constants import LAFAN1_DIRECTORY 4 | 5 | from util.lafan1 import LaFan1 6 | 7 | 8 | def load_train_dataset(dataset_config) -> DataLoader: 9 | """Function to load dataset from the given directory, perform pre-processing. 10 | 11 | Args: 12 | dataset_directory (str): Location of the dataset 13 | 14 | Returns: 15 | DataLoader: train_dataloader 16 | """ 17 | lafan_train_dataset = LaFan1(dataset_directory=LAFAN1_DIRECTORY, 18 | train=True, 19 | seq_len=dataset_config['window_size'], 20 | files_to_read=dataset_config['files_to_read']) 21 | lafan_train_loader = DataLoader( 22 | lafan_train_dataset, 23 | batch_size=dataset_config['batch_size'], 24 | shuffle=True, 25 | num_workers=dataset_config['num_workers'], 26 | drop_last=True, 27 | ) 28 | 29 | return lafan_train_loader 30 | 31 | 32 | def load_viz_dataset(dataset_config) -> DataLoader: 33 | """Function to load dataset from the given directory, perform pre-processing. 34 | 35 | Args: 36 | dataset_directory (str): Location of the dataset 37 | 38 | Returns: 39 | DataLoader: viz_dataloader 40 | """ 41 | lafan_viz_dataset = LaFan1(dataset_directory=LAFAN1_DIRECTORY, 42 | train=False, 43 | seq_len=dataset_config['window_size'], 44 | files_to_read=1) 45 | lafan_viz_loader = DataLoader( 46 | lafan_viz_dataset, 47 | batch_size=dataset_config['batch_size'], 48 | shuffle=False, 49 | num_workers=dataset_config['num_workers'], 50 | drop_last=True, 51 | ) 52 | 53 | return lafan_viz_loader 54 | 55 | 56 | def load_test_dataset(dataset_config) -> DataLoader: 57 | """Function to load dataset from the given directory, perform pre-processing. 
58 | 59 | Args: 60 | dataset_directory (str): Location of the dataset 61 | 62 | Returns: 63 | DataLoader: test_dataloader 64 | """ 65 | lafan_test_dataset = LaFan1(dataset_directory=LAFAN1_DIRECTORY, 66 | train=False, 67 | seq_len=dataset_config['window_size'], 68 | files_to_read=-1) 69 | 70 | lafan_test_loader = DataLoader( 71 | lafan_test_dataset, 72 | batch_size=dataset_config['batch_size'], 73 | shuffle=False, 74 | num_workers=dataset_config['num_workers'], 75 | drop_last=True, 76 | ) 77 | 78 | return lafan_test_loader 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transformer based Motion In-betweening 2 | 3 | In-Betweening is the process of drawing transition frames between temporally-sparse keyframes to create a smooth animation sequence. This work presents a novel transformer based in betweening technique that serves as a tool for 3D animators. 4 | 5 | 6 | ## Visualizer 7 | 8 | https://user-images.githubusercontent.com/44777563/182904169-af1d0b4d-a023-4ddd-b33e-b5a4ec4aa717.mp4 9 | 10 | 11 | 12 | 13 | 14 | ## Motion Generator 15 | 16 | https://user-images.githubusercontent.com/44777563/182904188-ce43c556-7472-47d2-bc04-c5cb868674b6.mp4 17 | 18 | 19 | 20 | ## Architecture 21 | ![](./static/model.jpg) 22 | 23 | ## Downloading Data 24 | 25 | 26 | ### LAFAN1 Dataset 27 | 28 | - Download the dataset from [Ubisoft's Github Repository](https://github.com/ubisoft/ubisoft-laforge-animation-dataset/blob/master/lafan1/lafan1.zip) and extract it to `/data/lafan1/` 29 | 30 | ## Installation 31 | 32 | 1. Install Pre-Requisites 33 | 34 | - Python 3.9 35 | - PyTorch 1.10 36 | 37 | 2. Clone the repository 38 | ```git clone https://github.com/Pavi114/motion-completion-using-transformers``` 39 | 40 | 3. Copy config/default.yml to config/`model_name`.yml and edit as needed. 41 | 42 | 4. Install Python Dependencies 43 | 44 | - Create a virtualenv: `python3 -m virtualenv -p python3.9 venv` 45 | 46 | - Install Dependencies: `pip install -r requirements.txt` 47 | 48 | ## Execution 49 | 50 | First activate the venv: `source venv/bin/activate` 51 | 52 | ### Training 53 | 54 | ``` 55 | train.py [-h] [--model_name MODEL_NAME] [--save_weights | --no-save_weights] [--load_weights | --no-load_weights] 56 | 57 | optional arguments: 58 | -h, --help show the help message and exit 59 | --model_name MODEL_NAME 60 | Name of the model. Used for loading and saving weights. 61 | --save_weights, --no-save_weights 62 | Save model weights. (default: False) 63 | --load_weights, --no-load_weights 64 | Load model weights. (default: False) 65 | ``` 66 | 67 | ### Visualization 68 | 69 | ``` 70 | visualize.py [-h] [--model_name MODEL_NAME] 71 | 72 | optional arguments: 73 | -h, --help show the help message and exit 74 | --model_name MODEL_NAME 75 | Name of the model. Used for loading and saving weights. 76 | ``` 77 | 78 | ### Evaluation 79 | 80 | ``` 81 | evaluate.py [-h] [--model_name MODEL_NAME] 82 | 83 | optional arguments: 84 | -h, --help show the help message and exit 85 | --model_name MODEL_NAME 86 | Name of the model. Used for loading and saving weights. 87 | ``` 88 | 89 | ### Running Visualizer 90 | 91 | 0. Navigate to `./viz` directory 92 | 93 | ``` 94 | cd ./viz 95 | ``` 96 | 97 | 1. Install NPM Modules 98 | 99 | ``` 100 | npm install 101 | ``` 102 | 103 | 2. Build visualizer 104 | 105 | ``` 106 | npm run build 107 | ``` 108 | 109 | 3. 
Copy output file to `./dist` 110 | 111 | ``` 112 | cp output/[MODEL_NAME] viz/dist/static/animations/[MODEL_NAME] 113 | ``` 114 | 115 | 4. Run viz 116 | 117 | ``` 118 | npm start 119 | ``` 120 | 121 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /generator_backend/model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from constants import DEVICE, MODEL_SAVE_DIRECTORY, PARENTS 4 | from model.encoding.input_encoder import InputEncoder 5 | from model.encoding.output_decoder import OutputDecoder 6 | from model.transformer import Transformer 7 | from util.interpolation.fixed_points import get_fixed_points 8 | from util.interpolation.interpolation_factory import get_p_interpolation, get_q_interpolation 9 | from util.quaternions import quat_fk, quat_ik_tensor 10 | from util.read_config import read_config 11 | 12 | class Model: 13 | def __init__(self, model_name) -> None: 14 | # Load config 15 | self.config = read_config(model_name) 16 | 17 | self.transformer = Transformer(self.config).to(DEVICE) 18 | 19 | self.input_encoder = InputEncoder(self.config['embedding_size']).to(DEVICE) 20 | 21 | self.output_decoder = OutputDecoder(self.config['embedding_size']).to(DEVICE) 22 | 23 | self.fixed_points = get_fixed_points(self.config['dataset']['window_size'], self.config['dataset']['keyframe_gap']) 24 | 25 | self.p_interpolation_function = get_p_interpolation(self.config['hyperparameters']['interpolation']) 26 | self.q_interpolation_function = get_q_interpolation(self.config['hyperparameters']['interpolation']) 27 | 28 | checkpoint = torch.load(f'{MODEL_SAVE_DIRECTORY}/model_{model_name}.pt') 29 | 30 | self.transformer.load_state_dict(checkpoint['transformer_state_dict']) 31 | self.input_encoder.load_state_dict(checkpoint['encoder_state_dict']) 32 | self.output_decoder.load_state_dict(checkpoint['decoder_state_dict']) 33 | 34 | def generate(self, gpos): 35 | gpos = torch.tensor(gpos).to(DEVICE) 36 | 37 | z_gpos = torch.zeros((128, 22, 3)).to(DEVICE) 38 | 39 | for i in range(127): 40 | z_gpos[i] = gpos[30 * (i // 30)] 41 | 42 | z_gpos[-1] = gpos[-1] 43 | 44 | in_gpos = self.p_interpolation_function(gpos, 0, self.fixed_points) 45 | 46 | return z_gpos.cpu().detach().numpy().tolist(), in_gpos.cpu().detach().numpy().tolist() 47 | 48 | def _generate(self, gpos): 49 | grot = torch.tensor(grot).to(DEVICE) 50 | gpos = torch.tensor(gpos).to(DEVICE) 51 | 52 | lrot, lpos = quat_ik_tensor(grot, gpos, PARENTS) 53 | 54 | lrotSeq = torch.zeros((128, 22, 4)).to(DEVICE) 55 | lposSeq = torch.zeros((128, 22, 3)).to(DEVICE) 56 | rootPosSeq = torch.zeros((128, 3)).to(DEVICE) 57 | rootVSeq = torch.zeros((128, 3)).to(DEVICE) 58 | 59 | for i in range(5): 60 | lrotSeq[30*i] = lrot[i] 61 | lposSeq[30*i] = lpos[i] 62 | rootPosSeq[30*i] = lpos[i][0] 63 | 64 | lrotSeq[-1] = lrot[-1] 65 | rootPosSeq[-1] = lpos[-1][0] 66 | 67 | in_local_q = self.q_interpolation_function(lrotSeq, -3, self.fixed_points) 68 | in_local_p = self.p_interpolation_function(lposSeq, -3, self.fixed_points) 69 | in_root_p = self.p_interpolation_function(rootPosSeq, -2, self.fixed_points) 70 | in_root_v = self.p_interpolation_function(rootVSeq, -2, self.fixed_points) 71 | 72 | seq = self.input_encoder(in_local_q, in_root_p, in_root_v) 73 | 74 | out = self.transformer(seq, seq) 75 | 76 | out_q, out_p, out_v = self.output_decoder(out) 77 | 78 | out_q = out_q / torch.norm(out_q, dim=-1, keepdim=True) 79 | 80 | out_local_p = in_local_p 81 | 
out_local_p[:, :, 0, :] = out_p 82 | 83 | _, x = quat_fk(lposSeq.detach().cpu().numpy(), 84 | lposSeq.detach().cpu().numpy(), PARENTS) 85 | _, in_x = quat_fk(in_local_q.detach().cpu().numpy(), 86 | in_local_p.detach().cpu().numpy(), PARENTS) 87 | _, out_x = quat_fk(out_q.detach().cpu().numpy(), 88 | out_local_p.detach().cpu().numpy(), PARENTS) 89 | 90 | return x, in_x, out_x 91 | -------------------------------------------------------------------------------- /util/lafan1.py: -------------------------------------------------------------------------------- 1 | from constants import DEVICE 2 | from util import quaternions, extract 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset, DataLoader 6 | 7 | import sys 8 | import os 9 | sys.path.insert(0, os.path.dirname(__file__)) 10 | sys.path.append("..") 11 | 12 | 13 | class LaFan1(Dataset): 14 | 15 | def __init__(self, dataset_directory, train=False, seq_len=50, offset=10, files_to_read=-1): 16 | """ 17 | Args: 18 | dataset_directory (string): Path to the bvh files. 19 | seq_len (int): The max len of the sequence for interpolation. 20 | """ 21 | if train: 22 | self.actors = ['subject1', 'subject2', 'subject3', 'subject4'] 23 | else: 24 | self.actors = ['subject5'] 25 | self.train = train 26 | self.seq_len = seq_len 27 | self.offset = offset 28 | self.files_to_read = files_to_read 29 | self.data = self.load_data(dataset_directory) 30 | self.cur_seq_length = 5 31 | 32 | def load_data(self, dataset_directory): 33 | 34 | print('Building the data set...') 35 | X, Q, parents = extract.get_lafan1_set( 36 | dataset_directory, self.actors, window=self.seq_len, offset=self.offset, files_to_read=self.files_to_read) 37 | 38 | Q = Q.cpu() 39 | X = X.cpu() 40 | 41 | # Global representation: 42 | q_glbl, x_glbl = quaternions.quat_fk_tensor(Q, X, parents) 43 | 44 | # Global positions stats: 45 | # self.x_mean = torch.mean(x_glbl.reshape( 46 | # [x_glbl.shape[0], x_glbl.shape[1], -1]).permute([0, 2, 1]), dim=(0, 2), keepdim=True) 47 | # self.x_std = torch.std(x_glbl.reshape( 48 | # [x_glbl.shape[0], x_glbl.shape[1], -1]).permute([0, 2, 1]), dim=(0, 2), keepdim=True) 49 | 50 | input_ = {} 51 | # The following features are inputs: 52 | # 1. local quaternion vector (J * 4d) 53 | input_['local_q'] = Q 54 | 55 | # 2. global root velocity vector (3d) 56 | input_['root_v'] = x_glbl[:, 1:, 0, :] - x_glbl[:, :-1, 0, :] 57 | 58 | # Add zero velocity vector for last frame 59 | input_['root_v'] = torch.cat( 60 | (input_['root_v'], torch.zeros((input_['root_v'].shape[0], 1, 3))), dim=-2) 61 | 62 | # 3. contact information vector (4d) 63 | # input_['contact'] = torch.cat([contacts_l, contacts_r], dim=-1) 64 | 65 | # 4. global root position offset (?d) 66 | input_['root_p_offset'] = x_glbl[:, -1, 0, :] 67 | 68 | # 5. local quaternion offset (?d) 69 | input_['local_q_offset'] = Q[:, -1, :, :] 70 | 71 | # 6. target 72 | input_['target'] = Q[:, -1, :, :] 73 | 74 | # 7. root pos 75 | input_['root_p'] = x_glbl[:, :, 0, :] 76 | 77 | # 8. X 78 | input_['X'] = x_glbl[:, :, :, :] 79 | 80 | # 9. 
local_p 81 | input_['local_p'] = X 82 | 83 | # print('Nb of sequences : {}\n'.format(X.shape[0])) 84 | # print(input_['X'].shape, input_['local_q'].shape) 85 | # print(input_['X'][0][0]) 86 | 87 | return input_ 88 | 89 | def __len__(self): 90 | return len(self.data['local_q']) 91 | 92 | def __getitem__(self, idx): 93 | sample = {} 94 | sample['local_q'] = self.data['local_q'][idx] 95 | sample['root_v'] = self.data['root_v'][idx] 96 | # sample['contact'] = self.data['contact'][idx] 97 | sample['root_p_offset'] = self.data['root_p_offset'][idx] 98 | sample['local_q_offset'] = self.data['local_q_offset'][idx] 99 | sample['target'] = self.data['target'][idx] 100 | sample['root_p'] = self.data['root_p'][idx] 101 | sample['X'] = self.data['X'][idx] 102 | sample['local_p'] = self.data['local_p'][idx] 103 | return sample 104 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from constants import DEVICE, MODEL_SAVE_DIRECTORY, OUTPUT_DIRECTORY, PARENTS 5 | 6 | import torch 7 | from torch.nn import functional as F 8 | from model.encoding.input_encoder import InputEncoder 9 | from model.encoding.output_decoder import OutputDecoder 10 | from util.interpolation.fixed_points import get_fixed_points 11 | from util.interpolation.interpolation_factory import get_p_interpolation, get_q_interpolation 12 | from util.load_data import load_viz_dataset 13 | from util.math import round_tensor 14 | from util.quaternions import quat_fk 15 | from model.transformer import Transformer 16 | from util.read_config import read_config 17 | from util.smoothing.moving_average_smoothing import moving_average_smoothing 18 | 19 | 20 | def visualize(model_name='default'): 21 | # Load config 22 | config = read_config(model_name) 23 | 24 | # Load and Preprocess Data 25 | test_dataloader = load_viz_dataset(config['dataset']) 26 | 27 | # Training Loop 28 | transformer = Transformer(config).to(DEVICE) 29 | 30 | input_encoder = InputEncoder(config['embedding_size']).to(DEVICE) 31 | 32 | output_decoder = OutputDecoder(config['embedding_size']).to(DEVICE) 33 | 34 | fixed_points = get_fixed_points(config['dataset']['window_size'], config['dataset']['keyframe_gap']) 35 | 36 | p_interpolation_function = get_p_interpolation(config['hyperparameters']['interpolation']) 37 | q_interpolation_function = get_q_interpolation(config['hyperparameters']['interpolation']) 38 | 39 | checkpoint = torch.load(f'{MODEL_SAVE_DIRECTORY}/model_{model_name}.pt') 40 | 41 | transformer.load_state_dict(checkpoint['transformer_state_dict']) 42 | input_encoder.load_state_dict(checkpoint['encoder_state_dict']) 43 | output_decoder.load_state_dict(checkpoint['decoder_state_dict']) 44 | 45 | transformer.eval() 46 | input_encoder.eval() 47 | output_decoder.eval() 48 | 49 | # Visualize 50 | viz_batch = next(iter(test_dataloader)) 51 | 52 | local_q = round_tensor(viz_batch["local_q"].to(DEVICE), decimals=4) 53 | local_p = round_tensor(viz_batch["local_p"].to(DEVICE), decimals=4) 54 | root_p = round_tensor(viz_batch["X"][:, :, 0, :].to(DEVICE), decimals=4) 55 | root_v = round_tensor(viz_batch["root_v"].to(DEVICE), decimals=4) 56 | 57 | in_local_q = q_interpolation_function(local_q, 1, fixed_points) 58 | in_local_p = p_interpolation_function(local_p, 1, fixed_points) 59 | in_root_p = p_interpolation_function(root_p, 1, fixed_points) 60 | in_root_v = 
p_interpolation_function(root_v, 1, fixed_points) 61 | 62 | seq = input_encoder(in_local_q, in_root_p, in_root_v) 63 | 64 | out = transformer(seq, seq) 65 | 66 | ma_out = moving_average_smoothing(out, dim=1) 67 | 68 | out_q, out_p, out_v = output_decoder(out) 69 | 70 | ma_out_q, ma_out_p, ma_out_v = output_decoder(ma_out) 71 | 72 | out_local_p = local_p 73 | out_local_p[:, :, 0, :] = out_p 74 | 75 | ma_out_local_p = local_p 76 | ma_out_local_p[:, :, 0, :] = ma_out_p 77 | 78 | _, x = quat_fk(local_q.detach().cpu().numpy(), 79 | local_p.detach().cpu().numpy(), PARENTS) 80 | _, in_x = quat_fk(in_local_q.detach().cpu().numpy(), 81 | in_local_p.detach().cpu().numpy(), PARENTS) 82 | _, out_x = quat_fk(out_q.detach().cpu().numpy(), 83 | out_local_p.detach().cpu().numpy(), PARENTS) 84 | _, ma_out_x = quat_fk(ma_out_q.detach().cpu().numpy(), 85 | ma_out_local_p.detach().cpu().numpy(), PARENTS) 86 | 87 | for i in range(config['dataset']['batch_size']): 88 | output_dir = f'{OUTPUT_DIRECTORY}/viz/{model_name}/{i}' 89 | 90 | Path(output_dir).mkdir(parents=True, exist_ok=True) 91 | 92 | with open(f'{output_dir}/ground_truth.json', 'w') as f: 93 | f.truncate(0) 94 | f.write(json.dumps(x[i, :, :, :].tolist())) 95 | 96 | with open(f'{output_dir}/input.json', 'w') as f: 97 | f.truncate(0) 98 | f.write(json.dumps(in_x[i, :, :, :].tolist())) 99 | 100 | with open(f'{output_dir}/output.json', 'w') as f: 101 | f.truncate(0) 102 | f.write(json.dumps(out_x[i, :, :, :].tolist())) 103 | 104 | with open(f'{output_dir}/output_smoothened.json', 'w') as f: 105 | f.truncate(0) 106 | f.write(json.dumps(ma_out_x[i, :, :, :].tolist())) 107 | 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser() 111 | 112 | parser.add_argument( 113 | '--model_name', 114 | help='Name of the model. 
Used for loading and saving weights.', 115 | type=str, 116 | default='default') 117 | 118 | args = parser.parse_args() 119 | 120 | visualize(args.model_name) 121 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from itertools import chain 3 | from constants import DEVICE, MODEL_SAVE_DIRECTORY 4 | 5 | import torch 6 | from torch.nn import L1Loss 7 | from torch.optim import Adam 8 | from tqdm import tqdm 9 | from model.encoding.input_encoder import InputEncoder 10 | from model.encoding.output_decoder import OutputDecoder 11 | from model.loss.fk_loss import FKLoss 12 | from util.interpolation.fixed_points import get_fixed_points 13 | from util.interpolation.interpolation_factory import get_p_interpolation, get_q_interpolation 14 | from util.load_data import load_train_dataset 15 | from model.transformer import Transformer 16 | from util.math import round_tensor 17 | from util.read_config import read_config 18 | from util.plot import plot_loss 19 | 20 | 21 | def train(model_name='default', save_weights=False, load_weights=False): 22 | # Load config 23 | config = read_config(model_name) 24 | 25 | # Load and Preprocess Data 26 | train_dataloader = load_train_dataset(config['dataset']) 27 | 28 | # Training Loop 29 | transformer = Transformer(config).to(DEVICE) 30 | 31 | input_encoder = InputEncoder(config['embedding_size']).to(DEVICE) 32 | 33 | output_decoder = OutputDecoder(config['embedding_size']).to(DEVICE) 34 | 35 | optimizer_g = Adam(lr=config['hyperparameters']['learning_rate'], 36 | params=chain(transformer.parameters(), 37 | input_encoder.parameters(), 38 | output_decoder.parameters())) 39 | 40 | criterion = L1Loss().to(DEVICE) 41 | 42 | fk_criterion = FKLoss().to(DEVICE) 43 | 44 | best_loss = torch.Tensor([float("+inf")]).to(DEVICE) 45 | 46 | loss_history = [] 47 | 48 | p_interpolation_function = get_p_interpolation(config['hyperparameters']['interpolation']) 49 | q_interpolation_function = get_q_interpolation(config['hyperparameters']['interpolation']) 50 | 51 | fixed_points = get_fixed_points(config['dataset']['window_size'], config['dataset']['keyframe_gap']) 52 | 53 | if load_weights: 54 | checkpoint = torch.load( 55 | f'{MODEL_SAVE_DIRECTORY}/model_{model_name}.pt') 56 | 57 | transformer.load_state_dict(checkpoint['transformer_state_dict']) 58 | input_encoder.load_state_dict(checkpoint['encoder_state_dict']) 59 | output_decoder.load_state_dict(checkpoint['decoder_state_dict']) 60 | optimizer_g.load_state_dict(checkpoint['optimizer_state_dict']) 61 | best_loss = checkpoint['loss'] 62 | 63 | for epoch in range(config['hyperparameters']['epochs']): 64 | transformer.train() 65 | train_loss = 0 66 | tqdm_dataloader = tqdm(train_dataloader) 67 | for index, batch in enumerate(tqdm_dataloader): 68 | local_q = round_tensor(batch["local_q"].to(DEVICE), decimals=4) 69 | local_p = round_tensor(batch["local_p"].to(DEVICE), decimals=4) 70 | root_p = round_tensor(batch["X"][:, :, 0, :].to(DEVICE), decimals=4) 71 | root_v = round_tensor(batch["root_v"].to(DEVICE), decimals=4) 72 | 73 | in_local_q = q_interpolation_function(local_q, 1, fixed_points) 74 | in_root_p = p_interpolation_function(root_p, 1, fixed_points) 75 | in_root_v = p_interpolation_function(root_v, 1, fixed_points) 76 | 77 | # print(torch.any(torch.isnan(in_local_q))) 78 | 79 | # print(local_q[0][0]) 80 | # print(local_q[0][3]) 81 | # print(in_local_q[0][1]) 82 | 83 | # seq = 
input_encoder(local_q, root_p, root_v) 84 | seq = input_encoder(in_local_q, in_root_p, in_root_v) 85 | 86 | out = transformer(seq, seq) 87 | 88 | # print(out) 89 | 90 | # break 91 | 92 | out_q, out_p, out_v = output_decoder(out) 93 | 94 | out_local_p = local_p 95 | out_local_p[:, :, 0, :] = out_p 96 | 97 | optimizer_g.zero_grad() 98 | 99 | q_loss = criterion(local_q, out_q) 100 | # p_loss = criterion(root_p, out_p) 101 | # v_loss = criterion(root_v, out_v) 102 | fk_loss = fk_criterion(local_p, local_q, out_local_p, out_q) 103 | 104 | loss = 10 * q_loss + fk_loss 105 | 106 | loss.backward() 107 | 108 | optimizer_g.step() 109 | tqdm_dataloader.set_description( 110 | f"{model_name} | batch: {index + 1} | loss: {loss:.4f} q_loss: {q_loss:.4f} fk_loss: {fk_loss:.4f}" 111 | ) 112 | train_loss += loss 113 | 114 | loss_history.append(train_loss.detach().cpu().numpy()) 115 | print(f"epoch: {epoch + 1}, train loss: {train_loss/index}") 116 | 117 | if save_weights and train_loss < best_loss: 118 | # Save weights 119 | torch.save( 120 | { 121 | 'transformer_state_dict': transformer.state_dict(), 122 | 'encoder_state_dict': input_encoder.state_dict(), 123 | 'decoder_state_dict': output_decoder.state_dict(), 124 | 'optimizer_state_dict': optimizer_g.state_dict(), 125 | 'loss': best_loss 126 | }, f'{MODEL_SAVE_DIRECTORY}/model_{model_name}.pt') 127 | 128 | best_loss = train_loss 129 | 130 | plot_loss(loss_history) 131 | 132 | 133 | if __name__ == '__main__': 134 | parser = argparse.ArgumentParser() 135 | 136 | parser.add_argument( 137 | '--model_name', 138 | help='Name of the model. Used for loading and saving weights.', 139 | type=str, 140 | default='default') 141 | 142 | parser.add_argument('--save_weights', 143 | help='Save model weights.', 144 | action=argparse.BooleanOptionalAction, 145 | default=False) 146 | 147 | parser.add_argument('--load_weights', 148 | help='Load model weights.', 149 | action=argparse.BooleanOptionalAction, 150 | default=False) 151 | 152 | args = parser.parse_args() 153 | 154 | train(args.model_name, args.save_weights, args.load_weights) 155 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from cgi import test 3 | from pathlib import Path 4 | 5 | import torch 6 | from torch.nn import L1Loss 7 | from tqdm import tqdm 8 | 9 | from constants import DEVICE, MODEL_SAVE_DIRECTORY, OUTPUT_DIRECTORY, PARENTS 10 | from model.encoding.input_encoder import InputEncoder 11 | from model.encoding.output_decoder import OutputDecoder 12 | from model.loss.fk_loss import FKLoss 13 | from model.loss.l2_loss import L2PLoss, L2QLoss 14 | from model.loss.npss_loss import NPSSLoss 15 | from util.interpolation.fixed_points import get_fixed_points 16 | from util.interpolation.interpolation_factory import get_p_interpolation, get_q_interpolation 17 | from util.load_data import load_test_dataset 18 | from model.transformer import Transformer 19 | from util.math import round_tensor 20 | from util.read_config import read_config 21 | from util.smoothing.moving_average_smoothing import moving_average_smoothing 22 | 23 | def evaluate(model_name='default'): 24 | # Load config 25 | config = read_config(model_name) 26 | 27 | # Load and Preprocess Data 28 | test_dataloader = load_test_dataset(config['dataset']) 29 | 30 | # Training Loop 31 | transformer = Transformer(config).to(DEVICE) 32 | 33 | input_encoder = InputEncoder(config['embedding_size']).to(DEVICE) 
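    # The encoder/decoder pair brackets the transformer: InputEncoder lifts the
    # per-frame pose features (local joint quaternions, global root position and
    # root velocity) into the embedding space sized by config['embedding_size'],
    # while OutputDecoder maps transformer outputs back to those quantities,
    # matching how the three modules are chained further below:
    #   seq = input_encoder(in_local_q, in_root_p, in_root_v)
    #   out = transformer(seq, seq)
    #   out_q, out_p, out_v = output_decoder(out)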
34 | 35 | output_decoder = OutputDecoder(config['embedding_size']).to(DEVICE) 36 | 37 | fixed_points = get_fixed_points(config['dataset']['window_size'], config['dataset']['keyframe_gap']) 38 | 39 | p_interpolation_function = get_p_interpolation(config['hyperparameters']['interpolation']) 40 | q_interpolation_function = get_q_interpolation(config['hyperparameters']['interpolation']) 41 | 42 | checkpoint = torch.load(f'{MODEL_SAVE_DIRECTORY}/model_{model_name}.pt', map_location=DEVICE) 43 | 44 | transformer.load_state_dict(checkpoint['transformer_state_dict']) 45 | input_encoder.load_state_dict(checkpoint['encoder_state_dict']) 46 | output_decoder.load_state_dict(checkpoint['decoder_state_dict']) 47 | 48 | # Initialize losses 49 | l1_criterion = L1Loss() 50 | fk_criterion = FKLoss() 51 | l2p_criterion = L2PLoss() 52 | l2q_criterion = L2QLoss() 53 | npss_criterion = NPSSLoss() 54 | 55 | transformer.eval() 56 | input_encoder.eval() 57 | output_decoder.eval() 58 | 59 | global_q_loss = 0 60 | global_fk_loss = 0 61 | global_l2p_loss = 0 62 | global_l2q_loss = 0 63 | global_npss_loss = 0 64 | 65 | global_interpolation_q_loss = 0 66 | global_interpolation_fk_loss = 0 67 | global_interpolation_l2p_loss = 0 68 | global_interpolation_l2q_loss = 0 69 | global_interpolation_npss_loss = 0 70 | 71 | # Visualize 72 | tqdm_dataloader = tqdm(test_dataloader) 73 | for index, batch in enumerate(tqdm_dataloader): 74 | local_q = round_tensor(batch["local_q"].to(DEVICE), decimals=4) 75 | local_p = round_tensor(batch["local_p"].to(DEVICE), decimals=4) 76 | root_p = round_tensor(batch["X"][:, :, 0, :].to(DEVICE), decimals=4) 77 | root_v = round_tensor(batch["root_v"].to(DEVICE), decimals=4) 78 | 79 | in_local_q = q_interpolation_function(local_q, 1, fixed_points) 80 | in_local_q = in_local_q / torch.norm(in_local_q, dim=-1, keepdim=True) 81 | in_local_p = p_interpolation_function(local_p, 1, fixed_points) 82 | in_root_p = p_interpolation_function(root_p, 1, fixed_points) 83 | in_root_v = p_interpolation_function(root_v, 1, fixed_points) 84 | 85 | seq = input_encoder(in_local_q, in_root_p, in_root_v) 86 | 87 | out = transformer(seq, seq) 88 | 89 | out_q, out_p, out_v = output_decoder(out) 90 | 91 | out_q = out_q / torch.norm(out_q, dim=-1, keepdim=True) 92 | 93 | out_local_p = local_p 94 | out_local_p[:, :, 0, :] = out_p 95 | 96 | # Evaluate 97 | q_loss = l1_criterion(local_q, out_q).item() 98 | fk_loss = fk_criterion(local_p, local_q, out_local_p, out_q).item() 99 | l2p_loss = l2p_criterion(local_p, local_q, out_local_p, out_q).item() 100 | l2q_loss = l2q_criterion(local_p, local_q, out_local_p, out_q).item() 101 | npss_loss = npss_criterion(local_p, local_q, out_local_p, out_q).item() 102 | 103 | in_q_loss = l1_criterion(local_q, in_local_q).item() 104 | in_fk_loss = fk_criterion(local_p, local_q, in_local_p, in_local_q).item() 105 | in_l2p_loss = l2p_criterion(local_p, local_q, in_local_p, in_local_q).item() 106 | in_l2q_loss = l2q_criterion(local_p, local_q, in_local_p, in_local_q).item() 107 | in_npss_loss = npss_criterion(local_p, local_q, in_local_p, in_local_q).item() 108 | 109 | tqdm_dataloader.set_description( 110 | f"batch: {index + 1} | q: {q_loss:.4f} fk: {fk_loss:.4f} l2p: {l2p_loss:.4f} l2q: {l2q_loss:.4f} npss: {npss_loss:.4f}" 111 | ) 112 | 113 | global_q_loss += q_loss 114 | global_fk_loss += fk_loss 115 | global_l2p_loss += l2p_loss 116 | global_l2q_loss += l2q_loss 117 | global_npss_loss += npss_loss 118 | 119 | global_interpolation_q_loss += in_q_loss 120 | global_interpolation_fk_loss += 
in_fk_loss 121 | global_interpolation_l2p_loss += in_l2p_loss 122 | global_interpolation_l2q_loss += in_l2q_loss 123 | global_interpolation_npss_loss += in_npss_loss 124 | 125 | # Store results 126 | path = f'{OUTPUT_DIRECTORY}/metrics' 127 | 128 | Path(path).mkdir(parents=True, exist_ok=True) 129 | 130 | s = f'Q: {global_q_loss / (index + 1)}\n' + \ 131 | f'FK: {global_fk_loss / (index + 1)}\n' + \ 132 | f'L2P: {global_l2p_loss / (index + 1)}\n' + \ 133 | f'L2Q: {global_l2q_loss / (index + 1)}\n' + \ 134 | f'NPSS: {global_npss_loss / (index + 1)}' 135 | 136 | in_s = f'IN_Q: {global_interpolation_q_loss / (index + 1)}\n' + \ 137 | f'IN_FK: {global_interpolation_fk_loss / (index + 1)}\n' + \ 138 | f'IN_L2P: {global_interpolation_l2p_loss / (index + 1)}\n' + \ 139 | f'IN_L2Q: {global_interpolation_l2q_loss / (index + 1)}\n' + \ 140 | f'IN_NPSS: {global_interpolation_npss_loss / (index + 1)}\n' 141 | 142 | with open(f'{path}/{model_name}.txt', 'w') as f: 143 | f.truncate(0) 144 | f.write(s) 145 | 146 | print(model_name, "\n", s, "\n", in_s) 147 | 148 | if __name__ == '__main__': 149 | parser = argparse.ArgumentParser() 150 | 151 | parser.add_argument( 152 | '--model_name', 153 | help='Name of the model. Used for loading and saving weights.', 154 | type=str, 155 | default='default') 156 | 157 | args = parser.parse_args() 158 | 159 | evaluate(args.model_name) 160 | -------------------------------------------------------------------------------- /util/conversion.py: -------------------------------------------------------------------------------- 1 | # This file will contain conversions like 2 | # - euler_angles to quaternions 3 | # - quaternions to euler_angles 4 | # - absolute positions to angles 5 | # - bvh to angles 6 | # etc. Add when necessary. 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from constants import DEVICE 12 | from . import quaternions 13 | 14 | 15 | def euler_to_quat(e, order='zyx'): 16 | """ 17 | 18 | Converts from an euler representation to a quaternion representation 19 | 20 | :param e: euler tensor 21 | :param order: order of euler rotations 22 | :return: quaternion tensor 23 | """ 24 | axis = { 25 | 'x': np.asarray([1, 0, 0], dtype=np.float32), 26 | 'y': np.asarray([0, 1, 0], dtype=np.float32), 27 | 'z': np.asarray([0, 0, 1], dtype=np.float32)} 28 | 29 | q0 = angle_axis_to_quat(e[..., 0], axis[order[0]]) 30 | q1 = angle_axis_to_quat(e[..., 1], axis[order[1]]) 31 | q2 = angle_axis_to_quat(e[..., 2], axis[order[2]]) 32 | 33 | return quaternions.quat_mul(q0, quaternions.quat_mul(q1, q2)) 34 | 35 | 36 | def angle_axis_to_quat(angle, axis): 37 | """ 38 | Converts from and angle-axis representation to a quaternion representation 39 | 40 | :param angle: angles tensor 41 | :param axis: axis tensor 42 | :return: quaternion tensor 43 | """ 44 | c = np.cos(angle / 2.0)[..., np.newaxis] 45 | s = np.sin(angle / 2.0)[..., np.newaxis] 46 | q = np.concatenate([c, s * axis], axis=-1) 47 | return q 48 | 49 | # Orient the data according to the las past keframe 50 | def rotate_at_frame(X, Q, parents, n_past=10): 51 | """ 52 | Re-orients the animation data according to the last frame of past context. 
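    The horizontal (XZ-plane) projection of the root's rotated Y axis at frame
    `n_past - 1` is taken as the facing direction, and the inverse of the
    resulting yaw is applied to every frame and joint so that this direction
    aligns with the global +X axis.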
53 | 54 | :param X: tensor of local positions of shape (Batchsize, Timesteps, Joints, 3) 55 | :param Q: tensor of local quaternions (Batchsize, Timesteps, Joints, 4) 56 | :param parents: list of parents' indices 57 | :param n_past: number of frames in the past context 58 | :return: The rotated positions X and quaternions Q 59 | """ 60 | # Get global quats and global poses (FK) 61 | global_q, global_x = quaternions.quat_fk(Q, X, parents) 62 | 63 | key_glob_Q = global_q[:, n_past - 1: n_past, 0:1, :] # (B, 1, 1, 4) 64 | forward = np.array([1, 0, 1])[np.newaxis, np.newaxis, np.newaxis, :] \ 65 | * quaternions.quat_mul_vec(key_glob_Q, np.array([0, 1, 0])[np.newaxis, np.newaxis, np.newaxis, :]) 66 | forward = normalize(forward) 67 | yrot = normalize(quaternions.quat_between(np.array([1, 0, 0]), forward)) 68 | new_glob_Q = quaternions.quat_mul(quaternions.quat_inv(yrot), global_q) 69 | new_glob_X = quaternions.quat_mul_vec(quaternions.quat_inv(yrot), global_x) 70 | 71 | # back to local quat-pos 72 | Q, X = quaternions.quat_ik(new_glob_Q, new_glob_X, parents) 73 | 74 | return X, Q 75 | 76 | # Orient the data according to the las past keframe 77 | def rotate_at_frame_tensor(X, Q, parents, n_past=10): 78 | """ 79 | Re-orients the animation data according to the last frame of past context. 80 | 81 | :param X: tensor of local positions of shape (Batchsize, Timesteps, Joints, 3) 82 | :param Q: tensor of local quaternions (Batchsize, Timesteps, Joints, 4) 83 | :param parents: list of parents' indices 84 | :param n_past: number of frames in the past context 85 | :return: The rotated positions X and quaternions Q 86 | """ 87 | # Get global quats and global poses (FK) 88 | global_q, global_x = quaternions.quat_fk_tensor(Q, X, parents) 89 | 90 | key_glob_Q = global_q[:, n_past - 1: n_past, 0:1, :] # (B, 1, 1, 4) 91 | 92 | forward = torch.Tensor([1, 0, 1]).reshape((1, 1, 1, 3)) \ 93 | * quaternions.quat_mul_vec_tensor(key_glob_Q, torch.Tensor([0, 1, 0]).reshape((1, 1, 1, 3))) 94 | 95 | forward = normalize_tensor(forward) 96 | 97 | yrot = normalize_tensor(quaternions.quat_between_tensor(torch.Tensor([1, 0, 0]).reshape((1, 1, 1, 3)), forward)) 98 | 99 | global_q = quaternions.quat_mul_tensor(quaternions.quat_inv_tensor(yrot), global_q) 100 | global_x = quaternions.quat_mul_vec_tensor(quaternions.quat_inv_tensor(yrot), global_x) 101 | 102 | # back to local quat-pos 103 | Q, X = quaternions.quat_ik_tensor(global_q, global_x, parents) 104 | 105 | return X, Q 106 | 107 | def length(x, axis=-1, keepdims=True): 108 | """ 109 | Computes vector norm along a tensor axis(axes) 110 | 111 | :param x: tensor 112 | :param axis: axis(axes) along which to compute the norm 113 | :param keepdims: indicates if the dimension(s) on axis should be kept 114 | :return: The length or vector of lengths. 115 | """ 116 | lgth = np.sqrt(np.sum(x * x, axis=axis, keepdims=keepdims)) 117 | return lgth 118 | 119 | def length_tensor(x, axis=-1, keepdims=True): 120 | """ 121 | Computes vector norm along a tensor axis(axes) 122 | 123 | :param x: tensor 124 | :param axis: axis(axes) along which to compute the norm 125 | :param keepdims: indicates if the dimension(s) on axis should be kept 126 | :return: The length or vector of lengths. 
127 | """ 128 | lgth = torch.sqrt(torch.sum(x * x, dim=axis, keepdim=keepdims)) 129 | return lgth 130 | 131 | 132 | def normalize(x, axis=-1, eps=1e-8): 133 | """ 134 | Normalizes a tensor over some axis (axes) 135 | 136 | :param x: data tensor 137 | :param axis: axis(axes) along which to compute the norm 138 | :param eps: epsilon to prevent numerical instabilities 139 | :return: The normalized tensor 140 | """ 141 | res = x / (length(x, axis=axis) + eps) 142 | return res 143 | 144 | def normalize_tensor(x, axis=-1, eps=1e-8): 145 | """ 146 | Normalizes a tensor over some axis (axes) 147 | 148 | :param x: data tensor 149 | :param axis: axis(axes) along which to compute the norm 150 | :param eps: epsilon to prevent numerical instabilities 151 | :return: The normalized tensor 152 | """ 153 | res = x / (length_tensor(x, axis=axis) + eps) 154 | return res 155 | 156 | def extract_feet_contacts(pos, lfoot_idx, rfoot_idx, velfactor=0.02): 157 | """ 158 | Extracts binary tensors of feet contacts 159 | 160 | :param pos: tensor of global positions of shape (Timesteps, Joints, 3) 161 | :param lfoot_idx: indices list of left foot joints 162 | :param rfoot_idx: indices list of right foot joints 163 | :param velfactor: velocity threshold to consider a joint moving or not 164 | :return: binary tensors of left foot contacts and right foot contacts 165 | """ 166 | lfoot_xyz = (pos[1:, lfoot_idx, :] - pos[:-1, lfoot_idx, :]) ** 2 167 | contacts_l = (np.sum(lfoot_xyz, axis=-1) < velfactor) 168 | 169 | rfoot_xyz = (pos[1:, rfoot_idx, :] - pos[:-1, rfoot_idx, :]) ** 2 170 | contacts_r = (np.sum(rfoot_xyz, axis=-1) < velfactor) 171 | 172 | # Duplicate the last frame for shape consistency 173 | contacts_l = np.concatenate([contacts_l, contacts_l[-1:]], axis=0) 174 | contacts_r = np.concatenate([contacts_r, contacts_r[-1:]], axis=0) 175 | 176 | return contacts_l, contacts_r 177 | 178 | -------------------------------------------------------------------------------- /util/extract.py: -------------------------------------------------------------------------------- 1 | import re, os, ntpath 2 | import numpy as np 3 | import torch 4 | 5 | from constants import DEVICE 6 | 7 | from . import conversion, quaternions 8 | 9 | channelmap = { 10 | 'Xrotation': 'x', 11 | 'Yrotation': 'y', 12 | 'Zrotation': 'z' 13 | } 14 | 15 | channelmap_inv = { 16 | 'x': 'Xrotation', 17 | 'y': 'Yrotation', 18 | 'z': 'Zrotation', 19 | } 20 | 21 | ordermap = { 22 | 'x': 0, 23 | 'y': 1, 24 | 'z': 2, 25 | } 26 | 27 | 28 | class Anim(object): 29 | """ 30 | A very basic animation object 31 | """ 32 | def __init__(self, quats, pos, offsets, parents, bones): 33 | """ 34 | :param quats: local quaternions tensor 35 | :param pos: local positions tensor 36 | :param offsets: local joint offsets 37 | :param parents: bone hierarchy 38 | :param bones: bone names 39 | """ 40 | self.quats = quats 41 | self.pos = pos 42 | self.offsets = offsets 43 | self.parents = parents 44 | self.bones = bones 45 | 46 | 47 | def read_bvh(filename, start=None, end=None, order=None): 48 | """ 49 | Reads a BVH file and extracts animation information. 50 | 51 | :param filename: BVh filename 52 | :param start: start frame 53 | :param end: end frame 54 | :param order: order of euler rotations 55 | :return: A simple Anim object conatining the extracted information. 
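    If `order` is None, the Euler rotation order is inferred from the first
    CHANNELS declaration encountered in the file.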
56 | """ 57 | 58 | f = open(filename, "r") 59 | 60 | i = 0 61 | active = -1 62 | end_site = False 63 | 64 | names = [] 65 | orients = np.array([]).reshape((0, 4)) 66 | offsets = np.array([]).reshape((0, 3)) 67 | parents = np.array([], dtype=int) 68 | 69 | # Parse the file, line by line 70 | for line in f: 71 | 72 | if "HIERARCHY" in line: continue 73 | if "MOTION" in line: continue 74 | 75 | rmatch = re.match(r"ROOT (\w+)", line) 76 | if rmatch: 77 | names.append(rmatch.group(1)) 78 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 79 | orients = np.append(orients, np.array([[1, 0, 0, 0]]), axis=0) 80 | parents = np.append(parents, active) 81 | active = (len(parents) - 1) 82 | continue 83 | 84 | if "{" in line: continue 85 | 86 | if "}" in line: 87 | if end_site: 88 | end_site = False 89 | else: 90 | active = parents[active] 91 | continue 92 | 93 | offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line) 94 | if offmatch: 95 | if not end_site: 96 | offsets[active] = np.array([list(map(float, offmatch.groups()))]) 97 | continue 98 | 99 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line) 100 | if chanmatch: 101 | channels = int(chanmatch.group(1)) 102 | if order is None: 103 | channelis = 0 if channels == 3 else 3 104 | channelie = 3 if channels == 3 else 6 105 | parts = line.split()[2 + channelis:2 + channelie] 106 | if any([p not in channelmap for p in parts]): 107 | continue 108 | order = "".join([channelmap[p] for p in parts]) 109 | continue 110 | 111 | jmatch = re.match("\s*JOINT\s+(\w+)", line) 112 | if jmatch: 113 | names.append(jmatch.group(1)) 114 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 115 | orients = np.append(orients, np.array([[1, 0, 0, 0]]), axis=0) 116 | parents = np.append(parents, active) 117 | active = (len(parents) - 1) 118 | continue 119 | 120 | if "End Site" in line: 121 | end_site = True 122 | continue 123 | 124 | fmatch = re.match("\s*Frames:\s+(\d+)", line) 125 | if fmatch: 126 | if start and end: 127 | fnum = (end - start) - 1 128 | else: 129 | fnum = int(fmatch.group(1)) 130 | 131 | # Initialize positions and rotations array. 132 | 133 | # [fnum, J, 3] 134 | positions = offsets[np.newaxis].repeat(fnum, axis=0) 135 | 136 | # [fnum, J, 3] 137 | rotations = np.zeros((fnum, len(orients), 3)) 138 | continue 139 | 140 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line) 141 | if fmatch: 142 | frametime = float(fmatch.group(1)) 143 | continue 144 | 145 | # If i doesn't lie in requried range, skip. 146 | if (start and end) and (i < start or i >= end - 1): 147 | i += 1 148 | continue 149 | 150 | # If nothing else matches, it is a row of values. 151 | dmatch = line.strip().split(' ') 152 | 153 | if dmatch: 154 | # [J + 1]. First 3 are root position. Next 3*J are euler rotations. 155 | data_block = np.array(list(map(float, dmatch))) 156 | 157 | # Number of joints. J 158 | N = len(parents) 159 | 160 | # Index of current joint 161 | fi = i - start if start else i 162 | 163 | # Position [fi, 0] is alone updated. Positions[fi, 1:J] is fixed offsets. 
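            # data_block holds 3 + 3*N floats: the root's 3 position channels
            # followed by 3 Euler rotation channels for each of the N joints
            # (root included), so the reshape below assumes every joint declares
            # exactly 3 rotation channels.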
164 | positions[fi, 0:1] = data_block[0:3] 165 | 166 | # Rest J are filled with rotations 167 | rotations[fi, :] = data_block[3:].reshape(N, 3) 168 | 169 | i += 1 170 | 171 | f.close() 172 | 173 | rotations = conversion.euler_to_quat(np.radians(rotations), order=order) 174 | rotations = quaternions.remove_quat_discontinuities(rotations) 175 | 176 | return Anim(rotations, positions, offsets, parents, names) 177 | 178 | 179 | def get_lafan1_set(bvh_path, actors, window=50, offset=20, files_to_read=-1): 180 | """ 181 | Extract the same test set as in the article, given the location of the BVH files. 182 | 183 | :param bvh_path: Path to the dataset BVH files 184 | :param list: actor prefixes to use in set 185 | :param window: width of the sliding windows (in timesteps) 186 | :param offset: offset between windows (in timesteps) 187 | :return: tuple: 188 | X: local positions 189 | Q: local quaternions 190 | parents: list of parent indices defining the bone hierarchy 191 | contacts_l: binary tensor of left-foot contacts of shape (Batchsize, Timesteps, 2) 192 | contacts_r: binary tensor of right-foot contacts of shape (Batchsize, Timesteps, 2) 193 | """ 194 | npast = 10 195 | subjects = [] 196 | seq_names = [] 197 | X = [] 198 | Q = [] 199 | contacts_l = [] 200 | contacts_r = [] 201 | 202 | # Extract 203 | bvh_files = os.listdir(bvh_path) 204 | files_read = 0 205 | 206 | for file_no, file in enumerate(bvh_files): 207 | if files_read == files_to_read: 208 | break 209 | 210 | if file.endswith('.bvh'): 211 | seq_name, subject = ntpath.basename(file[:-4]).split('_') 212 | 213 | if subject in actors: 214 | print(f'Processing file {files_read + 1}: {file}') 215 | seq_path = os.path.join(bvh_path, file) 216 | anim = read_bvh(seq_path) 217 | 218 | # Sliding windows 219 | i = 0 220 | while i+window < anim.pos.shape[0]: 221 | # q, x = quaternions.quat_fk(anim.quats[i: i+window], anim.pos[i: i+window], anim.parents) 222 | # Extract contacts 223 | # c_l, c_r = conversion.extract_feet_contacts(x, [3, 4], [7, 8], velfactor=0.02) 224 | X.append(anim.pos[i: i+window]) 225 | Q.append(anim.quats[i: i+window]) 226 | seq_names.append(seq_name) 227 | subjects.append(subjects) 228 | # contacts_l.append(c_l) 229 | # contacts_r.append(c_r) 230 | 231 | i += offset 232 | 233 | files_read += 1 234 | 235 | with torch.no_grad(): 236 | X = torch.Tensor(np.asarray(X)) 237 | Q = torch.Tensor(np.asarray(Q)) 238 | # contacts_l = np.asarray(contacts_l) 239 | # contacts_r = np.asarray(contacts_r) 240 | 241 | # Sequences around XZ = 0 242 | xzs = torch.mean(X[:, :, 0, ::2], dim=1, keepdim=True) 243 | X[:, :, 0, 0] = X[:, :, 0, 0] - xzs[..., 0] 244 | X[:, :, 0, 2] = X[:, :, 0, 2] - xzs[..., 1] 245 | 246 | # Unify facing on last seed frame 247 | X, Q = conversion.rotate_at_frame_tensor(X, Q, anim.parents, n_past=npast) 248 | 249 | return X, Q, anim.parents 250 | 251 | def get_train_stats(bvh_folder, train_set): 252 | """ 253 | Extract the same training set as in the paper in order to compute the normalizing statistics 254 | :return: Tuple of (local position mean vector, local position standard deviation vector, local joint offsets tensor) 255 | """ 256 | print("Building the train set...") 257 | xtrain, qtrain, parents, _, _ = get_lafan1_set( 258 | bvh_folder, train_set, window=50, offset=20 259 | ) 260 | 261 | print("Computing stats...\n") 262 | # Joint offsets : are constant, so just take the first frame: 263 | offsets = xtrain[0:1, 0:1, 1:, :] # Shape : (1, 1, J, 3) 264 | 265 | # Global representation: 266 | q_glbl, x_glbl = 
quaternions.quat_fk(qtrain, xtrain, parents) 267 | 268 | x_glbl = x_glbl.reshape(-1, 22, 3) 269 | 270 | # Global positions stats: 271 | x_mean = np.mean( 272 | x_glbl, 273 | axis=0 274 | ) 275 | x_std = np.std( 276 | x_glbl, 277 | axis=0 278 | ) 279 | 280 | return x_mean, x_std, offsets, parents -------------------------------------------------------------------------------- /util/quaternions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import Tensor 4 | 5 | def quat_mul(x, y): 6 | """ 7 | Performs quaternion multiplication on arrays of quaternions 8 | 9 | :param x: tensor of quaternions of shape (..., Nb of joints, 4) 10 | :param y: tensor of quaternions of shape (..., Nb of joints, 4) 11 | :return: The resulting quaternions 12 | """ 13 | x0, x1, x2, x3 = x[..., 0:1], x[..., 1:2], x[..., 2:3], x[..., 3:4] 14 | y0, y1, y2, y3 = y[..., 0:1], y[..., 1:2], y[..., 2:3], y[..., 3:4] 15 | 16 | res = np.concatenate([ 17 | y0 * x0 - y1 * x1 - y2 * x2 - y3 * x3, 18 | y0 * x1 + y1 * x0 - y2 * x3 + y3 * x2, 19 | y0 * x2 + y1 * x3 + y2 * x0 - y3 * x1, 20 | y0 * x3 - y1 * x2 + y2 * x1 + y3 * x0], axis=-1) 21 | 22 | return res 23 | 24 | def quat_mul_tensor(x, y): 25 | """ 26 | Performs quaternion multiplication on arrays of quaternions 27 | 28 | :param x: tensor of quaternions of shape (..., Nb of joints, 4) 29 | :param y: tensor of quaternions of shape (..., Nb of joints, 4) 30 | :return: The resulting quaternions 31 | """ 32 | x0, x1, x2, x3 = x[..., 0:1], x[..., 1:2], x[..., 2:3], x[..., 3:4] 33 | y0, y1, y2, y3 = y[..., 0:1], y[..., 1:2], y[..., 2:3], y[..., 3:4] 34 | 35 | res = torch.cat([ 36 | y0 * x0 - y1 * x1 - y2 * x2 - y3 * x3, 37 | y0 * x1 + y1 * x0 - y2 * x3 + y3 * x2, 38 | y0 * x2 + y1 * x3 + y2 * x0 - y3 * x1, 39 | y0 * x3 - y1 * x2 + y2 * x1 + y3 * x0], axis=-1) 40 | 41 | return res 42 | 43 | def quat_inv(q): 44 | """ 45 | Inverts a tensor of quaternions 46 | 47 | :param q: quaternion tensor 48 | :return: tensor of inverted quaternions 49 | """ 50 | res = np.asarray([1, -1, -1, -1], dtype=np.float32) * q 51 | return res 52 | 53 | def quat_inv_tensor(q: Tensor): 54 | """ 55 | Inverts a tensor of quaternions 56 | 57 | :param q: quaternion tensor 58 | :return: tensor of inverted quaternions 59 | """ 60 | res = torch.Tensor([1, -1, -1, -1]).to(q.device) * q 61 | return res 62 | 63 | 64 | def quat_norm(q: Tensor) -> Tensor: 65 | """Obtains the norm of a quaternion 66 | 67 | Args: 68 | q (Tensor): Tensor of Quaternions [..., 4] 69 | 70 | Returns: 71 | Tensor: Norm of the quaternions [..., 1] 72 | """ 73 | 74 | return torch.sqrt( 75 | q[..., 0:1] * q[..., 0:1] + 76 | q[..., 1:2] * q[..., 1:2] + 77 | q[..., 2:3] * q[..., 2:3] + 78 | q[..., 3:4] * q[..., 3:4] 79 | ) 80 | 81 | def quat_angle(q: Tensor) -> Tensor: 82 | """Obtains the angle phi of a quaternion 83 | 84 | a = ||q|| cos(phi) 85 | phi = acos(a / ||q||) 86 | 87 | Args: 88 | q (Tensor): Tensor of quaternions. [..., 4] 89 | 90 | Returns: 91 | Tensor: Tensor of Phis. [..., 1] 92 | """ 93 | return torch.acos(q[..., 0:1] / (quat_norm(q) + 1e-8)) 94 | 95 | def quat_unit_vector(q: Tensor) -> Tensor: 96 | """Returns the unit vector of the quaternions. 97 | 98 | q = a + v 99 | v = n ||v|| = n ||q|| sin(phi) 100 | n = v / (||q|| sin(phi)) 101 | 102 | Args: 103 | q (Tensor): Input quaternions. [..., 4] 104 | 105 | Returns: 106 | Tensor: Unit vectors. 
[..., 3] 107 | """ 108 | return q[..., 1:] / (quat_norm(q) * torch.sin(quat_angle(q)) + 1e-8) 109 | 110 | def quat_exp(q: Tensor, x: float) -> Tensor: 111 | """Performs quaternion exponentiation. 112 | 113 | q^x = ||q||^x . (cos(x . phi) + n . sin(x . phi)) 114 | 115 | Args: 116 | q (Tensor): Input quaternions. [..., 4] 117 | x (float): Power to raise. 118 | 119 | Returns: 120 | Tensor: Output quaternions. [..., 4] 121 | """ 122 | norm = quat_norm(q) 123 | phi = quat_angle(q) 124 | n = quat_unit_vector(q) 125 | 126 | # print("norm", norm) 127 | 128 | # print("phi", phi) 129 | 130 | # print("n", n) 131 | 132 | x_phi = x * phi 133 | 134 | return torch.pow(norm, x) * torch.cat([torch.cos(x_phi), n * torch.sin(x_phi)], dim = -1) 135 | 136 | def quat_fk(lrot, lpos, parents): 137 | """ 138 | Performs Forward Kinematics (FK) on local quaternions and local positions to retrieve global representations 139 | 140 | :param lrot: tensor of local quaternions with shape (..., Nb of joints, 4) 141 | :param lpos: tensor of local positions with shape (..., Nb of joints, 3) 142 | :param parents: list of parents indices 143 | :return: tuple of tensors of global quaternion, global positions 144 | """ 145 | gp, gr = [lpos[..., :1, :]], [lrot[..., :1, :]] 146 | for i in range(1, len(parents)): 147 | gp.append(quat_mul_vec(gr[parents[i]], lpos[..., i:i+1, :]) + gp[parents[i]]) 148 | gr.append(quat_mul (gr[parents[i]], lrot[..., i:i+1, :])) 149 | 150 | res = np.concatenate(gr, axis=-2), np.concatenate(gp, axis=-2) 151 | return res 152 | 153 | def quat_fk_tensor(lrot, lpos, parents): 154 | """ 155 | Performs Forward Kinematics (FK) on local quaternions and local positions to retrieve global representations 156 | 157 | :param lrot: tensor of local quaternions with shape (..., Nb of joints, 4) 158 | :param lpos: tensor of local positions with shape (..., Nb of joints, 3) 159 | :param parents: list of parents indices 160 | :return: tuple of tensors of global quaternion, global positions 161 | """ 162 | gp, gr = [lpos[..., :1, :]], [lrot[..., :1, :]] 163 | for i in range(1, len(parents)): 164 | gp.append(quat_mul_vec_tensor(gr[parents[i]], lpos[..., i:i+1, :]) + gp[parents[i]]) 165 | gr.append(quat_mul_tensor(gr[parents[i]], lrot[..., i:i+1, :])) 166 | 167 | return torch.cat(gr, axis=-2), torch.cat(gp, axis=-2) 168 | 169 | def quat_ik(grot, gpos, parents): 170 | """ 171 | Performs Inverse Kinematics (IK) on global quaternions and global positions to retrieve local representations 172 | 173 | :param grot: tensor of global quaternions with shape (..., Nb of joints, 4) 174 | :param gpos: tensor of global positions with shape (..., Nb of joints, 3) 175 | :param parents: list of parents indices 176 | :return: tuple of tensors of local quaternion, local positions 177 | """ 178 | res = [ 179 | np.concatenate([ 180 | grot[..., :1, :], 181 | quat_mul(quat_inv(grot[..., parents[1:], :]), grot[..., 1:, :]), 182 | ], axis=-2), 183 | np.concatenate([ 184 | gpos[..., :1, :], 185 | quat_mul_vec( 186 | quat_inv(grot[..., parents[1:], :]), 187 | gpos[..., 1:, :] - gpos[..., parents[1:], :]), 188 | ], axis=-2) 189 | ] 190 | 191 | return res 192 | 193 | def quat_ik_tensor(grot, gpos, parents): 194 | """ 195 | Performs Inverse Kinematics (IK) on global quaternions and global positions to retrieve local representations 196 | 197 | :param grot: tensor of global quaternions with shape (..., Nb of joints, 4) 198 | :param gpos: tensor of global positions with shape (..., Nb of joints, 3) 199 | :param parents: list of parents indices 200 | 
:return: tuple of tensors of local quaternion, local positions 201 | """ 202 | res = [ 203 | torch.cat([ 204 | grot[..., :1, :], 205 | quat_mul_tensor(quat_inv_tensor(grot[..., parents[1:], :]), grot[..., 1:, :]), 206 | ], dim=-2), 207 | torch.cat([ 208 | gpos[..., :1, :], 209 | quat_mul_vec_tensor( 210 | quat_inv_tensor(grot[..., parents[1:], :]), 211 | gpos[..., 1:, :] - gpos[..., parents[1:], :]), 212 | ], dim=-2) 213 | ] 214 | 215 | return res 216 | 217 | 218 | def quat_mul_vec(q, x): 219 | """ 220 | Performs multiplication of an array of 3D vectors by an array of quaternions (rotation). 221 | 222 | :param q: tensor of quaternions of shape (..., Nb of joints, 4) 223 | :param x: tensor of vectors of shape (..., Nb of joints, 3) 224 | :return: the resulting array of rotated vectors 225 | """ 226 | t = 2.0 * np.cross(q[..., 1:], x) 227 | res = x + q[..., 0][..., np.newaxis] * t + np.cross(q[..., 1:], t) 228 | 229 | return res 230 | 231 | def quat_mul_vec_tensor(q, x): 232 | """ 233 | Performs multiplication of an array of 3D vectors by an array of quaternions (rotation). 234 | 235 | :param q: tensor of quaternions of shape (..., Nb of joints, 4) 236 | :param x: tensor of vectors of shape (..., Nb of joints, 3) 237 | :return: the resulting array of rotated vectors 238 | """ 239 | t = 2.0 * torch.linalg.cross(q[..., 1:], x, dim=-1) 240 | res = x + q[..., 0].unsqueeze(dim=-1) * t + torch.linalg.cross(q[..., 1:], t, dim=-1) 241 | 242 | return res 243 | 244 | def quat_between(x, y): 245 | """ 246 | Quaternion rotations between two 3D-vector arrays 247 | 248 | :param x: tensor of 3D vectors 249 | :param y: tensor of 3D vetcors 250 | :return: tensor of quaternions 251 | """ 252 | res = np.concatenate([ 253 | np.sqrt(np.sum(x * x, axis=-1) * np.sum(y * y, axis=-1))[..., np.newaxis] + 254 | np.sum(x * y, axis=-1)[..., np.newaxis], 255 | np.cross(x, y)], axis=-1) 256 | return res 257 | 258 | def quat_between_tensor(x, y): 259 | """ 260 | Quaternion rotations between two 3D-vector arrays 261 | 262 | :param x: tensor of 3D vectors 263 | :param y: tensor of 3D vetcors 264 | :return: tensor of quaternions 265 | """ 266 | res = torch.cat([ 267 | torch.sqrt(torch.sum(x * x, dim=-1) * torch.sum(y * y, dim=-1)).unsqueeze(dim=-1) + 268 | torch.sum(x * y, dim=-1).unsqueeze(dim=-1), 269 | torch.linalg.cross(x, y, dim=-1)], dim=-1) 270 | return res 271 | 272 | 273 | def remove_quat_discontinuities(rotations): 274 | """ 275 | 276 | Removing quat discontinuities on the time dimension (removing flips) 277 | 278 | :param rotations: Array of quaternions of shape (T, J, 4) 279 | :return: The processed array without quaternion inversion. 
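    Since q and -q encode the same rotation, frame i is flipped to -rotations[i]
    whenever its dot product with rotations[i - 1] is smaller than that of the
    negated quaternion, keeping consecutive frames on the same hemisphere.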
280 | """ 281 | rots_inv = -rotations 282 | 283 | for i in range(1, rotations.shape[0]): 284 | # Compare dot products 285 | replace_mask = np.sum(rotations[i - 1: i] * rotations[i: i + 1], axis=-1) < np.sum( 286 | rotations[i - 1: i] * rots_inv[i: i + 1], axis=-1) 287 | replace_mask = replace_mask[..., np.newaxis] 288 | rotations[i] = replace_mask * rots_inv[i] + (1.0 - replace_mask) * rotations[i] 289 | 290 | return rotations -------------------------------------------------------------------------------- /viz/src/renderBVH.ts: -------------------------------------------------------------------------------- 1 | import { AmbientLight, AnimationClip, AnimationMixer, Bone, BoxGeometry, Clock, Color, CylinderGeometry, GridHelper, Group, KeyframeTrack, LineBasicMaterial, Matrix3, Matrix4, Mesh, MeshStandardMaterial, PerspectiveCamera, Plane, PlaneGeometry, PlaneHelper, Scene, Skeleton, SkeletonHelper, SkinnedMesh, Sphere, SphereGeometry, Vector3, VectorKeyframeTrack, WebGLRenderer } from 'three'; 2 | import { FBXLoader } from 'three/examples/jsm/loaders/FBXLoader'; 3 | 4 | export default class RenderBVH { 5 | renderer: WebGLRenderer; 6 | skeletonHelper: SkeletonHelper; 7 | scene: Scene; 8 | mixer: AnimationMixer; 9 | camera: PerspectiveCamera; 10 | clock: Clock; 11 | id: number; 12 | sphereMeshes: Mesh[]; 13 | cylinderMeshes: Mesh[]; 14 | 15 | constructor(canvas: HTMLCanvasElement, motionSequence: number[][][], clock: Clock, id: number) { 16 | this.clock = clock; 17 | this.id = id; 18 | 19 | this.init(canvas); 20 | this.animate(); 21 | 22 | const [skeleton, animationClip] = this.constructSkeleton(motionSequence); 23 | 24 | this.skeletonHelper = new SkeletonHelper(skeleton.bones[0]); 25 | // @ts-ignore 26 | this.skeletonHelper.skeleton = skeleton 27 | 28 | if (this.skeletonHelper.material instanceof LineBasicMaterial) { 29 | this.skeletonHelper.material.linewidth = 10 30 | } 31 | 32 | // this.scene.add(this.skeletonHelper); 33 | 34 | const boneContainer = new Group(); 35 | boneContainer.add(skeleton.bones[0]); 36 | this.scene.add(boneContainer); 37 | 38 | this.sphereMeshes = []; 39 | this.cylinderMeshes = []; 40 | 41 | const sphereMaterial = new MeshStandardMaterial(); 42 | sphereMaterial.color.setRGB(255, 0, 0); 43 | 44 | const cylinderMaterial = new MeshStandardMaterial(); 45 | cylinderMaterial.color.setRGB(0, 0, 255); 46 | 47 | skeleton.bones.forEach((bone, index) => { 48 | const sphereGeometry = new SphereGeometry(3.2); 49 | 50 | const sphereMesh = new Mesh(sphereGeometry, sphereMaterial); 51 | setSphereMesh(sphereMesh, bone); 52 | 53 | this.sphereMeshes.push(sphereMesh); 54 | this.scene.add(sphereMesh); 55 | }); 56 | 57 | skeleton.bones.forEach((bone, index) => { 58 | if (!(bone.parent instanceof Bone)) return; 59 | 60 | const height = bone.parent.position.distanceTo(bone.position); 61 | 62 | const cylinderGeometry = new CylinderGeometry(1.5, 1.5, height); 63 | 64 | const cylinderMesh = new Mesh(cylinderGeometry, cylinderMaterial); 65 | setCylinderMesh(cylinderMesh, bone); 66 | 67 | this.cylinderMeshes.push(cylinderMesh); 68 | 69 | this.scene.add(cylinderMesh); 70 | }); 71 | 72 | this.mixer = new AnimationMixer(this.skeletonHelper); 73 | this.mixer.clipAction(animationClip).play(); 74 | 75 | // const loader = new FBXLoader(); 76 | // loader.load('./static/rp_eric_rigged_001_u3d.fbx', model => { 77 | // model.traverse(child => { 78 | // if (child instanceof SkinnedMesh) { 79 | // // child.skeleton.bones = child.skeleton.bones.map((_, index) => child.skeleton.bones[permutation[index]]); 80 | 
81 | // console.log("FBX Bones | BVH Bones") 82 | // child.skeleton.bones.forEach((bone, index) => console.log(bone.name, bone.parent.name, "|", skeleton.bones[index].name, skeleton.bones[index].parent.name)) 83 | 84 | // // console.log("Original BVH Bones") 85 | // // skeleton.bones.forEach(bone => console.log(bone.name, bone.parent.name)) 86 | 87 | // // skeleton.pose() 88 | 89 | // // child.skeleton.bones = skeleton.bones.map( 90 | // // (_, index) => { 91 | // // // console.log(skeleton.bones, index, skeleton.bones[index]) 92 | // // return child.skeleton.getBoneByName(skeleton.bones[index].name) 93 | // // } 94 | // // ); 95 | 96 | // // console.log(child.skeleton.bones[0]) 97 | // // console.log(skeleton.bones[0]) 98 | 99 | // // child.skeleton.bones = child.skeleton.bones.map((_, index) => skeleton.bones[permutation[index]]); 100 | 101 | // // console.log("Permuted FBX Bones") 102 | // // child.skeleton.bones.forEach(bone => console.log(bone.name, bone.parent.name)) 103 | 104 | // // skeleton.pose(); 105 | // child.bind(skeleton); 106 | // // child.add(skeleton.bones[0]); 107 | 108 | // console.log(child) 109 | 110 | // console.log('\n\n\n') 111 | 112 | // console.log("FBX Bones | BVH Bones") 113 | // child.skeleton.bones.forEach((bone, index) => console.log(bone.name, bone.parent.name, "|", skeleton.bones[index].name, skeleton.bones[index].parent.name)) 114 | 115 | // } 116 | // }); 117 | 118 | // // this.mixer = new AnimationMixer(model); 119 | 120 | // this.scene.add(model); 121 | // }) 122 | } 123 | 124 | init(canvas: HTMLCanvasElement) { 125 | this.camera = new PerspectiveCamera(60, window.innerWidth / window.innerHeight, 1, 10000); 126 | this.camera.position.set(200, 100, 100); 127 | this.camera.lookAt(0, 0, 0); 128 | 129 | this.scene = new Scene(); 130 | this.scene.background = new Color(0xeeeeee); 131 | 132 | this.scene.add(new AmbientLight()); 133 | 134 | const ground = new Mesh( 135 | new PlaneGeometry(300, 300), 136 | new MeshStandardMaterial({ color: 0x333333 }) 137 | ); 138 | ground.rotation.set(-Math.PI / 2, 0, 0); 139 | ground.position.set(0, -100, 0); 140 | this.scene.add(ground); 141 | 142 | 143 | this.renderer = new WebGLRenderer({ antialias: true, canvas: canvas }); 144 | this.renderer.setPixelRatio(window.devicePixelRatio); 145 | this.renderer.setSize(window.innerWidth, window.innerHeight); 146 | } 147 | 148 | resizeCanvasToDisplaySize() { 149 | const canvas = this.renderer.domElement; 150 | // look up the size the canvas is being displayed 151 | const width = canvas.clientWidth; 152 | const height = canvas.clientHeight; 153 | 154 | // adjust displayBuffer size to match 155 | if (canvas.width !== width || canvas.height !== height) { 156 | // you must pass false here or three.js sadly fights the browser 157 | this.renderer.setSize(width, height, false); 158 | this.camera.aspect = width / height; 159 | this.camera.updateProjectionMatrix(); 160 | } 161 | } 162 | 163 | animate() { 164 | requestAnimationFrame(() => this.animate()); 165 | 166 | this.resizeCanvasToDisplaySize(); 167 | 168 | const delta = this.clock.getDelta(); 169 | 170 | if (this.mixer) this.mixer.update(delta); 171 | 172 | if (this.sphereMeshes) this.sphereMeshes.forEach((mesh, index) => { 173 | const bone = this.skeletonHelper.bones[index]; 174 | setSphereMesh(mesh, bone); 175 | }) 176 | 177 | if (this.cylinderMeshes) this.cylinderMeshes.forEach((mesh, index) => { 178 | const bone = this.skeletonHelper.bones[index + 1]; 179 | setCylinderMesh(mesh, bone); 180 | }) 181 | 182 | 
this.renderer.render(this.scene, this.camera); 183 | } 184 | 185 | constructSkeleton(motionSequence: number[][][]): [Skeleton, AnimationClip] { 186 | const bones: Bone[] = []; 187 | 188 | const pos = motionSequence[0]; 189 | const x = pos.map((p: number[]) => [ 190 | p[0] - pos[0][0], 191 | p[1] - pos[0][1], 192 | p[2] - pos[0][2] 193 | ]); 194 | 195 | // const x = pos; 196 | 197 | parents.forEach((parent, index) => { 198 | const bone = new Bone(); 199 | bone.name = `${names[index]}`; 200 | bone.position.fromArray(x[index]); 201 | bones.push(bone); 202 | 203 | if (parent >= 0) 204 | bones[parent].add(bone); 205 | }); 206 | 207 | const skeleton = new Skeleton(bones); 208 | 209 | const tracks: number[][][] = [...Array(22)].map(() => []); 210 | 211 | motionSequence.forEach((positions: number[][]) => { 212 | positions.forEach((position, index) => { 213 | tracks[index].push([ 214 | position[0] - positions[0][0], 215 | position[1] - positions[0][1], 216 | position[2] - positions[0][2] 217 | ]); 218 | // tracks[index].push(position) 219 | }) 220 | }) 221 | 222 | const times = [...Array(tracks[0].length)].map((_, i) => i * 0.033); 223 | 224 | const keyframeTracks = tracks.map((jointPositions, index) => { 225 | const vector3Sequence = jointPositions.flatMap( 226 | p => p 227 | ) 228 | 229 | const vectorTrack = new VectorKeyframeTrack( 230 | `.bones[${names[index]}].position`, 231 | times, 232 | vector3Sequence 233 | ) 234 | 235 | return vectorTrack 236 | }) 237 | 238 | const animationClip = new AnimationClip('animation', -1, keyframeTracks); 239 | 240 | return [skeleton, animationClip]; 241 | } 242 | } 243 | 244 | const setSphereMesh = (mesh: Mesh, bone: Bone) => { 245 | mesh.position.fromArray(bone.position.toArray()); // .add(bone.parent.position).divide(new Vector3(2, 2, 2)) 246 | 247 | // console.log(bone.position.clone().sub(bone.parent.position).normalize()) 248 | } 249 | 250 | const setCylinderMesh = (mesh: Mesh, bone: Bone) => { 251 | const parent = bone.parent.position.clone(); 252 | const self = bone.position.clone(); 253 | const d = self.distanceTo(parent); 254 | 255 | mesh.geometry.dispose(); 256 | 257 | mesh.geometry = new CylinderGeometry(1.5, 1.5, d); 258 | 259 | mesh.position.fromArray( 260 | parent.clone().add(self.sub(parent).divide(new Vector3(2, 2, 2))).toArray() 261 | ); 262 | 263 | mesh.lookAt(bone.position); 264 | mesh.rotateX(1.57); 265 | } 266 | 267 | const parents = [-1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 12, 11, 14, 15, 16, 11, 18, 19, 20]; 268 | // const permutation = [0, 9, 10, 11, 12, 13, 18, 19, 20, 21, 14, 15, 16, 17, 5, 6, 7, 8, 1, 2, 3, 4]; 269 | const permutation = [0, 14, 15, 16, 17, 18, 19, 20, 21, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]; 270 | const names = [ 271 | "ModelHips", 272 | "ModelLeftUpLeg", 273 | "ModelLeftLeg", 274 | "ModelLeftFoot", 275 | "ModelLeftToe", 276 | "ModelRightUpLeg", 277 | "ModelRightLeg", 278 | "ModelRightFoot", 279 | "ModelRightToe", 280 | "ModelSpine", 281 | "ModelSpine1", 282 | "ModelSpine2", 283 | "ModelNeck", 284 | "ModelHead", 285 | "ModelLeftShoulder", 286 | "ModelLeftArm", 287 | "ModelLeftForeArm", 288 | "ModelLeftHand", 289 | "ModelRightShoulder", 290 | "ModelRightArm", 291 | "ModelRightForeArm", 292 | "ModelRightHand", 293 | ] -------------------------------------------------------------------------------- /generator_frontend/src/renderBVH.ts: -------------------------------------------------------------------------------- 1 | import { AmbientLight, AnimationClip, AnimationMixer, Bone, BoxGeometry, Clock, Color, 
CylinderGeometry, GridHelper, Group, KeyframeTrack, LineBasicMaterial, Matrix3, Matrix4, Mesh, MeshStandardMaterial, PerspectiveCamera, Plane, PlaneGeometry, PlaneHelper, Scene, Skeleton, SkeletonHelper, SkinnedMesh, Sphere, SphereGeometry, Vector3, VectorKeyframeTrack, WebGLRenderer } from 'three'; 2 | import { FBXLoader } from 'three/examples/jsm/loaders/FBXLoader'; 3 | 4 | export default class RenderBVH { 5 | renderer: WebGLRenderer; 6 | skeletonHelper: SkeletonHelper; 7 | scene: Scene; 8 | mixer: AnimationMixer; 9 | camera: PerspectiveCamera; 10 | clock: Clock; 11 | id: number; 12 | sphereMeshes: Mesh[]; 13 | cylinderMeshes: Mesh[]; 14 | 15 | constructor(canvas: HTMLCanvasElement, motionSequence: number[][][], clock: Clock, id: number) { 16 | this.clock = clock; 17 | this.id = id; 18 | 19 | this.init(canvas); 20 | this.animate(); 21 | 22 | const [skeleton, animationClip] = this.constructSkeleton(motionSequence); 23 | 24 | this.skeletonHelper = new SkeletonHelper(skeleton.bones[0]); 25 | // @ts-ignore 26 | this.skeletonHelper.skeleton = skeleton 27 | 28 | if (this.skeletonHelper.material instanceof LineBasicMaterial) { 29 | this.skeletonHelper.material.linewidth = 10 30 | } 31 | 32 | // this.scene.add(this.skeletonHelper); 33 | 34 | const boneContainer = new Group(); 35 | boneContainer.add(skeleton.bones[0]); 36 | this.scene.add(boneContainer); 37 | 38 | this.sphereMeshes = []; 39 | this.cylinderMeshes = []; 40 | 41 | const sphereMaterial = new MeshStandardMaterial(); 42 | sphereMaterial.color.setRGB(255, 0, 0); 43 | 44 | const cylinderMaterial = new MeshStandardMaterial(); 45 | cylinderMaterial.color.setRGB(0, 0, 255); 46 | 47 | skeleton.bones.forEach((bone, index) => { 48 | const sphereGeometry = new SphereGeometry(3.2); 49 | 50 | const sphereMesh = new Mesh(sphereGeometry, sphereMaterial); 51 | setSphereMesh(sphereMesh, bone); 52 | 53 | this.sphereMeshes.push(sphereMesh); 54 | this.scene.add(sphereMesh); 55 | }); 56 | 57 | skeleton.bones.forEach((bone, index) => { 58 | if (!(bone.parent instanceof Bone)) return; 59 | 60 | const height = bone.parent.position.distanceTo(bone.position); 61 | 62 | const cylinderGeometry = new CylinderGeometry(1.5, 1.5, height); 63 | 64 | const cylinderMesh = new Mesh(cylinderGeometry, cylinderMaterial); 65 | setCylinderMesh(cylinderMesh, bone); 66 | 67 | this.cylinderMeshes.push(cylinderMesh); 68 | 69 | this.scene.add(cylinderMesh); 70 | }); 71 | 72 | this.mixer = new AnimationMixer(this.skeletonHelper); 73 | this.mixer.clipAction(animationClip).play(); 74 | 75 | // const loader = new FBXLoader(); 76 | // loader.load('./static/rp_eric_rigged_001_u3d.fbx', model => { 77 | // model.traverse(child => { 78 | // if (child instanceof SkinnedMesh) { 79 | // // child.skeleton.bones = child.skeleton.bones.map((_, index) => child.skeleton.bones[permutation[index]]); 80 | 81 | // console.log("FBX Bones | BVH Bones") 82 | // child.skeleton.bones.forEach((bone, index) => console.log(bone.name, bone.parent.name, "|", skeleton.bones[index].name, skeleton.bones[index].parent.name)) 83 | 84 | // // console.log("Original BVH Bones") 85 | // // skeleton.bones.forEach(bone => console.log(bone.name, bone.parent.name)) 86 | 87 | // // skeleton.pose() 88 | 89 | // // child.skeleton.bones = skeleton.bones.map( 90 | // // (_, index) => { 91 | // // // console.log(skeleton.bones, index, skeleton.bones[index]) 92 | // // return child.skeleton.getBoneByName(skeleton.bones[index].name) 93 | // // } 94 | // // ); 95 | 96 | // // console.log(child.skeleton.bones[0]) 97 | // // 
console.log(skeleton.bones[0]) 98 | 99 | // // child.skeleton.bones = child.skeleton.bones.map((_, index) => skeleton.bones[permutation[index]]); 100 | 101 | // // console.log("Permuted FBX Bones") 102 | // // child.skeleton.bones.forEach(bone => console.log(bone.name, bone.parent.name)) 103 | 104 | // // skeleton.pose(); 105 | // child.bind(skeleton); 106 | // // child.add(skeleton.bones[0]); 107 | 108 | // console.log(child) 109 | 110 | // console.log('\n\n\n') 111 | 112 | // console.log("FBX Bones | BVH Bones") 113 | // child.skeleton.bones.forEach((bone, index) => console.log(bone.name, bone.parent.name, "|", skeleton.bones[index].name, skeleton.bones[index].parent.name)) 114 | 115 | // } 116 | // }); 117 | 118 | // // this.mixer = new AnimationMixer(model); 119 | 120 | // this.scene.add(model); 121 | // }) 122 | } 123 | 124 | init(canvas: HTMLCanvasElement) { 125 | this.camera = new PerspectiveCamera(60, window.innerWidth / window.innerHeight, 1, 10000); 126 | this.camera.position.set(200, 100, 100); 127 | this.camera.lookAt(0, 0, 0); 128 | 129 | this.scene = new Scene(); 130 | this.scene.background = new Color(0xeeeeee); 131 | 132 | this.scene.add(new AmbientLight()); 133 | 134 | const ground = new Mesh( 135 | new PlaneGeometry(300, 300), 136 | new MeshStandardMaterial({ color: 0x333333 }) 137 | ); 138 | ground.rotation.set(-Math.PI / 2, 0, 0); 139 | ground.position.set(0, -100, 0); 140 | this.scene.add(ground); 141 | 142 | 143 | this.renderer = new WebGLRenderer({ antialias: true, canvas: canvas }); 144 | this.renderer.setPixelRatio(window.devicePixelRatio); 145 | this.renderer.setSize(window.innerWidth, window.innerHeight); 146 | } 147 | 148 | resizeCanvasToDisplaySize() { 149 | const canvas = this.renderer.domElement; 150 | // look up the size the canvas is being displayed 151 | const width = canvas.clientWidth; 152 | const height = canvas.clientHeight; 153 | 154 | // adjust displayBuffer size to match 155 | if (canvas.width !== width || canvas.height !== height) { 156 | // you must pass false here or three.js sadly fights the browser 157 | this.renderer.setSize(width, height, false); 158 | this.camera.aspect = width / height; 159 | this.camera.updateProjectionMatrix(); 160 | } 161 | } 162 | 163 | animate() { 164 | requestAnimationFrame(() => this.animate()); 165 | 166 | this.resizeCanvasToDisplaySize(); 167 | 168 | const delta = this.clock.getDelta(); 169 | 170 | if (this.mixer) this.mixer.update(delta); 171 | 172 | if (this.sphereMeshes) this.sphereMeshes.forEach((mesh, index) => { 173 | const bone = this.skeletonHelper.bones[index]; 174 | setSphereMesh(mesh, bone); 175 | }) 176 | 177 | if (this.cylinderMeshes) this.cylinderMeshes.forEach((mesh, index) => { 178 | const bone = this.skeletonHelper.bones[index + 1]; 179 | setCylinderMesh(mesh, bone); 180 | }) 181 | 182 | this.renderer.render(this.scene, this.camera); 183 | } 184 | 185 | constructSkeleton(motionSequence: number[][][]): [Skeleton, AnimationClip] { 186 | const bones: Bone[] = []; 187 | 188 | const pos = motionSequence[0]; 189 | const x = pos.map((p: number[]) => [ 190 | p[0] - pos[0][0], 191 | p[1] - pos[0][1], 192 | p[2] - pos[0][2] 193 | ]); 194 | 195 | // const x = pos; 196 | 197 | parents.forEach((parent, index) => { 198 | const bone = new Bone(); 199 | bone.name = `${names[index]}`; 200 | bone.position.fromArray(x[index]); 201 | bones.push(bone); 202 | 203 | if (parent >= 0) 204 | bones[parent].add(bone); 205 | }); 206 | 207 | const skeleton = new Skeleton(bones); 208 | 209 | const tracks: number[][][] = 
[...Array(22)].map(() => []); 210 | 211 | motionSequence.forEach((positions: number[][]) => { 212 | positions.forEach((position, index) => { 213 | tracks[index].push([ 214 | position[0] - positions[0][0], 215 | position[1] - positions[0][1], 216 | position[2] - positions[0][2] 217 | ]); 218 | // tracks[index].push(position) 219 | }) 220 | }) 221 | 222 | const times = [...Array(tracks[0].length)].map((_, i) => i * 0.033); 223 | 224 | const keyframeTracks = tracks.map((jointPositions, index) => { 225 | const vector3Sequence = jointPositions.flatMap( 226 | p => p 227 | ) 228 | 229 | const vectorTrack = new VectorKeyframeTrack( 230 | `.bones[${names[index]}].position`, 231 | times, 232 | vector3Sequence 233 | ) 234 | 235 | return vectorTrack 236 | }) 237 | 238 | const animationClip = new AnimationClip('animation', -1, keyframeTracks); 239 | 240 | return [skeleton, animationClip]; 241 | } 242 | } 243 | 244 | const setSphereMesh = (mesh: Mesh, bone: Bone) => { 245 | mesh.position.fromArray(bone.position.toArray()); // .add(bone.parent.position).divide(new Vector3(2, 2, 2)) 246 | 247 | // console.log(bone.position.clone().sub(bone.parent.position).normalize()) 248 | } 249 | 250 | const setCylinderMesh = (mesh: Mesh, bone: Bone) => { 251 | const parent = bone.parent.position.clone(); 252 | const self = bone.position.clone(); 253 | const d = self.distanceTo(parent); 254 | 255 | mesh.geometry.dispose(); 256 | 257 | mesh.geometry = new CylinderGeometry(1.5, 1.5, d); 258 | 259 | mesh.position.fromArray( 260 | parent.clone().add(self.sub(parent).divide(new Vector3(2, 2, 2))).toArray() 261 | ); 262 | 263 | mesh.lookAt(bone.position); 264 | mesh.rotateX(1.57); 265 | } 266 | 267 | const parents = [-1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 12, 11, 14, 15, 16, 11, 18, 19, 20]; 268 | // const permutation = [0, 9, 10, 11, 12, 13, 18, 19, 20, 21, 14, 15, 16, 17, 5, 6, 7, 8, 1, 2, 3, 4]; 269 | const permutation = [0, 14, 15, 16, 17, 18, 19, 20, 21, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]; 270 | const names = [ 271 | "ModelHips", 272 | "ModelLeftUpLeg", 273 | "ModelLeftLeg", 274 | "ModelLeftFoot", 275 | "ModelLeftToe", 276 | "ModelRightUpLeg", 277 | "ModelRightLeg", 278 | "ModelRightFoot", 279 | "ModelRightToe", 280 | "ModelSpine", 281 | "ModelSpine1", 282 | "ModelSpine2", 283 | "ModelNeck", 284 | "ModelHead", 285 | "ModelLeftShoulder", 286 | "ModelLeftArm", 287 | "ModelLeftForeArm", 288 | "ModelLeftHand", 289 | "ModelRightShoulder", 290 | "ModelRightArm", 291 | "ModelRightForeArm", 292 | "ModelRightHand", 293 | ] -------------------------------------------------------------------------------- /generator_frontend/src/motionEditor.ts: -------------------------------------------------------------------------------- 1 | import { AmbientLight, AnimationClip, AnimationMixer, Bone, BoxGeometry, Clock, Color, CylinderGeometry, GridHelper, Group, KeyframeTrack, LineBasicMaterial, Matrix3, Matrix4, Mesh, MeshStandardMaterial, Object3D, PerspectiveCamera, Plane, PlaneGeometry, PlaneHelper, Ray, Raycaster, Scene, Skeleton, SkeletonHelper, SkinnedMesh, Sphere, SphereGeometry, Vector2, Vector3, VectorKeyframeTrack, WebGLRenderer } from 'three'; 2 | import Controls from './controls'; 3 | 4 | export default class MotionEditor { 5 | renderer: WebGLRenderer; 6 | skeletonHelper: SkeletonHelper; 7 | scene: Scene; 8 | mixer: AnimationMixer; 9 | camera: PerspectiveCamera; 10 | clock: Clock; 11 | id: number; 12 | sphereMeshes: Mesh[]; 13 | cylinderMeshes: Mesh[]; 14 | track: number[][][]; 15 | frame: number; 16 | raycaster: 
Raycaster; 17 | pointer: Vector2; 18 | hovered: Object3D; 19 | clicked: Object3D; 20 | clickedIndex: number; 21 | controls: Controls; 22 | 23 | constructor(canvas: HTMLCanvasElement, motionSequence: number[][][], controls: Controls) { 24 | this.init(canvas); 25 | 26 | this.controls = controls; 27 | this.controls.hide(); 28 | 29 | const [skeleton, track] = this.constructSkeleton(motionSequence); 30 | 31 | this.track = track; 32 | 33 | this.frame = 0; 34 | 35 | this.skeletonHelper = new SkeletonHelper(skeleton.bones[0]); 36 | // @ts-ignore 37 | this.skeletonHelper.skeleton = skeleton 38 | 39 | if (this.skeletonHelper.material instanceof LineBasicMaterial) { 40 | this.skeletonHelper.material.linewidth = 10 41 | } 42 | 43 | // this.scene.add(this.skeletonHelper); 44 | 45 | const boneContainer = new Group(); 46 | boneContainer.add(skeleton.bones[0]); 47 | this.scene.add(boneContainer); 48 | 49 | this.sphereMeshes = []; 50 | this.cylinderMeshes = []; 51 | 52 | const cylinderMaterial = new MeshStandardMaterial(); 53 | cylinderMaterial.color.setRGB(0, 0, 255); 54 | 55 | skeleton.bones.forEach((bone, index) => { 56 | const sphereGeometry = new SphereGeometry(3.2); 57 | 58 | const sphereMaterial = new MeshStandardMaterial(); 59 | sphereMaterial.color.setRGB(255, 0, 0); 60 | 61 | const sphereMesh = new Mesh(sphereGeometry, sphereMaterial); 62 | setSphereMesh(sphereMesh, bone); 63 | 64 | this.sphereMeshes.push(sphereMesh); 65 | this.scene.add(sphereMesh); 66 | }); 67 | 68 | skeleton.bones.forEach((bone, index) => { 69 | if (!(bone.parent instanceof Bone)) return; 70 | 71 | const height = bone.parent.position.distanceTo(bone.position); 72 | 73 | const cylinderGeometry = new CylinderGeometry(1.5, 1.5, height); 74 | 75 | const cylinderMesh = new Mesh(cylinderGeometry, cylinderMaterial); 76 | setCylinderMesh(cylinderMesh, bone); 77 | 78 | this.cylinderMeshes.push(cylinderMesh); 79 | 80 | this.scene.add(cylinderMesh); 81 | }); 82 | 83 | this.initControls(); 84 | 85 | this.animate(); 86 | 87 | // this.mixer = new AnimationMixer(this.skeletonHelper); 88 | // this.mixer.clipAction(animationClip).play(); 89 | 90 | } 91 | 92 | init(canvas: HTMLCanvasElement) { 93 | this.camera = new PerspectiveCamera(60, window.innerWidth / window.innerHeight, 1, 10000); 94 | this.camera.position.set(200, 100, 100); 95 | this.camera.lookAt(0, 0, 0); 96 | 97 | this.scene = new Scene(); 98 | this.scene.background = new Color(0xeeeeee); 99 | 100 | this.scene.add(new AmbientLight()); 101 | 102 | const ground = new Mesh( 103 | new PlaneGeometry(300, 300), 104 | new MeshStandardMaterial({ color: 0x333333 }) 105 | ); 106 | ground.rotation.set(-Math.PI / 2, 0, 0); 107 | ground.position.set(0, -100, 0); 108 | this.scene.add(ground); 109 | 110 | 111 | this.renderer = new WebGLRenderer({ antialias: true, canvas: canvas }); 112 | this.renderer.setPixelRatio(window.devicePixelRatio); 113 | this.renderer.setSize(window.innerWidth, window.innerHeight); 114 | } 115 | 116 | resizeCanvasToDisplaySize() { 117 | const canvas = this.renderer.domElement; 118 | // look up the size the canvas is being displayed 119 | const width = canvas.clientWidth; 120 | const height = canvas.clientHeight; 121 | 122 | // adjust displayBuffer size to match 123 | if (canvas.width !== width || canvas.height !== height) { 124 | // you must pass false here or three.js sadly fights the browser 125 | this.renderer.setSize(width, height, false); 126 | this.camera.aspect = width / height; 127 | this.camera.updateProjectionMatrix(); 128 | } 129 | } 130 | 131 | initControls() 
{ 132 | this.raycaster = new Raycaster(); 133 | this.pointer = new Vector2(); 134 | 135 | window.addEventListener('keydown', (event: KeyboardEvent) => { 136 | switch (event.key) { 137 | case 'ArrowLeft': 138 | this.frame = (this.frame + this.track.length - 30) % this.track.length; 139 | break; 140 | case 'ArrowRight': 141 | this.frame = (this.frame + 30) % this.track.length; 142 | break; 143 | } 144 | }); 145 | 146 | window.addEventListener('pointermove', (event: MouseEvent) => { 147 | this.pointer.x = ( event.clientX / window.innerWidth ) * 2 - 1; 148 | this.pointer.y = - ( event.clientY / window.innerHeight ) * 2 + 1; 149 | }); 150 | 151 | window.addEventListener('click', () => { 152 | this.handleClick(); 153 | }) 154 | } 155 | 156 | handleIntersection() { 157 | this.raycaster.setFromCamera(this.pointer, this.camera); 158 | 159 | const intersects = this.raycaster.intersectObjects(this.sphereMeshes); 160 | 161 | if (intersects.length) { 162 | if (this.clicked === intersects[0].object) return; 163 | 164 | // @ts-ignore 165 | intersects[ 0 ].object.material.color.set( 0x880000 ); 166 | 167 | this.hovered = intersects[0].object; 168 | } else { 169 | if (this.clicked === this.hovered) return; 170 | 171 | // @ts-ignore 172 | this.hovered.material.color.set(0xff0000) 173 | 174 | this.hovered = undefined; 175 | } 176 | } 177 | 178 | handleClick() { 179 | console.log("onclick") 180 | 181 | if (this.clicked) { 182 | // @ts-ignore 183 | this.clicked.material.color.set(0xff0000) 184 | 185 | this.clicked = undefined; 186 | } 187 | 188 | this.raycaster.setFromCamera(this.pointer, this.camera); 189 | 190 | const intersects = this.raycaster.intersectObjects(this.sphereMeshes); 191 | 192 | if (intersects.length) { 193 | this.controls.show(); 194 | this.controls.set(intersects[0].object.position); 195 | 196 | // @ts-ignore 197 | intersects[ 0 ].object.material.color.set( 0x00ff00 ); 198 | 199 | this.clicked = intersects[0].object; 200 | 201 | // @ts-ignore 202 | this.clickedIndex = this.sphereMeshes.findIndex((sphereMesh) => { 203 | return sphereMesh.id === this.clicked.id; 204 | }) 205 | } 206 | } 207 | 208 | animate() { 209 | requestAnimationFrame(() => this.animate()); 210 | // setTimeout(() => this.animate(), 1000); 211 | 212 | this.resizeCanvasToDisplaySize(); 213 | 214 | if (this.clicked) { 215 | this.track[this.frame][this.clickedIndex] = this.controls.position.toArray(); 216 | } 217 | 218 | if (this.track && this.skeletonHelper) { 219 | // this.frame = (this.frame + 1) % this.track.length; 220 | 221 | this.skeletonHelper.bones.forEach((bone, index) => { 222 | bone.position.fromArray(this.track[this.frame][index]) 223 | }); 224 | } 225 | 226 | if (this.sphereMeshes) this.sphereMeshes.forEach((mesh, index) => { 227 | const bone = this.skeletonHelper.bones[index]; 228 | setSphereMesh(mesh, bone); 229 | }) 230 | 231 | if (this.cylinderMeshes) this.cylinderMeshes.forEach((mesh, index) => { 232 | const bone = this.skeletonHelper.bones[index + 1]; 233 | setCylinderMesh(mesh, bone); 234 | }) 235 | 236 | this.handleIntersection(); 237 | 238 | this.renderer.render(this.scene, this.camera); 239 | } 240 | 241 | constructSkeleton(motionSequence: number[][][]): [Skeleton, number[][][]] { 242 | const bones: Bone[] = []; 243 | 244 | const pos = motionSequence[0]; 245 | const x = pos.map((p: number[]) => [ 246 | p[0] - pos[0][0], 247 | p[1] - pos[0][1], 248 | p[2] - pos[0][2] 249 | ]); 250 | 251 | // const x = pos; 252 | 253 | parents.forEach((parent, index) => { 254 | const bone = new Bone(); 255 | bone.name = 
`${names[index]}`; 256 | bone.position.fromArray(x[index]); 257 | bones.push(bone); 258 | 259 | if (parent >= 0) 260 | bones[parent].add(bone); 261 | }); 262 | 263 | const skeleton = new Skeleton(bones); 264 | 265 | const tracks: number[][][] = [...Array(motionSequence.length)].map(() => []); 266 | 267 | motionSequence.forEach((positions: number[][], frameNo) => { 268 | positions.forEach((position, boneNo) => { 269 | tracks[frameNo].push([ 270 | position[0] - positions[0][0], 271 | position[1] - positions[0][1], 272 | position[2] - positions[0][2] 273 | ]); 274 | // tracks[index].push(position) 275 | }) 276 | }) 277 | 278 | return [skeleton, tracks]; 279 | } 280 | } 281 | 282 | const setSphereMesh = (mesh: Mesh, bone: Bone) => { 283 | mesh.position.fromArray(bone.position.toArray()); // .add(bone.parent.position).divide(new Vector3(2, 2, 2)) 284 | 285 | // console.log(bone.position.clone().sub(bone.parent.position).normalize()) 286 | } 287 | 288 | const setCylinderMesh = (mesh: Mesh, bone: Bone) => { 289 | const parent = bone.parent.position.clone(); 290 | const self = bone.position.clone(); 291 | const d = self.distanceTo(parent); 292 | 293 | mesh.geometry.dispose(); 294 | 295 | mesh.geometry = new CylinderGeometry(1.5, 1.5, d); 296 | 297 | mesh.position.fromArray( 298 | parent.clone().add(self.sub(parent).divide(new Vector3(2, 2, 2))).toArray() 299 | ); 300 | 301 | mesh.lookAt(bone.position); 302 | mesh.rotateX(1.57); 303 | } 304 | 305 | const parents = [-1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 12, 11, 14, 15, 16, 11, 18, 19, 20]; 306 | // const permutation = [0, 9, 10, 11, 12, 13, 18, 19, 20, 21, 14, 15, 16, 17, 5, 6, 7, 8, 1, 2, 3, 4]; 307 | const permutation = [0, 14, 15, 16, 17, 18, 19, 20, 21, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]; 308 | const names = [ 309 | "ModelHips", 310 | "ModelLeftUpLeg", 311 | "ModelLeftLeg", 312 | "ModelLeftFoot", 313 | "ModelLeftToe", 314 | "ModelRightUpLeg", 315 | "ModelRightLeg", 316 | "ModelRightFoot", 317 | "ModelRightToe", 318 | "ModelSpine", 319 | "ModelSpine1", 320 | "ModelSpine2", 321 | "ModelNeck", 322 | "ModelHead", 323 | "ModelLeftShoulder", 324 | "ModelLeftArm", 325 | "ModelLeftForeArm", 326 | "ModelLeftHand", 327 | "ModelRightShoulder", 328 | "ModelRightArm", 329 | "ModelRightForeArm", 330 | "ModelRightHand", 331 | ] --------------------------------------------------------------------------------
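Usage note: both renderBVH.ts and motionEditor.ts consume a motion sequence shaped [frames][22 joints][x, y, z], with joints ordered to match the parents and names tables above; each frame is re-centred on the hip (joint 0) before it is written into the tracks. The sketch below shows how RenderBVH might be driven from a page script. It is a minimal sketch only: the element id, the endpoint path, and the payload shape are assumptions for illustration, and the actual wiring lives in viz/src/index.ts, which is not reproduced here.

// Minimal sketch, not part of the repository: the element id, endpoint path, and
// payload shape are assumed for illustration; viz/src/index.ts holds the real entry point.
import { Clock } from 'three';
import fetchJson from './helpers/fetchJson';
import RenderBVH from './renderBVH';

async function main() {
    // Assumed payload: joint positions shaped [frames][22 joints][x, y, z],
    // ordered to match the `parents`/`names` tables defined in renderBVH.ts.
    const motion = await fetchJson('/output/sample_motion.json') as number[][][];

    const canvas = document.querySelector('#viewer') as HTMLCanvasElement;

    // A single shared Clock lets several RenderBVH canvases (e.g. input, prediction,
    // ground truth) advance their AnimationMixers in lock-step.
    const clock = new Clock();
    new RenderBVH(canvas, motion, clock, 0);
}

main();

MotionEditor is constructed the same way except that it takes a Controls instance (from generator_frontend/src/controls.ts) in place of the Clock and id arguments; its construction is omitted from the sketch because controls.ts is defined elsewhere in the tree.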