├── torchlight
│   ├── torchlight.egg-info
│   │   ├── dependency_links.txt
│   │   ├── top_level.txt
│   │   ├── SOURCES.txt
│   │   └── PKG-INFO
│   ├── torchlight
│   │   ├── __pycache__
│   │   │   ├── gpu.cpython-36.pyc
│   │   │   ├── util.cpython-36.pyc
│   │   │   ├── util.cpython-37.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── __init__.cpython-37.pyc
│   │   ├── __init__.py
│   │   ├── gpu.py
│   │   └── util.py
│   └── setup.py
├── feeders
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── tools.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── bone_pairs.cpython-36.pyc
│   │   ├── feeder_ntu.cpython-36.pyc
│   │   └── feeder_ucla.cpython-36.pyc
│   ├── bone_pairs.py
│   ├── feeder_ntu.py
│   └── tools.py
├── figures
│   └── framework.PNG
├── data
│   ├── NW-UCLA
│   │   └── val_label.pkl
│   ├── ntu
│   │   ├── get_raw_skes_data.py
│   │   ├── statistics
│   │   │   └── samples_with_missing_skeletons.txt
│   │   ├── seq_transformation.py
│   │   └── get_raw_denoised_data.py
│   └── ntu120
│       ├── get_raw_skes_data.py
│       ├── seq_transformation.py
│       ├── statistics
│       │   └── NTU_RGBD120_samples_with_missing_skeletons.txt
│       └── get_raw_denoised_data.py
├── model
│   ├── __pycache__
│   │   ├── aha.cpython-36.pyc
│   │   ├── hbg.cpython-36.pyc
│   │   ├── hca.cpython-36.pyc
│   │   ├── hdgcn.cpython-36.pyc
│   │   ├── tools.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── baseline.cpython-36.pyc
│   │   ├── hdgcn_hbg.cpython-36.pyc
│   │   ├── ctrgcn_flops.cpython-36.pyc
│   │   ├── dynamic_repr.cpython-36.pyc
│   │   ├── flops_count.cpython-36.pyc
│   │   ├── hdgcn_edge.cpython-36.pyc
│   │   ├── hdgcn_edge_T.cpython-36.pyc
│   │   ├── hdgcn_flops.cpython-36.pyc
│   │   ├── hdgcn_main.cpython-36.pyc
│   │   ├── spatial_conv.cpython-36.pyc
│   │   ├── hdgcn_edge_att.cpython-36.pyc
│   │   ├── hdgcn_main_v2.cpython-36.pyc
│   │   ├── temporal_conv.cpython-36.pyc
│   │   ├── hdgcn_edge_T_att.cpython-36.pyc
│   │   ├── hdgcn_edge_att_v2.cpython-36.pyc
│   │   └── spatial_conv_hbg.cpython-36.pyc
│   └── HDGCN.py
├── graph
│   ├── __pycache__
│   │   ├── tools.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── ntu_rgb_d.cpython-36.pyc
│   │   ├── ucla_hierarchy.cpython-36.pyc
│   │   ├── ntu_rgb_d_hierarchy.cpython-36.pyc
│   │   └── ntu_rgb_d_hierarchy_v2.cpython-36.pyc
│   ├── __init__.py
│   ├── ntu_rgb_d_hierarchy.py
│   └── tools.py
├── run_ensemble.sh
├── config
│   ├── nturgbd-cross-view
│   │   ├── bone_com_1.yaml
│   │   ├── bone_com_2.yaml
│   │   ├── bone_com_21.yaml
│   │   ├── joint_com_1.yaml
│   │   ├── joint_com_2.yaml
│   │   └── joint_com_21.yaml
│   ├── nturgbd-cross-subject
│   │   ├── bone_com_1.yaml
│   │   ├── bone_com_2.yaml
│   │   ├── bone_com_21.yaml
│   │   ├── joint_com_1.yaml
│   │   ├── joint_com_2.yaml
│   │   └── joint_com_21.yaml
│   ├── nturgbd120-cross-setup
│   │   ├── bone_com_1.yaml
│   │   ├── bone_com_2.yaml
│   │   ├── bone_com_21.yaml
│   │   ├── joint_com_1.yaml
│   │   ├── joint_com_2.yaml
│   │   └── joint_com_21.yaml
│   └── nturgbd120-cross-subject
│       ├── bone_com_1.yaml
│       ├── bone_com_2.yaml
│       ├── bone_com_21.yaml
│       ├── joint_com_1.yaml
│       ├── joint_com_2.yaml
│       └── joint_com_21.yaml
├── LICENSE
├── ensemble.py
├── README.md
└── main.py
/torchlight/torchlight.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/torchlight/torchlight.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | torchlight
2 | 
--------------------------------------------------------------------------------
/feeders/__init__.py:
--------------------------------------------------------------------------------
1 | from . import tools
2 | from . import feeder_ucla
3 | from . 
import feeder_ntu -------------------------------------------------------------------------------- /figures/framework.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/figures/framework.PNG -------------------------------------------------------------------------------- /data/NW-UCLA/val_label.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/data/NW-UCLA/val_label.pkl -------------------------------------------------------------------------------- /model/__pycache__/aha.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/aha.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hbg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/hbg.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hca.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/hca.cpython-36.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/tools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/feeders/__pycache__/tools.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__pycache__/tools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/graph/__pycache__/tools.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hdgcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/hdgcn.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/tools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/tools.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__init__.py: -------------------------------------------------------------------------------- 1 | from . import tools 2 | from . import ntu_rgb_d_hierarchy 3 | from . import ntu_rgb_d 4 | from . 
import ucla_hierarchy -------------------------------------------------------------------------------- /graph/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/graph/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__pycache__/ntu_rgb_d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/graph/__pycache__/ntu_rgb_d.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/baseline.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/baseline.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hdgcn_hbg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/hdgcn_hbg.cpython-36.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/feeders/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/bone_pairs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/feeders/__pycache__/bone_pairs.cpython-36.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/feeder_ntu.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/feeders/__pycache__/feeder_ntu.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/ctrgcn_flops.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/ctrgcn_flops.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/dynamic_repr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/dynamic_repr.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/flops_count.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/flops_count.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hdgcn_edge.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/hdgcn_edge.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hdgcn_edge_T.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/hdgcn_edge_T.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hdgcn_flops.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/hdgcn_flops.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hdgcn_main.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/hdgcn_main.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/spatial_conv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/spatial_conv.cpython-36.pyc -------------------------------------------------------------------------------- /feeders/__pycache__/feeder_ucla.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/feeders/__pycache__/feeder_ucla.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__pycache__/ucla_hierarchy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/graph/__pycache__/ucla_hierarchy.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hdgcn_edge_att.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/hdgcn_edge_att.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hdgcn_main_v2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/hdgcn_main_v2.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/temporal_conv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/temporal_conv.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hdgcn_edge_T_att.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/hdgcn_edge_T_att.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/hdgcn_edge_att_v2.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/hdgcn_edge_att_v2.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/spatial_conv_hbg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/model/__pycache__/spatial_conv_hbg.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__pycache__/ntu_rgb_d_hierarchy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/graph/__pycache__/ntu_rgb_d_hierarchy.cpython-36.pyc -------------------------------------------------------------------------------- /torchlight/torchlight/__pycache__/gpu.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/torchlight/torchlight/__pycache__/gpu.cpython-36.pyc -------------------------------------------------------------------------------- /graph/__pycache__/ntu_rgb_d_hierarchy_v2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/graph/__pycache__/ntu_rgb_d_hierarchy_v2.cpython-36.pyc -------------------------------------------------------------------------------- /torchlight/torchlight/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/torchlight/torchlight/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /torchlight/torchlight/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/torchlight/torchlight/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /torchlight/torchlight/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/torchlight/torchlight/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /torchlight/torchlight/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jho-Yonsei/HD-GCN/HEAD/torchlight/torchlight/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /torchlight/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='torchlight', 5 | version='1.0', 6 | description='A mini framework for pytorch', 7 | packages=find_packages(), 8 | install_requires=[]) 9 | -------------------------------------------------------------------------------- /torchlight/torchlight.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | torchlight/__init__.py 3 | torchlight/gpu.py 4 | torchlight/util.py 5 | torchlight.egg-info/PKG-INFO 6 | torchlight.egg-info/SOURCES.txt 7 | torchlight.egg-info/dependency_links.txt 8 | 
torchlight.egg-info/top_level.txt -------------------------------------------------------------------------------- /torchlight/torchlight/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import IO 2 | from .util import str2bool 3 | from .util import str2dict 4 | from .util import DictAction 5 | from .util import import_class 6 | from .gpu import visible_gpu 7 | from .gpu import occupy_gpu 8 | from .gpu import ngpu 9 | -------------------------------------------------------------------------------- /torchlight/torchlight.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: torchlight 3 | Version: 1.0 4 | Summary: A mini framework for pytorch 5 | Home-page: UNKNOWN 6 | Author: UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /feeders/bone_pairs.py: -------------------------------------------------------------------------------- 1 | ntu_pairs = ( 2 | (1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), 3 | (7, 6), (8, 7), (9, 21), (10, 9), (11, 10), (12, 11), 4 | (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), 5 | (19, 18), (20, 19), (22, 23), (21, 21), (23, 8), (24, 25),(25, 12) 6 | ) 7 | -------------------------------------------------------------------------------- /run_ensemble.sh: -------------------------------------------------------------------------------- 1 | printf "\nNTU-RGB+D 60 Cross-Subject\n" 2 | python ensemble.py \ 3 | --dataset ntu/xsub \ 4 | --main-dir ./work_dir/ntu/cross-subject/ 5 | 6 | printf "\nNTU-RGB+D 60 Cross-View\n" 7 | python ensemble.py \ 8 | --dataset ntu/xview \ 9 | --main-dir ./work_dir/ntu/cross-view/ 10 | 11 | printf "\nNTU-RGB+D 120 Cross-Subject\n" 12 | python ensemble.py \ 13 | --dataset ntu120/xsub \ 14 | --main-dir ./work_dir/ntu120/cross-subject/ 15 | 16 | printf "\nNTU-RGB+D 120 Cross-Setup\n" 17 | python ensemble.py \ 18 | --dataset ntu120/xset \ 19 | --main-dir ./work_dir/ntu120/cross-setup/ 20 | 21 | printf "\n" 22 | -------------------------------------------------------------------------------- /torchlight/torchlight/gpu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def visible_gpu(gpus): 6 | """ 7 | set visible gpu. 8 | 9 | can be a single id, or a list 10 | 11 | return a list of new gpus ids 12 | """ 13 | gpus = [gpus] if isinstance(gpus, int) else list(gpus) 14 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) 15 | return list(range(len(gpus))) 16 | 17 | 18 | def ngpu(gpus): 19 | """ 20 | count how many gpus used. 21 | """ 22 | gpus = [gpus] if isinstance(gpus, int) else list(gpus) 23 | return len(gpus) 24 | 25 | 26 | def occupy_gpu(gpus=None): 27 | """ 28 | make program appear on nvidia-smi. 
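    (It does this by allocating a one-element tensor on each requested GPU,
    so the process holds a CUDA context and is listed by nvidia-smi.
    Example, assuming CUDA is available: occupy_gpu([0, 1]).)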
29 | """ 30 | if gpus is None: 31 | torch.zeros(1).cuda() 32 | else: 33 | gpus = [gpus] if isinstance(gpus, int) else list(gpus) 34 | for g in gpus: 35 | torch.zeros(1).cuda(g) 36 | -------------------------------------------------------------------------------- /config/nturgbd-cross-view/bone_com_1.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu_hdgcn/cross-view/bone_CoM_1/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu/NTU60_CV.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: True 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu/NTU60_CV.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: True 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 60 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 1 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd-cross-view/bone_com_2.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu_hdgcn/cross-view/bone_CoM_2/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu/NTU60_CV.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: True 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu/NTU60_CV.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: True 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 60 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 2 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd-cross-view/bone_com_21.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu_hdgcn/cross-view/bone_CoM_21/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu/NTU60_CV.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: True 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu/NTU60_CV.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: True 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 60 31 | num_point: 25 32 | num_person: 2 33 | graph: 
graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 21 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd-cross-subject/bone_com_1.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu_hdgcn/cross-subject/bone_CoM_1/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu/NTU60_CS.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: True 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu/NTU60_CS.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: True 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 60 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 1 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd-cross-subject/bone_com_2.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu_hdgcn/cross-subject/bone_CoM_2/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu/NTU60_CS.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: True 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu/NTU60_CS.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: True 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 60 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 2 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd-cross-view/joint_com_1.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu_hdgcn/cross-view/joint_CoM_1/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu/NTU60_CV.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: False 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu/NTU60_CV.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: False 25 | debug: False 26 | 27 | # model 
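# CoM selects the joint used as the center of mass for the hierarchical
# graph: in the 25-joint NTU skeleton, joint 1 is the base of the spine,
# joint 2 the middle of the spine, and joint 21 the spine joint between
# the shoulders. num_point and num_person match NTU's 25 joints and up to
# two tracked bodies.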
28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 60 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 1 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd-cross-view/joint_com_2.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu_hdgcn/cross-view/joint_CoM_2/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu/NTU60_CV.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: False 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu/NTU60_CV.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: False 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 60 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 2 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd-cross-view/joint_com_21.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu_hdgcn/cross-view/joint_CoM_21/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu/NTU60_CV.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: False 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu/NTU60_CV.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: False 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 60 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 21 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd-cross-subject/bone_com_21.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu_hdgcn/cross-subject/bone_CoM_21/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu/NTU60_CS.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: True 18 | 19 | test_feeder_args: 20 | data_path: 
./data/ntu/NTU60_CS.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: True 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 60 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 21 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd-cross-subject/joint_com_1.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu_hdgcn/cross-subject/joint_CoM_1/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu/NTU60_CS.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: False 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu/NTU60_CS.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: False 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 60 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 1 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd-cross-subject/joint_com_2.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu_hdgcn/cross-subject/joint_CoM_2/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu/NTU60_CS.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: False 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu/NTU60_CS.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: False 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 60 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 2 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd-cross-subject/joint_com_21.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu_hdgcn/cross-subject/joint_CoM_21/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu/NTU60_CS.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 
| normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: False 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu/NTU60_CS.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: False 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 60 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 21 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd120-cross-setup/bone_com_1.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu120_hdgcn/cross-setup/bone_CoM_1/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu120/NTU120_CSet.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: True 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu120/NTU120_CSet.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: True 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 120 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 1 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd120-cross-setup/bone_com_2.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu120_hdgcn/cross-setup/bone_CoM_2/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu120/NTU120_CSet.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: True 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu120/NTU120_CSet.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: True 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 120 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 2 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd120-cross-setup/bone_com_21.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu120_hdgcn/cross-setup/bone_CoM_21/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: 
./data/ntu120/NTU120_CSet.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: True 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu120/NTU120_CSet.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: True 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 120 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 21 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd120-cross-setup/joint_com_1.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu120_hdgcn/cross-setup/joint_CoM_1/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu120/NTU120_CSet.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: False 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu120/NTU120_CSet.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: False 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 120 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 1 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd120-cross-setup/joint_com_2.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu120_hdgcn/cross-setup/joint_CoM_2/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu120/NTU120_CSet.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: False 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu120/NTU120_CSet.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: False 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 120 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 2 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd120-cross-setup/joint_com_21.yaml: -------------------------------------------------------------------------------- 
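# p_interval is handed to feeders/tools.valid_crop_resize: [0.5, 1] lets
# training crop a random sub-sequence covering roughly 50-100% of the valid
# frames, while [0.95] gives the test set a fixed 95% crop; both are then
# resized to window_size frames.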
1 | num_worker: 8 2 | work_dir: ./work_dir/ntu120_hdgcn/cross-setup/joint_CoM_21/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu120/NTU120_CSet.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: False 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu120/NTU120_CSet.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: False 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 120 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 21 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd120-cross-subject/bone_com_1.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu120_hdgcn/cross-subject/bone_CoM_1/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu120/NTU120_CSub.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: True 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu120/NTU120_CSub.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: True 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 120 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 1 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd120-cross-subject/bone_com_2.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu120_hdgcn/cross-subject/bone_CoM_2/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu120/NTU120_CSub.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: True 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu120/NTU120_CSub.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: True 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 120 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 2 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | 
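# A typical single-config launch looks like the following (a sketch: the
# exact CLI flags live in main.py, which is not included in this dump):
#   python main.py --config ./config/nturgbd120-cross-subject/bone_com_2.yaml --device 0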
-------------------------------------------------------------------------------- /config/nturgbd120-cross-subject/bone_com_21.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu120_hdgcn/cross-subject/bone_CoM_21/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu120/NTU120_CSub.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: True 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu120/NTU120_CSub.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: True 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 120 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 21 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd120-cross-subject/joint_com_1.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu120_hdgcn/cross-subject/joint_CoM_1/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu120/NTU120_CSub.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: False 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu120/NTU120_CSub.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: False 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 120 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 1 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd120-cross-subject/joint_com_2.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu120_hdgcn/cross-subject/joint_CoM_2/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu120/NTU120_CSub.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: False 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu120/NTU120_CSub.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: False 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 120 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 2 37 | 38 | #optim 39 | 
weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /config/nturgbd120-cross-subject/joint_com_21.yaml: -------------------------------------------------------------------------------- 1 | num_worker: 8 2 | work_dir: ./work_dir/ntu120_hdgcn/cross-subject/joint_CoM_21/ 3 | 4 | # feeder 5 | feeder: feeders.feeder_ntu.Feeder 6 | train_feeder_args: 7 | data_path: ./data/ntu120/NTU120_CSub.npz 8 | split: train 9 | debug: False 10 | random_choose: False 11 | random_shift: False 12 | random_move: False 13 | window_size: 64 14 | normalization: False 15 | random_rot: True 16 | p_interval: [0.5, 1] 17 | bone: False 18 | 19 | test_feeder_args: 20 | data_path: ./data/ntu120/NTU120_CSub.npz 21 | split: test 22 | window_size: 64 23 | p_interval: [0.95] 24 | bone: False 25 | debug: False 26 | 27 | # model 28 | model: model.HDGCN.Model 29 | model_args: 30 | num_class: 120 31 | num_point: 25 32 | num_person: 2 33 | graph: graph.ntu_rgb_d_hierarchy.Graph 34 | graph_args: 35 | labeling_mode: 'spatial' 36 | CoM: 21 37 | 38 | #optim 39 | weight_decay: 0.0004 40 | base_lr: 0.1 41 | warm_up_epoch: 5 42 | 43 | # training 44 | device: [0] 45 | batch_size: 64 46 | test_batch_size: 64 47 | num_epoch: 90 48 | nesterov: True 49 | -------------------------------------------------------------------------------- /graph/ntu_rgb_d_hierarchy.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import numpy as np 4 | 5 | sys.path.extend(['../']) 6 | from graph import tools 7 | 8 | num_node = 25 9 | 10 | class Graph: 11 | def __init__(self, CoM=21, labeling_mode='spatial'): 12 | self.num_node = num_node 13 | self.CoM = CoM 14 | self.A = self.get_adjacency_matrix(labeling_mode) 15 | 16 | 17 | def get_adjacency_matrix(self, labeling_mode=None): 18 | if labeling_mode is None: 19 | return self.A 20 | if labeling_mode == 'spatial': 21 | A = tools.get_hierarchical_graph(num_node, tools.get_edgeset(dataset='NTU', CoM=self.CoM)) # L, 3, 25, 25 22 | else: 23 | raise ValueError() 24 | return A, self.CoM 25 | 26 | 27 | if __name__ == '__main__': 28 | import tools 29 | g = Graph().A 30 | import matplotlib.pyplot as plt 31 | for i, g_ in enumerate(g[0]): 32 | plt.imshow(g_, cmap='gray') 33 | cb = plt.colorbar() 34 | plt.savefig('./graph_{}.png'.format(i)) 35 | cb.remove() 36 | plt.show() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Jungho Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /graph/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def edge2mat(link, num_node): 4 | A = np.zeros((num_node, num_node)) 5 | for i, j in link: 6 | A[j, i] = 1 7 | return A 8 | 9 | def normalize_digraph(A): 10 | Dl = np.sum(A, 0) 11 | h, w = A.shape 12 | Dn = np.zeros((w, w)) 13 | for i in range(w): 14 | if Dl[i] > 0: 15 | Dn[i, i] = Dl[i] ** (-1) 16 | AD = np.dot(A, Dn) 17 | return AD 18 | 19 | def get_spatial_graph(num_node, hierarchy): 20 | A = [] 21 | for i in range(len(hierarchy)): 22 | A.append(normalize_digraph(edge2mat(hierarchy[i], num_node))) 23 | 24 | A = np.stack(A) 25 | 26 | return A 27 | 28 | def get_spatial_graph_original(num_node, self_link, inward, outward): 29 | I = edge2mat(self_link, num_node) 30 | In = normalize_digraph(edge2mat(inward, num_node)) 31 | Out = normalize_digraph(edge2mat(outward, num_node)) 32 | A = np.stack((I, In, Out)) 33 | return A 34 | 35 | def normalize_adjacency_matrix(A): 36 | node_degrees = A.sum(-1) 37 | degs_inv_sqrt = np.power(node_degrees, -0.5) 38 | norm_degs_matrix = np.eye(len(node_degrees)) * degs_inv_sqrt 39 | return (norm_degs_matrix @ A @ norm_degs_matrix).astype(np.float32) 40 | 41 | def get_graph(num_node, edges): 42 | 43 | I = edge2mat(edges[0], num_node) 44 | Forward = normalize_digraph(edge2mat(edges[1], num_node)) 45 | Reverse = normalize_digraph(edge2mat(edges[2], num_node)) 46 | A = np.stack((I, Forward, Reverse)) 47 | return A # 3, 25, 25 48 | 49 | def get_hierarchical_graph(num_node, edges): 50 | A = [] 51 | for edge in edges: 52 | A.append(get_graph(num_node, edge)) 53 | A = np.stack(A) 54 | return A 55 | 56 | def get_groups(dataset='NTU', CoM=21): 57 | groups =[] 58 | 59 | if dataset == 'NTU': 60 | if CoM == 2: 61 | groups.append([2]) 62 | groups.append([1, 21]) 63 | groups.append([13, 17, 3, 5, 9]) 64 | groups.append([14, 18, 4, 6, 10]) 65 | groups.append([15, 19, 7, 11]) 66 | groups.append([16, 20, 8, 12]) 67 | groups.append([22, 23, 24, 25]) 68 | 69 | ## Center of mass : 21 70 | elif CoM == 21: 71 | groups.append([21]) 72 | groups.append([2, 3, 5, 9]) 73 | groups.append([4, 6, 10, 1]) 74 | groups.append([7, 11, 13, 17]) 75 | groups.append([8, 12, 14, 18]) 76 | groups.append([22, 23, 24, 25, 15, 19]) 77 | groups.append([16, 20]) 78 | 79 | ## Center of Mass : 1 80 | elif CoM == 1: 81 | groups.append([1]) 82 | groups.append([2, 13, 17]) 83 | groups.append([14, 18, 21]) 84 | groups.append([3, 5, 9, 15, 19]) 85 | groups.append([4, 6, 10, 16, 20]) 86 | groups.append([7, 11]) 87 | groups.append([8, 12, 22, 23, 24, 25]) 88 | 89 | else: 90 | raise ValueError() 91 | 92 | return groups 93 | 94 | def get_edgeset(dataset='NTU', CoM=21): 95 | groups = get_groups(dataset=dataset, CoM=CoM) 96 | 97 | for i, group in enumerate(groups): 98 | group = [i - 1 for i in group] 99 | groups[i] = group 100 | 101 | identity = [] 102 | forward_hierarchy = [] 103 | reverse_hierarchy = [] 104 | 105 | 
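    # The loop below runs once per adjacent pair of hierarchy levels
    # (i, i + 1) and collects three edge sets: self-loops over the joints of
    # both levels, forward edges from every joint in level i to every joint
    # in level i + 1, and reverse edges from level i + 1 back to level i
    # (built from the outermost groups inward, then re-aligned via the
    # [-1 - i] index when assembling `edges`). get_hierarchical_graph() then
    # normalizes each [identity, forward, reverse] triple, giving A a shape
    # of (len(groups) - 1, 3, 25, 25) for the 25-joint NTU skeleton.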
for i in range(len(groups) - 1): 106 | self_link = groups[i] + groups[i + 1] 107 | self_link = [(i, i) for i in self_link] 108 | identity.append(self_link) 109 | forward_g = [] 110 | for j in groups[i]: 111 | for k in groups[i + 1]: 112 | forward_g.append((j, k)) 113 | forward_hierarchy.append(forward_g) 114 | 115 | reverse_g = [] 116 | for j in groups[-1 - i]: 117 | for k in groups[-2 - i]: 118 | reverse_g.append((j, k)) 119 | reverse_hierarchy.append(reverse_g) 120 | 121 | edges = [] 122 | for i in range(len(groups) - 1): 123 | edges.append([identity[i], forward_hierarchy[i], reverse_hierarchy[-1 - i]]) 124 | 125 | return edges -------------------------------------------------------------------------------- /feeders/feeder_ntu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | 4 | from torch.utils.data import Dataset 5 | 6 | from feeders import tools 7 | 8 | 9 | class Feeder(Dataset): 10 | def __init__(self, data_path, label_path=None, p_interval=1, split='train', random_choose=False, random_shift=False, 11 | random_move=False, random_rot=False, window_size=-1, normalization=False, debug=False, use_mmap=False, 12 | bone=False): 13 | """ 14 | :param data_path: path to the .npz skeleton data 15 | :param label_path: unused here; labels are loaded from the same .npz file 16 | :param split: training set or test set 17 | :param random_choose: If true, randomly choose a portion of the input sequence 18 | :param random_shift: If true, randomly pad zeros at the beginning or end of sequence 19 | :param random_move: 20 | :param random_rot: rotate skeleton around xyz axis 21 | :param window_size: The length of the output sequence 22 | :param normalization: If true, normalize input sequence 23 | :param debug: If true, only use the first 100 samples 24 | :param use_mmap: If true, use mmap mode to load data, which can save the running memory 25 | :param bone: use bone modality or not 26 | :param vel: use motion modality or not 27 | :param only_label: only load label for ensemble score compute 28 | """ 29 | 30 | self.debug = debug 31 | self.data_path = data_path 32 | self.label_path = label_path 33 | self.split = split 34 | self.random_choose = random_choose 35 | self.random_shift = random_shift 36 | self.random_move = random_move 37 | self.window_size = window_size 38 | self.normalization = normalization 39 | self.use_mmap = use_mmap 40 | self.p_interval = p_interval 41 | self.random_rot = random_rot 42 | self.bone = bone 43 | self.load_data() 44 | if normalization: 45 | self.get_mean_map() 46 | 47 | def load_data(self): 48 | # data: N C T V M 49 | npz_data = np.load(self.data_path) 50 | if self.split == 'train': 51 | self.data = npz_data['x_train'] 52 | self.label = np.where(npz_data['y_train'] > 0)[1] 53 | self.sample_name = ['train_' + str(i) for i in range(len(self.data))] 54 | elif self.split == 'test': 55 | self.data = npz_data['x_test'] 56 | self.label = np.where(npz_data['y_test'] > 0)[1] 57 | self.sample_name = ['test_' + str(i) for i in range(len(self.data))] 58 | else: 59 | raise NotImplementedError('data split only supports train/test') 60 | N, T, _ = self.data.shape 61 | self.data = self.data.reshape((N, T, 2, 25, 3)).transpose(0, 4, 1, 3, 2) 62 | 63 | def get_mean_map(self): 64 | data = self.data 65 | N, C, T, V, M = data.shape 66 | self.mean_map = data.mean(axis=2, keepdims=True).mean(axis=4, keepdims=True).mean(axis=0) 67 | self.std_map = data.transpose((0, 2, 4, 1, 3)).reshape((N * T * M, C * V)).std(axis=0).reshape((C, 1, V, 1)) 68 | 69 | def __len__(self): 70 | return 
69 |     def __len__(self):
70 |         return len(self.label)
71 | 
72 |     def __iter__(self):
73 |         return self
74 | 
75 |     def __getitem__(self, index):
76 |         data_numpy = self.data[index]
77 |         label = self.label[index]
78 |         data_numpy = np.array(data_numpy)
79 |         valid_frame_num = np.sum(data_numpy.sum(0).sum(-1).sum(-1) != 0)
80 |         # reshape Tx(MVC) to CTVM
81 |         data_numpy = tools.valid_crop_resize(data_numpy, valid_frame_num, self.p_interval, self.window_size)
82 |         if self.random_rot:
83 |             data_numpy = tools.random_rot(data_numpy)
84 |         if self.bone:
85 |             from .bone_pairs import ntu_pairs
86 |             bone_data_numpy = np.zeros_like(data_numpy)  # 3, T, V, M
87 |             for v1, v2 in ntu_pairs:
88 |                 bone_data_numpy[:, :, v1 - 1] = data_numpy[:, :, v1 - 1] - data_numpy[:, :, v2 - 1]
89 |             data_numpy = bone_data_numpy
90 |         return data_numpy, label, index
91 | 
92 |     def top_k(self, score, top_k):
93 |         rank = score.argsort()
94 |         hit_top_k = [l in rank[i, -top_k:] for i, l in enumerate(self.label)]
95 |         return sum(hit_top_k) * 1.0 / len(hit_top_k)
96 | 
97 | 
98 | def import_class(name):
99 |     components = name.split('.')
100 |     mod = __import__(components[0])
101 |     for comp in components[1:]:
102 |         mod = getattr(mod, comp)
103 |     return mod
104 | 
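A minimal usage sketch for the feeder (the `.npz` path and crop/window parameters below are assumptions, mirroring the configs under `./config`):

```
# Sketch: read one bone-modality sample from the NTU60 cross-subject test split.
from feeders.feeder_ntu import Feeder

feeder = Feeder(data_path='./data/ntu/NTU60_CS.npz', split='test',
                p_interval=[0.95], window_size=64, bone=True)
data, label, index = feeder[0]  # data: (3, 64, 25, 2) = (C, T, V, M)
```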
--------------------------------------------------------------------------------
/ensemble.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import os
4 | 
5 | import numpy as np
6 | from tqdm import tqdm
7 | import math
8 | 
9 | def str2bool(v):
10 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
11 |         return True
12 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
13 |         return False
14 |     else:
15 |         raise argparse.ArgumentTypeError('Unsupported value encountered.')
16 | 
17 | 
18 | if __name__ == "__main__":
19 |     parser = argparse.ArgumentParser()
20 |     parser.add_argument('--dataset',
21 |                         required=True,
22 |                         choices={'ntu/xsub', 'ntu/xview', 'ntu120/xsub', 'ntu120/xset', 'NW-UCLA'},
23 |                         help='the work folder for storing results')
24 |     parser.add_argument('--main-dir',
25 |                         help='')
26 |     parser.add_argument('--CoM-2',
27 |                         type=str2bool,
28 |                         default=True)
29 |     parser.add_argument('--CoM-21',
30 |                         type=str2bool,
31 |                         default=True)
32 |     parser.add_argument('--CoM-1',
33 |                         type=str2bool,
34 |                         default=True)
35 | 
36 |     arg = parser.parse_args()
37 | 
38 |     dataset = arg.dataset
39 |     if 'UCLA' in arg.dataset:
40 |         label = []
41 |         with open('./data/' + 'NW-UCLA' + '/val_label.pkl', 'rb') as f:
42 |             data_info = pickle.load(f)
43 |             for index in range(len(data_info)):
44 |                 info = data_info[index]
45 |                 label.append(int(info['label']) - 1)
46 |     elif 'ntu120' in arg.dataset:
47 |         if 'xsub' in arg.dataset:
48 |             npz_data = np.load('./data/' + 'ntu120/' + 'NTU120_CSub.npz')
49 |             label = np.where(npz_data['y_test'] > 0)[1]
50 |         elif 'xset' in arg.dataset:
51 |             npz_data = np.load('./data/' + 'ntu120/' + 'NTU120_CSet.npz')
52 |             label = np.where(npz_data['y_test'] > 0)[1]
53 |     elif 'ntu' in arg.dataset:
54 |         if 'xsub' in arg.dataset:
55 |             npz_data = np.load('./data/' + 'ntu/' + 'NTU60_CS.npz')
56 |             label = np.where(npz_data['y_test'] > 0)[1]
57 |         elif 'xview' in arg.dataset:
58 |             npz_data = np.load('./data/' + 'ntu/' + 'NTU60_CV.npz')
59 |             label = np.where(npz_data['y_test'] > 0)[1]
60 |     else:
61 |         raise NotImplementedError
62 | 
63 |     dir_cnt = 0
64 | 
65 |     if arg.CoM_1:
66 |         with open(os.path.join(arg.main_dir, 'joint_CoM_1/', 'epoch1_test_score.pkl'), 'rb') as r1:
67 |             r1 = list(pickle.load(r1).items())
68 |         with open(os.path.join(arg.main_dir, 'bone_CoM_1/', 'epoch1_test_score.pkl'), 'rb') as r2:
69 |             r2 = list(pickle.load(r2).items())
70 |         dir_cnt += 2
71 | 
72 |     if arg.CoM_2:
73 |         with open(os.path.join(arg.main_dir, 'joint_CoM_2/', 'epoch1_test_score.pkl'), 'rb') as r3:
74 |             r3 = list(pickle.load(r3).items())
75 |         with open(os.path.join(arg.main_dir, 'bone_CoM_2/', 'epoch1_test_score.pkl'), 'rb') as r4:
76 |             r4 = list(pickle.load(r4).items())
77 |         dir_cnt += 2
78 | 
79 |     if arg.CoM_21:
80 |         if 'ntu' in arg.dataset:
81 |             with open(os.path.join(arg.main_dir, 'joint_CoM_21/', 'epoch1_test_score.pkl'), 'rb') as r5:
82 |                 r5 = list(pickle.load(r5).items())
83 |             with open(os.path.join(arg.main_dir, 'bone_CoM_21/', 'epoch1_test_score.pkl'), 'rb') as r6:
84 |                 r6 = list(pickle.load(r6).items())
85 |             dir_cnt += 2
86 |         elif 'UCLA' in arg.dataset:
87 |             with open(os.path.join(arg.main_dir, 'joint_CoM_3/', 'epoch1_test_score.pkl'), 'rb') as r5:
88 |                 r5 = list(pickle.load(r5).items())
89 |             with open(os.path.join(arg.main_dir, 'bone_CoM_3/', 'epoch1_test_score.pkl'), 'rb') as r6:
90 |                 r6 = list(pickle.load(r6).items())
91 |             dir_cnt += 2
92 | 
93 |     right_num = total_num = right_num_5 = 0
94 | 
95 |     norm = lambda x: x / np.linalg.norm(x)
96 | 
97 |     if dir_cnt == 6:
98 |         for i in tqdm(range(len(label))):
99 |             l = label[i]
100 |             r11 = norm(np.array(r1[i][1]))
101 |             r22 = norm(np.array(r2[i][1]))
102 |             r33 = norm(np.array(r3[i][1]))
103 |             r44 = norm(np.array(r4[i][1]))
104 |             r55 = norm(np.array(r5[i][1]))
105 |             r66 = norm(np.array(r6[i][1]))
106 |             r = r11 + r22 + r33 + r44 + r55 + r66
107 |             rank_5 = r.argsort()[-5:]
108 |             right_num_5 += int(int(l) in rank_5)
109 |             r = np.argmax(r)
110 |             right_num += int(r == int(l))
111 |             total_num += 1
112 |         acc = right_num / total_num
113 |         acc5 = right_num_5 / total_num
114 | 
115 |     elif dir_cnt == 4:
116 |         r = None
117 |         for i in tqdm(range(len(label))):
118 |             l = label[i]
119 |             if arg.CoM_1:
120 |                 r11 = np.array(r1[i][1])
121 |                 r22 = np.array(r2[i][1])
122 |                 r = norm(r11) + norm(r22)
123 |             if arg.CoM_2:
124 |                 r33 = np.array(r3[i][1])
125 |                 r44 = np.array(r4[i][1])
126 |                 r = r + norm(r33) + norm(r44) if r is not None else norm(r33) + norm(r44)
127 |             if arg.CoM_21:
128 |                 r55 = np.array(r5[i][1])
129 |                 r66 = np.array(r6[i][1])
130 |                 r = r + norm(r55) + norm(r66) if r is not None else norm(r55) + norm(r66)
131 | 
132 |             rank_5 = r.argsort()[-5:]
133 |             right_num_5 += int(int(l) in rank_5)
134 |             r = np.argmax(r)
135 |             right_num += int(r == int(l))
136 |             total_num += 1
137 |         acc = right_num / total_num
138 |         acc5 = right_num_5 / total_num
139 | 
140 |     elif dir_cnt == 2:
141 |         r = None
142 |         for i in tqdm(range(len(label))):
143 |             l = label[i]
144 |             if arg.CoM_1:
145 |                 r11 = np.array(r1[i][1])
146 |                 r22 = np.array(r2[i][1])
147 |                 r = norm(r11) + norm(r22)
148 |             if arg.CoM_2:
149 |                 r33 = np.array(r3[i][1])
150 |                 r44 = np.array(r4[i][1])
151 |                 r = r + norm(r33) + norm(r44) if r is not None else norm(r33) + norm(r44)
152 |             if arg.CoM_21:
153 |                 r55 = np.array(r5[i][1])
154 |                 r66 = np.array(r6[i][1])
155 |                 r = r + norm(r55) + norm(r66) if r is not None else norm(r55) + norm(r66)
156 | 
157 |             rank_5 = r.argsort()[-5:]
158 |             right_num_5 += int(int(l) in rank_5)
159 |             r = np.argmax(r)
160 |             right_num += int(r == int(l))
161 |             total_num += 1
162 |         acc = right_num / total_num
163 |         acc5 = right_num_5 / total_num
164 |     print('Top1 Acc: {:.4f}%'.format(acc * 100))
165 |     print('Top5 Acc: {:.4f}%'.format(acc5 * 100))
166 | 
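The three `dir_cnt` branches above share one fusion rule. A compact sketch of that rule (a hypothetical stand-alone form, where `scores` holds one `(num_classes,)` array per stream):

```
# Sketch of the six-way ensemble rule: L2-normalize each stream's class
# scores, sum them, then take top-1 / top-5 over the fused vector.
import numpy as np

def fuse(scores):
    return sum(s / np.linalg.norm(s) for s in scores)

# fused = fuse([joint_com1, bone_com1, joint_com2, bone_com2, joint_com21, bone_com21])
# top1 = np.argmax(fused); top5 = fused.argsort()[-5:]
```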
--------------------------------------------------------------------------------
/data/ntu/get_raw_skes_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 | import os.path as osp
4 | import os
5 | import numpy as np
6 | import pickle
7 | import logging
8 | 
9 | 
10 | def get_raw_bodies_data(skes_path, ske_name, frames_drop_skes, frames_drop_logger):
11 |     """
12 |     Get raw bodies data from a skeleton sequence.
13 | 
14 |     Each body's data is a dict that contains the following keys:
15 |       - joints: raw 3D joints positions. Shape: (num_frames x 25, 3)
16 |       - colors: raw 2D color locations. Shape: (num_frames, 25, 2)
17 |       - interval: a list which stores the frame indices of this body.
18 |       - motion: motion amount (only for the sequence with 2 or more bodyIDs).
19 | 
20 |     Return:
21 |       a dict for a skeleton sequence with 3 key-value pairs:
22 |         - name: the skeleton filename.
23 |         - data: a dict which stores raw data of each body.
24 |         - num_frames: the number of valid frames.
25 |     """
26 |     ske_file = osp.join(skes_path, ske_name + '.skeleton')
27 |     assert osp.exists(ske_file), 'Error: Skeleton file %s not found' % ske_file
28 |     # Read all data from .skeleton file into a list (in string format)
29 |     print('Reading data from %s' % ske_file[-29:])
30 |     with open(ske_file, 'r') as fr:
31 |         str_data = fr.readlines()
32 | 
33 |     num_frames = int(str_data[0].strip('\r\n'))
34 |     frames_drop = []
35 |     bodies_data = dict()
36 |     valid_frames = -1  # 0-based index
37 |     current_line = 1
38 | 
39 |     for f in range(num_frames):
40 |         num_bodies = int(str_data[current_line].strip('\r\n'))
41 |         current_line += 1
42 | 
43 |         if num_bodies == 0:  # no data in this frame, drop it
44 |             frames_drop.append(f)  # 0-based index
45 |             continue
46 | 
47 |         valid_frames += 1
48 |         joints = np.zeros((num_bodies, 25, 3), dtype=np.float32)
49 |         colors = np.zeros((num_bodies, 25, 2), dtype=np.float32)
50 | 
51 |         for b in range(num_bodies):
52 |             bodyID = str_data[current_line].strip('\r\n').split()[0]
53 |             current_line += 1
54 |             num_joints = int(str_data[current_line].strip('\r\n'))  # 25 joints
55 |             current_line += 1
56 | 
57 |             for j in range(num_joints):
58 |                 temp_str = str_data[current_line].strip('\r\n').split()
59 |                 joints[b, j, :] = np.array(temp_str[:3], dtype=np.float32)
60 |                 colors[b, j, :] = np.array(temp_str[5:7], dtype=np.float32)
61 |                 current_line += 1
62 | 
63 |             if bodyID not in bodies_data:  # Add a new body's data
64 |                 body_data = dict()
65 |                 body_data['joints'] = joints[b]  # ndarray: (25, 3)
66 |                 body_data['colors'] = colors[b, np.newaxis]  # ndarray: (1, 25, 2)
67 |                 body_data['interval'] = [valid_frames]  # the index of the first frame
68 |             else:  # Update an existing body's data
69 |                 body_data = bodies_data[bodyID]
70 |                 # Stack each body's data of each frame along the frame order
71 |                 body_data['joints'] = np.vstack((body_data['joints'], joints[b]))
72 |                 body_data['colors'] = np.vstack((body_data['colors'], colors[b, np.newaxis]))
73 |                 pre_frame_idx = body_data['interval'][-1]
74 |                 body_data['interval'].append(pre_frame_idx + 1)  # add a new frame index
75 | 
76 |             bodies_data[bodyID] = body_data  # Update bodies_data
77 | 
78 |     num_frames_drop = len(frames_drop)
79 |     assert num_frames_drop < num_frames, \
80 |         'Error: All frames data (%d) of %s is missing or lost' % (num_frames, ske_name)
81 |     if num_frames_drop > 0:
82 |         frames_drop_skes[ske_name] = np.array(frames_drop, dtype=np.int)
83 |         frames_drop_logger.info('{}: {} frames missed: {}\n'.format(ske_name, num_frames_drop,
84 |                                                                     frames_drop))
85 | 
86 | 
# Calculate motion (only for the sequence with 2 or more bodyIDs) 87 | if len(bodies_data) > 1: 88 | for body_data in bodies_data.values(): 89 | body_data['motion'] = np.sum(np.var(body_data['joints'], axis=0)) 90 | 91 | return {'name': ske_name, 'data': bodies_data, 'num_frames': num_frames - num_frames_drop} 92 | 93 | 94 | def get_raw_skes_data(): 95 | # # save_path = './data' 96 | # # skes_path = '/data/pengfei/NTU/nturgb+d_skeletons/' 97 | # stat_path = osp.join(save_path, 'statistics') 98 | # 99 | # skes_name_file = osp.join(stat_path, 'skes_available_name.txt') 100 | # save_data_pkl = osp.join(save_path, 'raw_skes_data.pkl') 101 | # frames_drop_pkl = osp.join(save_path, 'frames_drop_skes.pkl') 102 | # 103 | # frames_drop_logger = logging.getLogger('frames_drop') 104 | # frames_drop_logger.setLevel(logging.INFO) 105 | # frames_drop_logger.addHandler(logging.FileHandler(osp.join(save_path, 'frames_drop.log'))) 106 | # frames_drop_skes = dict() 107 | 108 | skes_name = np.loadtxt(skes_name_file, dtype=str) 109 | 110 | num_files = skes_name.size 111 | print('Found %d available skeleton files.' % num_files) 112 | 113 | raw_skes_data = [] 114 | frames_cnt = np.zeros(num_files, dtype=np.int) 115 | 116 | for (idx, ske_name) in enumerate(skes_name): 117 | bodies_data = get_raw_bodies_data(skes_path, ske_name, frames_drop_skes, frames_drop_logger) 118 | raw_skes_data.append(bodies_data) 119 | frames_cnt[idx] = bodies_data['num_frames'] 120 | if (idx + 1) % 1000 == 0: 121 | print('Processed: %.2f%% (%d / %d)' % \ 122 | (100.0 * (idx + 1) / num_files, idx + 1, num_files)) 123 | 124 | with open(save_data_pkl, 'wb') as fw: 125 | pickle.dump(raw_skes_data, fw, pickle.HIGHEST_PROTOCOL) 126 | np.savetxt(osp.join(save_path, 'raw_data', 'frames_cnt.txt'), frames_cnt, fmt='%d') 127 | 128 | print('Saved raw bodies data into %s' % save_data_pkl) 129 | print('Total frames: %d' % np.sum(frames_cnt)) 130 | 131 | with open(frames_drop_pkl, 'wb') as fw: 132 | pickle.dump(frames_drop_skes, fw, pickle.HIGHEST_PROTOCOL) 133 | 134 | if __name__ == '__main__': 135 | save_path = './' 136 | 137 | skes_path = '../nturgbd_raw/nturgb+d_skeletons/' 138 | stat_path = osp.join(save_path, 'statistics') 139 | if not osp.exists('./raw_data'): 140 | os.makedirs('./raw_data') 141 | 142 | skes_name_file = osp.join(stat_path, 'skes_available_name.txt') 143 | save_data_pkl = osp.join(save_path, 'raw_data', 'raw_skes_data.pkl') 144 | frames_drop_pkl = osp.join(save_path, 'raw_data', 'frames_drop_skes.pkl') 145 | 146 | frames_drop_logger = logging.getLogger('frames_drop') 147 | frames_drop_logger.setLevel(logging.INFO) 148 | frames_drop_logger.addHandler(logging.FileHandler(osp.join(save_path, 'raw_data', 'frames_drop.log'))) 149 | frames_drop_skes = dict() 150 | 151 | get_raw_skes_data() 152 | 153 | with open(frames_drop_pkl, 'wb') as fw: 154 | pickle.dump(frames_drop_skes, fw, pickle.HIGHEST_PROTOCOL) 155 | 156 | -------------------------------------------------------------------------------- /data/ntu120/get_raw_skes_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | import os.path as osp 4 | import os 5 | import numpy as np 6 | import pickle 7 | import logging 8 | 9 | 10 | def get_raw_bodies_data(skes_path, ske_name, frames_drop_skes, frames_drop_logger): 11 | """ 12 | Get raw bodies data from a skeleton sequence. 
13 | 14 | Each body's data is a dict that contains the following keys: 15 | - joints: raw 3D joints positions. Shape: (num_frames x 25, 3) 16 | - colors: raw 2D color locations. Shape: (num_frames, 25, 2) 17 | - interval: a list which stores the frame indices of this body. 18 | - motion: motion amount (only for the sequence with 2 or more bodyIDs). 19 | 20 | Return: 21 | a dict for a skeleton sequence with 3 key-value pairs: 22 | - name: the skeleton filename. 23 | - data: a dict which stores raw data of each body. 24 | - num_frames: the number of valid frames. 25 | """ 26 | if int(ske_name[1:4]) >= 18: 27 | skes_path = '../nturgbd_raw/nturgb+d_skeletons120/' 28 | ske_file = osp.join(skes_path, ske_name + '.skeleton') 29 | assert osp.exists(ske_file), 'Error: Skeleton file %s not found' % ske_file 30 | # Read all data from .skeleton file into a list (in string format) 31 | print('Reading data from %s' % ske_file[-29:]) 32 | with open(ske_file, 'r') as fr: 33 | str_data = fr.readlines() 34 | 35 | num_frames = int(str_data[0].strip('\r\n')) 36 | frames_drop = [] 37 | bodies_data = dict() 38 | valid_frames = -1 # 0-based index 39 | current_line = 1 40 | 41 | for f in range(num_frames): 42 | num_bodies = int(str_data[current_line].strip('\r\n')) 43 | current_line += 1 44 | 45 | if num_bodies == 0: # no data in this frame, drop it 46 | frames_drop.append(f) # 0-based index 47 | continue 48 | 49 | valid_frames += 1 50 | joints = np.zeros((num_bodies, 25, 3), dtype=np.float32) 51 | colors = np.zeros((num_bodies, 25, 2), dtype=np.float32) 52 | 53 | for b in range(num_bodies): 54 | bodyID = str_data[current_line].strip('\r\n').split()[0] 55 | current_line += 1 56 | num_joints = int(str_data[current_line].strip('\r\n')) # 25 joints 57 | current_line += 1 58 | 59 | for j in range(num_joints): 60 | temp_str = str_data[current_line].strip('\r\n').split() 61 | joints[b, j, :] = np.array(temp_str[:3], dtype=np.float32) 62 | colors[b, j, :] = np.array(temp_str[5:7], dtype=np.float32) 63 | current_line += 1 64 | 65 | if bodyID not in bodies_data: # Add a new body's data 66 | body_data = dict() 67 | body_data['joints'] = joints[b] # ndarray: (25, 3) 68 | body_data['colors'] = colors[b, np.newaxis] # ndarray: (1, 25, 2) 69 | body_data['interval'] = [valid_frames] # the index of the first frame 70 | else: # Update an already existed body's data 71 | body_data = bodies_data[bodyID] 72 | # Stack each body's data of each frame along the frame order 73 | body_data['joints'] = np.vstack((body_data['joints'], joints[b])) 74 | body_data['colors'] = np.vstack((body_data['colors'], colors[b, np.newaxis])) 75 | pre_frame_idx = body_data['interval'][-1] 76 | body_data['interval'].append(pre_frame_idx + 1) # add a new frame index 77 | 78 | bodies_data[bodyID] = body_data # Update bodies_data 79 | 80 | num_frames_drop = len(frames_drop) 81 | assert num_frames_drop < num_frames, \ 82 | 'Error: All frames data (%d) of %s is missing or lost' % (num_frames, ske_name) 83 | if num_frames_drop > 0: 84 | frames_drop_skes[ske_name] = np.array(frames_drop, dtype=np.int) 85 | frames_drop_logger.info('{}: {} frames missed: {}\n'.format(ske_name, num_frames_drop, 86 | frames_drop)) 87 | 88 | # Calculate motion (only for the sequence with 2 or more bodyIDs) 89 | if len(bodies_data) > 1: 90 | for body_data in bodies_data.values(): 91 | body_data['motion'] = np.sum(np.var(body_data['joints'], axis=0)) 92 | 93 | return {'name': ske_name, 'data': bodies_data, 'num_frames': num_frames - num_frames_drop} 94 | 95 | 96 | def 
get_raw_skes_data(): 97 | # # save_path = './data' 98 | # # skes_path = '/data/pengfei/NTU/nturgb+d_skeletons/' 99 | # stat_path = osp.join(save_path, 'statistics') 100 | # 101 | # skes_name_file = osp.join(stat_path, 'skes_available_name.txt') 102 | # save_data_pkl = osp.join(save_path, 'raw_skes_data.pkl') 103 | # frames_drop_pkl = osp.join(save_path, 'frames_drop_skes.pkl') 104 | # 105 | # frames_drop_logger = logging.getLogger('frames_drop') 106 | # frames_drop_logger.setLevel(logging.INFO) 107 | # frames_drop_logger.addHandler(logging.FileHandler(osp.join(save_path, 'frames_drop.log'))) 108 | # frames_drop_skes = dict() 109 | 110 | skes_name = np.loadtxt(skes_name_file, dtype=str) 111 | 112 | num_files = skes_name.size 113 | print('Found %d available skeleton files.' % num_files) 114 | 115 | raw_skes_data = [] 116 | frames_cnt = np.zeros(num_files, dtype=np.int) 117 | 118 | for (idx, ske_name) in enumerate(skes_name): 119 | bodies_data = get_raw_bodies_data(skes_path, ske_name, frames_drop_skes, frames_drop_logger) 120 | raw_skes_data.append(bodies_data) 121 | frames_cnt[idx] = bodies_data['num_frames'] 122 | if (idx + 1) % 1000 == 0: 123 | print('Processed: %.2f%% (%d / %d)' % \ 124 | (100.0 * (idx + 1) / num_files, idx + 1, num_files)) 125 | 126 | with open(save_data_pkl, 'wb') as fw: 127 | pickle.dump(raw_skes_data, fw, pickle.HIGHEST_PROTOCOL) 128 | np.savetxt(osp.join(save_path, 'raw_data', 'frames_cnt.txt'), frames_cnt, fmt='%d') 129 | 130 | print('Saved raw bodies data into %s' % save_data_pkl) 131 | print('Total frames: %d' % np.sum(frames_cnt)) 132 | 133 | with open(frames_drop_pkl, 'wb') as fw: 134 | pickle.dump(frames_drop_skes, fw, pickle.HIGHEST_PROTOCOL) 135 | 136 | if __name__ == '__main__': 137 | save_path = './' 138 | 139 | skes_path = '../nturgbd_raw/nturgb+d_skeletons/' 140 | stat_path = osp.join(save_path, 'statistics') 141 | if not osp.exists('./raw_data'): 142 | os.makedirs('./raw_data') 143 | 144 | skes_name_file = osp.join(stat_path, 'skes_available_name.txt') 145 | save_data_pkl = osp.join(save_path, 'raw_data', 'raw_skes_data.pkl') 146 | frames_drop_pkl = osp.join(save_path, 'raw_data', 'frames_drop_skes.pkl') 147 | 148 | frames_drop_logger = logging.getLogger('frames_drop') 149 | frames_drop_logger.setLevel(logging.INFO) 150 | frames_drop_logger.addHandler(logging.FileHandler(osp.join(save_path, 'raw_data', 'frames_drop.log'))) 151 | frames_drop_skes = dict() 152 | 153 | get_raw_skes_data() 154 | 155 | with open(frames_drop_pkl, 'wb') as fw: 156 | pickle.dump(frames_drop_skes, fw, pickle.HIGHEST_PROTOCOL) 157 | 158 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HD-GCN [ICCV 2023] 2 | 3 | Official Implementation of [Hierarchically Decomposed Graph Convolutional Networks for Skeleton-Based Action Recognition](https://arxiv.org/abs/2208.10741). 
4 | 
5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hierarchically-decomposed-graph-convolutional/skeleton-based-action-recognition-on-ntu-rgbd)](https://paperswithcode.com/sota/skeleton-based-action-recognition-on-ntu-rgbd?p=hierarchically-decomposed-graph-convolutional)
6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hierarchically-decomposed-graph-convolutional/skeleton-based-action-recognition-on-ntu-rgbd-1)](https://paperswithcode.com/sota/skeleton-based-action-recognition-on-ntu-rgbd-1?p=hierarchically-decomposed-graph-convolutional)
7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hierarchically-decomposed-graph-convolutional/skeleton-based-action-recognition-on-n-ucla)](https://paperswithcode.com/sota/skeleton-based-action-recognition-on-n-ucla?p=hierarchically-decomposed-graph-convolutional)
8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hierarchically-decomposed-graph-convolutional/skeleton-based-action-recognition-on-kinetics)](https://paperswithcode.com/sota/skeleton-based-action-recognition-on-kinetics?p=hierarchically-decomposed-graph-convolutional)
9 | 
10 | ![image](figures/framework.PNG)
11 | 
12 | # Abstract
13 | 
14 | Graph convolutional networks (GCNs) are the most commonly used methods for skeleton-based action recognition and have achieved remarkable performance. Generating adjacency matrices with semantically meaningful edges is particularly important for this task, but extracting such edges is a challenging problem. To solve this, we propose a hierarchically decomposed graph convolutional network (HD-GCN) architecture with a novel hierarchically decomposed graph (HD-Graph). The proposed HD-GCN effectively decomposes every joint node into several sets to extract major structurally adjacent and distant edges, and uses them to construct an HD-Graph containing those edges in the same semantic spaces of a human skeleton. In addition, we introduce an attention-guided hierarchy aggregation (A-HA) module to highlight the dominant hierarchical edge sets of the HD-Graph. Furthermore, we apply a new six-way ensemble method, which uses only the joint and bone streams without any motion stream. The proposed model is evaluated and achieves state-of-the-art performance on three large, popular datasets: NTU-RGB+D 60, NTU-RGB+D 120, and Northwestern-UCLA. Finally, we demonstrate the effectiveness of our model with various comparative experiments.
15 | 
16 | # Dependencies
17 | 
18 | - Python >= 3.6
19 | - PyTorch >= 1.10.0
20 | - PyYAML == 5.4.1
21 | - torchpack == 0.2.2
22 | - matplotlib, einops, sklearn, tqdm, tensorboardX, h5py
23 | - Run `pip install -e torchlight`
24 | 
25 | # Data Preparation
26 | 
27 | ### Download datasets.
28 | 
29 | #### There are 3 datasets to download:
30 | 
31 | - NTU RGB+D 60 Skeleton
32 | - NTU RGB+D 120 Skeleton
33 | - NW-UCLA
34 | 
35 | #### NTU RGB+D 60 and 120
36 | 
37 | 1. Request dataset here: https://rose1.ntu.edu.sg/dataset/actionRecognition
38 | 2. Download the skeleton-only datasets:
39 |    1. `nturgbd_skeletons_s001_to_s017.zip` (NTU RGB+D 60)
40 |    2. `nturgbd_skeletons_s018_to_s032.zip` (NTU RGB+D 120)
41 | 3. Extract the above files to `./data/nturgbd_raw`
42 | 
43 | #### NW-UCLA
44 | 
45 | 1. Download dataset from here: https://www.dropbox.com/s/10pcm4pksjy6mkq/all_sqe.zip?dl=0
46 | 2. Move `all_sqe` to `./data/NW-UCLA`
47 | 
48 | ### Data Processing
49 | 
50 | #### Directory Structure
51 | 
52 | Put downloaded data into the following directory structure:
53 | 
54 | ```
55 | - data/
56 |   - NW-UCLA/
57 |     - all_sqe
58 |       ... # raw data of NW-UCLA
59 |   - ntu/
60 |   - ntu120/
61 |   - nturgbd_raw/
62 |     - nturgb+d_skeletons/     # from `nturgbd_skeletons_s001_to_s017.zip`
63 |       ...
64 |     - nturgb+d_skeletons120/  # from `nturgbd_skeletons_s018_to_s032.zip`
65 |       ...
66 | ```
67 | 
68 | #### Generating Data
69 | 
70 | - Generate NTU RGB+D 60 or NTU RGB+D 120 dataset:
71 | 
72 | ```
73 |  cd ./data/ntu # or cd ./data/ntu120
74 |  # Get skeleton of each performer
75 |  python get_raw_skes_data.py
76 |  # Remove the bad skeleton
77 |  python get_raw_denoised_data.py
78 |  # Transform the skeleton to the center of the first frame
79 |  python seq_transformation.py
80 | ```
81 | 
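As a quick sanity check (optional; this assumes the cross-subject split has been generated by the steps above), the resulting `.npz` can be inspected directly:

```
import numpy as np

data = np.load('./data/ntu/NTU60_CS.npz')
print(data['x_train'].shape)  # (N, 300, 150): 300 frames, 2 bodies x 25 joints x 3 coords
print(data['y_train'].shape)  # (N, 60): one-hot action labels
```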
82 | # Training & Testing
83 | 
84 | ### Training
85 | 
86 | - NTU-RGB+D 60 & 120
87 | ```
88 | # Example: training HD-GCN (joint CoM 1) on NTU RGB+D 60 cross subject with GPU 0
89 | python main.py --config ./config/nturgbd-cross-subject/joint_com_1.yaml --device 0
90 | 
91 | # Example: training HD-GCN (bone CoM 1) on NTU RGB+D 60 cross subject with GPU 0
92 | python main.py --config ./config/nturgbd-cross-subject/bone_com_1.yaml --device 0
93 | 
94 | # Example: training HD-GCN (joint CoM 1) on NTU RGB+D 120 cross subject with GPU 0
95 | python main.py --config ./config/nturgbd120-cross-subject/joint_com_1.yaml --device 0
96 | 
97 | # Example: training HD-GCN (bone CoM 1) on NTU RGB+D 120 cross subject with GPU 0
98 | python main.py --config ./config/nturgbd120-cross-subject/bone_com_1.yaml --device 0
99 | ```
100 | 
101 | - To train your own model, put model file `your_model.py` under `./model` and run:
102 | 
103 | ```
104 | # Example: training your own model on NTU RGB+D 120 cross subject
105 | python main.py --config ./config/nturgbd120-cross-subject/your_config.yaml --model model.your_model.Model --work-dir ./work_dir/your_work_dir/ --device 0
106 | ```
107 | 
108 | ### Testing
109 | 
110 | - To test the trained models saved in `<work_dir>`, run the following command:
111 | 
112 | ```
113 | python main.py --config <work_dir>/config.yaml --work-dir <work_dir> --phase test --save-score True --weights <work_dir>/xxx.pt --device 0
114 | ```
115 | 
116 | - To ensemble the results of different modalities, run
117 | ```
118 | # Example: six-way ensemble for NTU-RGB+D 120 cross-subject
119 | python ensemble.py --dataset ntu120/xsub --main-dir ./work_dir/ntu120/cross-subject/
120 | ```
121 | 
122 | ### Pretrained Weights
123 | 
124 | - Pretrained weights for NTU RGB+D 60 and 120 can be downloaded from the following link [[Google Drive]](https://drive.google.com/drive/folders/1FB_IQdTMWE8cRvwE2KiyxC0P6LyqZku4?usp=sharing).
125 | 
126 | ## Acknowledgements
127 | This repo is based on [2s-AGCN](https://github.com/lshiwjx/2s-AGCN) and [CTR-GCN](https://github.com/Uason-Chen/CTR-GCN). The data processing is borrowed from [SGN](https://github.com/microsoft/SGN) and [HCN](https://github.com/huguyuehuhu/HCN-pytorch).
128 | 
129 | Thanks to the original authors for their awesome works!
130 | 131 | # Citation 132 | 133 | Please cite this work if you find it useful: 134 | ```BibTex 135 | @InProceedings{Lee_2023_ICCV, 136 | author = {Lee, Jungho and Lee, Minhyeok and Lee, Dogyoon and Lee, Sangyoun}, 137 | title = {Hierarchically Decomposed Graph Convolutional Networks for Skeleton-Based Action Recognition}, 138 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 139 | month = {October}, 140 | year = {2023}, 141 | pages = {10444-10453} 142 | } 143 | ``` 144 | 145 | # Contact 146 | If you have any questions, feel free to contact: 2015142131@yonsei.ac.kr 147 | -------------------------------------------------------------------------------- /torchlight/torchlight/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import os 4 | import sys 5 | import traceback 6 | import time 7 | import pickle 8 | from collections import OrderedDict 9 | import yaml 10 | import h5py 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from torch.autograd import Variable 16 | from torchpack.runner.hooks import PaviLogger 17 | 18 | 19 | class IO(): 20 | def __init__(self, work_dir, save_log=True, print_log=True): 21 | self.work_dir = work_dir 22 | self.save_log = save_log 23 | self.print_to_screen = print_log 24 | self.cur_time = time.time() 25 | self.split_timer = {} 26 | self.pavi_logger = None 27 | self.session_file = None 28 | self.model_text = '' 29 | 30 | def log(self, *args, **kwargs): 31 | try: 32 | if self.pavi_logger is None: 33 | url = 'http://pavi.parrotsdnn.org/log' 34 | with open(self.session_file, 'r') as f: 35 | info = dict(session_file=self.session_file, session_text=f.read(), model_text=self.model_text) 36 | self.pavi_logger = PaviLogger(url) 37 | self.pavi_logger.connect(self.work_dir, info=info) 38 | self.pavi_logger.log(*args, **kwargs) 39 | except: #pylint: disable=W0702 40 | pass 41 | 42 | def load_model(self, model, **model_args): 43 | Model = import_class(model) 44 | model = Model(**model_args) 45 | self.model_text += '\n\n' + str(model) 46 | return model 47 | 48 | def load_weights(self, model, weights_path, ignore_weights=None, fix_weights=False): 49 | if ignore_weights is None: 50 | ignore_weights = [] 51 | if isinstance(ignore_weights, str): 52 | ignore_weights = [ignore_weights] 53 | 54 | self.print_log(f'Load weights from {weights_path}.') 55 | weights = torch.load(weights_path) 56 | weights = OrderedDict([[k.split('module.')[-1], v.cpu()] for k, v in weights.items()]) 57 | 58 | # filter weights 59 | for i in ignore_weights: 60 | ignore_name = list() 61 | for w in weights: 62 | if w.find(i) == 0: 63 | ignore_name.append(w) 64 | for n in ignore_name: 65 | weights.pop(n) 66 | self.print_log(f'Filter [{i}] remove weights [{n}].') 67 | 68 | for w in weights: 69 | self.print_log(f'Load weights [{w}].') 70 | 71 | try: 72 | model.load_state_dict(weights) 73 | except (KeyError, RuntimeError): 74 | state = model.state_dict() 75 | diff = list(set(state.keys()).difference(set(weights.keys()))) 76 | for d in diff: 77 | self.print_log(f'Can not find weights [{d}].') 78 | state.update(weights) 79 | model.load_state_dict(state) 80 | 81 | if fix_weights: 82 | for name, param in model.named_parameters(): 83 | if name in weights.keys(): 84 | param.requires_grad = False 85 | self.print_log(f'Fix weights [{name}].') 86 | 87 | return model 88 | 89 | def save_pkl(self, result, filename): 90 | with 
open(f'{self.work_dir}/{filename}', 'wb') as f: 91 | pickle.dump(result, f) 92 | 93 | def save_h5(self, result, filename, append=False): 94 | with h5py.File(f'{self.work_dir}/{filename}', 'a' if append else 'w') as f: 95 | for k in result.keys(): 96 | f[k] = result[k] 97 | 98 | def save_model(self, model, name): 99 | model_path = f'{self.work_dir}/{name}' 100 | # symlink = f'{self.work_dir}/latest_model.pt' 101 | state_dict = model.state_dict() 102 | weights = OrderedDict([[''.join(k.split('module.')), v.cpu()] for k, v in state_dict.items()]) 103 | torch.save(weights, model_path) 104 | # os.symlink(model_path, symlink) 105 | self.print_log(f'The model has been saved as {model_path}.') 106 | 107 | def save_arg(self, arg): 108 | 109 | self.session_file = f'{self.work_dir}/config.yaml' 110 | 111 | # save arg 112 | arg_dict = vars(arg) 113 | if not os.path.exists(self.work_dir): 114 | os.makedirs(self.work_dir) 115 | with open(self.session_file, 'w') as f: 116 | f.write(f"# command line: {' '.join(sys.argv)}\n\n") 117 | yaml.dump(arg_dict, f, default_flow_style=False, indent=4) 118 | 119 | def print_log(self, str, print_time=True): 120 | if print_time: 121 | # localtime = time.asctime(time.localtime(time.time())) 122 | str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str 123 | 124 | if self.print_to_screen: 125 | print(str) 126 | if self.save_log: 127 | with open(f'{self.work_dir}/log.txt', 'a') as f: 128 | print(str, file=f) 129 | 130 | def init_timer(self, *name): 131 | self.record_time() 132 | self.split_timer = {k: 0.0000001 for k in name} 133 | 134 | def check_time(self, name): 135 | self.split_timer[name] += self.split_time() 136 | 137 | def record_time(self): 138 | self.cur_time = time.time() 139 | return self.cur_time 140 | 141 | def split_time(self): 142 | split_time = time.time() - self.cur_time 143 | self.record_time() 144 | return split_time 145 | 146 | def print_timer(self): 147 | proportion = { 148 | k: f'{int(round(v * 100 / sum(self.split_timer.values()))):02d}%' 149 | for k, v in self.split_timer.items() 150 | } 151 | self.print_log(f'Time consumption:') 152 | for k in proportion: 153 | self.print_log(f'\t[{k}][{proportion[k]}]: {self.split_timer[k]:.4f}') 154 | 155 | 156 | def str2bool(v): 157 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 158 | return True 159 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 160 | return False 161 | else: 162 | raise argparse.ArgumentTypeError('Boolean value expected.') 163 | 164 | 165 | def str2dict(v): 166 | return eval(f'dict({v})') #pylint: disable=W0123 167 | 168 | 169 | def _import_class_0(name): 170 | components = name.split('.') 171 | mod = __import__(components[0]) 172 | for comp in components[1:]: 173 | mod = getattr(mod, comp) 174 | return mod 175 | 176 | 177 | def import_class(import_str): 178 | mod_str, _sep, class_str = import_str.rpartition('.') 179 | __import__(mod_str) 180 | try: 181 | return getattr(sys.modules[mod_str], class_str) 182 | except AttributeError: 183 | raise ImportError('Class %s cannot be found (%s)' % (class_str, traceback.format_exception(*sys.exc_info()))) 184 | 185 | 186 | class DictAction(argparse.Action): 187 | def __init__(self, option_strings, dest, nargs=None, **kwargs): 188 | if nargs is not None: 189 | raise ValueError("nargs not allowed") 190 | super(DictAction, self).__init__(option_strings, dest, **kwargs) 191 | 192 | def __call__(self, parser, namespace, values, option_string=None): 193 | input_dict = eval(f'dict({values})') #pylint: disable=W0123 194 | output_dict = 
getattr(namespace, self.dest) 195 | for k in input_dict: 196 | output_dict[k] = input_dict[k] 197 | setattr(namespace, self.dest, output_dict) 198 | -------------------------------------------------------------------------------- /data/ntu/statistics/samples_with_missing_skeletons.txt: -------------------------------------------------------------------------------- 1 | S001C002P005R002A008 2 | S001C002P006R001A008 3 | S001C003P002R001A055 4 | S001C003P002R002A012 5 | S001C003P005R002A004 6 | S001C003P005R002A005 7 | S001C003P005R002A006 8 | S001C003P006R002A008 9 | S002C002P011R002A030 10 | S002C003P008R001A020 11 | S002C003P010R002A010 12 | S002C003P011R002A007 13 | S002C003P011R002A011 14 | S002C003P014R002A007 15 | S003C001P019R001A055 16 | S003C002P002R002A055 17 | S003C002P018R002A055 18 | S003C003P002R001A055 19 | S003C003P016R001A055 20 | S003C003P018R002A024 21 | S004C002P003R001A013 22 | S004C002P008R001A009 23 | S004C002P020R001A003 24 | S004C002P020R001A004 25 | S004C002P020R001A012 26 | S004C002P020R001A020 27 | S004C002P020R001A021 28 | S004C002P020R001A036 29 | S005C002P004R001A001 30 | S005C002P004R001A003 31 | S005C002P010R001A016 32 | S005C002P010R001A017 33 | S005C002P010R001A048 34 | S005C002P010R001A049 35 | S005C002P016R001A009 36 | S005C002P016R001A010 37 | S005C002P018R001A003 38 | S005C002P018R001A028 39 | S005C002P018R001A029 40 | S005C003P016R002A009 41 | S005C003P018R002A013 42 | S005C003P021R002A057 43 | S006C001P001R002A055 44 | S006C002P007R001A005 45 | S006C002P007R001A006 46 | S006C002P016R001A043 47 | S006C002P016R001A051 48 | S006C002P016R001A052 49 | S006C002P022R001A012 50 | S006C002P023R001A020 51 | S006C002P023R001A021 52 | S006C002P023R001A022 53 | S006C002P023R001A023 54 | S006C002P024R001A018 55 | S006C002P024R001A019 56 | S006C003P001R002A013 57 | S006C003P007R002A009 58 | S006C003P007R002A010 59 | S006C003P007R002A025 60 | S006C003P016R001A060 61 | S006C003P017R001A055 62 | S006C003P017R002A013 63 | S006C003P017R002A014 64 | S006C003P017R002A015 65 | S006C003P022R002A013 66 | S007C001P018R002A050 67 | S007C001P025R002A051 68 | S007C001P028R001A050 69 | S007C001P028R001A051 70 | S007C001P028R001A052 71 | S007C002P008R002A008 72 | S007C002P015R002A055 73 | S007C002P026R001A008 74 | S007C002P026R001A009 75 | S007C002P026R001A010 76 | S007C002P026R001A011 77 | S007C002P026R001A012 78 | S007C002P026R001A050 79 | S007C002P027R001A011 80 | S007C002P027R001A013 81 | S007C002P028R002A055 82 | S007C003P007R001A002 83 | S007C003P007R001A004 84 | S007C003P019R001A060 85 | S007C003P027R002A001 86 | S007C003P027R002A002 87 | S007C003P027R002A003 88 | S007C003P027R002A004 89 | S007C003P027R002A005 90 | S007C003P027R002A006 91 | S007C003P027R002A007 92 | S007C003P027R002A008 93 | S007C003P027R002A009 94 | S007C003P027R002A010 95 | S007C003P027R002A011 96 | S007C003P027R002A012 97 | S007C003P027R002A013 98 | S008C002P001R001A009 99 | S008C002P001R001A010 100 | S008C002P001R001A014 101 | S008C002P001R001A015 102 | S008C002P001R001A016 103 | S008C002P001R001A018 104 | S008C002P001R001A019 105 | S008C002P008R002A059 106 | S008C002P025R001A060 107 | S008C002P029R001A004 108 | S008C002P031R001A005 109 | S008C002P031R001A006 110 | S008C002P032R001A018 111 | S008C002P034R001A018 112 | S008C002P034R001A019 113 | S008C002P035R001A059 114 | S008C002P035R002A002 115 | S008C002P035R002A005 116 | S008C003P007R001A009 117 | S008C003P007R001A016 118 | S008C003P007R001A017 119 | S008C003P007R001A018 120 | S008C003P007R001A019 121 | S008C003P007R001A020 122 | S008C003P007R001A021 
123 | S008C003P007R001A022 124 | S008C003P007R001A023 125 | S008C003P007R001A025 126 | S008C003P007R001A026 127 | S008C003P007R001A028 128 | S008C003P007R001A029 129 | S008C003P007R002A003 130 | S008C003P008R002A050 131 | S008C003P025R002A002 132 | S008C003P025R002A011 133 | S008C003P025R002A012 134 | S008C003P025R002A016 135 | S008C003P025R002A020 136 | S008C003P025R002A022 137 | S008C003P025R002A023 138 | S008C003P025R002A030 139 | S008C003P025R002A031 140 | S008C003P025R002A032 141 | S008C003P025R002A033 142 | S008C003P025R002A049 143 | S008C003P025R002A060 144 | S008C003P031R001A001 145 | S008C003P031R002A004 146 | S008C003P031R002A014 147 | S008C003P031R002A015 148 | S008C003P031R002A016 149 | S008C003P031R002A017 150 | S008C003P032R002A013 151 | S008C003P033R002A001 152 | S008C003P033R002A011 153 | S008C003P033R002A012 154 | S008C003P034R002A001 155 | S008C003P034R002A012 156 | S008C003P034R002A022 157 | S008C003P034R002A023 158 | S008C003P034R002A024 159 | S008C003P034R002A044 160 | S008C003P034R002A045 161 | S008C003P035R002A016 162 | S008C003P035R002A017 163 | S008C003P035R002A018 164 | S008C003P035R002A019 165 | S008C003P035R002A020 166 | S008C003P035R002A021 167 | S009C002P007R001A001 168 | S009C002P007R001A003 169 | S009C002P007R001A014 170 | S009C002P008R001A014 171 | S009C002P015R002A050 172 | S009C002P016R001A002 173 | S009C002P017R001A028 174 | S009C002P017R001A029 175 | S009C003P017R002A030 176 | S009C003P025R002A054 177 | S010C001P007R002A020 178 | S010C002P016R002A055 179 | S010C002P017R001A005 180 | S010C002P017R001A018 181 | S010C002P017R001A019 182 | S010C002P019R001A001 183 | S010C002P025R001A012 184 | S010C003P007R002A043 185 | S010C003P008R002A003 186 | S010C003P016R001A055 187 | S010C003P017R002A055 188 | S011C001P002R001A008 189 | S011C001P018R002A050 190 | S011C002P008R002A059 191 | S011C002P016R002A055 192 | S011C002P017R001A020 193 | S011C002P017R001A021 194 | S011C002P018R002A055 195 | S011C002P027R001A009 196 | S011C002P027R001A010 197 | S011C002P027R001A037 198 | S011C003P001R001A055 199 | S011C003P002R001A055 200 | S011C003P008R002A012 201 | S011C003P015R001A055 202 | S011C003P016R001A055 203 | S011C003P019R001A055 204 | S011C003P025R001A055 205 | S011C003P028R002A055 206 | S012C001P019R001A060 207 | S012C001P019R002A060 208 | S012C002P015R001A055 209 | S012C002P017R002A012 210 | S012C002P025R001A060 211 | S012C003P008R001A057 212 | S012C003P015R001A055 213 | S012C003P015R002A055 214 | S012C003P016R001A055 215 | S012C003P017R002A055 216 | S012C003P018R001A055 217 | S012C003P018R001A057 218 | S012C003P019R002A011 219 | S012C003P019R002A012 220 | S012C003P025R001A055 221 | S012C003P027R001A055 222 | S012C003P027R002A009 223 | S012C003P028R001A035 224 | S012C003P028R002A055 225 | S013C001P015R001A054 226 | S013C001P017R002A054 227 | S013C001P018R001A016 228 | S013C001P028R001A040 229 | S013C002P015R001A054 230 | S013C002P017R002A054 231 | S013C002P028R001A040 232 | S013C003P008R002A059 233 | S013C003P015R001A054 234 | S013C003P017R002A054 235 | S013C003P025R002A022 236 | S013C003P027R001A055 237 | S013C003P028R001A040 238 | S014C001P027R002A040 239 | S014C002P015R001A003 240 | S014C002P019R001A029 241 | S014C002P025R002A059 242 | S014C002P027R002A040 243 | S014C002P039R001A050 244 | S014C003P007R002A059 245 | S014C003P015R002A055 246 | S014C003P019R002A055 247 | S014C003P025R001A048 248 | S014C003P027R002A040 249 | S015C001P008R002A040 250 | S015C001P016R001A055 251 | S015C001P017R001A055 252 | S015C001P017R002A055 253 | S015C002P007R001A059 254 | 
S015C002P008R001A003 255 | S015C002P008R001A004 256 | S015C002P008R002A040 257 | S015C002P015R001A002 258 | S015C002P016R001A001 259 | S015C002P016R002A055 260 | S015C003P008R002A007 261 | S015C003P008R002A011 262 | S015C003P008R002A012 263 | S015C003P008R002A028 264 | S015C003P008R002A040 265 | S015C003P025R002A012 266 | S015C003P025R002A017 267 | S015C003P025R002A020 268 | S015C003P025R002A021 269 | S015C003P025R002A030 270 | S015C003P025R002A033 271 | S015C003P025R002A034 272 | S015C003P025R002A036 273 | S015C003P025R002A037 274 | S015C003P025R002A044 275 | S016C001P019R002A040 276 | S016C001P025R001A011 277 | S016C001P025R001A012 278 | S016C001P025R001A060 279 | S016C001P040R001A055 280 | S016C001P040R002A055 281 | S016C002P008R001A011 282 | S016C002P019R002A040 283 | S016C002P025R002A012 284 | S016C003P008R001A011 285 | S016C003P008R002A002 286 | S016C003P008R002A003 287 | S016C003P008R002A004 288 | S016C003P008R002A006 289 | S016C003P008R002A009 290 | S016C003P019R002A040 291 | S016C003P039R002A016 292 | S017C001P016R002A031 293 | S017C002P007R001A013 294 | S017C002P008R001A009 295 | S017C002P015R001A042 296 | S017C002P016R002A031 297 | S017C002P016R002A055 298 | S017C003P007R002A013 299 | S017C003P008R001A059 300 | S017C003P016R002A031 301 | S017C003P017R001A055 302 | S017C003P020R001A059 303 | -------------------------------------------------------------------------------- /feeders/tools.py: -------------------------------------------------------------------------------- 1 | import random 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pdb 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | def valid_crop_resize(data_numpy,valid_frame_num,p_interval,window): 10 | # input: C,T,V,M 11 | C, T, V, M = data_numpy.shape 12 | begin = 0 13 | end = valid_frame_num 14 | valid_size = end - begin 15 | 16 | #crop 17 | if len(p_interval) == 1: 18 | p = p_interval[0] 19 | bias = int((1-p) * valid_size/2) 20 | data = data_numpy[:, begin+bias:end-bias, :, :]# center_crop 21 | cropped_length = data.shape[1] 22 | else: 23 | p = np.random.rand(1)*(p_interval[1]-p_interval[0])+p_interval[0] 24 | cropped_length = np.minimum(np.maximum(int(np.floor(valid_size*p)),64), valid_size)# constraint cropped_length lower bound as 64 25 | bias = np.random.randint(0,valid_size-cropped_length+1) 26 | data = data_numpy[:, begin+bias:begin+bias+cropped_length, :, :] 27 | if data.shape[1] == 0: 28 | print(cropped_length, bias, valid_size) 29 | 30 | # resize 31 | data = torch.tensor(data,dtype=torch.float) 32 | data = data.permute(0, 2, 3, 1).contiguous().view(C * V * M, cropped_length) 33 | data = data[None, None, :, :] 34 | data = F.interpolate(data, size=(C * V * M, window), mode='bilinear',align_corners=False).squeeze() # could perform both up sample and down sample 35 | data = data.contiguous().view(C, V, M, window).permute(0, 3, 1, 2).contiguous().numpy() 36 | 37 | return data 38 | 39 | def downsample(data_numpy, step, random_sample=True): 40 | # input: C,T,V,M 41 | begin = np.random.randint(step) if random_sample else 0 42 | return data_numpy[:, begin::step, :, :] 43 | 44 | 45 | def temporal_slice(data_numpy, step): 46 | # input: C,T,V,M 47 | C, T, V, M = data_numpy.shape 48 | return data_numpy.reshape(C, T / step, step, V, M).transpose( 49 | (0, 1, 3, 2, 4)).reshape(C, T / step, V, step * M) 50 | 51 | 52 | def mean_subtractor(data_numpy, mean): 53 | # input: C,T,V,M 54 | # naive version 55 | if mean == 0: 56 | return 57 | C, T, V, M = data_numpy.shape 58 | valid_frame = 
(data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0
59 |     begin = valid_frame.argmax()
60 |     end = len(valid_frame) - valid_frame[::-1].argmax()
61 |     data_numpy[:, :end, :, :] = data_numpy[:, :end, :, :] - mean
62 |     return data_numpy
63 | 
64 | 
65 | def auto_pading(data_numpy, size, random_pad=False):
66 |     C, T, V, M = data_numpy.shape
67 |     if T < size:
68 |         begin = random.randint(0, size - T) if random_pad else 0
69 |         data_numpy_paded = np.zeros((C, size, V, M))
70 |         data_numpy_paded[:, begin:begin + T, :, :] = data_numpy
71 |         return data_numpy_paded
72 |     else:
73 |         return data_numpy
74 | 
75 | 
76 | def random_choose(data_numpy, size, auto_pad=True):
77 |     # input: C,T,V,M; randomly choose a segment -- not entirely reasonable, since zero padding may be selected
78 |     C, T, V, M = data_numpy.shape
79 |     if T == size:
80 |         return data_numpy
81 |     elif T < size:
82 |         if auto_pad:
83 |             return auto_pading(data_numpy, size, random_pad=True)
84 |         else:
85 |             return data_numpy
86 |     else:
87 |         begin = random.randint(0, T - size)
88 |         return data_numpy[:, begin:begin + size, :, :]
89 | 
90 | def random_move(data_numpy,
91 |                 angle_candidate=[-10., -5., 0., 5., 10.],
92 |                 scale_candidate=[0.9, 1.0, 1.1],
93 |                 transform_candidate=[-0.2, -0.1, 0.0, 0.1, 0.2],
94 |                 move_time_candidate=[1]):
95 |     # input: C,T,V,M
96 |     C, T, V, M = data_numpy.shape
97 |     move_time = random.choice(move_time_candidate)
98 |     node = np.arange(0, T, T * 1.0 / move_time).round().astype(int)
99 |     node = np.append(node, T)
100 |     num_node = len(node)
101 | 
102 |     A = np.random.choice(angle_candidate, num_node)
103 |     S = np.random.choice(scale_candidate, num_node)
104 |     T_x = np.random.choice(transform_candidate, num_node)
105 |     T_y = np.random.choice(transform_candidate, num_node)
106 | 
107 |     a = np.zeros(T)
108 |     s = np.zeros(T)
109 |     t_x = np.zeros(T)
110 |     t_y = np.zeros(T)
111 | 
112 |     # linspace
113 |     for i in range(num_node - 1):
114 |         a[node[i]:node[i + 1]] = np.linspace(
115 |             A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180
116 |         s[node[i]:node[i + 1]] = np.linspace(S[i], S[i + 1],
117 |                                              node[i + 1] - node[i])
118 |         t_x[node[i]:node[i + 1]] = np.linspace(T_x[i], T_x[i + 1],
119 |                                                node[i + 1] - node[i])
120 |         t_y[node[i]:node[i + 1]] = np.linspace(T_y[i], T_y[i + 1],
121 |                                                node[i + 1] - node[i])
122 | 
123 |     theta = np.array([[np.cos(a) * s, -np.sin(a) * s],
124 |                       [np.sin(a) * s, np.cos(a) * s]])
125 | 
126 |     # perform transformation
127 |     for i_frame in range(T):
128 |         xy = data_numpy[0:2, i_frame, :, :]
129 |         new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1))
130 |         new_xy[0] += t_x[i_frame]
131 |         new_xy[1] += t_y[i_frame]
132 |         data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M)
133 | 
134 |     return data_numpy
135 | 
136 | 
137 | def random_shift(data_numpy):
138 |     C, T, V, M = data_numpy.shape
139 |     data_shift = np.zeros(data_numpy.shape)
140 |     valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0
141 |     begin = valid_frame.argmax()
142 |     end = len(valid_frame) - valid_frame[::-1].argmax()
143 | 
144 |     size = end - begin
145 |     bias = random.randint(0, T - size)
146 |     data_shift[:, bias:bias + size, :, :] = data_numpy[:, begin:end, :, :]
147 | 
148 |     return data_shift
149 | 
150 | 
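# NOTE: a usage sketch, not part of the original file. random_rot (below)
# draws one random rotation of up to +/-theta radians about each axis and
# applies the same rotation to every frame:
#
#     sample = np.random.randn(3, 64, 25, 2)   # C, T, V, M
#     rotated = random_rot(sample, theta=0.3)  # torch.Tensor, same shape
#
# _rot composes the per-frame rotation matrices as R = Rz @ Ry @ Rx.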
151 | def _rot(rot):
152 |     """
153 |     rot: T,3
154 |     """
155 |     cos_r, sin_r = rot.cos(), rot.sin()  # T,3
156 |     zeros = torch.zeros(rot.shape[0], 1)  # T,1
157 |     ones = torch.ones(rot.shape[0], 1)  # T,1
158 | 
159 |     r1 = torch.stack((ones, zeros, zeros), dim=-1)  # T,1,3
160 |     rx2 = torch.stack((zeros, cos_r[:, 0:1], sin_r[:, 0:1]), dim=-1)  # T,1,3
161 |     rx3 = torch.stack((zeros, -sin_r[:, 0:1], cos_r[:, 0:1]), dim=-1)  # T,1,3
162 |     rx = torch.cat((r1, rx2, rx3), dim=1)  # T,3,3
163 | 
164 |     ry1 = torch.stack((cos_r[:, 1:2], zeros, -sin_r[:, 1:2]), dim=-1)
165 |     r2 = torch.stack((zeros, ones, zeros), dim=-1)
166 |     ry3 = torch.stack((sin_r[:, 1:2], zeros, cos_r[:, 1:2]), dim=-1)
167 |     ry = torch.cat((ry1, r2, ry3), dim=1)
168 | 
169 |     rz1 = torch.stack((cos_r[:, 2:3], sin_r[:, 2:3], zeros), dim=-1)
170 |     r3 = torch.stack((zeros, zeros, ones), dim=-1)
171 |     rz2 = torch.stack((-sin_r[:, 2:3], cos_r[:, 2:3], zeros), dim=-1)
172 |     rz = torch.cat((rz1, rz2, r3), dim=1)
173 | 
174 |     rot = rz.matmul(ry).matmul(rx)
175 |     return rot
176 | 
177 | 
178 | def random_rot(data_numpy, theta=0.3):
179 |     """
180 |     data_numpy: C,T,V,M
181 |     """
182 |     data_torch = torch.from_numpy(data_numpy)
183 |     C, T, V, M = data_torch.shape
184 |     data_torch = data_torch.permute(1, 0, 2, 3).contiguous().view(T, C, V*M)  # T,3,V*M
185 |     rot = torch.zeros(3).uniform_(-theta, theta)
186 |     rot = torch.stack([rot, ] * T, dim=0)
187 |     rot = _rot(rot)  # T,3,3
188 |     data_torch = torch.matmul(rot, data_torch)
189 |     data_torch = data_torch.view(T, C, V, M).permute(1, 0, 2, 3).contiguous()
190 | 
191 |     return data_torch
192 | 
193 | def openpose_match(data_numpy):
194 |     C, T, V, M = data_numpy.shape
195 |     assert (C == 3)
196 |     score = data_numpy[2, :, :, :].sum(axis=1)
197 |     # the rank of body confidence in each frame (shape: T-1, M)
198 |     rank = (-score[0:T - 1]).argsort(axis=1).reshape(T - 1, M)
199 | 
200 |     # data of frame 1
201 |     xy1 = data_numpy[0:2, 0:T - 1, :, :].reshape(2, T - 1, V, M, 1)
202 |     # data of frame 2
203 |     xy2 = data_numpy[0:2, 1:T, :, :].reshape(2, T - 1, V, 1, M)
204 |     # square of distance between frame 1&2 (shape: T-1, M, M)
205 |     distance = ((xy2 - xy1) ** 2).sum(axis=2).sum(axis=0)
206 | 
207 |     # match pose
208 |     forward_map = np.zeros((T, M), dtype=int) - 1
209 |     forward_map[0] = range(M)
210 |     for m in range(M):
211 |         choose = (rank == m)
212 |         forward = distance[choose].argmin(axis=1)
213 |         for t in range(T - 1):
214 |             distance[t, :, forward[t]] = np.inf
215 |         forward_map[1:][choose] = forward
216 |     assert (np.all(forward_map >= 0))
217 | 
218 |     # string data
219 |     for t in range(T - 1):
220 |         forward_map[t + 1] = forward_map[t + 1][forward_map[t]]
221 | 
222 |     # generate data
223 |     new_data_numpy = np.zeros(data_numpy.shape)
224 |     for t in range(T):
225 |         new_data_numpy[:, t, :, :] = data_numpy[:, t, :, forward_map[
226 |             t]].transpose(1, 2, 0)
227 |     data_numpy = new_data_numpy
228 | 
229 |     # score sort
230 |     trace_score = data_numpy[2, :, :, :].sum(axis=1).sum(axis=0)
231 |     rank = (-trace_score).argsort()
232 |     data_numpy = data_numpy[:, :, :, rank]
233 | 
234 |     return data_numpy
235 | 
--------------------------------------------------------------------------------
/data/ntu/seq_transformation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 | import os 4 | import os.path as osp 5 | import numpy as np 6 | import pickle 7 | import logging 8 | import h5py 9 | from sklearn.model_selection import train_test_split 10 | 11 | root_path = './' 12 | stat_path = osp.join(root_path, 'statistics') 13 | setup_file = osp.join(stat_path, 'setup.txt') 14 | camera_file = osp.join(stat_path, 'camera.txt') 15 | performer_file = osp.join(stat_path, 'performer.txt') 16 | replication_file = osp.join(stat_path, 'replication.txt') 17 | label_file = osp.join(stat_path, 'label.txt') 18 | skes_name_file = osp.join(stat_path, 'skes_available_name.txt') 19 | 20 | denoised_path = osp.join(root_path, 'denoised_data') 21 | raw_skes_joints_pkl = osp.join(denoised_path, 'raw_denoised_joints.pkl') 22 | frames_file = osp.join(denoised_path, 'frames_cnt.txt') 23 | 24 | save_path = './' 25 | 26 | 27 | if not osp.exists(save_path): 28 | os.mkdir(save_path) 29 | 30 | 31 | def remove_nan_frames(ske_name, ske_joints, nan_logger): 32 | num_frames = ske_joints.shape[0] 33 | valid_frames = [] 34 | 35 | for f in range(num_frames): 36 | if not np.any(np.isnan(ske_joints[f])): 37 | valid_frames.append(f) 38 | else: 39 | nan_indices = np.where(np.isnan(ske_joints[f]))[0] 40 | nan_logger.info('{}\t{:^5}\t{}'.format(ske_name, f + 1, nan_indices)) 41 | 42 | return ske_joints[valid_frames] 43 | 44 | def seq_translation(skes_joints): 45 | for idx, ske_joints in enumerate(skes_joints): 46 | num_frames = ske_joints.shape[0] 47 | num_bodies = 1 if ske_joints.shape[1] == 75 else 2 48 | if num_bodies == 2: 49 | missing_frames_1 = np.where(ske_joints[:, :75].sum(axis=1) == 0)[0] 50 | missing_frames_2 = np.where(ske_joints[:, 75:].sum(axis=1) == 0)[0] 51 | cnt1 = len(missing_frames_1) 52 | cnt2 = len(missing_frames_2) 53 | 54 | i = 0 # get the "real" first frame of actor1 55 | while i < num_frames: 56 | if np.any(ske_joints[i, :75] != 0): 57 | break 58 | i += 1 59 | 60 | origin = np.copy(ske_joints[i, 3:6]) # new origin: joint-2 61 | 62 | for f in range(num_frames): 63 | if num_bodies == 1: 64 | ske_joints[f] -= np.tile(origin, 25) 65 | else: # for 2 actors 66 | ske_joints[f] -= np.tile(origin, 50) 67 | 68 | if (num_bodies == 2) and (cnt1 > 0): 69 | ske_joints[missing_frames_1, :75] = np.zeros((cnt1, 75), dtype=np.float32) 70 | 71 | if (num_bodies == 2) and (cnt2 > 0): 72 | ske_joints[missing_frames_2, 75:] = np.zeros((cnt2, 75), dtype=np.float32) 73 | 74 | skes_joints[idx] = ske_joints # Update 75 | 76 | return skes_joints 77 | 78 | 79 | def frame_translation(skes_joints, skes_name, frames_cnt): 80 | nan_logger = logging.getLogger('nan_skes') 81 | nan_logger.setLevel(logging.INFO) 82 | nan_logger.addHandler(logging.FileHandler("./nan_frames.log")) 83 | nan_logger.info('{}\t{}\t{}'.format('Skeleton', 'Frame', 'Joints')) 84 | 85 | for idx, ske_joints in enumerate(skes_joints): 86 | num_frames = ske_joints.shape[0] 87 | # Calculate the distance between spine base (joint-1) and spine (joint-21) 88 | j1 = ske_joints[:, 0:3] 89 | j21 = ske_joints[:, 60:63] 90 | dist = np.sqrt(((j1 - j21) ** 2).sum(axis=1)) 91 | 92 | for f in range(num_frames): 93 | origin = ske_joints[f, 3:6] # new origin: middle of the spine (joint-2) 94 | if (ske_joints[f, 75:] == 0).all(): 95 | ske_joints[f, :75] = (ske_joints[f, :75] - np.tile(origin, 25)) / \ 96 | dist[f] + np.tile(origin, 25) 97 | else: 98 | ske_joints[f] = (ske_joints[f] - np.tile(origin, 50)) / \ 99 | dist[f] + np.tile(origin, 50) 100 | 101 | ske_name = skes_name[idx] 102 | ske_joints = remove_nan_frames(ske_name, ske_joints, nan_logger) 103 | 
frames_cnt[idx] = num_frames  # update valid number of frames
104 |         skes_joints[idx] = ske_joints
105 | 
106 |     return skes_joints, frames_cnt
107 | 
108 | 
109 | def align_frames(skes_joints, frames_cnt):
110 |     """
111 |     Align all sequences to the same frame length.
112 | 
113 |     """
114 |     num_skes = len(skes_joints)
115 |     max_num_frames = frames_cnt.max()  # 300
116 |     aligned_skes_joints = np.zeros((num_skes, max_num_frames, 150), dtype=np.float32)
117 | 
118 |     for idx, ske_joints in enumerate(skes_joints):
119 |         num_frames = ske_joints.shape[0]
120 |         num_bodies = 1 if ske_joints.shape[1] == 75 else 2
121 |         if num_bodies == 1:
122 |             aligned_skes_joints[idx, :num_frames] = np.hstack((ske_joints,
123 |                                                                np.zeros_like(ske_joints)))
124 |         else:
125 |             aligned_skes_joints[idx, :num_frames] = ske_joints
126 | 
127 |     return aligned_skes_joints
128 | 
129 | 
130 | def one_hot_vector(labels):
131 |     num_skes = len(labels)
132 |     labels_vector = np.zeros((num_skes, 60))
133 |     for idx, l in enumerate(labels):
134 |         labels_vector[idx, l] = 1
135 | 
136 |     return labels_vector
137 | 
138 | 
139 | def split_train_val(train_indices, method='sklearn', ratio=0.05):
140 |     """
141 |     Get a validation set by randomly splitting it off from the training set, using one of two methods.
142 |     In practice the two methods appear equivalent, as they yield the same performance.
143 | 
144 |     """
145 |     if method == 'sklearn':
146 |         return train_test_split(train_indices, test_size=ratio, random_state=10000)
147 |     else:
148 |         np.random.seed(10000)
149 |         np.random.shuffle(train_indices)
150 |         val_num_skes = int(np.ceil(0.05 * len(train_indices)))
151 |         val_indices = train_indices[:val_num_skes]
152 |         train_indices = train_indices[val_num_skes:]
153 |         return train_indices, val_indices
154 | 
155 | 
156 | def split_dataset(skes_joints, label, performer, camera, evaluation, save_path):
157 |     train_indices, test_indices = get_indices(performer, camera, evaluation)
158 |     m = 'sklearn'  # 'sklearn' or 'numpy'
159 |     # Select validation set from training set
160 |     # train_indices, val_indices = split_train_val(train_indices, m)
161 | 
162 |     # Save labels and num_frames for each sequence of each data set
163 |     train_labels = label[train_indices]
164 |     test_labels = label[test_indices]
165 | 
166 |     train_x = skes_joints[train_indices]
167 |     train_y = one_hot_vector(train_labels)
168 |     test_x = skes_joints[test_indices]
169 |     test_y = one_hot_vector(test_labels)
170 | 
171 |     save_name = 'NTU60_%s.npz' % evaluation
172 |     np.savez(save_name, x_train=train_x, y_train=train_y, x_test=test_x, y_test=test_y)
173 | 
174 |     # Save data into a .h5 file
175 |     # h5file = h5py.File(osp.join(save_path, 'NTU_%s.h5' % (evaluation)), 'w')
176 |     # Training set
177 |     # h5file.create_dataset('x', data=skes_joints[train_indices])
178 |     # train_one_hot_labels = one_hot_vector(train_labels)
179 |     # h5file.create_dataset('y', data=train_one_hot_labels)
180 |     # Validation set
181 |     # h5file.create_dataset('valid_x', data=skes_joints[val_indices])
182 |     # val_one_hot_labels = one_hot_vector(val_labels)
183 |     # h5file.create_dataset('valid_y', data=val_one_hot_labels)
184 |     # Test set
185 |     # h5file.create_dataset('test_x', data=skes_joints[test_indices])
186 |     # test_one_hot_labels = one_hot_vector(test_labels)
187 |     # h5file.create_dataset('test_y', data=test_one_hot_labels)
188 | 
189 |     # h5file.close()
190 | 
191 | 
192 | def get_indices(performer, camera, evaluation='CS'):
193 |     test_indices = np.empty(0)
194 |     train_indices = np.empty(0)
195 | 
196 |     if evaluation == 'CS':  # Cross Subject (Subject IDs)
197 |         train_ids = [1, 2, 4, 5, 8, 9, 13, 14, 15, 16,
198 |                      17, 18, 19, 25, 27, 28, 31, 34, 35, 38]
199 |         test_ids = [3, 6, 7, 10, 11, 12, 20, 21, 22, 23,
200 |                     24, 26, 29, 30, 32, 33, 36, 37, 39, 40]
201 | 
202 |         # Get indices of test data
203 |         for idx in test_ids:
204 |             temp = np.where(performer == idx)[0]  # 0-based index
205 |             test_indices = np.hstack((test_indices, temp)).astype(np.int)
206 | 
207 |         # Get indices of training data
208 |         for train_id in train_ids:
209 |             temp = np.where(performer == train_id)[0]  # 0-based index
210 |             train_indices = np.hstack((train_indices, temp)).astype(np.int)
211 |     else:  # Cross View (Camera IDs)
212 |         train_ids = [2, 3]
213 |         test_ids = 1
214 |         # Get indices of test data
215 |         temp = np.where(camera == test_ids)[0]  # 0-based index
216 |         test_indices = np.hstack((test_indices, temp)).astype(np.int)
217 | 
218 |         # Get indices of training data
219 |         for train_id in train_ids:
220 |             temp = np.where(camera == train_id)[0]  # 0-based index
221 |             train_indices = np.hstack((train_indices, temp)).astype(np.int)
222 | 
223 |     return train_indices, test_indices
224 | 
225 | 
226 | if __name__ == '__main__':
227 |     camera = np.loadtxt(camera_file, dtype=np.int)  # camera id: 1, 2, 3
228 |     performer = np.loadtxt(performer_file, dtype=np.int)  # subject id: 1~40
229 |     label = np.loadtxt(label_file, dtype=np.int) - 1  # action label: 0~59
230 | 
231 |     frames_cnt = np.loadtxt(frames_file, dtype=np.int)  # frames_cnt
232 |     skes_name = np.loadtxt(skes_name_file, dtype=np.string_)
233 | 
234 |     with open(raw_skes_joints_pkl, 'rb') as fr:
235 |         skes_joints = pickle.load(fr)  # a list
236 | 
237 |     skes_joints = seq_translation(skes_joints)
238 | 
239 |     skes_joints = align_frames(skes_joints, frames_cnt)  # aligned to the same frame length
240 | 
241 |     evaluations = ['CS', 'CV']
242 |     for evaluation in evaluations:
243 |         split_dataset(skes_joints, label, performer, camera, evaluation, save_path)
244 | 
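`seq_translation` above re-centers each sequence on joint-2 of the first non-empty frame. A toy check of that guarantee (a sketch, not part of the script, run with `seq_translation` imported from it):

```
# Toy check: after seq_translation, joint-2 (x, y, z) of the first valid
# frame lies at the origin.
import numpy as np

ske = np.random.rand(10, 75).astype(np.float32)  # 10 frames, 1 body
out = seq_translation([ske.copy()])[0]
assert np.allclose(out[0, 3:6], 0.0)
```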
IDs)
197 |         train_ids = [1, 2, 4, 5, 8, 9, 13, 14, 15, 16,
198 |                      17, 18, 19, 25, 27, 28, 31, 34, 35, 38]
199 |         test_ids = [3, 6, 7, 10, 11, 12, 20, 21, 22, 23,
200 |                     24, 26, 29, 30, 32, 33, 36, 37, 39, 40]
201 | 
202 |         # Get indices of test data
203 |         for idx in test_ids:
204 |             temp = np.where(performer == idx)[0]  # 0-based index
205 |             test_indices = np.hstack((test_indices, temp)).astype(int)  # np.int was removed in NumPy 1.24; the builtin int is equivalent
206 | 
207 |         # Get indices of training data
208 |         for train_id in train_ids:
209 |             temp = np.where(performer == train_id)[0]  # 0-based index
210 |             train_indices = np.hstack((train_indices, temp)).astype(int)
211 |     else:  # Cross View (Camera IDs)
212 |         train_ids = [2, 3]
213 |         test_ids = 1
214 |         # Get indices of test data
215 |         temp = np.where(camera == test_ids)[0]  # 0-based index
216 |         test_indices = np.hstack((test_indices, temp)).astype(int)
217 | 
218 |         # Get indices of training data
219 |         for train_id in train_ids:
220 |             temp = np.where(camera == train_id)[0]  # 0-based index
221 |             train_indices = np.hstack((train_indices, temp)).astype(int)
222 | 
223 |     return train_indices, test_indices
224 | 
225 | 
226 | if __name__ == '__main__':
227 |     camera = np.loadtxt(camera_file, dtype=int)  # camera id: 1, 2, 3
228 |     performer = np.loadtxt(performer_file, dtype=int)  # subject id: 1~40
229 |     label = np.loadtxt(label_file, dtype=int) - 1  # action label: 0~59
230 | 
231 |     frames_cnt = np.loadtxt(frames_file, dtype=int)  # frames_cnt
232 |     skes_name = np.loadtxt(skes_name_file, dtype=str)  # str replaces the removed np.string_ alias
233 | 
234 |     with open(raw_skes_joints_pkl, 'rb') as fr:
235 |         skes_joints = pickle.load(fr)  # a list
236 | 
237 |     skes_joints = seq_translation(skes_joints)
238 | 
239 |     skes_joints = align_frames(skes_joints, frames_cnt)  # aligned to the same frame length
240 | 
241 |     evaluations = ['CS', 'CV']
242 |     for evaluation in evaluations:
243 |         split_dataset(skes_joints, label, performer, camera, evaluation, save_path)
244 | 
--------------------------------------------------------------------------------
/data/ntu120/seq_transformation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
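# ----------------------------------------------------------------------------
# A minimal loading sketch for the .npz archives written by split_dataset()
# (NTU60_CS.npz / NTU60_CV.npz above; NTU120_CSub.npz / NTU120_CSet.npz from
# this script). Variable names are illustrative; the shapes follow
# align_frames() and one_hot_vector().
#
#   import numpy as np
#   data = np.load('NTU60_CS.npz')
#   x_train, y_train = data['x_train'], data['y_train']  # (N, T, 150), one-hot (N, 60) -- (N, 120) for NTU-120
#   x_test, y_test = data['x_test'], data['y_test']
#   labels = y_train.argmax(axis=-1)            # one-hot -> integer class ids
#   N, T, _ = x_train.shape
#   x_train = x_train.reshape(N, T, 2, 25, 3)   # 2 bodies x 25 joints x (x, y, z)
# ----------------------------------------------------------------------------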
3 | import os 4 | import os.path as osp 5 | import numpy as np 6 | import pickle 7 | import logging 8 | import h5py 9 | from sklearn.model_selection import train_test_split 10 | 11 | root_path = './' 12 | stat_path = osp.join(root_path, 'statistics') 13 | setup_file = osp.join(stat_path, 'setup.txt') 14 | camera_file = osp.join(stat_path, 'camera.txt') 15 | performer_file = osp.join(stat_path, 'performer.txt') 16 | replication_file = osp.join(stat_path, 'replication.txt') 17 | label_file = osp.join(stat_path, 'label.txt') 18 | skes_name_file = osp.join(stat_path, 'skes_available_name.txt') 19 | 20 | denoised_path = osp.join(root_path, 'denoised_data') 21 | raw_skes_joints_pkl = osp.join(denoised_path, 'raw_denoised_joints.pkl') 22 | frames_file = osp.join(denoised_path, 'frames_cnt.txt') 23 | 24 | save_path = './' 25 | 26 | 27 | if not osp.exists(save_path): 28 | os.mkdir(save_path) 29 | 30 | 31 | def remove_nan_frames(ske_name, ske_joints, nan_logger): 32 | num_frames = ske_joints.shape[0] 33 | valid_frames = [] 34 | 35 | for f in range(num_frames): 36 | if not np.any(np.isnan(ske_joints[f])): 37 | valid_frames.append(f) 38 | else: 39 | nan_indices = np.where(np.isnan(ske_joints[f]))[0] 40 | nan_logger.info('{}\t{:^5}\t{}'.format(ske_name, f + 1, nan_indices)) 41 | 42 | return ske_joints[valid_frames] 43 | 44 | def seq_translation(skes_joints): 45 | for idx, ske_joints in enumerate(skes_joints): 46 | num_frames = ske_joints.shape[0] 47 | num_bodies = 1 if ske_joints.shape[1] == 75 else 2 48 | if num_bodies == 2: 49 | missing_frames_1 = np.where(ske_joints[:, :75].sum(axis=1) == 0)[0] 50 | missing_frames_2 = np.where(ske_joints[:, 75:].sum(axis=1) == 0)[0] 51 | cnt1 = len(missing_frames_1) 52 | cnt2 = len(missing_frames_2) 53 | 54 | i = 0 # get the "real" first frame of actor1 55 | while i < num_frames: 56 | if np.any(ske_joints[i, :75] != 0): 57 | break 58 | i += 1 59 | 60 | origin = np.copy(ske_joints[i, 3:6]) # new origin: joint-2 61 | 62 | for f in range(num_frames): 63 | if num_bodies == 1: 64 | ske_joints[f] -= np.tile(origin, 25) 65 | else: # for 2 actors 66 | ske_joints[f] -= np.tile(origin, 50) 67 | 68 | if (num_bodies == 2) and (cnt1 > 0): 69 | ske_joints[missing_frames_1, :75] = np.zeros((cnt1, 75), dtype=np.float32) 70 | 71 | if (num_bodies == 2) and (cnt2 > 0): 72 | ske_joints[missing_frames_2, 75:] = np.zeros((cnt2, 75), dtype=np.float32) 73 | 74 | skes_joints[idx] = ske_joints # Update 75 | 76 | return skes_joints 77 | 78 | 79 | def frame_translation(skes_joints, skes_name, frames_cnt): 80 | nan_logger = logging.getLogger('nan_skes') 81 | nan_logger.setLevel(logging.INFO) 82 | nan_logger.addHandler(logging.FileHandler("./nan_frames.log")) 83 | nan_logger.info('{}\t{}\t{}'.format('Skeleton', 'Frame', 'Joints')) 84 | 85 | for idx, ske_joints in enumerate(skes_joints): 86 | num_frames = ske_joints.shape[0] 87 | # Calculate the distance between spine base (joint-1) and spine (joint-21) 88 | j1 = ske_joints[:, 0:3] 89 | j21 = ske_joints[:, 60:63] 90 | dist = np.sqrt(((j1 - j21) ** 2).sum(axis=1)) 91 | 92 | for f in range(num_frames): 93 | origin = ske_joints[f, 3:6] # new origin: middle of the spine (joint-2) 94 | if (ske_joints[f, 75:] == 0).all(): 95 | ske_joints[f, :75] = (ske_joints[f, :75] - np.tile(origin, 25)) / \ 96 | dist[f] + np.tile(origin, 25) 97 | else: 98 | ske_joints[f] = (ske_joints[f] - np.tile(origin, 50)) / \ 99 | dist[f] + np.tile(origin, 50) 100 | 101 | ske_name = skes_name[idx] 102 | ske_joints = remove_nan_frames(ske_name, ske_joints, nan_logger) 103 | 
frames_cnt[idx] = num_frames # update valid number of frames 104 | skes_joints[idx] = ske_joints 105 | 106 | return skes_joints, frames_cnt 107 | 108 | 109 | def align_frames(skes_joints, frames_cnt): 110 | """ 111 | Align all sequences with the same frame length. 112 | 113 | """ 114 | num_skes = len(skes_joints) 115 | max_num_frames = frames_cnt.max() # 300 116 | aligned_skes_joints = np.zeros((num_skes, max_num_frames, 150), dtype=np.float32) 117 | 118 | for idx, ske_joints in enumerate(skes_joints): 119 | num_frames = ske_joints.shape[0] 120 | num_bodies = 1 if ske_joints.shape[1] == 75 else 2 121 | if num_bodies == 1: 122 | aligned_skes_joints[idx, :num_frames] = np.hstack((ske_joints, ske_joints)) 123 | # aligned_skes_joints[idx, :num_frames] = np.hstack((ske_joints, np.zeros_like(ske_joints))) 124 | else: 125 | aligned_skes_joints[idx, :num_frames] = ske_joints 126 | 127 | return aligned_skes_joints 128 | 129 | 130 | def one_hot_vector(labels): 131 | num_skes = len(labels) 132 | labels_vector = np.zeros((num_skes, 120)) 133 | for idx, l in enumerate(labels): 134 | labels_vector[idx, l] = 1 135 | 136 | return labels_vector 137 | 138 | 139 | def split_train_val(train_indices, method='sklearn', ratio=0.05): 140 | """ 141 | Get validation set by splitting data randomly from training set with two methods. 142 | In fact, I thought these two methods are equal as they got the same performance. 143 | 144 | """ 145 | if method == 'sklearn': 146 | return train_test_split(train_indices, test_size=ratio, random_state=10000) 147 | else: 148 | np.random.seed(10000) 149 | np.random.shuffle(train_indices) 150 | val_num_skes = int(np.ceil(0.05 * len(train_indices))) 151 | val_indices = train_indices[:val_num_skes] 152 | train_indices = train_indices[val_num_skes:] 153 | return train_indices, val_indices 154 | 155 | 156 | def split_dataset(skes_joints, label, performer, setup, evaluation, save_path): 157 | train_indices, test_indices = get_indices(performer, setup, evaluation) 158 | # m = 'sklearn' # 'sklearn' or 'numpy' 159 | # Select validation set from training set 160 | # train_indices, val_indices = split_train_val(train_indices, m) 161 | 162 | # Save labels and num_frames for each sequence of each data set 163 | train_labels = label[train_indices] 164 | test_labels = label[test_indices] 165 | 166 | train_x = skes_joints[train_indices] 167 | train_y = one_hot_vector(train_labels) 168 | test_x = skes_joints[test_indices] 169 | test_y = one_hot_vector(test_labels) 170 | 171 | save_name = 'NTU120_%s.npz' % evaluation 172 | np.savez(save_name, x_train=train_x, y_train=train_y, x_test=test_x, y_test=test_y) 173 | 174 | # # Save data into a .h5 file 175 | # h5file = h5py.File(osp.join(save_path, 'NTU_%s.h5' % (evaluation)), 'w') 176 | # # Training set 177 | # h5file.create_dataset('x', data=skes_joints[train_indices]) 178 | # train_one_hot_labels = one_hot_vector(train_labels) 179 | # h5file.create_dataset('y', data=train_one_hot_labels) 180 | # # Validation set 181 | # h5file.create_dataset('valid_x', data=skes_joints[val_indices]) 182 | # val_one_hot_labels = one_hot_vector(val_labels) 183 | # h5file.create_dataset('valid_y', data=val_one_hot_labels) 184 | # # Test set 185 | # h5file.create_dataset('test_x', data=skes_joints[test_indices]) 186 | # test_one_hot_labels = one_hot_vector(test_labels) 187 | # h5file.create_dataset('test_y', data=test_one_hot_labels) 188 | 189 | # h5file.close() 190 | 191 | 192 | def get_indices(performer, setup, evaluation='CSub'): 193 | test_indices = np.empty(0) 194 | 
train_indices = np.empty(0)
195 | 
196 |     if evaluation == 'CSub':  # Cross Subject (Subject IDs)
197 |         train_ids = [1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28,
198 |                      31, 34, 35, 38, 45, 46, 47, 49, 50, 52, 53, 54, 55, 56, 57,
199 |                      58, 59, 70, 74, 78, 80, 81, 82, 83, 84, 85, 86, 89, 91, 92,
200 |                      93, 94, 95, 97, 98, 100, 103]
201 |         test_ids = [i for i in range(1, 107) if i not in train_ids]
202 | 
203 |         # Get indices of test data
204 |         for idx in test_ids:
205 |             temp = np.where(performer == idx)[0]  # 0-based index
206 |             test_indices = np.hstack((test_indices, temp)).astype(int)  # np.int was removed in NumPy 1.24; the builtin int is equivalent
207 | 
208 |         # Get indices of training data
209 |         for train_id in train_ids:
210 |             temp = np.where(performer == train_id)[0]  # 0-based index
211 |             train_indices = np.hstack((train_indices, temp)).astype(int)
212 |     else:  # Cross Setup (Setup IDs)
213 |         train_ids = [i for i in range(1, 33) if i % 2 == 0]  # Even setups
214 |         test_ids = [i for i in range(1, 33) if i % 2 == 1]  # Odd setups
215 | 
216 |         # Get indices of test data
217 |         for test_id in test_ids:
218 |             temp = np.where(setup == test_id)[0]  # 0-based index
219 |             test_indices = np.hstack((test_indices, temp)).astype(int)
220 | 
221 |         # Get indices of training data
222 |         for train_id in train_ids:
223 |             temp = np.where(setup == train_id)[0]  # 0-based index
224 |             train_indices = np.hstack((train_indices, temp)).astype(int)
225 | 
226 |     return train_indices, test_indices
227 | 
228 | 
229 | if __name__ == '__main__':
230 |     setup = np.loadtxt(setup_file, dtype=int)  # setup id: 1~32
231 |     performer = np.loadtxt(performer_file, dtype=int)  # subject id: 1~106
232 |     label = np.loadtxt(label_file, dtype=int) - 1  # action label: 0~119
233 | 
234 |     frames_cnt = np.loadtxt(frames_file, dtype=int)  # frames_cnt
235 |     skes_name = np.loadtxt(skes_name_file, dtype=str)  # str replaces the removed np.string_ alias
236 | 
237 |     with open(raw_skes_joints_pkl, 'rb') as fr:
238 |         skes_joints = pickle.load(fr)  # a list
239 | 
240 |     skes_joints = seq_translation(skes_joints)
241 | 
242 |     skes_joints = align_frames(skes_joints, frames_cnt)  # aligned to the same frame length
243 | 
244 |     evaluations = ['CSet', 'CSub']
245 |     for evaluation in evaluations:
246 |         split_dataset(skes_joints, label, performer, setup, evaluation, save_path)
247 | 
--------------------------------------------------------------------------------
/data/ntu120/statistics/NTU_RGBD120_samples_with_missing_skeletons.txt:
--------------------------------------------------------------------------------
1 | S001C002P005R002A008
2 | S001C002P006R001A008
3 | S001C003P002R001A055
4 | S001C003P002R002A012
5 | S001C003P005R002A004
6 | S001C003P005R002A005
7 | S001C003P005R002A006
8 | S001C003P006R002A008
9 | S002C002P011R002A030
10 | S002C003P008R001A020
11 | S002C003P010R002A010
12 | S002C003P011R002A007
13 | S002C003P011R002A011
14 | S002C003P014R002A007
15 | S003C001P019R001A055
16 | S003C002P002R002A055
17 | S003C002P018R002A055
18 | S003C003P002R001A055
19 | S003C003P016R001A055
20 | S003C003P018R002A024
21 | S004C002P003R001A013
22 | S004C002P008R001A009
23 | S004C002P020R001A003
24 | S004C002P020R001A004
25 | S004C002P020R001A012
26 | S004C002P020R001A020
27 | S004C002P020R001A021
28 | S004C002P020R001A036
29 | S005C002P004R001A001
30 | S005C002P004R001A003
31 | S005C002P010R001A016
32 | S005C002P010R001A017
33 | S005C002P010R001A048
34 | S005C002P010R001A049
35 | S005C002P016R001A009
36 | S005C002P016R001A010
37 | S005C002P018R001A003
38 | S005C002P018R001A028
39 | S005C002P018R001A029
40 | S005C003P016R002A009
41 | 
S005C003P018R002A013 42 | S005C003P021R002A057 43 | S006C001P001R002A055 44 | S006C002P007R001A005 45 | S006C002P007R001A006 46 | S006C002P016R001A043 47 | S006C002P016R001A051 48 | S006C002P016R001A052 49 | S006C002P022R001A012 50 | S006C002P023R001A020 51 | S006C002P023R001A021 52 | S006C002P023R001A022 53 | S006C002P023R001A023 54 | S006C002P024R001A018 55 | S006C002P024R001A019 56 | S006C003P001R002A013 57 | S006C003P007R002A009 58 | S006C003P007R002A010 59 | S006C003P007R002A025 60 | S006C003P016R001A060 61 | S006C003P017R001A055 62 | S006C003P017R002A013 63 | S006C003P017R002A014 64 | S006C003P017R002A015 65 | S006C003P022R002A013 66 | S007C001P018R002A050 67 | S007C001P025R002A051 68 | S007C001P028R001A050 69 | S007C001P028R001A051 70 | S007C001P028R001A052 71 | S007C002P008R002A008 72 | S007C002P015R002A055 73 | S007C002P026R001A008 74 | S007C002P026R001A009 75 | S007C002P026R001A010 76 | S007C002P026R001A011 77 | S007C002P026R001A012 78 | S007C002P026R001A050 79 | S007C002P027R001A011 80 | S007C002P027R001A013 81 | S007C002P028R002A055 82 | S007C003P007R001A002 83 | S007C003P007R001A004 84 | S007C003P019R001A060 85 | S007C003P027R002A001 86 | S007C003P027R002A002 87 | S007C003P027R002A003 88 | S007C003P027R002A004 89 | S007C003P027R002A005 90 | S007C003P027R002A006 91 | S007C003P027R002A007 92 | S007C003P027R002A008 93 | S007C003P027R002A009 94 | S007C003P027R002A010 95 | S007C003P027R002A011 96 | S007C003P027R002A012 97 | S007C003P027R002A013 98 | S008C002P001R001A009 99 | S008C002P001R001A010 100 | S008C002P001R001A014 101 | S008C002P001R001A015 102 | S008C002P001R001A016 103 | S008C002P001R001A018 104 | S008C002P001R001A019 105 | S008C002P008R002A059 106 | S008C002P025R001A060 107 | S008C002P029R001A004 108 | S008C002P031R001A005 109 | S008C002P031R001A006 110 | S008C002P032R001A018 111 | S008C002P034R001A018 112 | S008C002P034R001A019 113 | S008C002P035R001A059 114 | S008C002P035R002A002 115 | S008C002P035R002A005 116 | S008C003P007R001A009 117 | S008C003P007R001A016 118 | S008C003P007R001A017 119 | S008C003P007R001A018 120 | S008C003P007R001A019 121 | S008C003P007R001A020 122 | S008C003P007R001A021 123 | S008C003P007R001A022 124 | S008C003P007R001A023 125 | S008C003P007R001A025 126 | S008C003P007R001A026 127 | S008C003P007R001A028 128 | S008C003P007R001A029 129 | S008C003P007R002A003 130 | S008C003P008R002A050 131 | S008C003P025R002A002 132 | S008C003P025R002A011 133 | S008C003P025R002A012 134 | S008C003P025R002A016 135 | S008C003P025R002A020 136 | S008C003P025R002A022 137 | S008C003P025R002A023 138 | S008C003P025R002A030 139 | S008C003P025R002A031 140 | S008C003P025R002A032 141 | S008C003P025R002A033 142 | S008C003P025R002A049 143 | S008C003P025R002A060 144 | S008C003P031R001A001 145 | S008C003P031R002A004 146 | S008C003P031R002A014 147 | S008C003P031R002A015 148 | S008C003P031R002A016 149 | S008C003P031R002A017 150 | S008C003P032R002A013 151 | S008C003P033R002A001 152 | S008C003P033R002A011 153 | S008C003P033R002A012 154 | S008C003P034R002A001 155 | S008C003P034R002A012 156 | S008C003P034R002A022 157 | S008C003P034R002A023 158 | S008C003P034R002A024 159 | S008C003P034R002A044 160 | S008C003P034R002A045 161 | S008C003P035R002A016 162 | S008C003P035R002A017 163 | S008C003P035R002A018 164 | S008C003P035R002A019 165 | S008C003P035R002A020 166 | S008C003P035R002A021 167 | S009C002P007R001A001 168 | S009C002P007R001A003 169 | S009C002P007R001A014 170 | S009C002P008R001A014 171 | S009C002P015R002A050 172 | S009C002P016R001A002 173 | S009C002P017R001A028 174 | S009C002P017R001A029 
175 | S009C003P017R002A030 176 | S009C003P025R002A054 177 | S010C001P007R002A020 178 | S010C002P016R002A055 179 | S010C002P017R001A005 180 | S010C002P017R001A018 181 | S010C002P017R001A019 182 | S010C002P019R001A001 183 | S010C002P025R001A012 184 | S010C003P007R002A043 185 | S010C003P008R002A003 186 | S010C003P016R001A055 187 | S010C003P017R002A055 188 | S011C001P002R001A008 189 | S011C001P018R002A050 190 | S011C002P008R002A059 191 | S011C002P016R002A055 192 | S011C002P017R001A020 193 | S011C002P017R001A021 194 | S011C002P018R002A055 195 | S011C002P027R001A009 196 | S011C002P027R001A010 197 | S011C002P027R001A037 198 | S011C003P001R001A055 199 | S011C003P002R001A055 200 | S011C003P008R002A012 201 | S011C003P015R001A055 202 | S011C003P016R001A055 203 | S011C003P019R001A055 204 | S011C003P025R001A055 205 | S011C003P028R002A055 206 | S012C001P019R001A060 207 | S012C001P019R002A060 208 | S012C002P015R001A055 209 | S012C002P017R002A012 210 | S012C002P025R001A060 211 | S012C003P008R001A057 212 | S012C003P015R001A055 213 | S012C003P015R002A055 214 | S012C003P016R001A055 215 | S012C003P017R002A055 216 | S012C003P018R001A055 217 | S012C003P018R001A057 218 | S012C003P019R002A011 219 | S012C003P019R002A012 220 | S012C003P025R001A055 221 | S012C003P027R001A055 222 | S012C003P027R002A009 223 | S012C003P028R001A035 224 | S012C003P028R002A055 225 | S013C001P015R001A054 226 | S013C001P017R002A054 227 | S013C001P018R001A016 228 | S013C001P028R001A040 229 | S013C002P015R001A054 230 | S013C002P017R002A054 231 | S013C002P028R001A040 232 | S013C003P008R002A059 233 | S013C003P015R001A054 234 | S013C003P017R002A054 235 | S013C003P025R002A022 236 | S013C003P027R001A055 237 | S013C003P028R001A040 238 | S014C001P027R002A040 239 | S014C002P015R001A003 240 | S014C002P019R001A029 241 | S014C002P025R002A059 242 | S014C002P027R002A040 243 | S014C002P039R001A050 244 | S014C003P007R002A059 245 | S014C003P015R002A055 246 | S014C003P019R002A055 247 | S014C003P025R001A048 248 | S014C003P027R002A040 249 | S015C001P008R002A040 250 | S015C001P016R001A055 251 | S015C001P017R001A055 252 | S015C001P017R002A055 253 | S015C002P007R001A059 254 | S015C002P008R001A003 255 | S015C002P008R001A004 256 | S015C002P008R002A040 257 | S015C002P015R001A002 258 | S015C002P016R001A001 259 | S015C002P016R002A055 260 | S015C003P008R002A007 261 | S015C003P008R002A011 262 | S015C003P008R002A012 263 | S015C003P008R002A028 264 | S015C003P008R002A040 265 | S015C003P025R002A012 266 | S015C003P025R002A017 267 | S015C003P025R002A020 268 | S015C003P025R002A021 269 | S015C003P025R002A030 270 | S015C003P025R002A033 271 | S015C003P025R002A034 272 | S015C003P025R002A036 273 | S015C003P025R002A037 274 | S015C003P025R002A044 275 | S016C001P019R002A040 276 | S016C001P025R001A011 277 | S016C001P025R001A012 278 | S016C001P025R001A060 279 | S016C001P040R001A055 280 | S016C001P040R002A055 281 | S016C002P008R001A011 282 | S016C002P019R002A040 283 | S016C002P025R002A012 284 | S016C003P008R001A011 285 | S016C003P008R002A002 286 | S016C003P008R002A003 287 | S016C003P008R002A004 288 | S016C003P008R002A006 289 | S016C003P008R002A009 290 | S016C003P019R002A040 291 | S016C003P039R002A016 292 | S017C001P016R002A031 293 | S017C002P007R001A013 294 | S017C002P008R001A009 295 | S017C002P015R001A042 296 | S017C002P016R002A031 297 | S017C002P016R002A055 298 | S017C003P007R002A013 299 | S017C003P008R001A059 300 | S017C003P016R002A031 301 | S017C003P017R001A055 302 | S017C003P020R001A059 303 | S019C001P046R001A075 304 | S019C002P042R001A094 305 | S019C002P042R001A095 306 | 
S019C002P042R001A096 307 | S019C002P042R001A097 308 | S019C002P042R001A098 309 | S019C002P042R001A099 310 | S019C002P042R001A100 311 | S019C002P042R001A101 312 | S019C002P042R001A102 313 | S019C002P049R002A074 314 | S019C002P049R002A079 315 | S019C002P051R001A061 316 | S019C003P046R001A061 317 | S019C003P046R002A061 318 | S019C003P046R002A062 319 | S020C002P041R001A063 320 | S020C002P041R001A064 321 | S020C002P044R001A063 322 | S020C002P044R001A064 323 | S020C002P044R001A066 324 | S020C002P044R001A084 325 | S020C002P054R001A081 326 | S021C001P059R001A108 327 | S021C002P055R001A065 328 | S021C002P055R001A092 329 | S021C002P055R001A093 330 | S021C002P057R001A064 331 | S021C002P058R001A063 332 | S021C002P058R001A064 333 | S021C002P059R001A074 334 | S021C002P059R001A075 335 | S021C002P059R001A076 336 | S021C002P059R001A077 337 | S021C002P059R001A078 338 | S021C002P059R001A079 339 | S021C003P057R002A078 340 | S021C003P057R002A079 341 | S021C003P057R002A094 342 | S022C002P061R001A113 343 | S022C003P061R002A061 344 | S022C003P061R002A062 345 | S022C003P063R002A061 346 | S022C003P063R002A062 347 | S022C003P063R002A063 348 | S022C003P063R002A064 349 | S022C003P063R002A078 350 | S022C003P064R002A061 351 | S022C003P064R002A062 352 | S022C003P065R002A061 353 | S022C003P065R002A062 354 | S022C003P065R002A119 355 | S022C003P067R002A064 356 | S023C002P055R001A114 357 | S023C002P055R002A092 358 | S023C002P059R001A075 359 | S023C002P063R001A075 360 | S023C003P055R002A093 361 | S023C003P055R002A094 362 | S023C003P061R002A061 363 | S023C003P064R001A092 364 | S024C001P063R001A109 365 | S024C002P062R002A074 366 | S024C002P067R001A100 367 | S024C002P067R001A101 368 | S024C002P067R001A102 369 | S024C002P067R001A103 370 | S024C003P062R002A074 371 | S024C003P063R002A061 372 | S024C003P063R002A062 373 | S025C001P055R002A119 374 | S025C003P056R002A119 375 | S025C003P059R002A115 376 | S026C002P044R001A061 377 | S026C002P044R001A062 378 | S026C002P070R001A092 379 | S026C003P069R002A075 380 | S026C003P074R002A061 381 | S026C003P074R002A062 382 | S026C003P075R001A117 383 | S026C003P075R001A118 384 | S027C001P082R001A063 385 | S027C002P044R002A092 386 | S027C002P079R001A061 387 | S027C002P079R001A062 388 | S027C002P079R001A063 389 | S027C002P079R001A064 390 | S027C002P082R001A092 391 | S027C002P084R001A061 392 | S027C002P084R001A062 393 | S027C002P086R001A061 394 | S027C003P041R002A087 395 | S027C003P080R002A061 396 | S027C003P082R002A061 397 | S027C003P082R002A062 398 | S027C003P086R002A061 399 | S027C003P086R002A062 400 | S028C001P087R001A061 401 | S028C002P041R001A091 402 | S028C002P087R001A061 403 | S028C003P042R002A064 404 | S028C003P046R002A063 405 | S028C003P046R002A066 406 | S028C003P046R002A067 407 | S028C003P046R002A068 408 | S028C003P046R002A069 409 | S028C003P046R002A070 410 | S028C003P046R002A071 411 | S028C003P046R002A072 412 | S028C003P046R002A074 413 | S028C003P046R002A075 414 | S028C003P046R002A077 415 | S028C003P046R002A081 416 | S028C003P046R002A082 417 | S028C003P046R002A083 418 | S028C003P046R002A084 419 | S028C003P048R002A061 420 | S028C003P048R002A062 421 | S028C003P048R002A073 422 | S028C003P073R002A073 423 | S028C003P087R001A061 424 | S028C003P087R002A061 425 | S028C003P087R002A062 426 | S029C001P043R002A092 427 | S029C001P044R002A092 428 | S029C001P048R001A073 429 | S029C001P089R001A063 430 | S029C002P041R001A074 431 | S029C002P041R001A084 432 | S029C002P044R001A091 433 | S029C002P048R001A075 434 | S029C002P048R001A081 435 | S029C002P074R001A081 436 | S029C002P074R001A095 437 | 
S029C002P074R001A096 438 | S029C002P080R001A091 439 | S029C002P088R001A066 440 | S029C002P089R001A065 441 | S029C002P090R001A067 442 | S029C003P008R002A065 443 | S029C003P008R002A067 444 | S029C003P041R001A089 445 | S029C003P043R001A080 446 | S029C003P043R001A092 447 | S029C003P043R001A105 448 | S029C003P043R002A085 449 | S029C003P043R002A086 450 | S029C003P044R002A106 451 | S029C003P048R001A065 452 | S029C003P048R002A073 453 | S029C003P048R002A074 454 | S029C003P048R002A075 455 | S029C003P048R002A076 456 | S029C003P048R002A092 457 | S029C003P048R002A094 458 | S029C003P051R002A073 459 | S029C003P051R002A074 460 | S029C003P051R002A075 461 | S029C003P051R002A076 462 | S029C003P051R002A077 463 | S029C003P051R002A078 464 | S029C003P051R002A079 465 | S029C003P051R002A080 466 | S029C003P051R002A081 467 | S029C003P051R002A082 468 | S029C003P051R002A083 469 | S029C003P051R002A084 470 | S029C003P051R002A085 471 | S029C003P051R002A086 472 | S029C003P051R002A110 473 | S029C003P067R001A098 474 | S029C003P074R002A110 475 | S029C003P080R002A066 476 | S029C003P088R002A078 477 | S029C003P089R001A075 478 | S029C003P089R002A061 479 | S029C003P089R002A062 480 | S029C003P089R002A063 481 | S029C003P090R002A092 482 | S029C003P090R002A095 483 | S030C002P091R002A091 484 | S030C002P091R002A092 485 | S030C002P091R002A093 486 | S030C002P091R002A094 487 | S030C002P091R002A095 488 | S030C002P091R002A096 489 | S030C002P091R002A097 490 | S030C002P091R002A098 491 | S030C002P091R002A099 492 | S030C002P091R002A100 493 | S030C002P091R002A101 494 | S030C002P091R002A102 495 | S030C002P091R002A103 496 | S030C002P091R002A104 497 | S030C002P091R002A105 498 | S030C003P044R002A065 499 | S030C003P044R002A081 500 | S030C003P044R002A084 501 | S031C002P042R001A111 502 | S031C002P051R001A061 503 | S031C002P051R001A062 504 | S031C002P067R001A067 505 | S031C002P067R001A068 506 | S031C002P067R001A069 507 | S031C002P067R001A070 508 | S031C002P067R001A071 509 | S031C002P067R001A072 510 | S031C002P082R001A075 511 | S031C002P082R002A117 512 | S031C002P097R001A061 513 | S031C002P097R001A062 514 | S031C003P043R002A074 515 | S031C003P043R002A075 516 | S031C003P044R002A094 517 | S031C003P082R002A067 518 | S031C003P082R002A068 519 | S031C003P082R002A069 520 | S031C003P082R002A070 521 | S031C003P082R002A071 522 | S031C003P082R002A072 523 | S031C003P082R002A073 524 | S031C003P082R002A075 525 | S031C003P082R002A076 526 | S031C003P082R002A077 527 | S031C003P082R002A084 528 | S031C003P082R002A085 529 | S031C003P082R002A086 530 | S032C002P067R001A092 531 | S032C003P067R002A066 532 | S032C003P067R002A067 533 | S032C003P067R002A075 534 | S032C003P067R002A076 535 | S032C003P067R002A077 536 | -------------------------------------------------------------------------------- /model/HDGCN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | from einops import rearrange, repeat 9 | from einops.layers.torch import Rearrange 10 | 11 | from graph.tools import get_groups 12 | 13 | def import_class(name): 14 | components = name.split('.') 15 | mod = __import__(components[0]) 16 | for comp in components[1:]: 17 | mod = getattr(mod, comp) 18 | return mod 19 | 20 | 21 | def conv_branch_init(conv, branches): 22 | weight = conv.weight 23 | n = weight.size(0) 24 | k1 = weight.size(1) 25 | k2 = weight.size(2) 26 | nn.init.normal_(weight, 0, math.sqrt(2. 
/ (n * k1 * k2 * branches))) 27 | if conv.bias is not None: 28 | nn.init.constant_(conv.bias, 0) 29 | 30 | 31 | def conv_init(conv): 32 | if conv.weight is not None: 33 | nn.init.kaiming_normal_(conv.weight, mode='fan_out') 34 | if conv.bias is not None: 35 | nn.init.constant_(conv.bias, 0) 36 | 37 | 38 | def bn_init(bn, scale): 39 | nn.init.constant_(bn.weight, scale) 40 | nn.init.constant_(bn.bias, 0) 41 | 42 | 43 | def weights_init(m): 44 | classname = m.__class__.__name__ 45 | if classname.find('Conv') != -1: 46 | if hasattr(m, 'weight'): 47 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 48 | if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor): 49 | nn.init.constant_(m.bias, 0) 50 | elif classname.find('BatchNorm') != -1: 51 | if hasattr(m, 'weight') and m.weight is not None: 52 | m.weight.data.normal_(1.0, 0.02) 53 | if hasattr(m, 'bias') and m.bias is not None: 54 | m.bias.data.fill_(0) 55 | 56 | 57 | class TemporalConv(nn.Module): 58 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1): 59 | super(TemporalConv, self).__init__() 60 | pad = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2 61 | self.conv = nn.Conv2d( 62 | in_channels, 63 | out_channels, 64 | kernel_size=(kernel_size, 1), 65 | padding=(pad, 0), 66 | stride=(stride, 1), 67 | dilation=(dilation, 1), 68 | bias=False) 69 | self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1), requires_grad=True) 70 | 71 | self.bn = nn.BatchNorm2d(out_channels) 72 | 73 | def forward(self, x): 74 | x = self.conv(x) + self.bias 75 | x = self.bn(x) 76 | return x 77 | 78 | 79 | class MultiScale_TemporalConv(nn.Module): 80 | def __init__(self, 81 | in_channels, 82 | out_channels, 83 | kernel_size=5, 84 | stride=1, 85 | dilations=[1,2], 86 | residual=True, 87 | residual_kernel_size=1): 88 | 89 | super().__init__() 90 | assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches' 91 | 92 | # Multiple branches of temporal convolution 93 | self.num_branches = len(dilations) + 2 94 | branch_channels = out_channels // self.num_branches 95 | if type(kernel_size) == list: 96 | assert len(kernel_size) == len(dilations) 97 | else: 98 | kernel_size = [kernel_size] * len(dilations) 99 | # Temporal Convolution branches 100 | self.branches = nn.ModuleList([ 101 | nn.Sequential( 102 | nn.Conv2d( 103 | in_channels, 104 | branch_channels, 105 | kernel_size=1, 106 | padding=0), 107 | nn.BatchNorm2d(branch_channels), 108 | nn.ReLU(inplace=True), 109 | TemporalConv( 110 | branch_channels, 111 | branch_channels, 112 | kernel_size=ks, 113 | stride=stride, 114 | dilation=dilation), 115 | ) 116 | for ks, dilation in zip(kernel_size, dilations) 117 | ]) 118 | 119 | # Additional Max & 1x1 branch 120 | self.branches.append(nn.Sequential( 121 | nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0), 122 | nn.BatchNorm2d(branch_channels), 123 | nn.ReLU(inplace=True), 124 | nn.MaxPool2d(kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)), 125 | nn.BatchNorm2d(branch_channels) 126 | )) 127 | 128 | self.branches.append(nn.Sequential( 129 | nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride, 1)), 130 | nn.BatchNorm2d(branch_channels) 131 | )) 132 | 133 | # Residual connection 134 | if not residual: 135 | self.residual = lambda x: 0 136 | elif (in_channels == out_channels) and (stride == 1): 137 | self.residual = lambda x: x 138 | else: 139 | self.residual = TemporalConv(in_channels, out_channels, 
kernel_size=residual_kernel_size, stride=stride) 140 | 141 | # initialize 142 | self.apply(weights_init) 143 | 144 | def forward(self, x): 145 | branch_outs = [] 146 | for tempconv in self.branches: 147 | out = tempconv(x) 148 | branch_outs.append(out) 149 | 150 | out = torch.cat(branch_outs, dim=1) 151 | out += self.residual(x) 152 | return out 153 | 154 | 155 | class residual_conv(nn.Module): 156 | def __init__(self, in_channels, out_channels, kernel_size=5, stride=1): 157 | super(residual_conv, self).__init__() 158 | pad = int((kernel_size - 1) / 2) 159 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0), 160 | stride=(stride, 1)) 161 | 162 | self.bn = nn.BatchNorm2d(out_channels) 163 | self.relu = nn.ReLU(inplace=True) 164 | conv_init(self.conv) 165 | bn_init(self.bn, 1) 166 | 167 | def forward(self, x): 168 | x = self.bn(self.conv(x)) 169 | return x 170 | 171 | class EdgeConv(nn.Module): 172 | def __init__(self, in_channels, out_channels, k): 173 | super(EdgeConv, self).__init__() 174 | 175 | self.k = k 176 | 177 | self.conv = nn.Sequential( 178 | nn.Conv2d(in_channels*2, out_channels, kernel_size=1, bias=False), 179 | nn.BatchNorm2d(out_channels), 180 | nn.LeakyReLU(inplace=True, negative_slope=0.2) 181 | ) 182 | 183 | for m in self.modules(): 184 | if isinstance(m, nn.Conv2d): 185 | conv_init(m) 186 | elif isinstance(m, nn.BatchNorm2d): 187 | bn_init(m, 1) 188 | 189 | def forward(self, x, dim=4): # N, C, T, V 190 | 191 | if dim == 3: 192 | N, C, L = x.size() 193 | pass 194 | else: 195 | N, C, T, V = x.size() 196 | x = x.mean(dim=-2, keepdim=False) # N, C, V 197 | 198 | x = self.get_graph_feature(x, self.k) 199 | x = self.conv(x) 200 | x = x.max(dim=-1, keepdim=False)[0] 201 | 202 | if dim == 3: 203 | pass 204 | else: 205 | x = repeat(x, 'n c v -> n c t v', t=T) 206 | 207 | return x 208 | 209 | def knn(self, x, k): 210 | 211 | inner = -2 * torch.matmul(x.transpose(2, 1), x) # N, V, V 212 | xx = torch.sum(x**2, dim=1, keepdim=True) 213 | pairwise_distance = - xx - inner - xx.transpose(2, 1) 214 | 215 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # N, V, k 216 | return idx 217 | 218 | def get_graph_feature(self, x, k, idx=None): 219 | N, C, V = x.size() 220 | if idx is None: 221 | idx = self.knn(x, k=k) 222 | device = x.get_device() 223 | 224 | idx_base = torch.arange(0, N, device=device).view(-1, 1, 1) * V 225 | 226 | idx = idx + idx_base 227 | idx = idx.view(-1) 228 | 229 | x = rearrange(x, 'n c v -> n v c') 230 | feature = rearrange(x, 'n v c -> (n v) c')[idx, :] 231 | feature = feature.view(N, V, k, C) 232 | x = repeat(x, 'n v c -> n v k c', k=k) 233 | 234 | feature = torch.cat((feature - x, x), dim=3) 235 | feature = rearrange(feature, 'n v k c -> n c v k') 236 | 237 | return feature 238 | 239 | 240 | class AHA(nn.Module): 241 | def __init__(self, in_channels, num_layers, CoM): 242 | super(AHA, self).__init__() 243 | 244 | self.num_layers = num_layers 245 | 246 | groups = get_groups(dataset='NTU', CoM=CoM) 247 | 248 | for i, group in enumerate(groups): 249 | group = [i - 1 for i in group] 250 | groups[i] = group 251 | 252 | inter_channels = in_channels // 4 253 | 254 | self.layers = [groups[i] + groups[i + 1] for i in range(len(groups) - 1)] 255 | 256 | self.conv_down = nn.Sequential( 257 | nn.Conv2d(in_channels, inter_channels, kernel_size=1), 258 | nn.BatchNorm2d(inter_channels), 259 | nn.ReLU(inplace=True) 260 | ) 261 | 262 | self.edge_conv = EdgeConv(inter_channels, inter_channels, k=3) 263 | 264 | self.aggregate = 
nn.Conv1d(inter_channels, in_channels, kernel_size=1)
265 |         self.sigmoid = nn.Sigmoid()
266 | 
267 | 
268 | 
269 |     def forward(self, x):
270 |         N, C, L, T, V = x.size()  # L = number of hierarchy levels stacked by HD_Gconv
271 | 
272 |         x_t = x.max(dim=-2, keepdim=False)[0]  # max-pool over time: (N, C, L, V)
273 |         x_t = self.conv_down(x_t)
274 | 
275 |         x_sampled = []
276 |         for i in range(self.num_layers):
277 |             s_t = x_t[:, :, i, self.layers[i]]
278 |             s_t = s_t.mean(dim=-1, keepdim=True)
279 |             x_sampled.append(s_t)
280 |         x_sampled = torch.cat(x_sampled, dim=2)
281 | 
282 |         att = self.edge_conv(x_sampled, dim=3)
283 |         att = self.aggregate(att).view(N, C, L, 1, 1)  # per-level attention weights
284 | 
285 |         out = (x * self.sigmoid(att)).sum(dim=2, keepdim=False)  # attention-weighted sum over the hierarchy axis
286 | 
287 |         return out
288 | 
289 | 
290 | 
291 | class HD_Gconv(nn.Module):
292 |     def __init__(self, in_channels, out_channels, A, adaptive=True, residual=True, att=False, CoM=21):
293 |         super(HD_Gconv, self).__init__()
294 |         self.num_layers = A.shape[0]
295 |         self.num_subset = A.shape[1]
296 | 
297 |         self.att = att
298 | 
299 |         inter_channels = out_channels // (self.num_subset + 1)
300 |         self.adaptive = adaptive
301 | 
302 |         if adaptive:
303 |             self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)), requires_grad=True)
304 |         else:
305 |             raise ValueError()
306 | 
307 |         self.conv_down = nn.ModuleList()
308 |         self.conv = nn.ModuleList()
309 |         for i in range(self.num_layers):
310 |             self.conv_d = nn.ModuleList()
311 |             self.conv_down.append(nn.Sequential(
312 |                 nn.Conv2d(in_channels, inter_channels, kernel_size=1),
313 |                 nn.BatchNorm2d(inter_channels),
314 |                 nn.ReLU(inplace=True)
315 |             ))
316 |             for j in range(self.num_subset):
317 |                 self.conv_d.append(nn.Sequential(
318 |                     nn.Conv2d(inter_channels, inter_channels, kernel_size=1),
319 |                     nn.BatchNorm2d(inter_channels)
320 |                 ))
321 | 
322 |             self.conv_d.append(EdgeConv(inter_channels, inter_channels, k=5))
323 |             self.conv.append(self.conv_d)
324 | 
325 |         if self.att:
326 |             self.aha = AHA(out_channels, num_layers=self.num_layers, CoM=CoM)
327 | 
328 |         if residual:
329 |             if in_channels != out_channels:
330 |                 self.down = nn.Sequential(
331 |                     nn.Conv2d(in_channels, out_channels, 1),
332 |                     nn.BatchNorm2d(out_channels)
333 |                 )
334 |             else:
335 |                 self.down = lambda x: x
336 |         else:
337 |             self.down = lambda x: 0
338 | 
339 |         self.bn = nn.BatchNorm2d(out_channels)
340 | 
341 |         # 7 conv layers
342 |         self.relu = nn.ReLU(inplace=True)
343 | 
344 |         for m in self.modules():
345 |             if isinstance(m, nn.Conv2d):
346 |                 conv_init(m)
347 |             elif isinstance(m, nn.BatchNorm2d):
348 |                 bn_init(m, 1)
349 |         bn_init(self.bn, 1e-6)
350 | 
351 |     def forward(self, x):
352 | 
353 |         A = self.PA
354 | 
355 |         out = []
356 |         for i in range(self.num_layers):
357 |             y = []
358 |             x_down = self.conv_down[i](x)
359 |             for j in range(self.num_subset):
360 |                 z = torch.einsum('n c t u, v u -> n c t v', x_down, A[i, j])  # graph conv: aggregate joints u into v via A[i, j]
361 |                 z = self.conv[i][j](z)
362 |                 y.append(z)
363 |             y_edge = self.conv[i][-1](x_down)
364 |             y.append(y_edge)
365 |             y = torch.cat(y, dim=1)
366 | 
367 |             out.append(y)
368 | 
369 |         out = torch.stack(out, dim=2)  # (N, C_out, L, T, V): one feature map per hierarchy level
370 |         if self.att:
371 |             out = self.aha(out)
372 |         else:
373 |             out = out.sum(dim=2, keepdim=False)
374 | 
375 |         out = self.bn(out)
376 | 
377 |         out += self.down(x)
378 |         out = self.relu(out)
379 | 
380 |         return out
381 | 
382 | class TCN_GCN_unit(nn.Module):
383 |     def __init__(self, in_channels, out_channels, A, stride=1, residual=True, adaptive=True,
384 |                  kernel_size=5, dilations=[1, 2], att=True, CoM=21):
385 |         super(TCN_GCN_unit, self).__init__()
386 |         self.gcn1 = HD_Gconv(in_channels, out_channels, A, adaptive=adaptive, att=att, CoM=CoM)
387 |         self.tcn1 =
MultiScale_TemporalConv(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilations=dilations, 388 | residual=False) 389 | self.relu = nn.ReLU(inplace=True) 390 | if not residual: 391 | self.residual = lambda x: 0 392 | 393 | elif (in_channels == out_channels) and (stride == 1): 394 | self.residual = lambda x: x 395 | 396 | else: 397 | self.residual = residual_conv(in_channels, out_channels, kernel_size=1, stride=stride) 398 | 399 | def forward(self, x): 400 | y = self.relu(self.tcn1(self.gcn1(x)) + self.residual(x)) 401 | return y 402 | 403 | 404 | class Model(nn.Module): 405 | def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, graph_args=dict(), in_channels=3, 406 | drop_out=0, adaptive=True): 407 | super(Model, self).__init__() 408 | 409 | if graph is None: 410 | raise ValueError() 411 | else: 412 | Graph = import_class(graph) 413 | self.graph = Graph(**graph_args) 414 | A, CoM = self.graph.A 415 | 416 | self.dataset = 'NTU' if num_point == 25 else 'UCLA' 417 | 418 | self.num_class = num_class 419 | self.num_point = num_point 420 | self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point) 421 | 422 | base_channels = 64 423 | 424 | self.l1 = TCN_GCN_unit(3, base_channels, A, residual=False, adaptive=adaptive, att=False, CoM=CoM) 425 | self.l2 = TCN_GCN_unit(base_channels, base_channels, A, adaptive=adaptive, CoM=CoM) 426 | self.l3 = TCN_GCN_unit(base_channels, base_channels, A, adaptive=adaptive, CoM=CoM) 427 | self.l4 = TCN_GCN_unit(base_channels, base_channels, A, adaptive=adaptive, CoM=CoM) 428 | self.l5 = TCN_GCN_unit(base_channels, base_channels*2, A, stride=2, adaptive=adaptive, CoM=CoM) 429 | self.l6 = TCN_GCN_unit(base_channels*2, base_channels*2, A, adaptive=adaptive, CoM=CoM) 430 | self.l7 = TCN_GCN_unit(base_channels*2, base_channels*2, A, adaptive=adaptive, CoM=CoM) 431 | self.l8 = TCN_GCN_unit(base_channels*2, base_channels*4, A, stride=2, adaptive=adaptive, CoM=CoM) 432 | self.l9 = TCN_GCN_unit(base_channels*4, base_channels*4, A, adaptive=adaptive, CoM=CoM) 433 | self.l10 = TCN_GCN_unit(base_channels*4, base_channels*4, A, adaptive=adaptive, CoM=CoM) 434 | 435 | self.fc = nn.Linear(base_channels*4, num_class) 436 | 437 | nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class)) 438 | bn_init(self.data_bn, 1) 439 | if drop_out: 440 | self.drop_out = nn.Dropout(drop_out) 441 | else: 442 | self.drop_out = lambda x: x 443 | 444 | def forward(self, x): 445 | N, C, T, V, M = x.size() 446 | x = rearrange(x, 'n c t v m -> n (m v c) t') 447 | x = self.data_bn(x) 448 | x = rearrange(x, 'n (m v c) t -> (n m) c t v', m=M, v=V) 449 | 450 | x = self.l1(x) 451 | x = self.l2(x) 452 | x = self.l3(x) 453 | x = self.l4(x) 454 | x = self.l5(x) 455 | x = self.l6(x) 456 | x = self.l7(x) 457 | x = self.l8(x) 458 | x = self.l9(x) 459 | x = self.l10(x) 460 | 461 | # N*M,C,T,V 462 | c_new = x.size(1) 463 | x = x.view(N, M, c_new, -1) 464 | x = x.mean(3).mean(1) 465 | x = self.drop_out(x) 466 | 467 | return self.fc(x) 468 | -------------------------------------------------------------------------------- /data/ntu/get_raw_denoised_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
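# ----------------------------------------------------------------------------
# A minimal instantiation sketch for model/HDGCN.py above, assuming the
# hierarchy graph class in graph/ntu_rgb_d_hierarchy.py is named Graph and
# accepts a CoM (center-of-mass joint) argument; neither is reproduced in this
# dump, so treat those names as assumptions. Input shape is (N, C, T, V, M) =
# (batch, 3 coords, frames, 25 joints, 2 bodies). EdgeConv.get_graph_feature()
# calls x.get_device(), so the sketch moves everything to CUDA.
#
#   import torch
#   from model.HDGCN import Model
#   model = Model(num_class=60, num_point=25, num_person=2,
#                 graph='graph.ntu_rgb_d_hierarchy.Graph',
#                 graph_args=dict(CoM=21)).cuda()   # graph_args keys are assumptions
#   x = torch.randn(2, 3, 64, 25, 2).cuda()         # (N, C, T, V, M)
#   logits = model(x)                               # -> (2, 60)
# ----------------------------------------------------------------------------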
3 | import os 4 | import os.path as osp 5 | import numpy as np 6 | import pickle 7 | import logging 8 | 9 | root_path = './' 10 | raw_data_file = osp.join(root_path, 'raw_data', 'raw_skes_data.pkl') 11 | save_path = osp.join(root_path, 'denoised_data') 12 | 13 | if not osp.exists(save_path): 14 | os.mkdir(save_path) 15 | 16 | rgb_ske_path = osp.join(save_path, 'rgb+ske') 17 | if not osp.exists(rgb_ske_path): 18 | os.mkdir(rgb_ske_path) 19 | 20 | actors_info_dir = osp.join(save_path, 'actors_info') 21 | if not osp.exists(actors_info_dir): 22 | os.mkdir(actors_info_dir) 23 | 24 | missing_count = 0 25 | noise_len_thres = 11 26 | noise_spr_thres1 = 0.8 27 | noise_spr_thres2 = 0.69754 28 | noise_mot_thres_lo = 0.089925 29 | noise_mot_thres_hi = 2 30 | 31 | noise_len_logger = logging.getLogger('noise_length') 32 | noise_len_logger.setLevel(logging.INFO) 33 | noise_len_logger.addHandler(logging.FileHandler(osp.join(save_path, 'noise_length.log'))) 34 | noise_len_logger.info('{:^20}\t{:^17}\t{:^8}\t{}'.format('Skeleton', 'bodyID', 'Motion', 'Length')) 35 | 36 | noise_spr_logger = logging.getLogger('noise_spread') 37 | noise_spr_logger.setLevel(logging.INFO) 38 | noise_spr_logger.addHandler(logging.FileHandler(osp.join(save_path, 'noise_spread.log'))) 39 | noise_spr_logger.info('{:^20}\t{:^17}\t{:^8}\t{:^8}'.format('Skeleton', 'bodyID', 'Motion', 'Rate')) 40 | 41 | noise_mot_logger = logging.getLogger('noise_motion') 42 | noise_mot_logger.setLevel(logging.INFO) 43 | noise_mot_logger.addHandler(logging.FileHandler(osp.join(save_path, 'noise_motion.log'))) 44 | noise_mot_logger.info('{:^20}\t{:^17}\t{:^8}'.format('Skeleton', 'bodyID', 'Motion')) 45 | 46 | fail_logger_1 = logging.getLogger('noise_outliers_1') 47 | fail_logger_1.setLevel(logging.INFO) 48 | fail_logger_1.addHandler(logging.FileHandler(osp.join(save_path, 'denoised_failed_1.log'))) 49 | 50 | fail_logger_2 = logging.getLogger('noise_outliers_2') 51 | fail_logger_2.setLevel(logging.INFO) 52 | fail_logger_2.addHandler(logging.FileHandler(osp.join(save_path, 'denoised_failed_2.log'))) 53 | 54 | missing_skes_logger = logging.getLogger('missing_frames') 55 | missing_skes_logger.setLevel(logging.INFO) 56 | missing_skes_logger.addHandler(logging.FileHandler(osp.join(save_path, 'missing_skes.log'))) 57 | missing_skes_logger.info('{:^20}\t{}\t{}'.format('Skeleton', 'num_frames', 'num_missing')) 58 | 59 | missing_skes_logger1 = logging.getLogger('missing_frames_1') 60 | missing_skes_logger1.setLevel(logging.INFO) 61 | missing_skes_logger1.addHandler(logging.FileHandler(osp.join(save_path, 'missing_skes_1.log'))) 62 | missing_skes_logger1.info('{:^20}\t{}\t{}\t{}\t{}\t{}'.format('Skeleton', 'num_frames', 'Actor1', 63 | 'Actor2', 'Start', 'End')) 64 | 65 | missing_skes_logger2 = logging.getLogger('missing_frames_2') 66 | missing_skes_logger2.setLevel(logging.INFO) 67 | missing_skes_logger2.addHandler(logging.FileHandler(osp.join(save_path, 'missing_skes_2.log'))) 68 | missing_skes_logger2.info('{:^20}\t{}\t{}\t{}'.format('Skeleton', 'num_frames', 'Actor1', 'Actor2')) 69 | 70 | 71 | def denoising_by_length(ske_name, bodies_data): 72 | """ 73 | Denoising data based on the frame length for each bodyID. 74 | Filter out the bodyID which length is less or equal than the predefined threshold. 
75 | 
76 |     """
77 |     noise_info = str()
78 |     new_bodies_data = bodies_data.copy()
79 |     for (bodyID, body_data) in new_bodies_data.items():
80 |         length = len(body_data['interval'])
81 |         if length <= noise_len_thres:
82 |             noise_info += 'Filter out: %s, %d (length).\n' % (bodyID, length)
83 |             noise_len_logger.info('{}\t{}\t{:.6f}\t{:^6d}'.format(ske_name, bodyID,
84 |                                                                   body_data['motion'], length))
85 |             del bodies_data[bodyID]
86 |     if noise_info != '':
87 |         noise_info += '\n'
88 | 
89 |     return bodies_data, noise_info
90 | 
91 | 
92 | def get_valid_frames_by_spread(points):
93 |     """
94 |     Find the valid (or reasonable) frames (index) based on the spread of X and Y.
95 | 
96 |     :param points: joints or colors
97 |     """
98 |     num_frames = points.shape[0]
99 |     valid_frames = []
100 |     for i in range(num_frames):
101 |         x = points[i, :, 0]
102 |         y = points[i, :, 1]
103 |         if (x.max() - x.min()) <= noise_spr_thres1 * (y.max() - y.min()):  # 0.8
104 |             valid_frames.append(i)
105 |     return valid_frames
106 | 
107 | 
108 | def denoising_by_spread(ske_name, bodies_data):
109 |     """
110 |     Denoise data based on the spread of the X and Y values.
111 |     Filter out bodyIDs whose ratio of noisy frames is higher than the predefined
112 |     threshold.
113 | 
114 |     bodies_data: contains at least 2 bodyIDs
115 |     """
116 |     noise_info = str()
117 |     denoised_by_spr = False  # mark if this sequence has been processed by spread.
118 | 
119 |     new_bodies_data = bodies_data.copy()
120 |     # for (bodyID, body_data) in bodies_data.items():
121 |     for (bodyID, body_data) in new_bodies_data.items():
122 |         if len(bodies_data) == 1:
123 |             break
124 |         valid_frames = get_valid_frames_by_spread(body_data['joints'].reshape(-1, 25, 3))
125 |         num_frames = len(body_data['interval'])
126 |         num_noise = num_frames - len(valid_frames)
127 |         if num_noise == 0:
128 |             continue
129 | 
130 |         ratio = num_noise / float(num_frames)
131 |         motion = body_data['motion']
132 |         if ratio >= noise_spr_thres2:  # 0.69754
133 |             del bodies_data[bodyID]
134 |             denoised_by_spr = True
135 |             noise_info += 'Filter out: %s (spread rate >= %.2f).\n' % (bodyID, noise_spr_thres2)
136 |             noise_spr_logger.info('%s\t%s\t%.6f\t%.6f' % (ske_name, bodyID, motion, ratio))
137 |         else:  # Update motion
138 |             joints = body_data['joints'].reshape(-1, 25, 3)[valid_frames]
139 |             body_data['motion'] = min(motion, np.sum(np.var(joints.reshape(-1, 3), axis=0)))
140 |             noise_info += '%s: motion %.6f -> %.6f\n' % (bodyID, motion, body_data['motion'])
141 |             # TODO: Consider removing noisy frames for each bodyID
142 | 
143 |     if noise_info != '':
144 |         noise_info += '\n'
145 | 
146 |     return bodies_data, noise_info, denoised_by_spr
147 | 
148 | 
149 | def denoising_by_motion(ske_name, bodies_data, bodies_motion):
150 |     """
151 |     Filter out bodyIDs whose motion falls outside the predefined interval.
152 | 
153 |     """
154 |     # Sort bodies based on the motion, return a list of tuples
155 |     # bodies_motion = sorted(bodies_motion.items(), key=lambda x, y: cmp(x[1], y[1]), reverse=True)
156 |     bodies_motion = sorted(bodies_motion.items(), key=lambda x: x[1], reverse=True)
157 | 
158 |     # Keep the body data with the largest motion
159 |     denoised_bodies_data = [(bodies_motion[0][0], bodies_data[bodies_motion[0][0]])]
160 |     noise_info = str()
161 | 
162 |     for (bodyID, motion) in bodies_motion[1:]:
163 |         if (motion < noise_mot_thres_lo) or (motion > noise_mot_thres_hi):
164 |             noise_info += 'Filter out: %s, %.6f (motion).\n' % (bodyID, motion)
165 |             noise_mot_logger.info('{}\t{}\t{:.6f}'.format(ske_name, bodyID, motion))
166 | else: 167 | denoised_bodies_data.append((bodyID, bodies_data[bodyID])) 168 | if noise_info != '': 169 | noise_info += '\n' 170 | 171 | return denoised_bodies_data, noise_info 172 | 173 | 174 | def denoising_bodies_data(bodies_data): 175 | """ 176 | Denoising data based on some heuristic methods, not necessarily correct for all samples. 177 | 178 | Return: 179 | denoised_bodies_data (list): tuple: (bodyID, body_data). 180 | """ 181 | ske_name = bodies_data['name'] 182 | bodies_data = bodies_data['data'] 183 | 184 | # Step 1: Denoising based on frame length. 185 | bodies_data, noise_info_len = denoising_by_length(ske_name, bodies_data) 186 | 187 | if len(bodies_data) == 1: # only has one bodyID left after step 1 188 | return bodies_data.items(), noise_info_len 189 | 190 | # Step 2: Denoising based on spread. 191 | bodies_data, noise_info_spr, denoised_by_spr = denoising_by_spread(ske_name, bodies_data) 192 | 193 | if len(bodies_data) == 1: 194 | return bodies_data.items(), noise_info_len + noise_info_spr 195 | 196 | bodies_motion = dict() # get body motion 197 | for (bodyID, body_data) in bodies_data.items(): 198 | bodies_motion[bodyID] = body_data['motion'] 199 | # Sort bodies based on the motion 200 | # bodies_motion = sorted(bodies_motion.items(), key=lambda x, y: cmp(x[1], y[1]), reverse=True) 201 | bodies_motion = sorted(bodies_motion.items(), key=lambda x: x[1], reverse=True) 202 | denoised_bodies_data = list() 203 | for (bodyID, _) in bodies_motion: 204 | denoised_bodies_data.append((bodyID, bodies_data[bodyID])) 205 | 206 | return denoised_bodies_data, noise_info_len + noise_info_spr 207 | 208 | # TODO: Consider denoising further by integrating motion method 209 | 210 | # if denoised_by_spr: # this sequence has been denoised by spread 211 | # bodies_motion = sorted(bodies_motion.items(), lambda x, y: cmp(x[1], y[1]), reverse=True) 212 | # denoised_bodies_data = list() 213 | # for (bodyID, _) in bodies_motion: 214 | # denoised_bodies_data.append((bodyID, bodies_data[bodyID])) 215 | # return denoised_bodies_data, noise_info 216 | 217 | # Step 3: Denoising based on motion 218 | # bodies_data, noise_info = denoising_by_motion(ske_name, bodies_data, bodies_motion) 219 | 220 | # return bodies_data, noise_info 221 | 222 | 223 | def get_one_actor_points(body_data, num_frames): 224 | """ 225 | Get joints and colors for only one actor. 226 | For joints, each frame contains 75 X-Y-Z coordinates. 227 | For colors, each frame contains 25 x 2 (X, Y) coordinates. 228 | """ 229 | joints = np.zeros((num_frames, 75), dtype=np.float32) 230 | colors = np.ones((num_frames, 1, 25, 2), dtype=np.float32) * np.nan 231 | start, end = body_data['interval'][0], body_data['interval'][-1] 232 | joints[start:end + 1] = body_data['joints'].reshape(-1, 75) 233 | colors[start:end + 1, 0] = body_data['colors'] 234 | 235 | return joints, colors 236 | 237 | 238 | def remove_missing_frames(ske_name, joints, colors): 239 | """ 240 | Cut off missing frames which all joints positions are 0s 241 | 242 | For the sequence with 2 actors' data, also record the number of missing frames for 243 | actor1 and actor2, respectively (for debug). 
244 | """ 245 | num_frames = joints.shape[0] 246 | num_bodies = colors.shape[1] # 1 or 2 247 | 248 | if num_bodies == 2: # DEBUG 249 | missing_indices_1 = np.where(joints[:, :75].sum(axis=1) == 0)[0] 250 | missing_indices_2 = np.where(joints[:, 75:].sum(axis=1) == 0)[0] 251 | cnt1 = len(missing_indices_1) 252 | cnt2 = len(missing_indices_2) 253 | 254 | start = 1 if 0 in missing_indices_1 else 0 255 | end = 1 if num_frames - 1 in missing_indices_1 else 0 256 | if max(cnt1, cnt2) > 0: 257 | if cnt1 > cnt2: 258 | info = '{}\t{:^10d}\t{:^6d}\t{:^6d}\t{:^5d}\t{:^3d}'.format(ske_name, num_frames, 259 | cnt1, cnt2, start, end) 260 | missing_skes_logger1.info(info) 261 | else: 262 | info = '{}\t{:^10d}\t{:^6d}\t{:^6d}'.format(ske_name, num_frames, cnt1, cnt2) 263 | missing_skes_logger2.info(info) 264 | 265 | # Find valid frame indices that the data is not missing or lost 266 | # For two-subjects action, this means both data of actor1 and actor2 is missing. 267 | valid_indices = np.where(joints.sum(axis=1) != 0)[0] # 0-based index 268 | missing_indices = np.where(joints.sum(axis=1) == 0)[0] 269 | num_missing = len(missing_indices) 270 | 271 | if num_missing > 0: # Update joints and colors 272 | joints = joints[valid_indices] 273 | colors[missing_indices] = np.nan 274 | global missing_count 275 | missing_count += 1 276 | missing_skes_logger.info('{}\t{:^10d}\t{:^11d}'.format(ske_name, num_frames, num_missing)) 277 | 278 | return joints, colors 279 | 280 | 281 | def get_bodies_info(bodies_data): 282 | bodies_info = '{:^17}\t{}\t{:^8}\n'.format('bodyID', 'Interval', 'Motion') 283 | for (bodyID, body_data) in bodies_data.items(): 284 | start, end = body_data['interval'][0], body_data['interval'][-1] 285 | bodies_info += '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start, end]), body_data['motion']) 286 | 287 | return bodies_info + '\n' 288 | 289 | 290 | def get_two_actors_points(bodies_data): 291 | """ 292 | Get the first and second actor's joints positions and colors locations. 293 | 294 | # Arguments: 295 | bodies_data (dict): 3 key-value pairs: 'name', 'data', 'num_frames'. 296 | bodies_data['data'] is also a dict, while the key is bodyID, the value is 297 | the corresponding body_data which is also a dict with 4 keys: 298 | - joints: raw 3D joints positions. Shape: (num_frames x 25, 3) 299 | - colors: raw 2D color locations. Shape: (num_frames, 25, 2) 300 | - interval: a list which records the frame indices. 301 | - motion: motion amount 302 | 303 | # Return: 304 | joints, colors. 
305 | """ 306 | ske_name = bodies_data['name'] 307 | label = int(ske_name[-2:]) 308 | num_frames = bodies_data['num_frames'] 309 | bodies_info = get_bodies_info(bodies_data['data']) 310 | 311 | bodies_data, noise_info = denoising_bodies_data(bodies_data) # Denoising data 312 | bodies_info += noise_info 313 | 314 | bodies_data = list(bodies_data) 315 | if len(bodies_data) == 1: # Only left one actor after denoising 316 | if label >= 50: # DEBUG: Denoising failed for two-subjects action 317 | fail_logger_2.info(ske_name) 318 | 319 | bodyID, body_data = bodies_data[0] 320 | joints, colors = get_one_actor_points(body_data, num_frames) 321 | bodies_info += 'Main actor: %s' % bodyID 322 | else: 323 | if label < 50: # DEBUG: Denoising failed for one-subject action 324 | fail_logger_1.info(ske_name) 325 | 326 | joints = np.zeros((num_frames, 150), dtype=np.float32) 327 | colors = np.ones((num_frames, 2, 25, 2), dtype=np.float32) * np.nan 328 | 329 | bodyID, actor1 = bodies_data[0] # the 1st actor with largest motion 330 | start1, end1 = actor1['interval'][0], actor1['interval'][-1] 331 | joints[start1:end1 + 1, :75] = actor1['joints'].reshape(-1, 75) 332 | colors[start1:end1 + 1, 0] = actor1['colors'] 333 | actor1_info = '{:^17}\t{}\t{:^8}\n'.format('Actor1', 'Interval', 'Motion') + \ 334 | '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start1, end1]), actor1['motion']) 335 | del bodies_data[0] 336 | 337 | actor2_info = '{:^17}\t{}\t{:^8}\n'.format('Actor2', 'Interval', 'Motion') 338 | start2, end2 = [0, 0] # initial interval for actor2 (virtual) 339 | 340 | while len(bodies_data) > 0: 341 | bodyID, actor = bodies_data[0] 342 | start, end = actor['interval'][0], actor['interval'][-1] 343 | if min(end1, end) - max(start1, start) <= 0: # no overlap with actor1 344 | joints[start:end + 1, :75] = actor['joints'].reshape(-1, 75) 345 | colors[start:end + 1, 0] = actor['colors'] 346 | actor1_info += '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start, end]), actor['motion']) 347 | # Update the interval of actor1 348 | start1 = min(start, start1) 349 | end1 = max(end, end1) 350 | elif min(end2, end) - max(start2, start) <= 0: # no overlap with actor2 351 | joints[start:end + 1, 75:] = actor['joints'].reshape(-1, 75) 352 | colors[start:end + 1, 1] = actor['colors'] 353 | actor2_info += '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start, end]), actor['motion']) 354 | # Update the interval of actor2 355 | start2 = min(start, start2) 356 | end2 = max(end, end2) 357 | del bodies_data[0] 358 | 359 | bodies_info += ('\n' + actor1_info + '\n' + actor2_info) 360 | 361 | with open(osp.join(actors_info_dir, ske_name + '.txt'), 'w') as fw: 362 | fw.write(bodies_info + '\n') 363 | 364 | return joints, colors 365 | 366 | 367 | def get_raw_denoised_data(): 368 | """ 369 | Get denoised data (joints positions and color locations) from raw skeleton sequences. 370 | 371 | For each frame of a skeleton sequence, an actor's 3D positions of 25 joints represented 372 | by an 2D array (shape: 25 x 3) is reshaped into a 75-dim vector by concatenating each 373 | 3-dim (x, y, z) coordinates along the row dimension in joint order. Each frame contains 374 | two actor's joints positions constituting a 150-dim vector. If there is only one actor, 375 | then the last 75 values are filled with zeros. Otherwise, select the main actor and the 376 | second actor based on the motion amount. Each 150-dim vector as a row vector is put into 377 | a 2D numpy array where the number of rows equals the number of valid frames. 
378 |     All such 2D arrays are put into a list, and finally the list is serialized into a
379 |     pickle file.
380 | 
381 |     For skeleton sequences that contain two or more actors (mostly corresponding to the
382 |     last 11 classes), the filename and the actors' information are recorded into log files.
383 |     For easier inspection, RGB+skeleton videos can also be generated for visualization.
384 |     """
385 | 
386 |     with open(raw_data_file, 'rb') as fr:  # load raw skeleton data
387 |         raw_skes_data = pickle.load(fr)
388 | 
389 |     num_skes = len(raw_skes_data)
390 |     print('Found %d available skeleton sequences.' % num_skes)
391 | 
392 |     raw_denoised_joints = []
393 |     raw_denoised_colors = []
394 |     frames_cnt = []
395 | 
396 |     for (idx, bodies_data) in enumerate(raw_skes_data):
397 |         ske_name = bodies_data['name']
398 |         print('Processing %s' % ske_name)
399 |         num_bodies = len(bodies_data['data'])
400 | 
401 |         if num_bodies == 1:  # only 1 actor
402 |             num_frames = bodies_data['num_frames']
403 |             body_data = list(bodies_data['data'].values())[0]
404 |             joints, colors = get_one_actor_points(body_data, num_frames)
405 |         else:  # more than 1 actor, select two main actors
406 |             joints, colors = get_two_actors_points(bodies_data)
407 |             # Remove missing frames
408 |             joints, colors = remove_missing_frames(ske_name, joints, colors)
409 |             num_frames = joints.shape[0]  # Update
410 |             # Visualize selected actors' skeletons on RGB videos.
411 | 
412 |         raw_denoised_joints.append(joints)
413 |         raw_denoised_colors.append(colors)
414 |         frames_cnt.append(num_frames)
415 | 
416 |         if (idx + 1) % 1000 == 0:
417 |             print('Processed: %.2f%% (%d / %d), ' % \
418 |                   (100.0 * (idx + 1) / num_skes, idx + 1, num_skes) + \
419 |                   'Missing count: %d' % missing_count)
420 | 
421 |     raw_skes_joints_pkl = osp.join(save_path, 'raw_denoised_joints.pkl')
422 |     with open(raw_skes_joints_pkl, 'wb') as f:
423 |         pickle.dump(raw_denoised_joints, f, pickle.HIGHEST_PROTOCOL)
424 | 
425 |     raw_skes_colors_pkl = osp.join(save_path, 'raw_denoised_colors.pkl')
426 |     with open(raw_skes_colors_pkl, 'wb') as f:
427 |         pickle.dump(raw_denoised_colors, f, pickle.HIGHEST_PROTOCOL)
428 | 
429 |     frames_cnt = np.array(frames_cnt, dtype=int)  # np.int was removed in NumPy 1.24
430 |     np.savetxt(osp.join(save_path, 'frames_cnt.txt'), frames_cnt, fmt='%d')
431 | 
432 |     print('Saved raw denoised positions of {} frames into {}'.format(np.sum(frames_cnt),
433 |                                                                      raw_skes_joints_pkl))
434 |     print('Found %d files that have missing data' % missing_count)
435 | 
436 | if __name__ == '__main__':
437 |     get_raw_denoised_data()
--------------------------------------------------------------------------------
/data/ntu120/get_raw_denoised_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
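# Denoising pipeline overview (see denoising_bodies_data below). Candidate bodies in a
# sequence are filtered by a chain of heuristics before the main actor(s) are selected:
#   1. denoising_by_length: drop bodyIDs tracked for noise_len_thres (11) frames or fewer;
#   2. denoising_by_spread: drop bodyIDs whose ratio of noisy frames (frames where the
#      X-spread exceeds noise_spr_thres1 (0.8) times the Y-spread) reaches noise_spr_thres2;
#   3. denoising_by_motion (currently disabled): drop bodyIDs whose motion lies outside
#      [noise_mot_thres_lo, noise_mot_thres_hi].
# The remaining bodies are sorted by motion in descending order, and the most active body
# is taken as the main actor.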
3 | import os
4 | import os.path as osp
5 | import numpy as np
6 | import pickle
7 | import logging
8 | 
9 | root_path = './'
10 | raw_data_file = osp.join(root_path, 'raw_data', 'raw_skes_data.pkl')
11 | save_path = osp.join(root_path, 'denoised_data')
12 | 
13 | if not osp.exists(save_path):
14 |     os.mkdir(save_path)
15 | 
16 | rgb_ske_path = osp.join(save_path, 'rgb+ske')
17 | if not osp.exists(rgb_ske_path):
18 |     os.mkdir(rgb_ske_path)
19 | 
20 | actors_info_dir = osp.join(save_path, 'actors_info')
21 | if not osp.exists(actors_info_dir):
22 |     os.mkdir(actors_info_dir)
23 | 
24 | missing_count = 0
25 | noise_len_thres = 11           # bodyIDs tracked for this many frames or fewer are discarded
26 | noise_spr_thres1 = 0.8         # a frame is "noisy" if its X-spread > 0.8 * Y-spread
27 | noise_spr_thres2 = 0.69754     # max tolerated ratio of noisy frames per bodyID
28 | noise_mot_thres_lo = 0.089925  # lower motion bound (motion-based step, currently disabled)
29 | noise_mot_thres_hi = 2         # upper motion bound (motion-based step, currently disabled)
30 | 
31 | noise_len_logger = logging.getLogger('noise_length')
32 | noise_len_logger.setLevel(logging.INFO)
33 | noise_len_logger.addHandler(logging.FileHandler(osp.join(save_path, 'noise_length.log')))
34 | noise_len_logger.info('{:^20}\t{:^17}\t{:^8}\t{}'.format('Skeleton', 'bodyID', 'Motion', 'Length'))
35 | 
36 | noise_spr_logger = logging.getLogger('noise_spread')
37 | noise_spr_logger.setLevel(logging.INFO)
38 | noise_spr_logger.addHandler(logging.FileHandler(osp.join(save_path, 'noise_spread.log')))
39 | noise_spr_logger.info('{:^20}\t{:^17}\t{:^8}\t{:^8}'.format('Skeleton', 'bodyID', 'Motion', 'Rate'))
40 | 
41 | noise_mot_logger = logging.getLogger('noise_motion')
42 | noise_mot_logger.setLevel(logging.INFO)
43 | noise_mot_logger.addHandler(logging.FileHandler(osp.join(save_path, 'noise_motion.log')))
44 | noise_mot_logger.info('{:^20}\t{:^17}\t{:^8}'.format('Skeleton', 'bodyID', 'Motion'))
45 | 
46 | fail_logger_1 = logging.getLogger('noise_outliers_1')
47 | fail_logger_1.setLevel(logging.INFO)
48 | fail_logger_1.addHandler(logging.FileHandler(osp.join(save_path, 'denoised_failed_1.log')))
49 | 
50 | fail_logger_2 = logging.getLogger('noise_outliers_2')
51 | fail_logger_2.setLevel(logging.INFO)
52 | fail_logger_2.addHandler(logging.FileHandler(osp.join(save_path, 'denoised_failed_2.log')))
53 | 
54 | missing_skes_logger = logging.getLogger('missing_frames')
55 | missing_skes_logger.setLevel(logging.INFO)
56 | missing_skes_logger.addHandler(logging.FileHandler(osp.join(save_path, 'missing_skes.log')))
57 | missing_skes_logger.info('{:^20}\t{}\t{}'.format('Skeleton', 'num_frames', 'num_missing'))
58 | 
59 | missing_skes_logger1 = logging.getLogger('missing_frames_1')
60 | missing_skes_logger1.setLevel(logging.INFO)
61 | missing_skes_logger1.addHandler(logging.FileHandler(osp.join(save_path, 'missing_skes_1.log')))
62 | missing_skes_logger1.info('{:^20}\t{}\t{}\t{}\t{}\t{}'.format('Skeleton', 'num_frames', 'Actor1',
63 |                                                               'Actor2', 'Start', 'End'))
64 | 
65 | missing_skes_logger2 = logging.getLogger('missing_frames_2')
66 | missing_skes_logger2.setLevel(logging.INFO)
67 | missing_skes_logger2.addHandler(logging.FileHandler(osp.join(save_path, 'missing_skes_2.log')))
68 | missing_skes_logger2.info('{:^20}\t{}\t{}\t{}'.format('Skeleton', 'num_frames', 'Actor1', 'Actor2'))
69 | 
70 | 
71 | def denoising_by_length(ske_name, bodies_data):
72 |     """
73 |     Denoising data based on the frame length for each bodyID.
74 |     Filter out bodyIDs whose tracked length is less than or equal to the predefined threshold.
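    For example, with noise_len_thres = 11, a bodyID that appears in 11 or fewer frames
    is discarded.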
75 | 
76 |     """
77 |     noise_info = str()
78 |     new_bodies_data = bodies_data.copy()
79 |     for (bodyID, body_data) in new_bodies_data.items():
80 |         length = len(body_data['interval'])
81 |         if length <= noise_len_thres:
82 |             noise_info += 'Filter out: %s, %d (length).\n' % (bodyID, length)
83 |             noise_len_logger.info('{}\t{}\t{:.6f}\t{:^6d}'.format(ske_name, bodyID,
84 |                                                                   body_data['motion'], length))
85 |             del bodies_data[bodyID]
86 |     if noise_info != '':
87 |         noise_info += '\n'
88 | 
89 |     return bodies_data, noise_info
90 | 
91 | 
92 | def get_valid_frames_by_spread(points):
93 |     """
94 |     Find the indices of valid (reasonable) frames based on the spread of X and Y.
95 | 
96 |     :param points: joints or colors
97 |     """
98 |     num_frames = points.shape[0]
99 |     valid_frames = []
100 |     for i in range(num_frames):
101 |         x = points[i, :, 0]
102 |         y = points[i, :, 1]
103 |         if (x.max() - x.min()) <= noise_spr_thres1 * (y.max() - y.min()):  # 0.8
104 |             valid_frames.append(i)
105 |     return valid_frames
106 | 
107 | 
108 | def denoising_by_spread(ske_name, bodies_data):
109 |     """
110 |     Denoising data based on the spread of the X and Y values.
111 |     Filter out bodyIDs for which the ratio of noisy frames is higher than the predefined
112 |     threshold.
113 | 
114 |     bodies_data: contains at least 2 bodyIDs
115 |     """
116 |     noise_info = str()
117 |     denoised_by_spr = False  # marks whether this sequence has been processed by spread.
118 | 
119 |     new_bodies_data = bodies_data.copy()
120 |     # for (bodyID, body_data) in bodies_data.items():
121 |     for (bodyID, body_data) in new_bodies_data.items():
122 |         if len(bodies_data) == 1:
123 |             break
124 |         valid_frames = get_valid_frames_by_spread(body_data['joints'].reshape(-1, 25, 3))
125 |         num_frames = len(body_data['interval'])
126 |         num_noise = num_frames - len(valid_frames)
127 |         if num_noise == 0:
128 |             continue
129 | 
130 |         ratio = num_noise / float(num_frames)
131 |         motion = body_data['motion']
132 |         if ratio >= noise_spr_thres2:  # 0.69754
133 |             del bodies_data[bodyID]
134 |             denoised_by_spr = True
135 |             noise_info += 'Filter out: %s (spread rate >= %.2f).\n' % (bodyID, noise_spr_thres2)
136 |             noise_spr_logger.info('%s\t%s\t%.6f\t%.6f' % (ske_name, bodyID, motion, ratio))
137 |         else:  # Update motion
138 |             joints = body_data['joints'].reshape(-1, 25, 3)[valid_frames]
139 |             body_data['motion'] = min(motion, np.sum(np.var(joints.reshape(-1, 3), axis=0)))
140 |             noise_info += '%s: motion %.6f -> %.6f\n' % (bodyID, motion, body_data['motion'])
141 |             # TODO: Consider removing noisy frames for each bodyID
142 | 
143 |     if noise_info != '':
144 |         noise_info += '\n'
145 | 
146 |     return bodies_data, noise_info, denoised_by_spr
147 | 
148 | 
149 | def denoising_by_motion(ske_name, bodies_data, bodies_motion):
150 |     """
151 |     Filter out bodyIDs whose motion is outside the predefined interval.
152 | 
153 |     """
154 |     # Sort bodies based on the motion, return a list of tuples
155 |     # bodies_motion = sorted(bodies_motion.items(), key=lambda x, y: cmp(x[1], y[1]), reverse=True)
156 |     bodies_motion = sorted(bodies_motion.items(), key=lambda x: x[1], reverse=True)
157 | 
158 |     # Reserve the body data with the largest motion
159 |     denoised_bodies_data = [(bodies_motion[0][0], bodies_data[bodies_motion[0][0]])]
160 |     noise_info = str()
161 | 
162 |     for (bodyID, motion) in bodies_motion[1:]:
163 |         if (motion < noise_mot_thres_lo) or (motion > noise_mot_thres_hi):
164 |             noise_info += 'Filter out: %s, %.6f (motion).\n' % (bodyID, motion)
165 |             noise_mot_logger.info('{}\t{}\t{:.6f}'.format(ske_name, bodyID, motion))
166 |         else:
167 |             denoised_bodies_data.append((bodyID, bodies_data[bodyID]))
168 |     if noise_info != '':
169 |         noise_info += '\n'
170 | 
171 |     return denoised_bodies_data, noise_info
172 | 
173 | 
174 | def denoising_bodies_data(bodies_data):
175 |     """
176 |     Denoising data based on some heuristic methods, not necessarily correct for all samples.
177 | 
178 |     Return:
179 |         denoised_bodies_data (list): a list of (bodyID, body_data) tuples.
180 |     """
181 |     ske_name = bodies_data['name']
182 |     bodies_data = bodies_data['data']
183 | 
184 |     # Step 1: Denoising based on frame length.
185 |     bodies_data, noise_info_len = denoising_by_length(ske_name, bodies_data)
186 | 
187 |     if len(bodies_data) == 1:  # only one bodyID left after step 1
188 |         return bodies_data.items(), noise_info_len
189 | 
190 |     # Step 2: Denoising based on spread.
191 |     bodies_data, noise_info_spr, denoised_by_spr = denoising_by_spread(ske_name, bodies_data)
192 | 
193 |     if len(bodies_data) == 1:
194 |         return bodies_data.items(), noise_info_len + noise_info_spr
195 | 
196 |     bodies_motion = dict()  # get body motion
197 |     for (bodyID, body_data) in bodies_data.items():
198 |         bodies_motion[bodyID] = body_data['motion']
199 |     # Sort bodies based on the motion
200 |     # bodies_motion = sorted(bodies_motion.items(), key=lambda x, y: cmp(x[1], y[1]), reverse=True)
201 |     bodies_motion = sorted(bodies_motion.items(), key=lambda x: x[1], reverse=True)
202 |     denoised_bodies_data = list()
203 |     for (bodyID, _) in bodies_motion:
204 |         denoised_bodies_data.append((bodyID, bodies_data[bodyID]))
205 | 
206 |     return denoised_bodies_data, noise_info_len + noise_info_spr
207 | 
208 |     # TODO: Consider denoising further by integrating motion method
209 | 
210 |     # if denoised_by_spr:  # this sequence has been denoised by spread
211 |     #     bodies_motion = sorted(bodies_motion.items(), lambda x, y: cmp(x[1], y[1]), reverse=True)
212 |     #     denoised_bodies_data = list()
213 |     #     for (bodyID, _) in bodies_motion:
214 |     #         denoised_bodies_data.append((bodyID, bodies_data[bodyID]))
215 |     #     return denoised_bodies_data, noise_info
216 | 
217 |     # Step 3: Denoising based on motion
218 |     # bodies_data, noise_info = denoising_by_motion(ske_name, bodies_data, bodies_motion)
219 | 
220 |     # return bodies_data, noise_info
221 | 
222 | 
223 | def get_one_actor_points(body_data, num_frames):
224 |     """
225 |     Get joints and colors for only one actor.
226 |     For joints, each frame contains 75 X-Y-Z coordinates.
227 |     For colors, each frame contains 25 x 2 (X, Y) coordinates.
228 |     """
229 |     joints = np.zeros((num_frames, 75), dtype=np.float32)
230 |     colors = np.ones((num_frames, 1, 25, 2), dtype=np.float32) * np.nan
231 |     start, end = body_data['interval'][0], body_data['interval'][-1]
232 |     joints[start:end + 1] = body_data['joints'].reshape(-1, 75)
233 |     colors[start:end + 1, 0] = body_data['colors']
234 | 
235 |     return joints, colors
236 | 
237 | 
238 | def remove_missing_frames(ske_name, joints, colors):
239 |     """
240 |     Cut off missing frames, i.e., frames in which all joint positions are zeros.
241 | 
242 |     For sequences with two actors' data, also record the number of missing frames for
243 |     actor1 and actor2, respectively (for debugging).
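    A frame is treated as missing when its entire joint vector sums to exactly zero,
    i.e., joints[i].sum() == 0.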
244 |     """
245 |     num_frames = joints.shape[0]
246 |     num_bodies = colors.shape[1]  # 1 or 2
247 | 
248 |     if num_bodies == 2:  # DEBUG
249 |         missing_indices_1 = np.where(joints[:, :75].sum(axis=1) == 0)[0]
250 |         missing_indices_2 = np.where(joints[:, 75:].sum(axis=1) == 0)[0]
251 |         cnt1 = len(missing_indices_1)
252 |         cnt2 = len(missing_indices_2)
253 | 
254 |         start = 1 if 0 in missing_indices_1 else 0
255 |         end = 1 if num_frames - 1 in missing_indices_1 else 0
256 |         if max(cnt1, cnt2) > 0:
257 |             if cnt1 > cnt2:
258 |                 info = '{}\t{:^10d}\t{:^6d}\t{:^6d}\t{:^5d}\t{:^3d}'.format(ske_name, num_frames,
259 |                                                                             cnt1, cnt2, start, end)
260 |                 missing_skes_logger1.info(info)
261 |             else:
262 |                 info = '{}\t{:^10d}\t{:^6d}\t{:^6d}'.format(ske_name, num_frames, cnt1, cnt2)
263 |                 missing_skes_logger2.info(info)
264 | 
265 |     # Find the indices of valid frames, i.e., frames whose data is not missing or lost.
266 |     # For a two-subject action, a frame counts as missing only if both actors' data is missing.
267 |     valid_indices = np.where(joints.sum(axis=1) != 0)[0]  # 0-based index
268 |     missing_indices = np.where(joints.sum(axis=1) == 0)[0]
269 |     num_missing = len(missing_indices)
270 | 
271 |     if num_missing > 0:  # Update joints and colors
272 |         joints = joints[valid_indices]
273 |         colors[missing_indices] = np.nan
274 |         global missing_count
275 |         missing_count += 1
276 |         missing_skes_logger.info('{}\t{:^10d}\t{:^11d}'.format(ske_name, num_frames, num_missing))
277 | 
278 |     return joints, colors
279 | 
280 | 
281 | def get_bodies_info(bodies_data):
282 |     bodies_info = '{:^17}\t{}\t{:^8}\n'.format('bodyID', 'Interval', 'Motion')
283 |     for (bodyID, body_data) in bodies_data.items():
284 |         start, end = body_data['interval'][0], body_data['interval'][-1]
285 |         bodies_info += '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start, end]), body_data['motion'])
286 | 
287 |     return bodies_info + '\n'
288 | 
289 | 
290 | def get_two_actors_points(bodies_data):
291 |     """
292 |     Get the joint positions and color locations of the first and second actors.
293 | 
294 |     # Arguments:
295 |         bodies_data (dict): 3 key-value pairs: 'name', 'data', 'num_frames'.
296 |             bodies_data['data'] is also a dict, where each key is a bodyID and the value is
297 |             the corresponding body_data, itself a dict with 4 keys:
298 |                 - joints: raw 3D joint positions. Shape: (num_frames x 25, 3)
299 |                 - colors: raw 2D color locations. Shape: (num_frames, 25, 2)
300 |                 - interval: a list recording the frame indices.
301 |                 - motion: motion amount
302 | 
303 |     # Return:
304 |         joints, colors.
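        joints: shape (num_frames, 75) if only one actor remains after denoising,
            otherwise (num_frames, 150).
        colors: shape (num_frames, 1, 25, 2) or (num_frames, 2, 25, 2), respectively.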
305 |     """
306 |     ske_name = bodies_data['name']
307 |     label = int(ske_name[-3:])  # NTU120 labels have three digits (A001-A120); [-2:] would mis-parse A100+
308 |     num_frames = bodies_data['num_frames']
309 |     bodies_info = get_bodies_info(bodies_data['data'])
310 | 
311 |     bodies_data, noise_info = denoising_bodies_data(bodies_data)  # Denoising data
312 |     bodies_info += noise_info
313 | 
314 |     bodies_data = list(bodies_data)
315 |     if len(bodies_data) == 1:  # Only one actor left after denoising
316 |         if (50 <= label <= 60) or (106 <= label <= 120):  # DEBUG: denoising failed for a two-subject (mutual) action
317 |             fail_logger_2.info(ske_name)
318 | 
319 |         bodyID, body_data = bodies_data[0]
320 |         joints, colors = get_one_actor_points(body_data, num_frames)
321 |         bodies_info += 'Main actor: %s' % bodyID
322 |     else:
323 |         if not ((50 <= label <= 60) or (106 <= label <= 120)):  # DEBUG: denoising failed for a one-subject action
324 |             fail_logger_1.info(ske_name)
325 | 
326 |         joints = np.zeros((num_frames, 150), dtype=np.float32)
327 |         colors = np.ones((num_frames, 2, 25, 2), dtype=np.float32) * np.nan
328 | 
329 |         bodyID, actor1 = bodies_data[0]  # the 1st actor, with the largest motion
330 |         start1, end1 = actor1['interval'][0], actor1['interval'][-1]
331 |         joints[start1:end1 + 1, :75] = actor1['joints'].reshape(-1, 75)
332 |         colors[start1:end1 + 1, 0] = actor1['colors']
333 |         actor1_info = '{:^17}\t{}\t{:^8}\n'.format('Actor1', 'Interval', 'Motion') + \
334 |                       '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start1, end1]), actor1['motion'])
335 |         del bodies_data[0]
336 | 
337 |         actor2_info = '{:^17}\t{}\t{:^8}\n'.format('Actor2', 'Interval', 'Motion')
338 |         start2, end2 = [0, 0]  # initial (virtual) interval for actor2
339 | 
340 |         while len(bodies_data) > 0:
341 |             bodyID, actor = bodies_data[0]
342 |             start, end = actor['interval'][0], actor['interval'][-1]
343 |             if min(end1, end) - max(start1, start) <= 0:  # no overlap with actor1
344 |                 joints[start:end + 1, :75] = actor['joints'].reshape(-1, 75)
345 |                 colors[start:end + 1, 0] = actor['colors']
346 |                 actor1_info += '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start, end]), actor['motion'])
347 |                 # Update the interval of actor1
348 |                 start1 = min(start, start1)
349 |                 end1 = max(end, end1)
350 |             elif min(end2, end) - max(start2, start) <= 0:  # no overlap with actor2
351 |                 joints[start:end + 1, 75:] = actor['joints'].reshape(-1, 75)
352 |                 colors[start:end + 1, 1] = actor['colors']
353 |                 actor2_info += '{}\t{:^8}\t{:f}\n'.format(bodyID, str([start, end]), actor['motion'])
354 |                 # Update the interval of actor2
355 |                 start2 = min(start, start2)
356 |                 end2 = max(end, end2)
357 |             del bodies_data[0]
358 | 
359 |         bodies_info += ('\n' + actor1_info + '\n' + actor2_info)
360 | 
361 |     with open(osp.join(actors_info_dir, ske_name + '.txt'), 'w') as fw:
362 |         fw.write(bodies_info + '\n')
363 | 
364 |     return joints, colors
365 | 
366 | 
367 | def get_raw_denoised_data():
368 |     """
369 |     Get denoised data (joint positions and color locations) from raw skeleton sequences.
370 | 
371 |     For each frame of a skeleton sequence, an actor's 3D positions of 25 joints, represented
372 |     by a 2D array (shape: 25 x 3), are reshaped into a 75-dim vector by concatenating the
373 |     3-dim (x, y, z) coordinates along the row dimension in joint order. Each frame contains
374 |     two actors' joint positions, constituting a 150-dim vector. If there is only one actor,
375 |     the last 75 values are filled with zeros. Otherwise, the main actor and the second
376 |     actor are selected based on the motion amount. Each 150-dim vector is put, as a row,
377 |     into a 2D numpy array where the number of rows equals the number of valid frames.
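    For example, a denoised two-actor sequence with 80 valid frames yields an array of
    shape (80, 150).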
378 |     All such 2D arrays are put into a list, and finally the list is serialized into a
379 |     pickle file.
380 | 
381 |     For skeleton sequences that contain two or more actors (mostly the mutual-action
382 |     classes, A050-A060 and A106-A120), the filename and the actors' information are
383 |     recorded into log files. For easier inspection, RGB+skeleton videos can also be
384 |     generated for visualization.
385 |     """
386 | 
387 |     with open(raw_data_file, 'rb') as fr:  # load raw skeleton data
388 |         raw_skes_data = pickle.load(fr)
389 | 
390 |     num_skes = len(raw_skes_data)
391 |     print('Found %d available skeleton sequences.' % num_skes)
392 | 
393 |     raw_denoised_joints = []
394 |     raw_denoised_colors = []
395 |     frames_cnt = []
396 | 
397 |     for (idx, bodies_data) in enumerate(raw_skes_data):
398 |         ske_name = bodies_data['name']
399 |         print('Processing %s' % ske_name)
400 |         num_bodies = len(bodies_data['data'])
401 | 
402 |         if num_bodies == 1:  # only 1 actor
403 |             num_frames = bodies_data['num_frames']
404 |             body_data = list(bodies_data['data'].values())[0]
405 |             joints, colors = get_one_actor_points(body_data, num_frames)
406 |         else:  # more than 1 actor, select two main actors
407 |             joints, colors = get_two_actors_points(bodies_data)
408 |             # Remove missing frames
409 |             joints, colors = remove_missing_frames(ske_name, joints, colors)
410 |             num_frames = joints.shape[0]  # Update
411 |             # Visualize selected actors' skeletons on RGB videos.
412 | 
413 |         raw_denoised_joints.append(joints)
414 |         raw_denoised_colors.append(colors)
415 |         frames_cnt.append(num_frames)
416 | 
417 |         if (idx + 1) % 1000 == 0:
418 |             print('Processed: %.2f%% (%d / %d), ' % \
419 |                   (100.0 * (idx + 1) / num_skes, idx + 1, num_skes) + \
420 |                   'Missing count: %d' % missing_count)
421 | 
422 |     raw_skes_joints_pkl = osp.join(save_path, 'raw_denoised_joints.pkl')
423 |     with open(raw_skes_joints_pkl, 'wb') as f:
424 |         pickle.dump(raw_denoised_joints, f, pickle.HIGHEST_PROTOCOL)
425 | 
426 |     raw_skes_colors_pkl = osp.join(save_path, 'raw_denoised_colors.pkl')
427 |     with open(raw_skes_colors_pkl, 'wb') as f:
428 |         pickle.dump(raw_denoised_colors, f, pickle.HIGHEST_PROTOCOL)
429 | 
430 |     frames_cnt = np.array(frames_cnt, dtype=int)  # np.int was removed in NumPy 1.24
431 |     np.savetxt(osp.join(save_path, 'frames_cnt.txt'), frames_cnt, fmt='%d')
432 | 
433 |     print('Saved raw denoised positions of {} frames into {}'.format(np.sum(frames_cnt),
434 |                                                                      raw_skes_joints_pkl))
435 |     print('Found %d files that have missing data' % missing_count)
436 | 
437 | if __name__ == '__main__':
438 |     get_raw_denoised_data()
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import print_function
3 | 
4 | import argparse
5 | import inspect
6 | import os
7 | import pickle
8 | import random
9 | import shutil
10 | import sys
11 | import time
12 | from collections import OrderedDict
13 | import traceback
14 | from sklearn.metrics import confusion_matrix
15 | import csv
16 | import numpy as np
17 | import glob
18 | 
19 | # torch
20 | import torch
21 | import torch.backends.cudnn as cudnn
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 | import torch.optim as optim
25 | import yaml
26 | from tensorboardX import SummaryWriter
27 | from tqdm import tqdm
28 | 
29 | from torchlight import DictAction
30 | 
31 | 
32 | import resource
33 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
34 | resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
35 | 
36 | def init_seed(seed):
37 |     torch.cuda.manual_seed_all(seed)
38 |     torch.manual_seed(seed)
39 |     np.random.seed(seed)
40 |     random.seed(seed)
41 |     # torch.backends.cudnn.enabled = False
42 |     torch.backends.cudnn.deterministic = True
43 |     torch.backends.cudnn.benchmark = False
44 | 
45 | def import_class(import_str):
46 |     mod_str, _sep, class_str = import_str.rpartition('.')
47 |     __import__(mod_str)
48 |     try:
49 |         return getattr(sys.modules[mod_str], class_str)
50 |     except AttributeError:
51 |         raise ImportError('Class %s cannot be found (%s)' % (class_str, traceback.format_exception(*sys.exc_info())))
52 | 
53 | def str2bool(v):
54 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
55 |         return True
56 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
57 |         return False
58 |     else:
59 |         raise argparse.ArgumentTypeError('Unsupported value encountered.')
60 | 
61 | class LabelSmoothingCrossEntropy(nn.Module):
62 |     def __init__(self, smoothing=0.1):
63 |         super(LabelSmoothingCrossEntropy, self).__init__()
64 |         self.smoothing = smoothing
65 | 
66 |     def forward(self, x, target):
67 |         confidence = 1. - self.smoothing
68 |         logprobs = F.log_softmax(x, dim=-1)
69 |         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
70 |         nll_loss = nll_loss.squeeze(1)
71 |         smooth_loss = -logprobs.mean(dim=-1)
72 |         loss = confidence * nll_loss + self.smoothing * smooth_loss
73 |         return loss.mean()
74 | 
75 | 
76 | def get_parser():
77 |     # parameter priority: command line > config > default
78 |     parser = argparse.ArgumentParser(
79 |         description='Spatial Temporal Graph Convolution Network')
80 |     parser.add_argument(
81 |         '--work-dir',
82 |         default='./work_dir/temp',
83 |         help='the work folder for storing results')
84 | 
85 |     parser.add_argument('-model_saved_name', default='')
86 |     parser.add_argument(
87 |         '--config',
88 |         default='./config/nturgbd-cross-view/test_bone.yaml',
89 |         help='path to the configuration file')
90 | 
91 |     # processor
92 |     parser.add_argument(
93 |         '--phase', default='train', help='must be train or test')
94 |     parser.add_argument(
95 |         '--save-score',
96 |         type=str2bool,
97 |         default=False,
98 |         help='if true, the classification score will be stored')
99 | 
100 |     # visualize and debug
101 |     parser.add_argument(
102 |         '--seed', type=int, default=1, help='random seed for pytorch')
103 |     parser.add_argument(
104 |         '--log-interval',
105 |         type=int,
106 |         default=100,
107 |         help='the interval for printing messages (#iteration)')
108 |     parser.add_argument(
109 |         '--save-interval',
110 |         type=int,
111 |         default=1,
112 |         help='the interval for storing models (#epoch)')
113 |     parser.add_argument(
114 |         '--save-epoch',
115 |         type=int,
116 |         default=30,
117 |         help='the start epoch to save models (#epoch)')
118 |     parser.add_argument(
119 |         '--eval-interval',
120 |         type=int,
121 |         default=5,
122 |         help='the interval for evaluating models (#epoch)')
123 |     parser.add_argument(
124 |         '--print-log',
125 |         type=str2bool,
126 |         default=True,
127 |         help='print logging or not')
128 |     parser.add_argument(
129 |         '--show-topk',
130 |         type=int,
131 |         default=[1, 5],
132 |         nargs='+',
133 |         help='which Top-K accuracies will be shown')
134 | 
135 |     # feeder
136 |     parser.add_argument(
137 |         '--feeder', default='feeder.feeder', help='the data loader to be used')
138 |     parser.add_argument(
139 |         '--num-worker',
140 |         type=int,
141 |         default=32,
142 |         help='the number of workers for the data loader')
143 |     parser.add_argument(
144 |         '--train-feeder-args',
145 |         action=DictAction,
146 |         default=dict(),
147 |         help='the arguments of the data loader for training')
148 |     parser.add_argument(
149 |         '--test-feeder-args',
150 |         action=DictAction,
151 |         default=dict(),
152 |         help='the arguments of the data loader for testing')
153 | 
154 |     # model
155 |     parser.add_argument('--model', default=None, help='the model to be used')
156 |     parser.add_argument(
157 |         '--model-args',
158 |         action=DictAction,
159 |         default=dict(),
160 |         help='the arguments of the model')
161 |     parser.add_argument(
162 |         '--weights',
163 |         default=None,
164 |         help='the weights for network initialization')
165 |     parser.add_argument(
166 |         '--ignore-weights',
167 |         type=str,
168 |         default=[],
169 |         nargs='+',
170 |         help='the names of weights to be ignored in the initialization')
171 | 
172 |     # optim
173 |     parser.add_argument(
174 |         '--base-lr', type=float, default=0.01, help='initial learning rate')
175 |     parser.add_argument(
176 |         '--step',
177 |         type=int,
178 |         default=[20, 40, 60],
179 |         nargs='+',
180 |         help='the epochs where the optimizer reduces the learning rate')
181 |     parser.add_argument(
182 |         '--device',
183 |         type=int,
184 |         default=0,
185 |         nargs='+',
186 |         help='the indexes of GPUs for training or testing')
187 |     parser.add_argument('--optimizer', default='SGD', help='type of optimizer')
188 |     parser.add_argument(
189 |         '--nesterov', type=str2bool, default=False, help='use nesterov or not')
190 |     parser.add_argument(
191 |         '--batch-size', type=int, default=256, help='training batch size')
192 |     parser.add_argument(
193 |         '--test-batch-size', type=int, default=256, help='test batch size')
194 |     parser.add_argument(
195 |         '--start-epoch',
196 |         type=int,
197 |         default=0,
198 |         help='the epoch to start training from')
199 |     parser.add_argument(
200 |         '--num-epoch',
201 |         type=int,
202 |         default=80,
203 |         help='the epoch at which training stops')
204 |     parser.add_argument(
205 |         '--weight-decay',
206 |         type=float,
207 |         default=0.0005,
208 |         help='weight decay for optimizer')
209 |     parser.add_argument(
210 |         '--lr-ratio',
211 |         type=float,
212 |         default=0.001,
213 |         help='minimum lr ratio for cosine annealing (eta_min = base_lr * lr_ratio)')
214 |     parser.add_argument(
215 |         '--lr-decay-rate',
216 |         type=float,
217 |         default=0.1,
218 |         help='decay rate for learning rate')
219 |     parser.add_argument('--warm_up_epoch', type=int, default=0)
220 |     parser.add_argument('--loss-type', type=str, default='CE')
221 | 
222 |     return parser
223 | 
224 | 
225 | class Processor():
226 |     """
227 |     Processor for Skeleton-based Action Recognition
228 |     """
229 | 
230 |     def __init__(self, arg):
231 |         self.arg = arg
232 |         self.save_arg()
233 |         if arg.phase == 'train':
234 |             if not arg.train_feeder_args['debug']:
235 |                 arg.model_saved_name = os.path.join(arg.work_dir, 'runs')
236 |                 if os.path.isdir(arg.model_saved_name):
237 |                     print('log_dir: ', arg.model_saved_name, 'already exists')
238 |                     answer = input('delete it? y/n:')
239 |                     if answer == 'y':
240 |                         shutil.rmtree(arg.model_saved_name)
241 |                         print('Dir removed: ', arg.model_saved_name)
242 |                         input('Refresh the tensorboard page, then press any key to continue')
243 |                     else:
244 |                         print('Dir not removed: ', arg.model_saved_name)
245 |                 self.train_writer = SummaryWriter(os.path.join(arg.model_saved_name, 'train'), 'train')
246 |                 self.val_writer = SummaryWriter(os.path.join(arg.model_saved_name, 'val'), 'val')
247 |             else:
248 |                 self.train_writer = self.val_writer = SummaryWriter(os.path.join(arg.model_saved_name, 'test'), 'test')
249 |         self.global_step = 0
250 |         # pdb.set_trace()
251 |         self.load_model()
252 | 
253 |         if self.arg.phase == 'model_size':
254 |             pass
255 |         else:
256 |             self.load_optimizer()
257 |             self.load_data()
258 |         self.lr = self.arg.base_lr
259 |         self.best_acc = 0
260 |         self.best_acc_epoch = 0
261 | 
262 |         self.model = self.model.cuda(self.output_device)
263 | 
264 |         if type(self.arg.device) is list:
265 |             if len(self.arg.device) > 1:
266 |                 self.model = nn.DataParallel(
267 |                     self.model,
268 |                     device_ids=self.arg.device,
269 |                     output_device=self.output_device)
270 | 
271 |     def load_data(self):
272 |         Feeder = import_class(self.arg.feeder)
273 |         self.data_loader = dict()
274 |         if self.arg.phase == 'train':
275 |             self.data_loader['train'] = torch.utils.data.DataLoader(
276 |                 dataset=Feeder(**self.arg.train_feeder_args),
277 |                 batch_size=self.arg.batch_size,
278 |                 shuffle=True,
279 |                 num_workers=self.arg.num_worker,
280 |                 drop_last=True,
281 |                 worker_init_fn=init_seed)
282 |         self.data_loader['test'] = torch.utils.data.DataLoader(
283 |             dataset=Feeder(**self.arg.test_feeder_args),
284 |             batch_size=self.arg.test_batch_size,
285 |             shuffle=False,
286 |             num_workers=self.arg.num_worker,
287 |             drop_last=False,
288 |             worker_init_fn=init_seed)
289 | 
290 |     def load_model(self):
291 |         output_device = self.arg.device[0] if type(self.arg.device) is list else self.arg.device
292 |         self.output_device = output_device
293 |         Model = import_class(self.arg.model)
294 |         shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
295 |         print(Model)
296 |         self.model = Model(**self.arg.model_args)
297 |         if self.arg.loss_type == 'CE':
298 |             self.loss = nn.CrossEntropyLoss().cuda(output_device)
299 |         else:
300 |             self.loss = LabelSmoothingCrossEntropy(smoothing=0.1).cuda(output_device)
301 | 
302 |         if self.arg.weights:
303 |             self.global_step = int(self.arg.weights[:-3].split('-')[-1])
304 |             self.print_log('Load weights from {}.'.format(self.arg.weights))
305 |             if '.pkl' in self.arg.weights:
306 |                 with open(self.arg.weights, 'rb') as f:  # pickle files must be opened in binary mode
307 |                     weights = pickle.load(f)
308 |             else:
309 |                 weights = torch.load(self.arg.weights)
310 | 
311 |             weights = OrderedDict([[k.split('module.')[-1], v.cuda(output_device)] for k, v in weights.items()])
312 | 
313 |             keys = list(weights.keys())
314 |             for w in self.arg.ignore_weights:
315 |                 for key in keys:
316 |                     if w in key:
317 |                         if weights.pop(key, None) is not None:
318 |                             self.print_log('Successfully removed weights: {}.'.format(key))
319 |                         else:
320 |                             self.print_log('Can not remove weights: {}.'.format(key))
321 | 
322 |             try:
323 |                 self.model.load_state_dict(weights)
324 |             except Exception:
325 |                 state = self.model.state_dict()
326 |                 diff = list(set(state.keys()).difference(set(weights.keys())))
327 |                 print('Can not find these weights:')
328 |                 for d in diff:
329 |                     print('  ' + d)
330 |                 state.update(weights)
331 |                 self.model.load_state_dict(state)
332 | 
333 |     def load_optimizer(self):
334 |         if self.arg.optimizer == 'SGD':
335 |             self.optimizer = optim.SGD(
336 |                 self.model.parameters(),
337 |                 lr=self.arg.base_lr,
338 |                 momentum=0.9,
339 |                 nesterov=self.arg.nesterov,
340 |                 weight_decay=self.arg.weight_decay)
341 |         elif self.arg.optimizer == 'Adam':
342 |             self.optimizer = optim.Adam(
343 |                 self.model.parameters(),
344 |                 lr=self.arg.base_lr,
345 |                 weight_decay=self.arg.weight_decay)
346 |         else:
347 |             raise ValueError()
348 | 
349 |         self.print_log('Using warm-up, epochs: {}'.format(self.arg.warm_up_epoch))
350 | 
351 |     def save_arg(self):
352 |         # save arg
353 |         arg_dict = vars(self.arg)
354 |         if not os.path.exists(self.arg.work_dir):
355 |             os.makedirs(self.arg.work_dir)
356 |         with open('{}/config.yaml'.format(self.arg.work_dir), 'w') as f:
357 |             f.write(f"# command line: {' '.join(sys.argv)}\n\n")
358 |             yaml.dump(arg_dict, f)
359 | 
360 |     def adjust_learning_rate(self, epoch, idx):
361 |         if self.arg.optimizer == 'SGD' or self.arg.optimizer == 'Adam':
362 |             if epoch < self.arg.warm_up_epoch:
363 |                 lr = self.arg.base_lr * (epoch + 1) / self.arg.warm_up_epoch
364 |             else:
365 |                 T_max = len(self.data_loader['train']) * (self.arg.num_epoch - self.arg.warm_up_epoch)
366 |                 T_cur = len(self.data_loader['train']) * (epoch - self.arg.warm_up_epoch) + idx
367 | 
368 |                 eta_min = self.arg.base_lr * self.arg.lr_ratio
369 |                 lr = eta_min + 0.5 * (self.arg.base_lr - eta_min) * (1 + np.cos((T_cur / T_max) * np.pi))
370 |             for param_group in self.optimizer.param_groups:
371 |                 param_group['lr'] = lr
372 |             return lr
373 |         else:
374 |             raise ValueError()
375 | 
376 |     def print_time(self):
377 |         localtime = time.asctime(time.localtime(time.time()))
378 |         self.print_log("Local current time : " + localtime)
379 | 
380 |     def print_log(self, msg, print_time=True):  # parameter renamed from `str` to avoid shadowing the built-in
381 |         if print_time:
382 |             localtime = time.asctime(time.localtime(time.time()))
383 |             msg = "[ " + localtime + ' ] ' + msg
384 |         print(msg)
385 |         if self.arg.print_log:
386 |             with open('{}/log.txt'.format(self.arg.work_dir), 'a') as f:
387 |                 print(msg, file=f)
388 | 
389 |     def record_time(self):
390 |         self.cur_time = time.time()
391 |         return self.cur_time
392 | 
393 |     def split_time(self):
394 |         split_time = time.time() - self.cur_time
395 |         self.record_time()
396 |         return split_time
397 | 
398 |     def train(self, epoch, save_model=False):
399 |         self.model.train()
400 |         self.print_log('Training epoch: {}'.format(epoch + 1))
401 |         loader = self.data_loader['train']
402 | 
403 |         loss_value = []
404 |         acc_value = []
405 |         self.train_writer.add_scalar('epoch', epoch, self.global_step)
406 |         self.record_time()
407 |         timer = dict(dataloader=0.001, model=0.001, statistics=0.001)
408 |         process = tqdm(loader)
409 | 
410 |         for batch_idx, (data, label, index) in enumerate(process):
411 | 
412 |             self.adjust_learning_rate(epoch, batch_idx)
413 | 
414 |             self.global_step += 1
415 |             with torch.no_grad():
416 |                 data = data.float().cuda(self.output_device)
417 |                 label = label.long().cuda(self.output_device)
418 |             timer['dataloader'] += self.split_time()
419 | 
420 |             # forward
421 |             output = self.model(data)
422 |             loss = self.loss(output, label)
423 |             # backward
424 |             self.optimizer.zero_grad()
425 |             loss.backward()
426 |             self.optimizer.step()
427 | 
428 |             loss_value.append(loss.data.item())
429 |             timer['model'] += self.split_time()
430 | 
431 |             value, predict_label = torch.max(output.data, 1)
432 |             acc = torch.mean((predict_label == label.data).float())
433 |             acc_value.append(acc.data.item())
434 |             self.train_writer.add_scalar('acc', acc, self.global_step)
435 |             self.train_writer.add_scalar('loss', loss.data.item(), self.global_step)
436 | 
437 |             # 
statistics 438 | self.lr = self.optimizer.param_groups[0]['lr'] 439 | self.train_writer.add_scalar('lr', self.lr, self.global_step) 440 | timer['statistics'] += self.split_time() 441 | 442 | # statistics of time consumption and loss 443 | proportion = { 444 | k: '{:02d}%'.format(int(round(v * 100 / sum(timer.values())))) 445 | for k, v in timer.items() 446 | } 447 | self.print_log( 448 | '\tMean training loss: {:.4f}. Mean training acc: {:.2f}%.'.format(np.mean(loss_value), np.mean(acc_value)*100)) 449 | self.print_log('\tLearning Rate: {:.4f}'.format(self.lr)) 450 | self.print_log('\tTime consumption: [Data]{dataloader}, [Network]{model}'.format(**proportion)) 451 | 452 | if save_model: 453 | state_dict = self.model.state_dict() 454 | weights = OrderedDict([[k.split('module.')[-1], v.cpu()] for k, v in state_dict.items()]) 455 | 456 | torch.save(weights, self.arg.model_saved_name + '-' + str(epoch+1) + '-' + str(int(self.global_step)) + '.pt') 457 | 458 | def eval(self, epoch, save_score=False, loader_name=['test'], wrong_file=None, result_file=None): 459 | if wrong_file is not None: 460 | f_w = open(wrong_file, 'w') 461 | if result_file is not None: 462 | f_r = open(result_file, 'w') 463 | self.model.eval() 464 | self.print_log('Eval epoch: {}'.format(epoch + 1)) 465 | for ln in loader_name: 466 | loss_value = [] 467 | score_frag = [] 468 | label_list = [] 469 | pred_list = [] 470 | step = 0 471 | process = tqdm(self.data_loader[ln]) 472 | for batch_idx, (data, label, index) in enumerate(process): 473 | label_list.append(label) 474 | with torch.no_grad(): 475 | data = data.float().cuda(self.output_device) 476 | label = label.long().cuda(self.output_device) 477 | output = self.model(data) 478 | loss = self.loss(output, label) 479 | score_frag.append(output.data.cpu().numpy()) 480 | loss_value.append(loss.data.item()) 481 | 482 | _, predict_label = torch.max(output.data, 1) 483 | pred_list.append(predict_label.data.cpu().numpy()) 484 | step += 1 485 | 486 | if wrong_file is not None or result_file is not None: 487 | predict = list(predict_label.cpu().numpy()) 488 | true = list(label.data.cpu().numpy()) 489 | for i, x in enumerate(predict): 490 | if result_file is not None: 491 | f_r.write(str(x) + ',' + str(true[i]) + '\n') 492 | if x != true[i] and wrong_file is not None: 493 | f_w.write(str(index[i]) + ',' + str(x) + ',' + str(true[i]) + '\n') 494 | score = np.concatenate(score_frag) 495 | loss = np.mean(loss_value) 496 | if 'ucla' in self.arg.feeder: 497 | self.data_loader[ln].dataset.sample_name = np.arange(len(score)) 498 | accuracy = self.data_loader[ln].dataset.top_k(score, 1) 499 | if accuracy > self.best_acc: 500 | self.best_acc = accuracy 501 | self.best_acc_epoch = epoch + 1 502 | 503 | print('Accuracy: ', accuracy, ' model: ', self.arg.model_saved_name) 504 | if self.arg.phase == 'train': 505 | self.val_writer.add_scalar('loss', loss, self.global_step) 506 | self.val_writer.add_scalar('acc', accuracy, self.global_step) 507 | 508 | score_dict = dict( 509 | zip(self.data_loader[ln].dataset.sample_name, score)) 510 | self.print_log('\tMean {} loss of {} batches: {}.'.format( 511 | ln, len(self.data_loader[ln]), np.mean(loss_value))) 512 | for k in self.arg.show_topk: 513 | self.print_log('\tTop{}: {:.2f}%'.format( 514 | k, 100 * self.data_loader[ln].dataset.top_k(score, k))) 515 | 516 | if save_score: 517 | with open('{}/epoch{}_{}_score.pkl'.format( 518 | self.arg.work_dir, epoch + 1, ln), 'wb') as f: 519 | pickle.dump(score_dict, f) 520 | 521 | # acc for each class: 522 | 
label_list = np.concatenate(label_list)
523 |             pred_list = np.concatenate(pred_list)
524 |             confusion = confusion_matrix(label_list, pred_list)
525 |             list_diag = np.diag(confusion)
526 |             list_row_sum = np.sum(confusion, axis=1)  # per-class sample counts (row sums)
527 |             each_acc = list_diag / list_row_sum
528 |             with open('{}/epoch{}_{}_each_class_acc.csv'.format(self.arg.work_dir, epoch + 1, ln), 'w') as f:
529 |                 writer = csv.writer(f)
530 |                 writer.writerow(each_acc)
531 |                 writer.writerows(confusion)
532 | 
533 |     def start(self):
534 |         if self.arg.phase == 'train':
535 |             self.print_log('Parameters:\n{}\n'.format(str(vars(self.arg))))
536 |             self.global_step = self.arg.start_epoch * len(self.data_loader['train']) / self.arg.batch_size
537 |             def count_parameters(model):
538 |                 return sum(p.numel() for p in model.parameters() if p.requires_grad)
539 |             self.print_log(f'# Parameters: {count_parameters(self.model)}')
540 |             for epoch in range(self.arg.start_epoch, self.arg.num_epoch):
541 |                 save_model = (((epoch + 1) % self.arg.save_interval == 0) or (
542 |                     epoch + 1 == self.arg.num_epoch)) and (epoch + 1) > self.arg.save_epoch
543 | 
544 |                 self.train(epoch, save_model=save_model)  # respect the save schedule computed above
545 |                 self.eval(epoch, save_score=True, loader_name=['test'])
546 | 
547 |             # test the best model
548 |             weights_path = glob.glob(os.path.join(self.arg.work_dir, 'runs-' + str(self.best_acc_epoch) + '*'))[0]
549 |             weights = torch.load(weights_path)
550 |             if type(self.arg.device) is list:
551 |                 if len(self.arg.device) > 1:
552 |                     weights = OrderedDict([['module.' + k, v.cuda(self.output_device)] for k, v in weights.items()])
553 |             self.model.load_state_dict(weights)
554 | 
555 |             wf = weights_path.replace('.pt', '_wrong.txt')
556 |             rf = weights_path.replace('.pt', '_right.txt')
557 |             self.arg.print_log = False
558 |             self.eval(epoch=0, save_score=True, loader_name=['test'], wrong_file=wf, result_file=rf)
559 |             self.arg.print_log = True
560 | 
561 | 
562 |             num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
563 |             self.print_log(f'Best accuracy: {self.best_acc}')
564 |             self.print_log(f'Epoch number: {self.best_acc_epoch}')
565 |             self.print_log(f'Model name: {self.arg.work_dir}')
566 |             self.print_log(f'Model total number of params: {num_params}')
567 |             self.print_log(f'Weight decay: {self.arg.weight_decay}')
568 |             self.print_log(f'Base LR: {self.arg.base_lr}')
569 |             self.print_log(f'Batch Size: {self.arg.batch_size}')
570 |             self.print_log(f'Test Batch Size: {self.arg.test_batch_size}')
571 |             self.print_log(f'seed: {self.arg.seed}')
572 | 
573 |         elif self.arg.phase == 'test':
574 |             if self.arg.weights is None:  # check before deriving file names from the weights path
575 |                 raise ValueError('Please appoint --weights.')
576 |             wf = self.arg.weights.replace('.pt', '_wrong.txt')
577 |             rf = self.arg.weights.replace('.pt', '_right.txt')
578 | 
579 |             self.arg.print_log = False
580 |             self.print_log('Model: {}.'.format(self.arg.model))
581 |             self.print_log('Weights: {}.'.format(self.arg.weights))
582 |             self.eval(epoch=0, save_score=self.arg.save_score, loader_name=['test'], wrong_file=wf, result_file=rf)
583 |             self.print_log('Done.\n')
584 | 
585 | if __name__ == '__main__':
586 |     parser = get_parser()
587 | 
588 |     # load arg from config file
589 |     p = parser.parse_args()
590 |     if p.config is not None:
591 |         with open(p.config, 'r') as f:
592 |             default_arg = yaml.load(f, Loader=yaml.FullLoader)  # PyYAML >= 5.1 requires an explicit Loader
593 |         key = vars(p).keys()
594 |         for k in default_arg.keys():
595 |             if k not in key:
596 |                 print('WRONG ARG: {}'.format(k))
597 |                 assert (k in key)
598 |         parser.set_defaults(**default_arg)
599 | 
600 |     arg = parser.parse_args()
601 |     init_seed(arg.seed)
602 | 
processor = Processor(arg) 603 | processor.start() 604 | --------------------------------------------------------------------------------
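A minimal standalone sketch of the learning-rate schedule implemented by Processor.adjust_learning_rate in main.py above: a linear warm-up followed by per-iteration cosine annealing down to eta_min = base_lr * lr_ratio. The base_lr, warm_up_epoch, num_epoch, and steps_per_epoch values below are illustrative assumptions, not values taken from the repository's configs.

import numpy as np

def lr_at(epoch, idx, steps_per_epoch, base_lr=0.1, lr_ratio=0.001,
          warm_up_epoch=5, num_epoch=90):
    # Linear warm-up: scale base_lr by (epoch + 1) / warm_up_epoch.
    if epoch < warm_up_epoch:
        return base_lr * (epoch + 1) / warm_up_epoch
    # Cosine annealing, advanced once per training iteration (idx).
    T_max = steps_per_epoch * (num_epoch - warm_up_epoch)
    T_cur = steps_per_epoch * (epoch - warm_up_epoch) + idx
    eta_min = base_lr * lr_ratio
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + np.cos((T_cur / T_max) * np.pi))

# With 500 iterations per epoch:
# lr_at(0, 0, 500)    -> 0.02     (first warm-up epoch)
# lr_at(5, 0, 500)    -> 0.1      (annealing starts at base_lr)
# lr_at(89, 499, 500) -> ~0.0001  (approaches eta_min)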