├── .gitignore
├── .gitmodules
├── LICENSE
├── Performance.png
├── README.md
├── data
│   ├── JHMDB
│   │   ├── GT_test_1.pkl
│   │   └── GT_train_1.pkl
│   └── SHREC
│       ├── test.pkl
│       └── train.pkl
├── dataloader
│   ├── __init__.py
│   ├── jhmdb_loader.py
│   └── shrec_loader.py
├── keres_performance.png
├── models
│   ├── DDNet_Original.py
│   └── __init__.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
.vscode
experiments
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "pytorch-summary"]
	path = pytorch-summary
	url = https://github.com/sksq96/pytorch-summary.git
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Nightwatch

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/Performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlurryLight/DD-Net-Pytorch/52c38d9e9c01e94de2958c1a7b1e9176e067f4a4/Performance.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DD-Net Pytorch

A pytorch reimplementation of [DD-Net](https://github.com/fandulu/DD-Net), a lightweight network for body/hand action recognition.

## How to use the code

A subset of the preprocessed JHMDB data is included in this repo: `GT_train_1.pkl` and `GT_test_1.pkl`.

The code is written and tested with `Pytorch 1.3.0`.

```bash
git clone https://github.com/BlurryLight/DD-Net-Pytorch.git
cd DD-Net-Pytorch
python train.py --dataset 0 --batch-size 512 --epochs 600 --lr 0.001 | tee train.log
```

## Performance

The number of parameters is the same as in the original Keras version, which is 1.80M on JHMDB.

The validation performance is comparable to the original version. On the provided JHMDB split, the performance is as below.

- The Pytorch version

![Performance](./Performance.png)

- The Keras version

![Keras performance](./keres_performance.png)
--------------------------------------------------------------------------------
/data/JHMDB/GT_test_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlurryLight/DD-Net-Pytorch/52c38d9e9c01e94de2958c1a7b1e9176e067f4a4/data/JHMDB/GT_test_1.pkl
--------------------------------------------------------------------------------
/data/JHMDB/GT_train_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlurryLight/DD-Net-Pytorch/52c38d9e9c01e94de2958c1a7b1e9176e067f4a4/data/JHMDB/GT_train_1.pkl
--------------------------------------------------------------------------------
/data/SHREC/test.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlurryLight/DD-Net-Pytorch/52c38d9e9c01e94de2958c1a7b1e9176e067f4a4/data/SHREC/test.pkl
--------------------------------------------------------------------------------
/data/SHREC/train.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlurryLight/DD-Net-Pytorch/52c38d9e9c01e94de2958c1a7b1e9176e067f4a4/data/SHREC/train.pkl
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlurryLight/DD-Net-Pytorch/52c38d9e9c01e94de2958c1a7b1e9176e067f4a4/dataloader/__init__.py
--------------------------------------------------------------------------------
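The pickle files above are plain Python dicts. A minimal inspection sketch, assuming the key layout the loaders below rely on (`pose` plus `label` for JHMDB; `pose` plus `coarse_label`/`fine_label` for SHREC):

```python
import pickle

# Peek at the bundled JHMDB split; the keys below are the ones
# jhmdb_loader.py reads ('pose' and 'label').
with open("data/JHMDB/GT_train_1.pkl", "rb") as f:
    train = pickle.load(f)

print(train.keys())
print(len(train["pose"]))      # number of clips
print(train["pose"][0].shape)  # (frames, 15 joints, 2 coords) per JConfig
```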
/dataloader/jhmdb_loader.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

from tqdm import tqdm
import numpy as np
from sklearn import preprocessing
import pickle
from pathlib import Path
import sys
sys.path.insert(0, '..')
from utils import *  # noqa
current_file_dirpath = Path(__file__).parent.absolute()


def load_jhmdb_data(
        train_path=current_file_dirpath / Path("../data/JHMDB/GT_train_1.pkl"),
        test_path=current_file_dirpath / Path("../data/JHMDB/GT_test_1.pkl")):
    Train = pickle.load(open(train_path, "rb"))
    Test = pickle.load(open(test_path, "rb"))
    le = preprocessing.LabelEncoder()
    le.fit(Train['label'])
    print("Loading JHMDB Dataset")
    return Train, Test, le


class JConfig():
    def __init__(self):
        self.frame_l = 32  # the length of frames
        self.joint_n = 15  # the number of joints
        self.joint_d = 2  # the dimension of joints
        self.clc_num = 21  # the number of classes
        self.feat_d = 105
        self.filters = 64


# Generate the dataset
# T: dataset  C: config  le: LabelEncoder
def Jdata_generator(T, C, le):
    X_0 = []
    X_1 = []
    Y = []
    labels = le.transform(T['label'])
    for i in tqdm(range(len(T['pose']))):
        p = np.copy(T['pose'][i])
        # p.shape (frame, joint_num, joint_coords_dims)
        p = zoom(p, target_l=C.frame_l,
                 joints_num=C.joint_n, joints_dim=C.joint_d)
        # p.shape (target_frame, joint_num, joint_coords_dims)
        label = labels[i]
        # M.shape (target_frame, (joint_num - 1) * joint_num / 2)
        M = get_CG(p, C)

        X_0.append(M)
        X_1.append(p)
        Y.append(label)

    X_0 = np.stack(X_0)
    X_1 = np.stack(X_1)
    Y = np.stack(Y)
    return X_0, X_1, Y
--------------------------------------------------------------------------------
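A minimal sketch of the loader's intended call sequence (mirroring how `train.py` consumes it):

```python
# Sketch: build the JHMDB feature tensors with the loader above.
from dataloader.jhmdb_loader import load_jhmdb_data, Jdata_generator, JConfig

C = JConfig()
Train, Test, le = load_jhmdb_data()
X_0, X_1, Y = Jdata_generator(Train, C, le)
# X_0: JCD features, shape (N, 32, 105)
# X_1: resampled poses, shape (N, 32, 15, 2)
# Y:   integer class labels, shape (N,)
print(X_0.shape, X_1.shape, Y.shape)
```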
/dataloader/shrec_loader.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

import numpy as np
from sklearn import preprocessing
import pickle
from pathlib import Path
from tqdm import tqdm
import sys
sys.path.insert(0, '..')
from utils import *  # noqa
current_file_dirpath = Path(__file__).parent.absolute()


def load_shrec_data(
        train_path=current_file_dirpath / Path("../data/SHREC/train.pkl"),
        test_path=current_file_dirpath / Path("../data/SHREC/test.pkl"),
):
    Train = pickle.load(open(train_path, "rb"))
    Test = pickle.load(open(test_path, "rb"))
    print("Loading SHREC Dataset")
    dummy = None  # return a dummy to provide the same interface as the JHMDB loader
    return Train, Test, dummy


class SConfig():
    def __init__(self):
        self.frame_l = 32  # the length of frames
        self.joint_n = 22  # the number of joints
        self.joint_d = 3  # the dimension of joints
        self.class_coarse_num = 14
        self.class_fine_num = 28
        self.feat_d = 231
        self.filters = 64


class Sdata_generator:
    def __init__(self, label_level='coarse_label'):
        self.label_level = label_level

    # le is None to provide a unified interface with the JHMDB data generator
    def __call__(self, T, C, le=None):
        X_0 = []
        X_1 = []
        Y = []
        for i in tqdm(range(len(T['pose']))):
            p = np.copy(T['pose'][i].reshape([-1, 22, 3]))
            # p.shape (frame, joint_num, joint_coords_dims)
            p = zoom(p, target_l=C.frame_l,
                     joints_num=C.joint_n, joints_dim=C.joint_d)
            # p.shape (target_frame, joint_num, joint_coords_dims)
            # SHREC labels are 1-based, so shift them to 0-based
            label = (T[self.label_level])[i] - 1
            # M.shape (target_frame, (joint_num - 1) * joint_num / 2)
            M = get_CG(p, C)

            X_0.append(M)
            X_1.append(p)
            Y.append(label)

        self.X_0 = np.stack(X_0)
        self.X_1 = np.stack(X_1)
        self.Y = np.stack(Y)
        return self.X_0, self.X_1, self.Y


if __name__ == '__main__':
    Train, Test, _ = load_shrec_data()
    C = SConfig()
    X_0, X_1, Y = Sdata_generator('coarse_label')(Train, C)
    print(Y)
    X_0, X_1, Y = Sdata_generator('fine_label')(Train, C)
    print(Y)
    print("X_0.shape", X_0.shape)
    print("X_1.shape", X_1.shape)
--------------------------------------------------------------------------------
/keres_performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlurryLight/DD-Net-Pytorch/52c38d9e9c01e94de2958c1a7b1e9176e067f4a4/keres_performance.png
--------------------------------------------------------------------------------
/models/DDNet_Original.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

from utils import poses_motion
import torch.nn.functional as F
import torch.nn as nn
import torch
import math


class c1D(nn.Module):
    # input (B,C,D)  // batch, channels, dims
    # output (B,C,filters)
    def __init__(self, input_channels, input_dims, filters, kernel):
        super(c1D, self).__init__()
        # For even kernels, pad one extra element and cut it off after the
        # conv so the output length matches the input ("same" padding).
        self.cut_last_element = (kernel % 2 == 0)
        self.padding = math.ceil((kernel - 1) / 2)
        self.conv1 = nn.Conv1d(input_dims, filters,
                               kernel, bias=False, padding=self.padding)
        self.bn = nn.BatchNorm1d(num_features=input_channels)

    def forward(self, x):
        # x (B,D,C)
        x = x.permute(0, 2, 1)
        # output (B,filters,C)
        if self.cut_last_element:
            output = self.conv1(x)[:, :, :-1]
        else:
            output = self.conv1(x)
        # output (B,C,filters)
        output = output.permute(0, 2, 1)
        output = self.bn(output)
        output = F.leaky_relu(output, 0.2, True)
        return output


class block(nn.Module):
    def __init__(self, input_channels, input_dims, filters, kernel):
        super(block, self).__init__()
        self.c1D1 = c1D(input_channels, input_dims, filters, kernel)
        self.c1D2 = c1D(input_channels, filters, filters, kernel)

    def forward(self, x):
        output = self.c1D1(x)
        output = self.c1D2(output)
        return output


class d1D(nn.Module):
    def __init__(self, input_dims, filters):
        super(d1D, self).__init__()
        self.linear = nn.Linear(input_dims, filters)
        self.bn = nn.BatchNorm1d(num_features=filters)

    def forward(self, x):
        output = self.linear(x)
        output = self.bn(output)
        output = F.leaky_relu(output, 0.2)
        return output


class spatialDropout1D(nn.Module):
    def __init__(self, p):
        super(spatialDropout1D, self).__init__()
        self.dropout = nn.Dropout2d(p)

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.dropout(x)
        x = x.permute(0, 2, 1)
        return x


class DDNet_Original(nn.Module):
    def __init__(self, frame_l, joint_n, joint_d, feat_d, filters, class_num):
        super(DDNet_Original, self).__init__()
        # JCD part
        self.jcd_conv1 = nn.Sequential(
            c1D(frame_l, feat_d, 2 * filters, 1),
            spatialDropout1D(0.1)
        )
        self.jcd_conv2 = nn.Sequential(
            c1D(frame_l, 2 * filters, filters, 3),
            spatialDropout1D(0.1)
        )
        self.jcd_conv3 = c1D(frame_l, filters, filters, 1)
        self.jcd_pool = nn.Sequential(
            nn.MaxPool1d(kernel_size=2),
            spatialDropout1D(0.1)
        )

        # diff_slow part
        self.slow_conv1 = nn.Sequential(
            c1D(frame_l, joint_n * joint_d, 2 * filters, 1),
            spatialDropout1D(0.1)
        )
        self.slow_conv2 = nn.Sequential(
            c1D(frame_l, 2 * filters, filters, 3),
            spatialDropout1D(0.1)
        )
        self.slow_conv3 = c1D(frame_l, filters, filters, 1)
        self.slow_pool = nn.Sequential(
            nn.MaxPool1d(kernel_size=2),
            spatialDropout1D(0.1)
        )

        # fast part
        self.fast_conv1 = nn.Sequential(
            c1D(frame_l//2, joint_n * joint_d, 2 * filters, 1), spatialDropout1D(0.1))
        self.fast_conv2 = nn.Sequential(
            c1D(frame_l//2, 2 * filters, filters, 3), spatialDropout1D(0.1))
        self.fast_conv3 = nn.Sequential(
            c1D(frame_l//2, filters, filters, 1), spatialDropout1D(0.1))

        # after concatenation
        self.block1 = block(frame_l//2, 3 * filters, 2 * filters, 3)
        self.block_pool1 = nn.Sequential(
            nn.MaxPool1d(kernel_size=2), spatialDropout1D(0.1))

        self.block2 = block(frame_l//4, 2 * filters, 4 * filters, 3)
        self.block_pool2 = nn.Sequential(nn.MaxPool1d(
            kernel_size=2), spatialDropout1D(0.1))

        self.block3 = nn.Sequential(
            block(frame_l//8, 4 * filters, 8 * filters, 3), spatialDropout1D(0.1))

        self.linear1 = nn.Sequential(
            d1D(8 * filters, 128),
            nn.Dropout(0.5)
        )
        self.linear2 = nn.Sequential(
            d1D(128, 128),
            nn.Dropout(0.5)
        )

        self.linear3 = nn.Linear(128, class_num)

    def forward(self, M, P=None):
        x = self.jcd_conv1(M)
        x = self.jcd_conv2(x)
        x = self.jcd_conv3(x)
        x = x.permute(0, 2, 1)
        # pool downsamples the D dim of (B,C,D),
        # but we want to downsample the C channels,
        # hence the permutes around it; a 1x1 conv may be a better choice
        x = self.jcd_pool(x)
        x = x.permute(0, 2, 1)

        diff_slow, diff_fast = poses_motion(P)
        x_d_slow = self.slow_conv1(diff_slow)
        x_d_slow = self.slow_conv2(x_d_slow)
        x_d_slow = self.slow_conv3(x_d_slow)
        x_d_slow = x_d_slow.permute(0, 2, 1)
        x_d_slow = self.slow_pool(x_d_slow)
        x_d_slow = x_d_slow.permute(0, 2, 1)

        x_d_fast = self.fast_conv1(diff_fast)
        x_d_fast = self.fast_conv2(x_d_fast)
        x_d_fast = self.fast_conv3(x_d_fast)
        # x, x_d_fast, x_d_slow shape: (B, frame_l//2, filters)

        x = torch.cat((x, x_d_slow, x_d_fast), dim=2)
        x = self.block1(x)
        x = x.permute(0, 2, 1)
        x = self.block_pool1(x)
        x = x.permute(0, 2, 1)

        x = self.block2(x)
        x = x.permute(0, 2, 1)
        x = self.block_pool2(x)
        x = x.permute(0, 2, 1)

        x = self.block3(x)
        # global max pool over the C channels of (B,C,D)
        x = torch.max(x, dim=1).values

        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlurryLight/DD-Net-Pytorch/52c38d9e9c01e94de2958c1a7b1e9176e067f4a4/models/__init__.py
--------------------------------------------------------------------------------
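A minimal smoke-test sketch for the model above, using the JHMDB config (the input shapes follow the `summary` call in `train.py`):

```python
# Sketch: instantiate DD-Net and run a dummy forward pass.
import torch
from models.DDNet_Original import DDNet_Original
from dataloader.jhmdb_loader import JConfig

C = JConfig()
model = DDNet_Original(C.frame_l, C.joint_n, C.joint_d,
                       C.feat_d, C.filters, C.clc_num)

M = torch.randn(4, C.frame_l, C.feat_d)              # JCD features (B, 32, 105)
P = torch.randn(4, C.frame_l, C.joint_n, C.joint_d)  # poses (B, 32, 15, 2)
logits = model(M, P)
print(logits.shape)  # expected: torch.Size([4, 21])
```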
/train.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
import argparse
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import confusion_matrix

from dataloader.jhmdb_loader import load_jhmdb_data, Jdata_generator, JConfig
from dataloader.shrec_loader import load_shrec_data, Sdata_generator, SConfig
from models.DDNet_Original import DDNet_Original as DDNet
from utils import makedir
import sys
import time
import numpy as np
import logging
sys.path.insert(0, './pytorch-summary/torchsummary/')
from torchsummary import summary  # noqa

savedir = Path('experiments') / Path(str(int(time.time())))
makedir(savedir)
logging.basicConfig(filename=savedir / 'train.log', level=logging.INFO)
history = {
    "train_loss": [],
    "test_loss": [],
    "test_acc": []
}


def train(args, model, device, train_loader, optimizer, epoch, criterion):
    model.train()
    train_loss = 0
    for batch_idx, (data1, data2, target) in enumerate(tqdm(train_loader)):
        M, P, target = data1.to(device), data2.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(M, P)
        loss = criterion(output, target)
        train_loss += loss.detach().item()
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            msg = ('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data1), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            print(msg)
            logging.info(msg)
            if args.dry_run:
                break
    history['train_loss'].append(train_loss)
    return train_loss


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    criterion = nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for _, (data1, data2, target) in enumerate(tqdm(test_loader)):
            M, P, target = data1.to(device), data2.to(device), target.to(device)
            output = model(M, P)
            # sum up batch loss
            test_loss += criterion(output, target).item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            # output shape (B, Class)
            # target shape (B)
            # pred shape (B, 1)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    history['test_loss'].append(test_loss)
    history['test_acc'].append(correct / len(test_loader.dataset))
    msg = ('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    print(msg)
    logging.info(msg)


def main():
    # Training settings
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=199, metavar='N',
                        help='number of epochs to train (default: 199)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--gamma', type=float, default=0.5, metavar='M',
                        help='Learning rate step gamma (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--log-interval', type=int, default=2, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--dataset', type=int, required=True, metavar='N',
                        help='0 for JHMDB, 1 for SHREC coarse, 2 for SHREC fine; other values are undefined')
    parser.add_argument('--model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--calc_time', action='store_true', default=False,
                        help='calc time per sample')
    args = parser.parse_args()
    logging.info(args)
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'batch_size': args.batch_size}
    if use_cuda:
        kwargs.update({'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True},)

    # aliases
    Config = None
    data_generator = None
    load_data = None
    clc_num = 0
    if args.dataset == 0:
        Config = JConfig()
        data_generator = Jdata_generator
        load_data = load_jhmdb_data
        clc_num = Config.clc_num
    elif args.dataset == 1:
        Config = SConfig()
        load_data = load_shrec_data
        clc_num = Config.class_coarse_num
        data_generator = Sdata_generator('coarse_label')
    elif args.dataset == 2:
        Config = SConfig()
        clc_num = Config.class_fine_num
        load_data = load_shrec_data
        data_generator = Sdata_generator('fine_label')
    else:
        print("Unsupported dataset!")
        sys.exit(1)

    C = Config
    Train, Test, le = load_data()
    X_0, X_1, Y = data_generator(Train, C, le)
    X_0 = torch.from_numpy(X_0).type('torch.FloatTensor')
    X_1 = torch.from_numpy(X_1).type('torch.FloatTensor')
    Y = torch.from_numpy(Y).type('torch.LongTensor')

    X_0_t, X_1_t, Y_t = data_generator(Test, C, le)
    X_0_t = torch.from_numpy(X_0_t).type('torch.FloatTensor')
    X_1_t = torch.from_numpy(X_1_t).type('torch.FloatTensor')
    Y_t = torch.from_numpy(Y_t).type('torch.LongTensor')

    trainset = torch.utils.data.TensorDataset(X_0, X_1, Y)
    train_loader = torch.utils.data.DataLoader(trainset, **kwargs)

    testset = torch.utils.data.TensorDataset(X_0_t, X_1_t, Y_t)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=args.test_batch_size)

    Net = DDNet(C.frame_l, C.joint_n, C.joint_d,
                C.feat_d, C.filters, clc_num)
    model = Net.to(device)

    summary(model, [(C.frame_l, C.feat_d), (C.frame_l, C.joint_n, C.joint_d)])
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))

    criterion = nn.CrossEntropyLoss()
    scheduler = ReduceLROnPlateau(
        optimizer, factor=args.gamma, patience=5, cooldown=0.5, min_lr=5e-6, verbose=True)
    for epoch in range(1, args.epochs + 1):
        train_loss = train(args, model, device, train_loader,
                           optimizer, epoch, criterion)
        test(model, device, test_loader)
        scheduler.step(train_loss)

    fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1)
    ax1.plot(history['train_loss'])
    ax1.plot(history['test_loss'])
    ax1.legend(['Train', 'Test'], loc='upper left')
    ax1.set_xlabel('Epoch')
    ax1.set_title('Loss')

    ax2.set_title('Model accuracy')
    ax2.set_ylabel('Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.plot(history['test_acc'])
    xmax = np.argmax(history['test_acc'])
    ymax = np.max(history['test_acc'])
    text = "x={}, y={:.3f}".format(xmax, ymax)
    ax2.annotate(text, xy=(xmax, ymax))

    ax3.set_title('Confusion matrix')
    model.eval()
    with torch.no_grad():
        Y_pred = model(X_0_t.to(device), X_1_t.to(device)).cpu().numpy()
        Y_test = Y_t.numpy()
        cnf_matrix = confusion_matrix(Y_test, np.argmax(Y_pred, axis=1))
        ax3.imshow(cnf_matrix)
    fig.tight_layout()
    fig.savefig(str(savedir / "perf.png"))
    if args.save_model:
        torch.save(model.state_dict(), str(savedir / "model.pt"))
    if args.calc_time:
        devices = ['cpu', 'cuda']
        # measure the per-sample inference time on each device
        for d in devices:
            tmp_X_0_t = X_0_t.to(d)
            tmp_X_1_t = X_1_t.to(d)
            model = model.to(d)
            # warm up
            _ = model(tmp_X_0_t, tmp_X_1_t)

            tmp_X_0_t = tmp_X_0_t.unsqueeze(1)
            tmp_X_1_t = tmp_X_1_t.unsqueeze(1)
            start = time.perf_counter_ns()
            for i in range(tmp_X_0_t.shape[0]):
                _ = model(tmp_X_0_t[i, :, :, :], tmp_X_1_t[i, :, :, :])
            end = time.perf_counter_ns()
            msg = ("total {}ns, {:.2f}ns per one on {}".format(
                (end - start), ((end - start) / (X_0_t.shape[0])), d))
            print(msg)
            logging.info(msg)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
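`train.py` only persists a state dict when `--save-model` is passed; a minimal sketch of reloading such a checkpoint for inference. The `experiments/<timestamp>` path is a placeholder for whatever run directory `train.py` created:

```python
# Sketch: reload a checkpoint written by train.py --save-model.
# The checkpoint path below is hypothetical; substitute your own run directory.
import torch
from models.DDNet_Original import DDNet_Original
from dataloader.jhmdb_loader import JConfig

C = JConfig()
model = DDNet_Original(C.frame_l, C.joint_n, C.joint_d,
                       C.feat_d, C.filters, C.clc_num)
state = torch.load("experiments/<timestamp>/model.pt", map_location="cpu")
model.load_state_dict(state)
model.eval()
```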
/utils.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
# coding:utf-8

import scipy.ndimage.interpolation as inter
from scipy.spatial.distance import cdist
import numpy as np
import torch
import torch.nn.functional as F
import pathlib
import copy
from scipy.signal import medfilt


# Temporal resizing function:
# interpolate l frames to target_l frames
def zoom(p, target_l=64, joints_num=25, joints_dim=3):
    p_copy = copy.deepcopy(p)
    l = p_copy.shape[0]
    p_new = np.empty([target_l, joints_num, joints_dim])
    for m in range(joints_num):
        for n in range(joints_dim):
            # p_new[:, m, n] = medfilt(p_new[:, m, n], 3)  # makes no sense: p_new is still empty here
            p_copy[:, m, n] = medfilt(p_copy[:, m, n], 3)
            p_new[:, m, n] = inter.zoom(p_copy[:, m, n], target_l / l)[:target_l]
    return p_new


# Calculate the JCD feature
def norm_scale(x):
    return (x - np.mean(x)) / np.mean(x)


def get_CG(p, C):
    M = []
    # upper-triangle indices with offset 1, i.e. the upper triangle without the diagonal
    iu = np.triu_indices(C.joint_n, 1, C.joint_n)
    for f in range(C.frame_l):
        # iterate over all frames and compute each frame's JCD matrix
        # p[f].shape (15, 2)
        d_m = cdist(p[f], p[f], 'euclidean')
        d_m = d_m[iu]
        # the upper triangle of the matrix, flattened to a vector of shape (105,)
        M.append(d_m)
    M = np.stack(M)
    M = norm_scale(M)  # normalize
    return M


def poses_diff(x):
    _, H, W, _ = x.shape

    # x.shape (batch, channel, joint_num, joint_dim)
    x = x[:, 1:, ...] - x[:, :-1, ...]

    # x.shape (batch, joint_dim, channel, joint_num)
    x = x.permute(0, 3, 1, 2)
    x = F.interpolate(x, size=(H, W),
                      align_corners=False, mode='bilinear')
    x = x.permute(0, 2, 3, 1)
    # x.shape (batch, channel, joint_num, joint_dim)
    return x


def poses_motion(P):
    # different from the original version
    # TODO: check the function, make sure it's right
    P_diff_slow = poses_diff(P)
    P_diff_slow = torch.flatten(P_diff_slow, start_dim=2)
    P_fast = P[:, ::2, :, :]
    P_diff_fast = poses_diff(P_fast)
    P_diff_fast = torch.flatten(P_diff_fast, start_dim=2)
    # return (B, target_l, joint_d * joint_n), (B, target_l/2, joint_d * joint_n)
    return P_diff_slow, P_diff_fast


def makedir(path):
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)
--------------------------------------------------------------------------------
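A minimal sketch exercising the three helpers above on random JHMDB-shaped data (the shapes follow the comments in the code):

```python
import numpy as np
import torch
from dataloader.jhmdb_loader import JConfig
from utils import zoom, get_CG, poses_motion

C = JConfig()
p = np.random.randn(40, C.joint_n, C.joint_d)   # a fake 40-frame pose clip
p32 = zoom(p, target_l=C.frame_l, joints_num=C.joint_n, joints_dim=C.joint_d)
print(p32.shape)        # (32, 15, 2): resampled to frame_l frames

M = get_CG(p32, C)      # JCD: pairwise joint distances per frame
print(M.shape)          # (32, 105) == (32, 15*14/2)

P = torch.from_numpy(p32).float().unsqueeze(0)  # add a batch dimension
slow, fast = poses_motion(P)
print(slow.shape, fast.shape)  # (1, 32, 30) and (1, 16, 30)
```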