├── interp.py ├── LICENSE ├── liveplot.py ├── dataloader.py ├── eval.py ├── README.md ├── model.py ├── train.py ├── env.py ├── spatialhelper.py ├── agent.py └── bvh.py /interp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def lerp(key1, key2, t1, t2): 5 | interval = t2 - t1 6 | linear_step = (key2 - key1) / interval 7 | linear_interp = np.expand_dims(key1, 1).repeat(interval + 1, axis=1) + \ 8 | np.expand_dims(linear_step, 1).repeat(interval + 1, axis=1) * np.arange(interval + 1) 9 | 10 | return linear_interp 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 MiniEval 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /liveplot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | class LivePlot: 5 | def __init__(self): 6 | self.reward = [] 7 | self.episodes = [] 8 | self.q_loss = [] 9 | self.reward_avgs = [] 10 | plt.ioff() 11 | 12 | def append(self, x, reward, q_loss): 13 | self.episodes.append(x) 14 | self.reward.append(reward) 15 | self.q_loss.append(q_loss) 16 | 17 | plt.figure(1) 18 | plt.xlim(-0.01, max(self.episodes)) 19 | plt.ylim(min(self.reward), max(self.reward) + 0.01) 20 | 21 | plt.scatter(self.episodes[-1:], self.reward[-1:], color="#000000") 22 | 23 | if len(self.reward) > 100: 24 | self.reward_avgs.append(sum(self.reward[-100:]) / 100) 25 | if len(self.reward_avgs) > 1: 26 | plt.plot(self.episodes[-2:], self.reward_avgs[-2:], color="#ffa500") 27 | 28 | plt.grid(True) 29 | plt.xlabel("Episode") 30 | plt.ylabel("$R_1$") 31 | 32 | plt.draw() 33 | 34 | plt.figure(2) 35 | plt.xlim(-0.01, max(self.episodes)) 36 | plt.ylim(-0.01, max(self.q_loss)) 37 | 38 | plt.plot(self.episodes[-2:], self.q_loss[-2:], color="#0000ff") 39 | plt.grid(True) 40 | plt.xlabel("Episode") 41 | plt.ylabel("Q-Loss") 42 | 43 | plt.draw() 44 | 45 | def draw(self, block=False): 46 | plt.pause(0.1) 47 | plt.show(block=block) 48 | 49 | def save(self): 50 | with open("q_loss.log", "w") as f: 51 | f.write(str(self.q_loss)) 52 | f.close() 53 | 54 | with open("reward.log", "w") as f: 55 | f.write(str(self.reward)) 56 | f.close() 57 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from pathlib import Path 4 | import sys 5 | 6 | 7 | class DataLoader: 8 | def __init__(self): 9 | train_data_paths = list(Path("./train_data/").rglob("*.npy")) 10 | test_data_paths = list(Path("./test_data/").rglob("*.npy")) 11 | 12 | if len(train_data_paths) == 0: 13 | print("Training data not found. Please ensure your training data is in the \"train_data\" subfolder.") 14 | sys.exit() 15 | if len(test_data_paths) == 0: 16 | print("Testing data not found. 
Please ensure your testing data is in the \"test_data\" subfolder.") 17 | sys.exit() 18 | 19 | self.train_data = [] 20 | self.test_data = [] 21 | for path in train_data_paths: 22 | self.train_data.append(np.load(path)) 23 | for path in test_data_paths: 24 | self.test_data.append(np.load(path)) 25 | 26 | self.train_max = 0 27 | self.test_max = 0 28 | for data in self.train_data: 29 | self.train_max = max(self.train_max, data.shape[1]) 30 | for data in self.test_data: 31 | self.test_max = max(self.test_max, data.shape[1]) 32 | 33 | def sample_train(self, length): 34 | if self.train_max < length: 35 | return None 36 | 37 | data = None 38 | while data is None: 39 | sample = random.sample(self.train_data, 1)[0] 40 | if sample.shape[1] >= length: 41 | start_frame = random.randint(0, sample.shape[1] - length) 42 | data = sample[:, start_frame:start_frame + length] 43 | 44 | return data 45 | 46 | def sample_test(self, length): 47 | if self.test_max < length: 48 | return None 49 | 50 | data = None 51 | while data is None: 52 | sample = random.sample(self.test_data, 1)[0] 53 | if sample.shape[1] >= length: 54 | start_frame = random.randint(0, sample.shape[1] - length) 55 | data = sample[:, start_frame:start_frame + length] 56 | 57 | return data 58 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | 5 | from agent import KeyframeExtractor 6 | from env import Env 7 | from bvh import CMU_parse 8 | 9 | 10 | def input_error(): 11 | print("Usage: python eval.py [BVH FILE] [NUMBER OF KEYFRAMES]") 12 | sys.exit() 13 | 14 | 15 | if __name__ == "__main__": 16 | if len(sys.argv) == 3: 17 | if not os.path.isfile(sys.argv[1]): 18 | print("BVH file", sys.argv[1], "cannot be found.") 19 | sys.exit() 20 | try: 21 | num_keys = int(sys.argv[2]) 22 | except ValueError: 23 | input_error() 24 | else: 25 | input_error() 26 | 27 | if not os.path.isfile("./model.pt"): 28 | print("Trained model cannot be found. 
Please download a pre-trained model, or train a new model using train.py") 29 | sys.exit() 30 | 31 | print("Processing BVH file...", end="") 32 | with open(sys.argv[1], "r") as f: 33 | bvh_file = f.read() 34 | f.close() 35 | bvh_input = CMU_parse(bvh_file) 36 | print("done!") 37 | 38 | print("Loading GKEN model...", end="") 39 | agent = KeyframeExtractor(gamma=0.99, eps=1.0, eps_factor=2000, learning_rate=0.0004, training_len=72, mem_size=1, 40 | batch_size=1, tau=(200, 1.0)) 41 | agent.load_models() 42 | agent.policy_net.eval() 43 | print("done!") 44 | 45 | print("Evaluating keyframes...") 46 | env = Env(bvh_input, num_keys) 47 | env_state, remaining_actions = env.get_state() 48 | 49 | done = False 50 | state = torch.tensor(env_state, dtype=torch.float) 51 | while not done: 52 | remain = torch.tensor(remaining_actions / state.shape[1]).view(1, 1) 53 | action = agent.select_action(state.unsqueeze(0), remain, env.get_keyframes()[1:-1], use_eps=False) 54 | next_state, next_remaining_actions, _, done, result = env.step(action) 55 | if not done: 56 | state = torch.tensor(next_state, dtype=torch.float) 57 | remaining_actions = next_remaining_actions 58 | 59 | keyframe_mask = result == 1 60 | keyframe_idx = [] 61 | for k in range(keyframe_mask.shape[0]): 62 | if keyframe_mask[k].item(): 63 | keyframe_idx.append(k) 64 | print("done!") 65 | 66 | print("\n24 FPS Keyframes:", keyframe_idx) 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Keyframe Evaluation Network - PyTorch implementation 2 | Official PyTorch implementation of the ACM MM'21 paper: [Keyframe Extraction from Motion Capture Sequences with Graph based Deep Reinforcement Learning](https://dl.acm.org/doi/10.1145/3474085.3475635) 3 | 4 | # Prerequisites 5 | - [Python 3.9](https://www.python.org/) 6 | - [PyTorch 1.9](https://github.com/pytorch/pytorch) with CUDA 7 | - [PyTorch Geometric 1.7.2](https://github.com/rusty1s/pytorch_geometric) with CUDA 8 | - [NumPy v1.21.0](https://github.com/numpy/numpy) 9 | - [Matplotlib 3.4.2](https://github.com/matplotlib/matplotlib) 10 | 11 | # Usage 12 | ## Training 13 | To train the model, download our preprocessed CMU Mocap dataset from the [releases](https://github.com/MiniEval/pytorch-gken/releases/tag/1) page and place the `/train_data` and `/test_data` folders in the repository root. `train.py` is used as follows: 14 | 15 | `python train.py [NUMBER OF EPISODES]` 16 | 17 | The paper uses 5000 episodes to train the model. 18 | 19 | Two Matplotlib diagrams are updated every 200 episodes. The scatter plot displays rewards over time, while the line chart displays Q-loss over time. 20 | 21 | 22 | ### Custom datasets 23 | We provide a `CMU_parse(file, start=1, frame_skip=5)` function in `bvh.py`. This function formats motion capture data from the [BVH conversion of the CMU Mocap Dataset](https://sites.google.com/a/cgspeed.com/cgspeed/motion-capture) into a NumPy tensor. The NumPy tensor can be saved as a file that is used by our `dataloader.py`. The default parameters of `CMU_parse` scale the BVH file down from 120 FPS to 24 FPS. A short preprocessing sketch is shown below. 24 | 25 | ## Evaluation 26 | To evaluate with our model, download our pre-trained model from the [releases](https://github.com/MiniEval/pytorch-gken/releases/tag/1) page and place `model.pt` in the repository root. Alternatively, you may train your own model with `train.py`. 
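For custom BVH data (see Custom datasets above), the sketch below illustrates one way to produce the `.npy` clips; the file names and output folder are placeholders, and clips shorter than the 72-frame training window used by `train.py` are never sampled by `dataloader.py`.

```python
# Illustrative preprocessing sketch: convert one BVH clip into the .npy
# format consumed by dataloader.py. File names and paths are placeholders.
import numpy as np
from bvh import CMU_parse

with open("my_clip.bvh", "r") as f:
    motion = CMU_parse(f.read())  # NumPy tensor of shape [ATTRIBUTES, FRAMES], 24 FPS by default

np.save("./train_data/my_clip.npy", motion)
print("saved clip with shape", motion.shape)
```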
27 | 28 | The evaluation accepts BVH files as input, using the CMU Mocap skeleton format. `eval.py` is used as follows: 29 | 30 | `python eval.py [BVH FILE] [NUMBER OF KEYFRAMES]` 31 | 32 | # Human annotations 33 | In the [releases](https://github.com/MiniEval/pytorch-gken/releases/tag/1) page, we provide five sets of human annotations in `Keyframe Extraction - Demonstration.blend`, which can be opened using [Blender](https://www.blender.org/). 34 | 35 | # Acknowledgements 36 | The motion capture data used in this project was obtained from [mocap.cs.cmu.edu](https://mocap.cs.cmu.edu). The database was created with funding from the American National Science Foundation, under EIA-0196217. 37 | 38 | The BVH file parsing module was written by [20tab srl](https://github.com/20tab/bvh-python). 39 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch_geometric.nn as gnn 4 | from torch_geometric.data import Data, Batch 5 | 6 | 7 | NUM_BONES = 23 8 | device = torch.device("cuda") 9 | 10 | 11 | class GKEN(nn.Module): 12 | def __init__(self, n_hidden=128): 13 | super(GKEN, self).__init__() 14 | 15 | self.CMU_SKELETON = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 16 | 22, 0, 1, 2, 3, 4, 5, 3, 7, 8, 9, 3, 11, 12, 13, 0, 15, 16, 17, 0, 19, 20, 17 | 21], 18 | [0, 1, 2, 3, 4, 5, 3, 7, 8, 9, 3, 11, 12, 13, 0, 15, 16, 17, 0, 19, 20, 21, 1, 19 | 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]], 20 | dtype=torch.long, device=torch.device("cpu")) 21 | 22 | self.n_hidden = n_hidden 23 | 24 | self.gconv1 = gnn.GCNConv(7, n_hidden, improved=True, bias=False) 25 | self.gbn1 = gnn.BatchNorm(n_hidden) 26 | self.grelu1 = nn.LeakyReLU(inplace=True) 27 | self.gconv2 = gnn.GCNConv(n_hidden, n_hidden, improved=True, bias=False) 28 | self.gbn2 = gnn.BatchNorm(n_hidden) 29 | self.grelu2 = nn.LeakyReLU(inplace=True) 30 | self.gconv3 = gnn.GCNConv(n_hidden, n_hidden, improved=True, bias=False) 31 | self.gbn3 = gnn.BatchNorm(n_hidden) 32 | self.grelu3 = nn.LeakyReLU(inplace=True) 33 | 34 | self.lstm = nn.LSTM(NUM_BONES * n_hidden + 1, 4 * n_hidden, batch_first=True, bidirectional=True, num_layers=8) 35 | 36 | self.fc1 = nn.Conv1d(8 * n_hidden, 8 * n_hidden, 3, padding=1, padding_mode="replicate", bias=False) 37 | self.bn_fc1 = nn.BatchNorm1d(8 * n_hidden) 38 | self.relu_fc1 = nn.LeakyReLU(inplace=True) 39 | self.fc2 = nn.Conv1d(8 * n_hidden, 8 * n_hidden, 3, padding=1, padding_mode="replicate", bias=False) 40 | self.bn_fc2 = nn.BatchNorm1d(8 * n_hidden) 41 | self.relu_fc2 = nn.LeakyReLU(inplace=True) 42 | 43 | self.out = nn.Conv1d(8 * n_hidden, 1, 1) 44 | 45 | def _create_graph(self, state): 46 | # Pose data in state is [ATTRIBUTES, FRAMES] 47 | 48 | node_data = torch.cat([state[:-1].t().reshape(-1, 6), 49 | state[-1].unsqueeze(1).repeat(1, NUM_BONES).reshape(-1, 1)], dim=1) 50 | 51 | skeleton_edges = torch.arange(state.shape[1], device=torch.device("cpu")).reshape(-1, 1).repeat(2, 1, self.CMU_SKELETON.shape[1]).reshape(2, -1) 52 | skeleton_edges *= NUM_BONES 53 | skeleton_edges += self.CMU_SKELETON.repeat(1, state.shape[1]) 54 | 55 | return Data(x=node_data, edge_index=skeleton_edges, device=device) 56 | 57 | def forward(self, state_batch, remaining_actions): 58 | batch_list = torch.unbind(state_batch) 59 | data_list = [] 60 | for i in range(len(batch_list)): 61 | 
data_list.append(self._create_graph(batch_list[i])) 62 | 63 | graph = Batch().from_data_list(data_list).to(device) 64 | 65 | g, edges = graph.x, graph.edge_index 66 | 67 | g = self.grelu1(self.gbn1(self.gconv1(g, edges))) 68 | g = self.grelu2(self.gbn2(self.gconv2(g, edges))) 69 | g = self.grelu3(self.gbn3(self.gconv3(g, edges))) 70 | 71 | g_data = g.reshape(state_batch.shape[0], state_batch.shape[2], -1) 72 | concat = torch.cat([g_data, remaining_actions.view(-1, 1, 1).repeat(1, state_batch.shape[2], 1).to(device)], dim=2) 73 | 74 | h, _ = self.lstm(concat) 75 | h = self.relu_fc1(self.bn_fc1(self.fc1(h.transpose(1, 2)))) 76 | h = self.relu_fc2(self.bn_fc2(self.fc2(h))) 77 | out = self.out(h) 78 | 79 | return out.squeeze(1)[:, 1:-1] 80 | 81 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import math 4 | from env import Env 5 | from dataloader import DataLoader 6 | from agent import KeyframeExtractor 7 | from liveplot import LivePlot 8 | import numpy as np 9 | import sys 10 | 11 | 12 | class Trainer: 13 | def __init__(self, train_sample_len=72): 14 | self.env = Env() 15 | self.agent = KeyframeExtractor(gamma=1.0, eps=1.0, eps_factor=2000, learning_rate=0.0004, 16 | training_len=train_sample_len, mem_size=100000, batch_size=256, tau=(200, 1.0)) 17 | 18 | self.steps = 0 19 | self.plot = LivePlot() 20 | self.data_loader = DataLoader() 21 | self.train_sample_len = train_sample_len 22 | 23 | def train(self, n_episodes): 24 | prev_results = np.zeros(100, dtype=np.float32) 25 | prev_results_idx = 0 26 | best_result = 0.0 27 | 28 | for i_episode in range(1, n_episodes + 1): 29 | num_keys = random.randint(self.train_sample_len // 12 + 2, self.train_sample_len // 4) 30 | self.env.new_motion(self.data_loader.sample_train(self.train_sample_len), num_keys) 31 | env_state, remaining_actions = self.env.get_state() 32 | 33 | state = torch.tensor(env_state, dtype=torch.float) 34 | 35 | done = False 36 | loss = 0.0 37 | score = 0.0 38 | 39 | use_eps = random.random() < 0.1 + 0.9 * math.exp(-1 * self.steps / 2000) 40 | 41 | if i_episode % 100 == 0: 42 | if best_result < prev_results.mean(): 43 | best_result = prev_results.mean() 44 | self.agent.save_models() 45 | 46 | while not done: 47 | remain = torch.tensor(remaining_actions / state.shape[1]).view(1, 1) 48 | action = self.agent.select_action(state.unsqueeze(0), remain, self.env.get_keyframes()[1:-1], 49 | use_eps=use_eps) 50 | next_state, next_remaining_actions, evals, done, result = self.env.step(action) 51 | reward = evals[2] 52 | 53 | concat_state = torch.cat([state, remain.view(1, 1).repeat(1, state.shape[1])], dim=0) 54 | concat_next_state = torch.cat([torch.tensor(next_state), 55 | torch.tensor(next_remaining_actions / state.shape[1]).view(1, 1).repeat(1, state.shape[1])], dim=0) 56 | 57 | self.agent.store_transition(concat_state, action, reward, concat_next_state, done) 58 | score = evals[1] 59 | prev_results[prev_results_idx] = score 60 | prev_results_idx = (prev_results_idx + 1) % 100 61 | 62 | if not done: 63 | state = torch.tensor(next_state, dtype=torch.float) 64 | remaining_actions = next_remaining_actions 65 | else: 66 | loss = self.agent.optimise_model() 67 | 68 | self.steps += 1 69 | if not use_eps: 70 | self.plot.append(i_episode, score, loss) 71 | 72 | keyframe_mask = result == 1 73 | keyframe_idx = [] 74 | for i in range(keyframe_mask.shape[0]): 75 | if 
keyframe_mask[i].item(): 76 | keyframe_idx.append(i) 77 | 78 | if i_episode % 200 == 0: 79 | self.plot.draw() 80 | 81 | print(keyframe_idx) 82 | print("Episode %d, Q-Loss: %.4f, Last Reward: %.4f" % (i_episode, loss, score)) 83 | 84 | self.plot.save() 85 | self.plot.draw(block=True) 86 | 87 | 88 | def input_error(): 89 | print("Usage: python train.py [NUMBER OF EPISODES]") 90 | sys.exit() 91 | 92 | 93 | if __name__ == "__main__": 94 | if len(sys.argv) == 2: 95 | try: 96 | epochs = int(sys.argv[1]) 97 | except ValueError: 98 | input_error() 99 | else: 100 | input_error() 101 | 102 | trainer = Trainer() 103 | trainer.train(epochs) 104 | -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spatialhelper import CMUSpatialHelper 3 | from interp import lerp 4 | 5 | 6 | class Env: 7 | def __init__(self, motion_data=None, n_keys=None): 8 | self.spatial_helper = CMUSpatialHelper() 9 | 10 | self.original = None 11 | self.spatial_original = None 12 | 13 | self.original_mean = None 14 | self.original_std = None 15 | self.recon = None 16 | self.spatial_recon = None 17 | self.keyframes = None 18 | 19 | self.remaining_actions = 0 20 | self.total_keys = 0 21 | self.base_score = 0.0 22 | 23 | if motion_data is not None and n_keys is not None: 24 | self.new_motion(motion_data, n_keys) 25 | 26 | def _normalise(self, x): 27 | return (x - self.original_mean) / self.original_std 28 | 29 | def _update_recon(self, key=None): 30 | keyframe_mask = self.keyframes == 1 31 | keyframe_idx = [] 32 | for i in range(keyframe_mask.shape[0]): 33 | if keyframe_mask[i].item(): 34 | keyframe_idx.append(i) 35 | 36 | if key is None: 37 | self.recon = np.zeros(self.original.shape) 38 | for i in range(len(keyframe_idx) - 1): 39 | self.recon[:, keyframe_idx[i]:keyframe_idx[i + 1] + 1] = lerp(self.original[:, keyframe_idx[i]], 40 | self.original[:, keyframe_idx[i+1]], 41 | keyframe_idx[i], 42 | keyframe_idx[i+1]) 43 | else: 44 | idx = keyframe_idx.index(key) 45 | for i in range(idx - 1, idx + 1): 46 | self.recon[:, keyframe_idx[i]:keyframe_idx[i + 1] + 1] = lerp(self.original[:, keyframe_idx[i]], 47 | self.original[:, keyframe_idx[i+1]], 48 | keyframe_idx[i], 49 | keyframe_idx[i+1]) 50 | 51 | return self.recon 52 | 53 | # Motion data in shape [ATTRIBUTES, FRAMES] 54 | def new_motion(self, motion_data, n_keys): 55 | if n_keys < 3: 56 | raise ValueError("Must provide at least 3 keys") 57 | if motion_data.shape[1] < 3: 58 | raise ValueError("Motion must have at least 3 frames") 59 | if n_keys > motion_data.shape[1]: 60 | raise ValueError("Motion must have more frames than keys") 61 | 62 | self.original = motion_data 63 | 64 | self.spatial_original = self.spatial_helper.get_spatial_representation(self.original) 65 | 66 | self.original_mean = np.mean(self.spatial_original, (0, 1, 3), keepdims=True) 67 | self.original_std = np.std(self.spatial_original - self.original_mean) 68 | 69 | self.spatial_original = self._normalise(self.spatial_original) 70 | 71 | self.keyframes = None 72 | self.keyframes = np.zeros(self.original.shape[1]) 73 | self.keyframes[0] = 1 74 | self.keyframes[-1] = 1 75 | 76 | self.remaining_actions = n_keys - 2 77 | 78 | self._update_recon() 79 | self.spatial_recon = self._normalise(self.spatial_helper.get_spatial_representation(self.recon)) 80 | self.base_score = self._get_score() + np.finfo(np.float32).tiny 81 | self.total_keys = n_keys 82 | 83 | def _get_score(self): 84 
| return np.mean(np.sqrt(np.sum((self.spatial_original - self.spatial_recon) ** 2, axis=2))) 85 | 86 | def step(self, action): 87 | old_loss = np.arctanh(max(1 - (self._get_score() / self.base_score), -1 + 1e-6)) 88 | self.keyframes[action + 1] = 1 89 | self._update_recon(action + 1) 90 | 91 | self.remaining_actions -= 1 92 | 93 | done = self.remaining_actions == 0 94 | 95 | self.spatial_recon = self._normalise(self.spatial_helper.get_spatial_representation(self.recon)) 96 | score = self._get_score() 97 | reward = np.arctanh(max(1 - (score / self.base_score), -1 + 1e-6)) - old_loss 98 | 99 | next_state, _ = self.get_state() 100 | 101 | return next_state, self.remaining_actions, (score, 1 - (score / self.base_score), reward), done, self.keyframes 102 | 103 | def get_state(self): 104 | state = list() 105 | state.append(self.spatial_original.reshape(self.spatial_original.shape[0], -1).transpose((1, 0))) 106 | state.append(self.spatial_recon.reshape(self.spatial_original.shape[0], -1).transpose((1, 0))) 107 | state.append(np.expand_dims(self.keyframes, 0)) 108 | 109 | return np.concatenate(state, 0), self.remaining_actions 110 | 111 | def get_keyframes(self): 112 | return self.keyframes 113 | -------------------------------------------------------------------------------- /spatialhelper.py: -------------------------------------------------------------------------------- 1 | # Spatial Representation Helper 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | from scipy.spatial.transform import Rotation 7 | 8 | NUM_BONES = 23 9 | 10 | 11 | class CMUSpatialHelper: 12 | 13 | class Bone: 14 | def __init__(self): 15 | self.parent = None 16 | self.local_rot = None 17 | self.local_pos = None 18 | self.global_transform = None 19 | 20 | def set(self, offset, rotation_data, position_data=None, parent=None): 21 | self.parent = parent 22 | self.local_rot = rotation_data 23 | self.local_pos = np.expand_dims(offset, axis=0).repeat(self.local_rot.shape[0], axis=0) 24 | if position_data is not None: 25 | self.local_pos += position_data 26 | 27 | self.global_transform = None 28 | 29 | def get_transform_matrix(self): 30 | rotation = np.deg2rad(self.local_rot) 31 | location = self.local_pos 32 | 33 | vector = np.zeros((self.local_rot.shape[0], 4, 4), dtype=np.float32) 34 | 35 | r = Rotation.from_euler("ZYX", rotation[:, [2, 1, 0]]) 36 | vector[:, :3, :3] = r.as_matrix() 37 | vector[:, :3, 3] = location 38 | 39 | vector[:, 3, 3] = 1 40 | 41 | return vector 42 | 43 | def get_global_transform(self): 44 | if self.parent: 45 | transform = self.parent.global_transform 46 | else: 47 | transform = torch.eye(4).unsqueeze(0).repeat(self.local_rot.shape[0], 1, 1) 48 | 49 | transform_matrix = torch.from_numpy(self.get_transform_matrix()).float() 50 | self.global_transform = torch.bmm(transform, transform_matrix) 51 | 52 | return self.global_transform 53 | 54 | def get_global_position(self): 55 | global_transform = self.get_global_transform() 56 | 57 | position_vector = torch.zeros(global_transform.shape[0], 4, 1, dtype=torch.float32) 58 | position_vector[:, 3] = 1 59 | 60 | return torch.bmm(global_transform, position_vector).numpy() 61 | 62 | def __init__(self): 63 | self.CMU_OFFSETS = np.zeros((23, 3), dtype=np.float32) 64 | self.CMU_OFFSETS[2] = np.array([0.02827, 2.03559, -0.19338]) 65 | self.CMU_OFFSETS[3] = np.array([0.05672, 2.04885, -0.04275]) 66 | self.CMU_OFFSETS[5] = np.array([-0.05417, 1.74624, 0.17202]) 67 | self.CMU_OFFSETS[6] = np.array([0.10407, 1.76136, -0.12397]) 68 | self.CMU_OFFSETS[8] = 
np.array([3.36241, 1.20089, -0.31121]) 69 | self.CMU_OFFSETS[9] = np.array([4.98300, 0, 0]) 70 | self.CMU_OFFSETS[10] = np.array([3.48356, 0, 0]) 71 | self.CMU_OFFSETS[12] = np.array([-3.13660, 1.37405, -0.40465]) 72 | self.CMU_OFFSETS[13] = np.array([-5.24190, 0, 0]) 73 | self.CMU_OFFSETS[14] = np.array([-3.44417, 0, 0]) 74 | self.CMU_OFFSETS[15] = np.array([1.36306, -1.79463, 0.83929]) 75 | self.CMU_OFFSETS[16] = np.array([2.44811, -6.72613, 0]) 76 | self.CMU_OFFSETS[17] = np.array([2.56220, -7.03959, 0]) 77 | self.CMU_OFFSETS[18] = np.array([0.15764, -0.43311, 2.32255]) 78 | self.CMU_OFFSETS[19] = np.array([-1.30552, -1.79463, 0.83929]) 79 | self.CMU_OFFSETS[20] = np.array([-2.54253, -6.98555, 0]) 80 | self.CMU_OFFSETS[21] = np.array([-2.56826, -7.05623, 0]) 81 | self.CMU_OFFSETS[22] = np.array([-0.16473, -0.45259, 2.36315]) 82 | 83 | self.bones = [self.Bone() for _ in range(NUM_BONES)] 84 | 85 | # Motion data in shape [ATTRIBUTES, FRAMES] 86 | def get_spatial_representation(self, motion_data): 87 | data = np.transpose(motion_data, (1, 0)) 88 | self.bones[0].set(self.CMU_OFFSETS[0], data[:, 3:6], position_data=data[:, :3]) 89 | self.bones[1].set(self.CMU_OFFSETS[1], data[:, 6:9], parent=self.bones[0]) 90 | self.bones[2].set(self.CMU_OFFSETS[2], data[:, 9:12], parent=self.bones[1]) 91 | self.bones[3].set(self.CMU_OFFSETS[3], data[:, 12:15], parent=self.bones[2]) 92 | self.bones[4].set(self.CMU_OFFSETS[4], data[:, 15:18], parent=self.bones[3]) 93 | self.bones[5].set(self.CMU_OFFSETS[5], data[:, 18:21], parent=self.bones[4]) 94 | self.bones[6].set(self.CMU_OFFSETS[6], data[:, 21:24], parent=self.bones[5]) 95 | self.bones[7].set(self.CMU_OFFSETS[7], data[:, 24:27], parent=self.bones[3]) 96 | self.bones[8].set(self.CMU_OFFSETS[8], data[:, 27:30], parent=self.bones[7]) 97 | self.bones[9].set(self.CMU_OFFSETS[9], data[:, 30:33], parent=self.bones[8]) 98 | self.bones[10].set(self.CMU_OFFSETS[10], data[:, 33:36], parent=self.bones[9]) 99 | self.bones[11].set(self.CMU_OFFSETS[11], data[:, 36:39], parent=self.bones[3]) 100 | self.bones[12].set(self.CMU_OFFSETS[12], data[:, 39:42], parent=self.bones[11]) 101 | self.bones[13].set(self.CMU_OFFSETS[13], data[:, 42:45], parent=self.bones[12]) 102 | self.bones[14].set(self.CMU_OFFSETS[14], data[:, 45:48], parent=self.bones[13]) 103 | self.bones[15].set(self.CMU_OFFSETS[15], data[:, 48:51], parent=self.bones[0]) 104 | self.bones[16].set(self.CMU_OFFSETS[16], data[:, 51:54], parent=self.bones[15]) 105 | self.bones[17].set(self.CMU_OFFSETS[17], data[:, 54:57], parent=self.bones[16]) 106 | self.bones[18].set(self.CMU_OFFSETS[18], data[:, 57:60], parent=self.bones[17]) 107 | self.bones[19].set(self.CMU_OFFSETS[19], data[:, 60:63], parent=self.bones[0]) 108 | self.bones[20].set(self.CMU_OFFSETS[20], data[:, 63:66], parent=self.bones[19]) 109 | self.bones[21].set(self.CMU_OFFSETS[21], data[:, 66:69], parent=self.bones[20]) 110 | self.bones[22].set(self.CMU_OFFSETS[22], data[:, 69:72], parent=self.bones[21]) 111 | 112 | representation = [] 113 | 114 | for bone in self.bones: 115 | representation.append(bone.get_global_position()[:, :3]) 116 | 117 | # Output shape: [FRAMES, BONES, XYZ] 118 | return np.stack(representation, axis=1) 119 | -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import random 5 | from model import GKEN 6 | import numpy as np 7 | 8 | NUM_BONES = 23 9 | 10 | 11 | class 
ReplayMemory(object): 12 | def __init__(self, capacity, state_size): 13 | self.capacity = capacity 14 | self.ctr = 0 15 | self.full = False 16 | 17 | self.states = np.zeros((self.capacity, *state_size), dtype=np.float32) 18 | self.next_states = np.zeros((self.capacity, *state_size), dtype=np.float32) 19 | 20 | self.actions = np.zeros(self.capacity, dtype=np.int8) 21 | self.rewards = np.zeros(self.capacity, dtype=np.float32) 22 | self.dones = np.zeros(self.capacity, dtype=np.bool_) 23 | 24 | def push(self, state, action, reward, next_state, done): 25 | idx = self.ctr % self.capacity 26 | self.states[idx] = state 27 | self.actions[idx] = action 28 | self.rewards[idx] = reward 29 | self.next_states[idx] = next_state 30 | self.dones[idx] = done 31 | 32 | self.ctr += 1 33 | 34 | def sample(self, batch_size): 35 | total = min(self.ctr, self.capacity) 36 | 37 | batch = np.random.choice(total, batch_size, replace=False) 38 | 39 | states = self.states[batch] 40 | actions = self.actions[batch] 41 | rewards = self.rewards[batch] 42 | next_states = self.next_states[batch] 43 | dones = self.dones[batch] 44 | 45 | return states, actions, rewards, next_states, dones 46 | 47 | 48 | class KeyframeExtractor(object): 49 | def __init__(self, gamma, eps, eps_factor, learning_rate, training_len, mem_size, batch_size, tau, save_dir='dqn'): 50 | self.GAMMA = gamma 51 | self.eps = eps 52 | self.EPS_MAX = eps 53 | self.lr = learning_rate 54 | self.training_len = training_len 55 | self.batch_size = batch_size 56 | 57 | self.EPS_MIN = 0.1 58 | self.EPS_DEC = 1e-3 59 | self.EPS_FACTOR = eps_factor 60 | 61 | self.TAU = tau 62 | self.save_dir = save_dir 63 | 64 | self.steps = 0 65 | self.input_size = NUM_BONES * 6 + 2 66 | 67 | self.memory = ReplayMemory(mem_size, (self.input_size, training_len)) 68 | 69 | self.device = torch.device("cuda") 70 | 71 | self.policy_net = GKEN().to(self.device) 72 | self.target_net = GKEN().to(self.device) 73 | self.optimiser = torch.optim.Adam(self.policy_net.parameters(), lr=learning_rate) 74 | self.criterion = torch.nn.L1Loss() 75 | self.target_net.eval() 76 | 77 | def store_transition(self, state, action, reward, next_state, done): 78 | self.memory.push(state, action, reward, next_state, done) 79 | 80 | def sample_memory(self): 81 | state, action, reward, next_state, done = self.memory.sample(self.batch_size) 82 | 83 | states = torch.tensor(state) 84 | actions = torch.tensor(action, dtype=torch.long) 85 | rewards = torch.tensor(reward) 86 | next_states = torch.tensor(next_state) 87 | dones = torch.tensor(done) 88 | 89 | return states, actions, rewards, next_states, dones 90 | 91 | def select_action(self, state, remaining_actions, keyframe_mask, use_eps=True): 92 | if use_eps is False or random.random() > self.eps: 93 | pred = self.policy_net.forward(state, remaining_actions) 94 | for i in range(keyframe_mask.shape[0]): 95 | if keyframe_mask[i] == 1: 96 | pred[:, i] = float("-inf") 97 | 98 | action = torch.argmax(pred).item() 99 | else: 100 | valid = [] 101 | for i in range(keyframe_mask.shape[0]): 102 | if keyframe_mask[i] == 0: 103 | valid.append(i) 104 | action = random.sample(valid, 1)[0] 105 | return action 106 | 107 | def replace_target(self): 108 | for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()): 109 | target_param.data.copy_(self.TAU[1] * policy_param.data + (1.0 - self.TAU[1]) * target_param.data) 110 | self.target_net.eval() 111 | 112 | def decrement_epsilon(self): 113 | self.eps = self.EPS_MIN + (self.EPS_MAX - self.EPS_MIN) * 
math.exp(-1 * self.steps / self.EPS_FACTOR) 114 | 115 | def decrement_lr(self): 116 | for g in self.optimiser.param_groups: 117 | g['lr'] = self.lr / 10 + (self.lr * 9 / 10) * math.exp(-1 * self.steps / self.EPS_FACTOR) 118 | 119 | def optimise_model(self): 120 | if self.memory.ctr < self.batch_size: 121 | return 0 122 | 123 | states, actions, rewards, next_states, dones = self.sample_memory() 124 | 125 | indices = np.arange(self.batch_size) 126 | 127 | q_pred = self.policy_net(states[:, :-1], states[:, -1, 0])[indices, actions] 128 | q_next = self.target_net(next_states[:, :-1], next_states[:, -1, 0]).detach() 129 | q_eval = self.policy_net(next_states[:, :-1], next_states[:, -1, 0]) 130 | 131 | max_actions = torch.argmax(q_eval, dim=1) 132 | q_next[dones] = float(0.0) 133 | 134 | q_target = rewards.to(self.device) + self.GAMMA * q_next[indices, max_actions] 135 | loss = self.criterion(q_target, q_pred) 136 | 137 | self.optimiser.zero_grad() 138 | loss.backward() 139 | # for param in self.policy_net.parameters(): 140 | # param.grad.data.clamp_(-1, 1) 141 | self.optimiser.step() 142 | 143 | self.steps += 1 144 | self.decrement_epsilon() 145 | self.decrement_lr() 146 | 147 | if self.steps % self.TAU[0] == 0: 148 | print("Replacing") 149 | self.replace_target() 150 | 151 | return loss.item() 152 | 153 | def save_models(self): 154 | torch.save({ 155 | 'policy_model_state_dict': self.policy_net.state_dict(), 156 | 'target_model_state_dict': self.target_net.state_dict(), 157 | 'optimiser_state_dict': self.optimiser.state_dict(), 158 | }, "./model.pt") 159 | 160 | def load_models(self, steps=0): 161 | self.policy_net = GKEN().to(self.device) 162 | self.target_net = GKEN().to(self.device) 163 | self.optimiser = torch.optim.Adam(self.policy_net.parameters(), lr=self.lr) 164 | 165 | checkpoint = torch.load("./model.pt") 166 | 167 | self.policy_net.load_state_dict(checkpoint['policy_model_state_dict']) 168 | self.target_net.load_state_dict(checkpoint['target_model_state_dict']) 169 | self.optimiser.load_state_dict(checkpoint['optimiser_state_dict']) 170 | 171 | self.target_net.eval() 172 | 173 | self.steps = steps 174 | -------------------------------------------------------------------------------- /bvh.py: -------------------------------------------------------------------------------- 1 | # BVH file parser from https://github.com/20tab/bvh-python 2 | 3 | import re 4 | import numpy as np 5 | 6 | 7 | class BvhNode: 8 | 9 | def __init__(self, value=[], parent=None): 10 | self.value = value 11 | self.children = [] 12 | self.parent = parent 13 | if self.parent: 14 | self.parent.add_child(self) 15 | 16 | def add_child(self, item): 17 | item.parent = self 18 | self.children.append(item) 19 | 20 | def filter(self, key): 21 | for child in self.children: 22 | if child.value[0] == key: 23 | yield child 24 | 25 | def __iter__(self): 26 | for child in self.children: 27 | yield child 28 | 29 | def __getitem__(self, key): 30 | for child in self.children: 31 | for index, item in enumerate(child.value): 32 | if item == key: 33 | if index + 1 >= len(child.value): 34 | return None 35 | else: 36 | return child.value[index + 1:] 37 | raise IndexError('key {} not found'.format(key)) 38 | 39 | def __repr__(self): 40 | return str(' '.join(self.value)) 41 | 42 | @property 43 | def name(self): 44 | return self.value[1] 45 | 46 | 47 | class Bvh: 48 | 49 | def __init__(self, data): 50 | self.data = data 51 | self.root = BvhNode() 52 | self.frames = [] 53 | self.tokenize() 54 | 55 | def tokenize(self): 56 | first_round = [] 
57 | accumulator = '' 58 | for char in self.data: 59 | if char not in ('\n', '\r'): 60 | accumulator += char 61 | elif accumulator: 62 | first_round.append(re.split('\\s+', accumulator.strip())) 63 | accumulator = '' 64 | node_stack = [self.root] 65 | frame_time_found = False 66 | node = None 67 | for item in first_round: 68 | if frame_time_found: 69 | self.frames.append(item) 70 | continue 71 | key = item[0] 72 | if key == '{': 73 | node_stack.append(node) 74 | elif key == '}': 75 | node_stack.pop() 76 | else: 77 | node = BvhNode(item) 78 | node_stack[-1].add_child(node) 79 | if item[0] == 'Frame' and item[1] == 'Time:': 80 | frame_time_found = True 81 | 82 | def search(self, *items): 83 | found_nodes = [] 84 | 85 | def check_children(node): 86 | if len(node.value) >= len(items): 87 | failed = False 88 | for index, item in enumerate(items): 89 | if node.value[index] != item: 90 | failed = True 91 | break 92 | if not failed: 93 | found_nodes.append(node) 94 | for child in node: 95 | check_children(child) 96 | 97 | check_children(self.root) 98 | return found_nodes 99 | 100 | def get_joints(self): 101 | joints = [] 102 | 103 | def iterate_joints(joint): 104 | joints.append(joint) 105 | for child in joint.filter('JOINT'): 106 | iterate_joints(child) 107 | 108 | iterate_joints(next(self.root.filter('ROOT'))) 109 | return joints 110 | 111 | def get_joints_names(self): 112 | joints = [] 113 | 114 | def iterate_joints(joint): 115 | joints.append(joint.value[1]) 116 | for child in joint.filter('JOINT'): 117 | iterate_joints(child) 118 | 119 | iterate_joints(next(self.root.filter('ROOT'))) 120 | return joints 121 | 122 | def joint_direct_children(self, name): 123 | joint = self.get_joint(name) 124 | return [child for child in joint.filter('JOINT')] 125 | 126 | def get_joint_index(self, name): 127 | return self.get_joints().index(self.get_joint(name)) 128 | 129 | def get_joint(self, name): 130 | found = self.search('ROOT', name) 131 | if not found: 132 | found = self.search('JOINT', name) 133 | if found: 134 | return found[0] 135 | raise LookupError('joint not found') 136 | 137 | def joint_offset(self, name): 138 | joint = self.get_joint(name) 139 | offset = joint['OFFSET'] 140 | return (float(offset[0]), float(offset[1]), float(offset[2])) 141 | 142 | def joint_channels(self, name): 143 | joint = self.get_joint(name) 144 | return joint['CHANNELS'][1:] 145 | 146 | def get_joint_channels_index(self, joint_name): 147 | index = 0 148 | for joint in self.get_joints(): 149 | if joint.value[1] == joint_name: 150 | return index 151 | index += int(joint['CHANNELS'][0]) 152 | raise LookupError('joint not found') 153 | 154 | def get_joint_channel_index(self, joint, channel): 155 | channels = self.joint_channels(joint) 156 | if channel in channels: 157 | channel_index = channels.index(channel) 158 | else: 159 | channel_index = -1 160 | return channel_index 161 | 162 | def frame_joint_channel(self, frame_index, joint, channel, value=None): 163 | joint_index = self.get_joint_channels_index(joint) 164 | channel_index = self.get_joint_channel_index(joint, channel) 165 | if channel_index == -1 and value is not None: 166 | return value 167 | return float(self.frames[frame_index][joint_index + channel_index]) 168 | 169 | def frame_joint_channels(self, frame_index, joint, channels, value=None): 170 | values = [] 171 | joint_index = self.get_joint_channels_index(joint) 172 | for channel in channels: 173 | channel_index = self.get_joint_channel_index(joint, channel) 174 | if channel_index == -1 and value is not None: 
175 | values.append(value) 176 | else: 177 | values.append( 178 | float( 179 | self.frames[frame_index][joint_index + channel_index] 180 | ) 181 | ) 182 | return values 183 | 184 | def frames_joint_channels(self, joint, channels, value=None): 185 | all_frames = [] 186 | joint_index = self.get_joint_channels_index(joint) 187 | for frame in self.frames: 188 | values = [] 189 | for channel in channels: 190 | channel_index = self.get_joint_channel_index(joint, channel) 191 | if channel_index == -1 and value is not None: 192 | values.append(value) 193 | else: 194 | values.append( 195 | float(frame[joint_index + channel_index])) 196 | all_frames.append(values) 197 | return all_frames 198 | 199 | def joint_parent(self, name): 200 | joint = self.get_joint(name) 201 | if joint.parent == self.root: 202 | return None 203 | return joint.parent 204 | 205 | def joint_parent_index(self, name): 206 | joint = self.get_joint(name) 207 | if joint.parent == self.root: 208 | return -1 209 | return self.get_joints().index(joint.parent) 210 | 211 | @property 212 | def nframes(self): 213 | try: 214 | return int(next(self.root.filter('Frames:')).value[1]) 215 | except StopIteration: 216 | raise LookupError('number of frames not found') 217 | 218 | @property 219 | def frame_time(self): 220 | try: 221 | return float(next(self.root.filter('Frame')).value[2]) 222 | except StopIteration: 223 | raise LookupError('frame time not found') 224 | 225 | 226 | NUM_CHANNELS = 24 227 | 228 | 229 | def CMU_parse(file, start=1, frame_skip=5): 230 | mocap = Bvh(file) 231 | 232 | f = 0 233 | for i in range(start, mocap.nframes, frame_skip): 234 | f += 1 235 | 236 | motion_data = np.zeros((NUM_CHANNELS, f, 3), dtype=np.float) 237 | 238 | root_channels = ["Xposition", "Yposition", "Zposition"] 239 | joint_channels = ["Zrotation", "Yrotation", "Xrotation"] 240 | 241 | f = 0 242 | 243 | for i in range(start, mocap.nframes, frame_skip): 244 | motion_data[0, f] = mocap.frame_joint_channels(i, "Hips", root_channels) 245 | motion_data[1, f] = mocap.frame_joint_channels(i, "Hips", joint_channels) 246 | motion_data[2, f] = mocap.frame_joint_channels(i, "LowerBack", joint_channels) 247 | motion_data[3, f] = mocap.frame_joint_channels(i, "Spine", joint_channels) 248 | motion_data[4, f] = mocap.frame_joint_channels(i, "Spine1", joint_channels) 249 | motion_data[5, f] = mocap.frame_joint_channels(i, "Neck", joint_channels) 250 | motion_data[6, f] = mocap.frame_joint_channels(i, "Neck1", joint_channels) 251 | motion_data[7, f] = mocap.frame_joint_channels(i, "Head", joint_channels) 252 | motion_data[8, f] = mocap.frame_joint_channels(i, "LeftShoulder", joint_channels) 253 | motion_data[9, f] = mocap.frame_joint_channels(i, "LeftArm", joint_channels) 254 | motion_data[10, f] = mocap.frame_joint_channels(i, "LeftForeArm", joint_channels) 255 | motion_data[11, f] = mocap.frame_joint_channels(i, "LeftHand", joint_channels) 256 | motion_data[12, f] = mocap.frame_joint_channels(i, "RightShoulder", joint_channels) 257 | motion_data[13, f] = mocap.frame_joint_channels(i, "RightArm", joint_channels) 258 | motion_data[14, f] = mocap.frame_joint_channels(i, "RightForeArm", joint_channels) 259 | motion_data[15, f] = mocap.frame_joint_channels(i, "RightHand", joint_channels) 260 | motion_data[16, f] = mocap.frame_joint_channels(i, "LeftUpLeg", joint_channels) 261 | motion_data[17, f] = mocap.frame_joint_channels(i, "LeftLeg", joint_channels) 262 | motion_data[18, f] = mocap.frame_joint_channels(i, "LeftFoot", joint_channels) 263 | motion_data[19, f] = 
mocap.frame_joint_channels(i, "LeftToeBase", joint_channels) 264 | motion_data[20, f] = mocap.frame_joint_channels(i, "RightUpLeg", joint_channels) 265 | motion_data[21, f] = mocap.frame_joint_channels(i, "RightLeg", joint_channels) 266 | motion_data[22, f] = mocap.frame_joint_channels(i, "RightFoot", joint_channels) 267 | motion_data[23, f] = mocap.frame_joint_channels(i, "RightToeBase", joint_channels) 268 | f += 1 269 | 270 | motion_data = motion_data.transpose((0, 2, 1)) 271 | motion_data = np.reshape(motion_data, (3 * NUM_CHANNELS, f)) 272 | 273 | return motion_data 274 | --------------------------------------------------------------------------------
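For reference, the sketch below (illustrative only; `clip.npy` is a placeholder file name) shows how a clip produced by `CMU_parse` can be loaded and converted into the global joint positions that `env.py` compares against its interpolated reconstruction:

```python
# Illustrative sketch: load a preprocessed clip and compute the spatial
# (global joint position) representation used by Env. "clip.npy" is a placeholder.
import numpy as np
from spatialhelper import CMUSpatialHelper

motion = np.load("clip.npy")  # [ATTRIBUTES, FRAMES], as produced by CMU_parse
positions = CMUSpatialHelper().get_spatial_representation(motion)
print("motion:", motion.shape, "joint positions:", positions.shape)
```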