├── LICENSE ├── README.md ├── dataset ├── __pycache__ │ ├── adhd.cpython-36.pyc │ ├── dhg_skeleton.cpython-36.pyc │ ├── ntu_skeleton.cpython-36.pyc │ ├── preparedata.cpython-36.pyc │ ├── rotation.cpython-36.pyc │ ├── skeleton.cpython-36.pyc │ └── video_data.cpython-36.pyc ├── dhg_skeleton.py ├── normalize_skeletons.py ├── ntu_skeleton.py ├── preparedata.py ├── rotation.py ├── skeleton.py └── video_data.py ├── method_choose ├── __pycache__ │ ├── data_choose.cpython-36.pyc │ ├── loss_choose.cpython-36.pyc │ ├── lr_scheduler_choose.cpython-36.pyc │ ├── model_choose.cpython-36.pyc │ ├── optimizer_choose.cpython-36.pyc │ └── tra_val_choose.cpython-36.pyc ├── data_choose.py ├── loss_choose.py ├── lr_scheduler_choose.py ├── model_choose.py ├── optimizer_choose.py └── tra_val_choose.py ├── model ├── __pycache__ │ └── dstanet.cpython-36.pyc └── dstanet.py ├── prepare ├── ADNI │ └── ADNI_data_gen.py ├── dhg │ ├── gendata.py │ ├── joints.txt │ └── label.txt ├── ntu_120 │ ├── gendata.py │ ├── joints.txt │ └── label.txt ├── ntu_60 │ ├── gendata.py │ ├── joints.txt │ └── label.txt └── shrec │ ├── gendata.py │ ├── joints.txt │ ├── label.txt │ └── label_28.txt ├── requirements.txt ├── train_val_test ├── __pycache__ │ ├── loss.cpython-36.pyc │ ├── parser_args.cpython-36.pyc │ └── train_val_model.cpython-36.pyc ├── config │ └── config.yaml ├── ensemble.py ├── eval.py ├── loss.py ├── optimizer.py ├── parser_args.py ├── train.py └── train_val_model.py └── utility ├── __pycache__ └── log.cpython-36.pyc └── log.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 seuzjj 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion_kernel_attention_network 2 | 3 | This is the code for our TMI paper: 4 | 5 | J. Zhang, L. Zhou, L. Wang, M. Liu and D. Shen, "Diffusion Kernel Attention Network for Brain Disorder Classification," in IEEE Transactions on Medical Imaging, vol. 41, no. 10, pp. 2814-2827, Oct. 2022, doi: 10.1109/TMI.2022.3170701. 
6 | 7 | Synthetic data can be downloaded from https://drive.google.com/file/d/17ju4HiZOX7MtdKxKv1aNXISdQkCg-qwq/view?usp=sharing 8 | Please note that the synthetic data are provided only to facilitate running the code and release of the original data used in the paper is out of the authority of the author. Please refer to the corresponding data resource. 9 | 10 | Citation Information: 11 | 12 | @ARTICLE{9763540, 13 | author={Zhang, Jianjia and Zhou, Luping and Wang, Lei and Liu, Mengting and Shen, Dinggang}, 14 | journal={IEEE Transactions on Medical Imaging}, 15 | title={Diffusion Kernel Attention Network for Brain Disorder Classification}, 16 | year={2022}, 17 | volume={41}, 18 | number={10}, 19 | pages={2814-2827}, 20 | doi={10.1109/TMI.2022.3170701}} 21 | -------------------------------------------------------------------------------- /dataset/__pycache__/adhd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/dataset/__pycache__/adhd.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/dhg_skeleton.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/dataset/__pycache__/dhg_skeleton.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/ntu_skeleton.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/dataset/__pycache__/ntu_skeleton.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/preparedata.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/dataset/__pycache__/preparedata.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/rotation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/dataset/__pycache__/rotation.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/skeleton.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/dataset/__pycache__/skeleton.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/video_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/dataset/__pycache__/video_data.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/dhg_skeleton.py: 
-------------------------------------------------------------------------------- 1 | 2 | from torch.utils.data import DataLoader, Dataset 3 | from dataset.skeleton import Skeleton, vis 4 | 5 | 6 | edge = ((0, 1), 7 | (1, 2), (2, 3), (3, 4), (4, 5), 8 | (1, 6), (6, 7), (7, 8), (8, 9), 9 | (1, 10), (10, 11), (11, 12), (12, 13), 10 | (1, 14), (14, 15), (15, 16), (16, 17), 11 | (1, 18), (18, 19), (19, 20), (20, 21)) 12 | 13 | 14 | class DHG_SKE(Skeleton): 15 | def __init__(self, data_path, label_path, window_size, final_size, mode='train', decouple_spatial=False, 16 | num_skip_frame=None, random_choose=False, center_choose=False): 17 | super().__init__(data_path, label_path, window_size, final_size, mode, decouple_spatial, num_skip_frame, 18 | random_choose, center_choose) 19 | self.edge = edge 20 | 21 | 22 | def test(data_path, label_path, vid=None, edge=None, is_3d=False, mode='train'): 23 | loader = DataLoader( 24 | dataset=DHG_SKE(data_path, label_path, window_size=150, final_size=128, mode=mode, 25 | random_choose=True, center_choose=False, decouple_spatial=False, num_skip_frame=None), 26 | batch_size=1, 27 | shuffle=False, 28 | num_workers=0) 29 | 30 | labels = open('../prepare/shrec/label_28.txt', 'r').readlines() 31 | for i, (data, label) in enumerate(loader): 32 | if i%100==0: 33 | vis(data[0].numpy(), edge=edge, view=0.2, pause=0.01, title=labels[label.item()].rstrip()) 34 | 35 | sample_name = loader.dataset.sample_name 36 | index = sample_name.index(vid) 37 | if mode != 'train': 38 | data, label, index = loader.dataset[index] 39 | else: 40 | data, label = loader.dataset[index] 41 | # skeleton 42 | vis(data, edge=edge, view=0.2, pause=0.1, title=labels[label].rstrip()) 43 | 44 | 45 | if __name__ == '__main__': 46 | data_path = "/your/path/to/shrec_hand/train_skeleton.pkl" 47 | label_path = "/your/path/to/shrec_hand/train_label_28.pkl" 48 | # data_path = "/your/path/to/dhg_hand_shrec/train_skeleton_ddnet.pkl" 49 | # label_path = "/your/path/to/dhg_hand_shrec/train_label_ddnet_14.pkl" 50 | # test(data_path, label_path, vid=1, edge=edge, is_3d=True, mode='train') 51 | test(data_path, label_path, vid='14_2_27_5', edge=edge, is_3d=True, mode='train') 52 | -------------------------------------------------------------------------------- /dataset/normalize_skeletons.py: -------------------------------------------------------------------------------- 1 | from dataset.rotation import * 2 | import numpy as np 3 | 4 | 5 | def normalize_skeletons(skeleton, origin=None, base_bone=None, zaxis=None, xaxis=None): 6 | ''' 7 | 8 | :param skeleton: M, T, V, C(x, y, z) 9 | :param origin: int 10 | :param base_bone: [int, int] 11 | :param zaxis: [int, int] 12 | :param xaxis: [int, int] 13 | :return: 14 | ''' 15 | 16 | M, T, V, C = skeleton.shape 17 | 18 | # print('move skeleton to begin') 19 | if skeleton.sum() == 0: 20 | raise RuntimeError('null skeleton') 21 | if skeleton[:, 0].sum() == 0: # pad top null frames 22 | index = (skeleton.sum(-1).sum(-1).sum(0) != 0) 23 | tmp = skeleton[:, index].copy() 24 | skeleton *= 0 25 | skeleton[:, :tmp.shape[1]] = tmp 26 | 27 | if origin is not None: 28 | # print('sub the center joint #0 (wrist)') 29 | main_body_center = skeleton[0, 0, origin].copy() # c 30 | for i_p, person in enumerate(skeleton): 31 | if person.sum() == 0: 32 | continue 33 | mask = (person.sum(-1) != 0).reshape(T, V, 1) # only for none zero frames 34 | skeleton[i_p] = (skeleton[i_p] - main_body_center) * mask 35 | 36 | if base_bone is not None: 37 | # skeleton /= base_bone 38 | # div base bone 
lenghth 39 | t = 0 40 | main_body_spine = 0 41 | while t < T and main_body_spine == 0: 42 | main_body_spine = np.linalg.norm(skeleton[0, t, base_bone[1]] - skeleton[0, t, base_bone[0]]) 43 | t += 1 44 | # print(main_body_spine) 45 | if main_body_spine == 0: 46 | print('zero bone') 47 | else: 48 | skeleton /= main_body_spine 49 | 50 | if zaxis is not None: 51 | # print('parallel the bone between wrist(jpt 0) and MMCP(jpt 1) of the first person to the z axis') 52 | joint_bottom = skeleton[0, 0, zaxis[0]] 53 | joint_top = skeleton[0, 0, zaxis[1]] 54 | axis = np.cross(joint_top - joint_bottom, [0, 0, 1]) 55 | angle = angle_between(joint_top - joint_bottom, [0, 0, 1]) 56 | matrix_z = rotation_matrix(axis, angle) 57 | for i_p, person in enumerate(skeleton): 58 | if person.sum() == 0: 59 | continue 60 | for i_f, frame in enumerate(person): 61 | if frame.sum() == 0: 62 | continue 63 | for i_j, joint in enumerate(frame): 64 | skeleton[i_p, i_f, i_j] = np.dot(matrix_z, joint) 65 | 66 | if xaxis is not None: 67 | # print('parallel the bone in x plane between wrist(jpt 0) and TMCP(jpt 1) of the first person to the x axis') 68 | joint_left = skeleton[0, 0, xaxis[0]].copy() 69 | joint_right = skeleton[0, 0, xaxis[1]].copy() 70 | # axis = np.cross(joint_right - joint_left, [1, 0, 0]) 71 | joint_left[2] = 0 72 | joint_right[2] = 0 # rotate by zaxis 73 | axis = np.cross(joint_right - joint_left, [1, 0, 0]) 74 | angle = angle_between(joint_right - joint_left, [1, 0, 0]) 75 | matrix_x = rotation_matrix(axis, angle) 76 | for i_p, person in enumerate(skeleton): 77 | if person.sum() == 0: 78 | continue 79 | for i_f, frame in enumerate(person): 80 | if frame.sum() == 0: 81 | continue 82 | for i_j, joint in enumerate(frame): 83 | skeleton[i_p, i_f, i_j] = np.dot(matrix_x, joint) 84 | 85 | # print(skeleton[0, 0, zaxis[0]], skeleton[0, 0, zaxis[1]], skeleton[0, 0, xaxis[0]], skeleton[0, 0, xaxis[1]]) 86 | skeleton = np.transpose(skeleton, [3, 1, 2, 0]) # mtvc - ctvm 87 | return skeleton 88 | -------------------------------------------------------------------------------- /dataset/ntu_skeleton.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | from torch.utils.data import DataLoader, Dataset 4 | #from dataset.video_data import * 5 | from dataset.skeleton import Skeleton, vis 6 | 7 | edge = ((0, 1), (1, 20), (2, 20), (3, 2), (4, 20), (5, 4), (6, 5), 8 | (7, 6), (8, 20), (9, 8), (10, 9), (11, 10), (12, 0), 9 | (13, 12), (14, 13), (15, 14), (16, 0), (17, 16), (18, 17), 10 | (19, 18), (21, 22), (22, 7), (23, 24), (24, 11)) 11 | 12 | 13 | class NTU_SKE(Skeleton): 14 | def __init__(self, data_path, label_path, window_size, final_size, mode='train', decouple_spatial=False, 15 | num_skip_frame=None, random_choose=False, center_choose=False, random_noise=False, random_scale=False): 16 | super().__init__(data_path, label_path, window_size, final_size, mode, decouple_spatial, num_skip_frame, 17 | random_choose, center_choose, random_noise, random_scale) 18 | self.edge = edge 19 | 20 | def load_data(self): 21 | with open(self.label_path, 'rb') as f: 22 | self.sample_name, self.label = pickle.load(f) 23 | 24 | # load data 25 | self.data = np.load(self.data_path, mmap_mode='r')[:, :3] # NCTVM 26 | print(self.data.shape) 27 | 28 | 29 | def test(data_path, label_path, vid=None, edge=None, is_3d=False, mode='train'): 30 | dataset = NTU_SKE(data_path, label_path, window_size=48, final_size=32, mode=mode, 31 | random_choose=True, center_choose=False) 32 | loader = 
DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 33 | 34 | labels = open('../prepare/ntu_120/label.txt', 'r').readlines() 35 | for i, (data, label) in enumerate(loader): 36 | if i%1000==0: 37 | vis(data[0].numpy(), edge=edge, view=1, pause=0.01, title=labels[label.item()].rstrip()) 38 | 39 | sample_name = loader.dataset.sample_name 40 | sample_id = [name.split('.')[0] for name in sample_name] 41 | index = sample_id.index(vid) 42 | if mode != 'train': 43 | data, label, index = loader.dataset[index] 44 | else: 45 | data, label = loader.dataset[index] 46 | # skeleton 47 | vis(data, edge=edge, view=1, pause=0.1) 48 | 49 | 50 | if __name__ == '__main__': 51 | data_path = "/your/path/to/ntu/xsub/val_data_joint.npy" 52 | label_path = "/your/path/to/ntu/xsub/val_label.pkl" 53 | test(data_path, label_path, vid='S004C001P003R001A032', edge=edge, is_3d=True, mode='train') 54 | # data_path = "/your/path/to/ntu/xsub/val_data_joint.npy" 55 | # label_path = "/your/path/to/ntu/xsub/val_label.pkl" 56 | # test(data_path, label_path, vid='S004C001P003R001A032', edge=edge, is_3d=True, mode='train') 57 | -------------------------------------------------------------------------------- /dataset/preparedata.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | from torch.utils.data import DataLoader, Dataset 4 | import numpy as np 5 | #from dataset.video_data import * 6 | from dataset.skeleton import Skeleton, vis, Skeleton_val 7 | 8 | 9 | class adni(Skeleton): 10 | def __init__(self, data_path, label_path, window_size, final_size,mode='train', decouple_spatial=False, 11 | num_skip_frame=None, random_choose=False, center_choose=False, random_noise=False, random_scale=False): 12 | super().__init__(data_path, label_path, window_size, final_size, mode, decouple_spatial, num_skip_frame, 13 | random_choose, center_choose, random_noise, random_scale) 14 | #self.edge = edge 15 | 16 | def load_data(self): 17 | with open(self.label_path, 'rb') as f: 18 | self.sample_name, self.label = pickle.load(f) 19 | # load data 20 | self.data = np.load(self.data_path, mmap_mode='r') # NCTVM 21 | 22 | 23 | 24 | class adni_val(Skeleton_val): 25 | def __init__(self, data_path, label_path, window_size, final_size, augtimes=1,mode='train', decouple_spatial=False, 26 | num_skip_frame=None, random_choose=False, center_choose=False, random_noise=False, random_scale=False): 27 | super().__init__(data_path, label_path, window_size, final_size, augtimes, mode, decouple_spatial, num_skip_frame, 28 | random_choose, center_choose, random_noise, random_scale) 29 | #self.edge = edge 30 | 31 | def load_data(self): 32 | with open(self.label_path, 'rb') as f: 33 | self.sample_name, self.label = pickle.load(f) 34 | 35 | # load data 36 | self.data = np.load(self.data_path, mmap_mode='r') # NCTVM 37 | 38 | 39 | 40 | if __name__ == '__main__': 41 | data_path = "/your/path/to/ntu/xsub/val_data_joint.npy" 42 | label_path = "/your/path/to/ntu/xsub/val_label.pkl" 43 | test(data_path, label_path, vid='S004C001P003R001A032', edge=edge, is_3d=True, mode='train') 44 | # data_path = "/your/path/to/ntu/xsub/val_data_joint.npy" 45 | # label_path = "/your/path/to/ntu/xsub/val_label.pkl" 46 | # test(data_path, label_path, vid='S004C001P003R001A032', edge=edge, is_3d=True, mode='train') 47 | -------------------------------------------------------------------------------- /dataset/rotation.py: -------------------------------------------------------------------------------- 1 | import numpy as 
np 2 | import math 3 | 4 | 5 | def rotation_matrix(axis, theta): 6 | """ 7 | Return the rotation matrix associated with counterclockwise rotation about 8 | the given axis by theta radians. 9 | """ 10 | if np.abs(axis).sum() < 1e-6 or np.abs(theta) < 1e-6: 11 | return np.eye(3) 12 | axis = np.asarray(axis) 13 | axis = axis / math.sqrt(np.dot(axis, axis)) 14 | a = math.cos(theta / 2.0) 15 | b, c, d = -axis * math.sin(theta / 2.0) 16 | aa, bb, cc, dd = a * a, b * b, c * c, d * d 17 | bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d 18 | return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], 19 | [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], 20 | [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) 21 | 22 | 23 | def unit_vector(vector): 24 | """ Returns the unit vector of the vector. """ 25 | return vector / np.linalg.norm(vector) 26 | 27 | 28 | def angle_between(v1, v2): 29 | """ Returns the angle in radians between vectors 'v1' and 'v2':: 30 | 31 | >>> angle_between((1, 0, 0), (0, 1, 0)) 32 | 1.5707963267948966 33 | >>> angle_between((1, 0, 0), (1, 0, 0)) 34 | 0.0 35 | >>> angle_between((1, 0, 0), (-1, 0, 0)) 36 | 3.141592653589793 37 | """ 38 | if np.abs(v1).sum() < 1e-6 or np.abs(v2).sum() < 1e-6: 39 | return 0 40 | v1_u = unit_vector(v1) 41 | v2_u = unit_vector(v2) 42 | return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) 43 | 44 | 45 | def x_rotation(vector, theta): 46 | """Rotates 3-D vector around x-axis""" 47 | R = np.array([[1, 0, 0], [0, np.cos(theta), -np.sin(theta)], [0, np.sin(theta), np.cos(theta)]]) 48 | return np.dot(R, vector) 49 | 50 | 51 | def y_rotation(vector, theta): 52 | """Rotates 3-D vector around y-axis""" 53 | R = np.array([[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]]) 54 | return np.dot(R, vector) 55 | 56 | 57 | def z_rotation(vector, theta): 58 | """Rotates 3-D vector around z-axis""" 59 | R = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) 60 | return np.dot(R, vector) 61 | 62 | 63 | if __name__ == '__main__': 64 | print(angle_between([0, 0, 1], [0, 1, 0])) 65 | print(angle_between([0, 1, 0], [0, 0, 1])) -------------------------------------------------------------------------------- /dataset/skeleton.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import pickle 5 | import torch 6 | import random 7 | from torch.utils.data import DataLoader, Dataset 8 | from dataset.video_data import * 9 | 10 | trainsiteindex = [list(range(83)),list(range(83,299)),list(range(299,347)),list(range(347,426)),list(range(426,511)), list(range(0,709))] 11 | testsiteindex = [list(range(37-26)),list(range(37-26,78-26)),list(range(78-26,103-26)),list(range(103-26,137-26)), list(range(137-26,188-26)), list(range(188-26,197-26))] 12 | 13 | class Skeleton(Dataset): 14 | def __init__(self, data_path, label_path, window_size, final_size, 15 | mode='train', decouple_spatial=False, num_skip_frame=None, 16 | random_choose=False, center_choose=False, random_noise=False, random_scale=False, site=0): 17 | self.data_path = data_path 18 | self.label_path = label_path 19 | self.mode = mode 20 | self.random_choose = random_choose 21 | self.center_choose = center_choose 22 | self.window_size = window_size 23 | self.final_size = final_size 24 | self.num_skip_frame = num_skip_frame 25 | self.decouple_spatial = decouple_spatial 26 | self.edge = None 27 | 28 | self.random_noise = random_noise 29 | 
self.random_scale = random_scale 30 | #self.augtimes = augtimes 31 | self.site = site 32 | self.load_data() 33 | 34 | def load_data(self): 35 | with open(self.label_path, 'rb') as f: 36 | self.sample_name, self.label = pickle.load(f) 37 | with open(self.data_path, 'rb') as f: 38 | self.data = pickle.load(f) 39 | if self.site: 40 | self.sample_name = self.sample_name[trainsiteindex[self.site]] 41 | self.label = self.label[trainsiteindex[self.site]] 42 | self.data = self.data[trainsiteindex[self.site]] 43 | 44 | def __len__(self): 45 | return len(self.label) 46 | 47 | def __getitem__(self, index): 48 | data_numpy = self.data[index] 49 | label = int(self.label[index]) 50 | sample_name = self.sample_name[index] 51 | data_numpy = np.array(data_numpy) # nctv 52 | data_numpy = data_numpy[:, data_numpy.sum(0).sum(-1).sum(-1) != 0] # CTVM 53 | # data transform 54 | if self.decouple_spatial: 55 | data_numpy = decouple_spatial(data_numpy, edges=self.edge) 56 | if self.num_skip_frame is not None: 57 | velocity = decouple_temporal(data_numpy, self.num_skip_frame) 58 | C, T, V, M = velocity.shape 59 | data_numpy = np.concatenate((velocity, np.zeros((C, 1, V, M))), 1) 60 | 61 | # data_numpy = pad_recurrent_fix(data_numpy, self.window_size) # if short: pad recurrent 62 | # data_numpy = uniform_sample_np(data_numpy, self.window_size) # if long: resize 63 | ############################################################################# 64 | if self.random_choose: 65 | data_numpy = random_sample_np(data_numpy, self.window_size) 66 | # data_numpy = random_choose_simple(data_numpy, self.final_size) 67 | else: 68 | data_numpy = uniform_sample_np(data_numpy, self.window_size) 69 | ############################################################################# 70 | 71 | if self.center_choose: 72 | # data_numpy = uniform_sample_np(data_numpy, self.final_size) 73 | data_numpy = random_choose_simple(data_numpy, self.final_size, center=True) 74 | else: 75 | data_numpy = random_choose_simple(data_numpy, self.final_size) 76 | 77 | if self.mode == 'train': 78 | return data_numpy.astype(np.float32), label 79 | else: 80 | return data_numpy.astype(np.float32), label, sample_name 81 | 82 | def top_k(self, score, top_k): 83 | rank = score.argsort() 84 | hit_top_k = [l in rank[i, -top_k:] for i, l in enumerate(self.label)] 85 | return sum(hit_top_k) * 1.0 / len(hit_top_k) 86 | 87 | 88 | 89 | class Skeleton_val(Dataset): 90 | def __init__(self, data_path, label_path, window_size, final_size,augtimes=1, 91 | mode='train', decouple_spatial=False, num_skip_frame=None, 92 | random_choose=False, center_choose=False, random_noise=False, random_scale=False, site=0): 93 | self.data_path = data_path 94 | self.label_path = label_path 95 | self.mode = mode 96 | self.random_choose = random_choose 97 | self.center_choose = center_choose 98 | self.window_size = window_size 99 | self.final_size = final_size 100 | self.num_skip_frame = num_skip_frame 101 | self.decouple_spatial = decouple_spatial 102 | self.edge = None 103 | 104 | self.random_noise = random_noise 105 | self.random_scale = random_scale 106 | self.augtimes = augtimes 107 | self.site = site 108 | self.load_data() 109 | 110 | def load_data(self): 111 | with open(self.label_path, 'rb') as f: 112 | self.sample_name, self.label = pickle.load(f) 113 | with open(self.data_path, 'rb') as f: 114 | self.data = pickle.load(f) 115 | if self.site: 116 | self.sample_name = self.sample_name[testsiteindex[self.site]] 117 | self.label = self.label[testsiteindex[self.site]] 118 | self.data = 
self.data[testsiteindex[self.site]] 119 | 120 | def __len__(self): 121 | return len(self.label) 122 | 123 | def __getitem__(self, index): 124 | data_numpy = self.data[index] 125 | label = int(self.label[index]) 126 | sample_name = self.sample_name[index] 127 | data_numpy = np.array(data_numpy) # nctv 128 | data_numpy = data_numpy[:, data_numpy.sum(0).sum(-1).sum(-1) != 0] # CTVM 129 | # data transform 130 | if self.decouple_spatial: 131 | data_numpy = decouple_spatial(data_numpy, edges=self.edge) 132 | if self.num_skip_frame is not None: 133 | velocity = decouple_temporal(data_numpy, self.num_skip_frame) 134 | C, T, V, M = velocity.shape 135 | data_numpy = np.concatenate((velocity, np.zeros((C, 1, V, M))), 1) 136 | 137 | # data_numpy = pad_recurrent_fix(data_numpy, self.window_size) # if short: pad recurrent 138 | # data_numpy = uniform_sample_np(data_numpy, self.window_size) # if long: resize 139 | ############################################################################# 140 | if self.random_choose: 141 | data_numpy = random_sample_np(data_numpy, self.window_size) 142 | # data_numpy = random_choose_simple(data_numpy, self.final_size) 143 | else: 144 | data_numpy = uniform_sample_np(data_numpy, self.window_size) 145 | ############################################################################# 146 | 147 | if self.center_choose: 148 | # data_numpy = uniform_sample_np(data_numpy, self.final_size) 149 | data_numpy = random_choose_simple(data_numpy, self.final_size, center=True) 150 | else: 151 | data_numpy = random_choose_simple(data_numpy, self.final_size) 152 | 153 | if self.mode == 'train': 154 | return data_numpy.astype(np.float32), label 155 | else: 156 | return data_numpy.astype(np.float32), label, sample_name 157 | 158 | def top_k(self, score, top_k): 159 | rank = score.argsort() 160 | hit_top_k = [l in rank[i, -top_k:] for i, l in enumerate(self.label)] 161 | return sum(hit_top_k) * 1.0 / len(hit_top_k) 162 | 163 | 164 | def vis(data, edge, is_3d=True, pause=0.01, view=0.25, title=''): 165 | import os 166 | 167 | os.environ['DISPLAY'] = 'localhost:10.0' 168 | import matplotlib.pyplot as plt 169 | import matplotlib 170 | 171 | matplotlib.use('Qt5Agg') 172 | C, T, V, M = data.shape 173 | 174 | plt.ion() 175 | fig = plt.figure() 176 | if is_3d: 177 | from mpl_toolkits.mplot3d import Axes3D 178 | ax = fig.add_subplot(111, projection='3d') 179 | else: 180 | ax = fig.add_subplot(111) 181 | ax.set_title(title) 182 | p_type = ['b-', 'g-', 'r-', 'c-', 'm-', 'y-', 'k-', 'k-', 'k-', 'k-'] 183 | import sys 184 | from os import path 185 | sys.path.append( 186 | path.dirname(path.dirname(path.dirname(path.abspath(__file__))))) 187 | pose = [] 188 | for m in range(M): 189 | a = [] 190 | for i in range(len(edge)): 191 | if is_3d: 192 | a.append(ax.plot(np.zeros(3), np.zeros(3), p_type[m])[0]) 193 | else: 194 | a.append(ax.plot(np.zeros(2), np.zeros(2), p_type[m])[0]) 195 | pose.append(a) 196 | ax.axis([-view, view, -view, view]) 197 | if is_3d: 198 | ax.set_zlim3d(-view, view) 199 | for t in range(T): 200 | for m in range(M): 201 | for i, (v1, v2) in enumerate(edge): 202 | x1 = data[:2, t, v1, m] 203 | x2 = data[:2, t, v2, m] 204 | if (x1.sum() != 0 and x2.sum() != 0) or v1 == 1 or v2 == 1: 205 | pose[m][i].set_xdata(data[0, t, [v1, v2], m]) 206 | pose[m][i].set_ydata(data[1, t, [v1, v2], m]) 207 | if is_3d: 208 | pose[m][i].set_3d_properties(data[2, t, [v1, v2], m]) 209 | fig.canvas.draw() 210 | plt.pause(pause) 211 | plt.close() 212 | plt.ioff() 213 | 214 | 
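# --- Added illustration (not part of the original skeleton.py) ---------------
# The __getitem__ methods above resample every clip along the temporal axis in
# two stages: (1) resample to `window_size` frames (uniform indices for
# evaluation, random sorted indices for training), then (2) crop `final_size`
# consecutive frames (a centered crop when center_choose=True). The helpers
# actually used for this live in dataset/video_data.py; the functions below
# are a minimal, self-contained sketch for illustration only and may differ
# in detail from the repo's own implementation.

import random as _random
import numpy as _np


def _uniform_sample(data, size):
    # data: C, T, V, M -> keep `size` evenly spaced frames
    C, T, V, M = data.shape
    if T == size:
        return data
    idx = [int(i * (T / size)) for i in range(size)]
    return data[:, idx]


def _random_sample(data, size):
    # data: C, T, V, M -> keep `size` randomly chosen (sorted) frames,
    # repeating the frame index list first if the clip is shorter than `size`
    C, T, V, M = data.shape
    if T == size:
        return data
    repeats = int(_np.ceil(size / T))
    idx = sorted(_random.sample(list(range(T)) * repeats, size))
    return data[:, idx]


def _crop(data, size, center=False):
    # data: C, T, V, M -> keep `size` consecutive frames
    C, T, V, M = data.shape
    if T <= size:
        return data
    begin = (T - size) // 2 if center else _random.randint(0, T - size)
    return data[:, begin:begin + size]


if __name__ == '__main__':
    clip = _np.random.randn(3, 90, 116, 1)                        # dummy C, T, V, M clip
    train_clip = _crop(_random_sample(clip, 64), 32)              # training-style pipeline
    val_clip = _crop(_uniform_sample(clip, 64), 32, center=True)  # evaluation-style pipeline
    print(train_clip.shape, val_clip.shape)                       # both (3, 32, 116, 1)
# -----------------------------------------------------------------------------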
-------------------------------------------------------------------------------- /dataset/video_data.py: -------------------------------------------------------------------------------- 1 | #import cv2 2 | from numpy import random as nprand 3 | import random 4 | #import imutils 5 | import torch 6 | from dataset.rotation import * 7 | 8 | 9 | def video_aug(video, 10 | brightness_delta=32, 11 | contrast_range=(0.5, 1.5), 12 | saturation_range=(0.5, 1.5), 13 | angle_range=(-30, 30), 14 | hue_delta=18): 15 | ''' 16 | 17 | :param video: list of images 18 | :param brightness_delta: 19 | :param contrast_range: 20 | :param saturation_range: 21 | :param angle_range: 22 | :param hue_delta: 23 | :return: 24 | ''' 25 | # brightness_delta = brightness_delta 26 | # contrast_lower, contrast_upper = contrast_range 27 | # saturation_lower, saturation_upper = saturation_range 28 | # angle_lower, angle_upper = angle_range 29 | # hue_delta = hue_delta 30 | # for index, img in enumerate(video): 31 | # video[index] = img.astype(np.float32) 32 | # 33 | # # random brightness 34 | # if nprand.randint(2): 35 | # delta = nprand.uniform(-brightness_delta, 36 | # brightness_delta) 37 | # for index, img in enumerate(video): 38 | # video[index] += delta 39 | # 40 | # # random rotate 41 | # if nprand.randint(2): 42 | # angle = nprand.uniform(angle_lower, 43 | # angle_upper) 44 | # for index, img in enumerate(video): 45 | # video[index] = imutils.rotate(img, angle) 46 | # 47 | # # if nprand.randint(2): 48 | # # alpha = nprand.uniform(contrast_lower, 49 | # # contrast_upper) 50 | # # for index, img in enumerate(video): 51 | # # video[index] *= alpha 52 | # 53 | # # mode == 0 --> do random contrast first 54 | # # mode == 1 --> do random contrast last 55 | # mode = nprand.randint(2) 56 | # if mode == 1: 57 | # if nprand.randint(2): 58 | # alpha = nprand.uniform(contrast_lower, 59 | # contrast_upper) 60 | # for index, img in enumerate(video): 61 | # video[index] *= alpha 62 | # 63 | # # convert color from BGR to HSV 64 | # for index, img in enumerate(video): 65 | # video[index] = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 66 | # 67 | # # random saturation 68 | # if nprand.randint(2): 69 | # for index, img in enumerate(video): 70 | # video[index][..., 1] *= nprand.uniform(saturation_lower, 71 | # saturation_upper) 72 | # 73 | # # random hue 74 | # if nprand.randint(2): 75 | # for index, img in enumerate(video): 76 | # video[index][..., 0] += nprand.uniform(-hue_delta, hue_delta) 77 | # video[index][..., 0][video[index][..., 0] > 360] -= 360 78 | # video[index][..., 0][video[index][..., 0] < 0] += 360 79 | # 80 | # # convert color from HSV to BGR 81 | # for index, img in enumerate(video): 82 | # video[index] = cv2.cvtColor(img, cv2.COLOR_HSV2BGR) 83 | # 84 | # # random contrast 85 | # if mode == 0: 86 | # if nprand.randint(2): 87 | # alpha = nprand.uniform(contrast_lower, 88 | # contrast_upper) 89 | # for index, img in enumerate(video): 90 | # video[index] *= alpha 91 | # 92 | # # randomly swap channels 93 | # # if nprand.randint(2): 94 | # # for index, img in enumerate(video): 95 | # # video[index] = img[..., nprand.permutation(3)] 96 | # 97 | # return video 98 | 99 | 100 | def expand_list(l, length): 101 | if len(l) < length: 102 | while len(l) < length: 103 | tmp = [] 104 | [tmp.extend([x, x]) for x in l] 105 | l = tmp 106 | return sample_uniform_list(l, length) 107 | else: 108 | return l 109 | 110 | 111 | def sample_uniform_list(l, length): 112 | if len(l)==length: 113 | return l 114 | interval = len(l) / length 115 | uniform_list 
= [int(i * interval) for i in range(length)] 116 | tmp = [l[x] for x in uniform_list] 117 | return tmp 118 | 119 | 120 | def uniform_sample_np(data_numpy, size): 121 | C, T, V, M = data_numpy.shape 122 | if T == size: 123 | return data_numpy 124 | interval = T / size 125 | uniform_list = [int(i * interval) for i in range(size)] 126 | return data_numpy[:, uniform_list] 127 | 128 | 129 | def random_sample_np(data_numpy, size): 130 | C, T, V, M = data_numpy.shape 131 | if T == size: 132 | return data_numpy 133 | interval = int(np.ceil(size / T)) 134 | random_list = sorted(random.sample(list(range(T))*interval, size)) 135 | return data_numpy[:, random_list] 136 | 137 | def add_random_noise(data_numpy, scale=0.1): 138 | return data_numpy + scale * torch.randn(data_numpy.shape) 139 | 140 | def random_scale(data_numpy, range=0.2): 141 | scale = random.uniform(1-range,1+range) 142 | return scale * data_numpy 143 | 144 | def random_choose_simple(data_numpy, size, center=False): 145 | # input: C,T,V,M 随机选择其中一段,不是很合理。因为有0 146 | C, T, V, M = data_numpy.shape 147 | if size < 0: 148 | assert 'resize shape is not right' 149 | if T == size: 150 | return data_numpy 151 | elif T < size: 152 | return data_numpy 153 | else: 154 | if center: 155 | begin = (T - size) // 2 156 | else: 157 | begin = random.randint(0, T - size) 158 | return data_numpy[:, begin:begin + size, :, :] 159 | 160 | 161 | def random_move(data_numpy, 162 | angle_candidate=[-10., -5., 0., 5., 10.], 163 | scale_candidate=[0.9, 1.0, 1.1], 164 | transform_candidate=[0.0], 165 | move_time_candidate=[1]): 166 | # input: C,T,V,M 167 | C, T, V, M = data_numpy.shape 168 | move_time = random.choice(move_time_candidate) 169 | node = np.arange(0, T, T * 1.0 / move_time).round().astype(int) # 需要变换的帧的段数 0, 16, 32 170 | node = np.append(node, T) 171 | num_node = len(node) 172 | 173 | A = np.random.choice(angle_candidate, num_node) 174 | S = np.random.choice(scale_candidate, num_node) 175 | T_x = np.random.choice(transform_candidate, num_node) 176 | T_y = np.random.choice(transform_candidate, num_node) 177 | 178 | a = np.zeros(T) 179 | s = np.zeros(T) 180 | t_x = np.zeros(T) 181 | t_y = np.zeros(T) 182 | 183 | # linspace 184 | for i in range(num_node - 1): # 使得每一帧的旋转都不一样 185 | a[node[i]:node[i + 1]] = np.linspace( 186 | A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180 187 | s[node[i]:node[i + 1]] = np.linspace(S[i], S[i + 1], 188 | node[i + 1] - node[i]) 189 | t_x[node[i]:node[i + 1]] = np.linspace(T_x[i], T_x[i + 1], 190 | node[i + 1] - node[i]) 191 | t_y[node[i]:node[i + 1]] = np.linspace(T_y[i], T_y[i + 1], 192 | node[i + 1] - node[i]) 193 | 194 | theta = np.array([[np.cos(a) * s, -np.sin(a) * s], 195 | [np.sin(a) * s, np.cos(a) * s]]) # xuanzhuan juzhen 196 | 197 | # perform transformation 198 | for i_frame in range(T): 199 | xy = data_numpy[0:2, i_frame, :, :] 200 | new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1)) 201 | new_xy[0] += t_x[i_frame] 202 | new_xy[1] += t_y[i_frame] # pingyi bianhuan 203 | data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M) 204 | 205 | return data_numpy 206 | 207 | 208 | def random_move_whole(data_numpy, agx=0, agy=0, s=1): 209 | # input: C,T,V,M 210 | C, T, V, M = data_numpy.shape 211 | data_numpy = data_numpy.transpose((1, 2, 3, 0)).reshape(-1, C) 212 | 213 | agx = math.radians(agx) 214 | agy = math.radians(agy) 215 | Rx = np.asarray([[1, 0, 0], [0, math.cos(agx), math.sin(agx)], [0, -math.sin(agx), math.cos(agx)]]) 216 | Ry = np.asarray([[math.cos(agy), 0, -math.sin(agy)], [0, 1, 0], [math.sin(agy), 0, 
math.cos(agy)]]) 217 | Ss = np.asarray([[s, 0, 0], [0, s, 0], [0, 0, s]]) 218 | 219 | data_numpy = np.dot(np.reshape(data_numpy, (-1, 3)), np.dot(Ry, np.dot(Rx, Ss))) 220 | data_numpy = data_numpy.reshape((T, V, M, C)).transpose((3, 0, 1, 2)) 221 | return data_numpy.astype(np.float32) 222 | 223 | 224 | def rot_to_fix_angle_fstframe(skeleton, jpts=[0, 1], axis=[0, 0, 1], frame=0, person=0): 225 | ''' 226 | :param skeleton: c t v m 227 | :param axis: 001 for z, 100 for x, 010 for y 228 | ''' 229 | skeleton = np.transpose(skeleton, [3, 1, 2, 0]) # M, T, V, C 230 | joint_bottom = skeleton[person, frame, jpts[0]] 231 | joint_top = skeleton[person, frame, jpts[1]] 232 | axis_c = np.cross(joint_top - joint_bottom, axis) 233 | angle = angle_between(joint_top - joint_bottom, axis) 234 | matrix_z = rotation_matrix(axis_c, angle) 235 | tmp = np.dot(np.reshape(skeleton, (-1, 3)), matrix_z.transpose()) 236 | skeleton = np.reshape(tmp, skeleton.shape) 237 | return skeleton.transpose((3, 1, 2, 0)) 238 | 239 | 240 | def sub_center_jpt_fstframe(skeleton, jpt=0, frame=0, person=0): 241 | C, T, V, M = skeleton.shape 242 | skeleton = np.transpose(skeleton, [3, 1, 2, 0]) # M, T, V, C 243 | main_body_center = skeleton[person, frame, jpt].copy() # c 244 | for i_p, person in enumerate(skeleton): 245 | if person.sum() == 0: 246 | continue 247 | mask = (person.sum(-1) != 0).reshape(T, V, 1) # only for none zero frames 248 | skeleton[i_p] = (skeleton[i_p] - main_body_center) * mask 249 | return skeleton.transpose((3, 1, 2, 0)) 250 | 251 | 252 | def sub_center_jpt_perframe(skeleton, jpt=0, person=0): 253 | C, T, V, M = skeleton.shape 254 | skeleton = np.transpose(skeleton, [3, 1, 2, 0]) # M, T, V, C 255 | main_body_center = skeleton[person, :, jpt].copy().reshape((T, 1, C)) # tc 256 | for i_p, person in enumerate(skeleton): 257 | if person.sum() == 0: 258 | continue 259 | skeleton[i_p] = (skeleton[i_p] - main_body_center) # TVC-T1C 260 | return skeleton.transpose((3, 1, 2, 0)) 261 | 262 | 263 | def decouple_spatial(skeleton, edges=()): 264 | tmp = np.zeros(skeleton.shape) 265 | for v1, v2 in edges: 266 | tmp[:, :, v2, :] = skeleton[:, :, v2] - skeleton[:, :, v1] 267 | return tmp 268 | 269 | 270 | def obtain_angle(skeleton, edges=()): 271 | tmp = skeleton.copy() 272 | for v1, v2 in edges: 273 | v1 -= 1 274 | v2 -= 1 275 | x = skeleton[0, :, v1, :] - skeleton[0, :, v2, :] 276 | y = skeleton[1, :, v1, :] - skeleton[1, :, v2, :] 277 | z = skeleton[2, :, v1, :] - skeleton[2, :, v2, :] 278 | atan0 = np.arctan2(y, x) / 3.14 279 | atan1 = np.arctan2(z, x) / 3.14 280 | atan2 = np.arctan2(z, y) / 3.14 281 | t = np.stack([atan0, atan1, atan2], 0) 282 | tmp[:, :, v1, :] = t 283 | return tmp 284 | 285 | 286 | def decouple_temporal(skeleton, inter_frame=1): # CTVM 287 | skeleton = skeleton[:, ::inter_frame] 288 | diff = skeleton[:, 1:] - skeleton[:, :-1] 289 | return diff 290 | 291 | 292 | def norm_len_fstframe(skeleton, jpts=[0, 1], frame=0, person=0): 293 | C, T, V, M = skeleton.shape 294 | skeleton = np.transpose(skeleton, [3, 1, 2, 0]) # M, T, V, C 295 | main_body_spine = np.linalg.norm(skeleton[person, frame, jpts[0]] - skeleton[person, frame, jpts[1]]) 296 | if main_body_spine == 0: 297 | print('zero bone') 298 | else: 299 | skeleton /= main_body_spine 300 | return skeleton.transpose((3, 1, 2, 0)) 301 | 302 | 303 | def random_move_joint(data_numpy, sigma=0.1): # 只随机扰动坐标点 304 | # input: C,T,V,M 305 | C, T, V, M = data_numpy.shape 306 | 307 | rand_joint = np.random.randn(C, T, V, M) * sigma 308 | 309 | return data_numpy + 
rand_joint 310 | 311 | 312 | def pad_recurrent(data): 313 | skeleton = np.transpose(data, [3, 1, 2, 0]) # C, T, V, M to M, T, V, C 314 | for i_p, person in enumerate(skeleton): 315 | if person.sum() == 0: 316 | continue 317 | if person[0].sum() == 0: # TVC 去掉头空帧,然后对齐到顶端 318 | index = (person.sum(-1).sum(-1) != 0) 319 | tmp = person[index].copy() 320 | person *= 0 321 | person[:len(tmp)] = tmp 322 | for i_f, frame in enumerate(person): 323 | if frame.sum() == 0: 324 | if person[i_f:].sum() == 0: # 循环pad之前的帧 325 | rest = len(person) - i_f 326 | num = int(np.ceil(rest / i_f)) 327 | pad = np.concatenate([person[0:i_f] for _ in range(num)], 0)[:rest] 328 | skeleton[i_p, i_f:] = pad 329 | break 330 | return skeleton.transpose((3, 1, 2, 0)) # ctvm 331 | 332 | 333 | def pad_recurrent_fix(data, length): # CTVM 334 | if data.shape[1] < length: 335 | num = int(np.ceil(length / data.shape[1])) 336 | data = np.concatenate([data for _ in range(num)], 1)[:, :length] 337 | return data 338 | 339 | 340 | def pad_zero(data, length): 341 | if data.shape[1] < length: 342 | new = np.zeros([data.shape[0], length - data.shape[1], data.shape[2], data.shape[3]]) 343 | data = np.concatenate([data, new], 1) 344 | return data 345 | 346 | 347 | import scipy.ndimage.interpolation as inter 348 | from scipy.signal import medfilt 349 | import warnings 350 | 351 | warnings.filterwarnings('ignore', '.*output shape of zoom.*') 352 | 353 | 354 | def zoom_T(p, target_l=64): 355 | ''' 356 | 357 | :param p: ctv 358 | :param target_l: 359 | :return: 360 | ''' 361 | C, T, V, M = p.shape 362 | p_new = np.empty([C, target_l, V, M]) 363 | for m in range(M): 364 | for v in range(V): 365 | for c in range(C): 366 | p_new[c, :, v, m] = inter.zoom(p[c, :, v, m], target_l / T)[:target_l] 367 | return p_new 368 | 369 | 370 | def filter_T(p, kernel_size=3): 371 | C, T, V, M = p.shape 372 | p_new = np.empty([C, T, V, M]) 373 | for m in range(M): 374 | for v in range(V): 375 | for c in range(C): 376 | p_new[c, :, v, m] = medfilt(p[c, :, v, m], kernel_size=kernel_size) 377 | return p_new 378 | 379 | 380 | def coor_to_volume(data, size): 381 | ''' 382 | 383 | :param data: CTVM 384 | :param size: [D, H, W] 385 | :return: CTDHW 386 | ''' 387 | C, T, V, M = data.shape 388 | volume = np.zeros([V * M, T, size[0], size[1], size[2]], dtype=np.float32) 389 | fst_ind = np.indices([T, V, M])[0] # T, V, M 390 | # one_hots = np.concatenate([np.tile(np.eye(V), [M, 1]), np.repeat(np.eye(M), V, axis=0)], axis=1).reshape( 391 | # (V, M, V + M)).transpose((2, 0, 1)) 392 | one_hots = np.eye(V * M).reshape((M, V, V * M)).transpose((2, 1, 0)) # C, V, M 393 | scd_inds = (data[::-1, :, :, :] * (np.array(size) - 1)[:, np.newaxis, np.newaxis, np.newaxis]).astype( 394 | np.long) # 3, T, V, M 395 | scd_inds = np.split(scd_inds, 3, axis=0) 396 | volume[:, fst_ind, scd_inds[0][0], scd_inds[1][0], scd_inds[2][0]] = one_hots[:, np.newaxis, :, :] 397 | return volume 398 | 399 | 400 | def coor_to_sparse(data, size, dilate_value=0, edges=None): 401 | ''' 402 | 403 | :param data: CTVM 404 | :param size: [D, H, W] 405 | :return: coords->TVMx(MC) 406 | ''' 407 | C, T, V, M = data.shape 408 | # features = np.tile(np.concatenate([np.tile(np.eye(V), [M, 1]), np.repeat(np.eye(M), V, axis=0)], axis=1).reshape( 409 | # (V, M, V + M)), [T, 1, 1, 1]).reshape((T * V * M, V + M)) 410 | features = np.tile(np.eye(V * M).reshape((M, V, V * M)).transpose((1, 0, 2)), [T, 1, 1, 1]) 411 | coords = (data * (np.array(size) - 1)[::-1, np.newaxis, np.newaxis, np.newaxis]) 412 | coords = 
np.concatenate([np.repeat(np.array(list(range(T))), V * M).reshape(1, T, V, M), coords], axis=0) 413 | coords = coords.transpose((1, 2, 3, 0)).astype(np.int32) 414 | 415 | if edges is not None: 416 | ecoords = [] 417 | efeatures = [] 418 | for t in range(T): 419 | for m in range(M): 420 | for edge in edges: 421 | f1 = features[t, edge[0], m] 422 | f2 = features[t, edge[1], m] 423 | c1 = coords[t, edge[0], m] 424 | c2 = coords[t, edge[1], m] 425 | c = max(np.abs(c2 - c1)) 426 | ecoords.extend( 427 | np.array([np.linspace(cc1, cc2, c) for cc1, cc2 in zip(c1, c2)]).transpose((1, 0)).astype( 428 | np.int)) 429 | efeatures.extend([np.maximum(f1, f2) for _ in range(c)]) 430 | features = np.concatenate([features.reshape((T * V * M, V * M)), efeatures], axis=0) 431 | coords = np.concatenate([coords.reshape((T * V * M, C + 1)), ecoords], axis=0) 432 | else: 433 | features = features.reshape((T * V * M, V * M)) 434 | coords = coords.reshape((T * V * M, C + 1)) 435 | 436 | coords_new = [] 437 | features_new = [] 438 | if dilate_value == 0: # remove pts 439 | for i, coord in enumerate(coords): 440 | if list(coord) in coords_new: 441 | ind = coords_new.index(list(coord)) 442 | features_new[ind] = np.maximum(features[i], features_new[ind]) 443 | else: 444 | coords_new.append(list(coord)) 445 | features_new.append(features[i]) 446 | else: 447 | dilates = list(range(-dilate_value, dilate_value + 1)) 448 | for i, coord in enumerate(coords): 449 | for j in range(C): 450 | for k in dilates: 451 | coord_e = coord.copy() 452 | coord_e[-j] += k 453 | if list(coord_e) in coords_new: 454 | ind = coords_new.index(list(coord_e)) 455 | features_new[ind] = np.maximum(features[i], features_new[ind]) 456 | else: 457 | coords_new.append(list(coord_e)) 458 | features_new.append(features[i]) 459 | 460 | return np.array(coords_new, dtype=np.int32), np.array(features_new, dtype=np.float32) 461 | 462 | 463 | def judge_type(paths, final_shape): 464 | if type(paths[0]) is str: 465 | try: 466 | img = cv2.imread(paths[0]) 467 | pre_shape = [len(paths), *img.shape] 468 | except: 469 | print(paths[0], ' is wrong') 470 | pre_shape = [len(paths), *final_shape[1:]] 471 | else: 472 | pre_shape = [len(paths), *paths[0].shape] 473 | 474 | return pre_shape 475 | 476 | 477 | def crop_resize(imgs, starts, cshape, final_shape, mean, use_flip, other_aug): 478 | imgs_crop = imgs[starts[0]:starts[0] + cshape[0]] # TODO: paths < cshape[0] 479 | imgs_final = sample_uniform_list(imgs_crop, final_shape[0]) 480 | 481 | if other_aug: 482 | imgs_final = video_aug(imgs_final) 483 | clip = [] 484 | for index, img in enumerate(imgs_final): 485 | clip.append(cv2.resize(img[starts[1]:starts[1] + cshape[1], starts[2]:starts[2] + cshape[2]], 486 | (final_shape[2], final_shape[1])).astype(np.float32) / 255 - mean) 487 | clip = np.transpose(np.array(clip, dtype=np.float32), (3, 0, 1, 2)) 488 | for i, f in enumerate(use_flip): 489 | if f: 490 | clip = np.flip(clip, i + 1).copy() # avoid negative strides 491 | return clip 492 | 493 | 494 | def resize_crop(imgs, resize_shape, final_shape, starts, mean, use_flip, other_aug): 495 | imgs_resize = np.array(sample_uniform_list(imgs, resize_shape[0])) 496 | imgs_crop = imgs_resize[starts[0]:starts[0] + final_shape[0]] 497 | if other_aug: 498 | imgs_crop = video_aug(imgs_crop) 499 | 500 | clip = [] 501 | for index, img in enumerate(imgs_crop): 502 | clip.append(cv2.resize(img, (resize_shape[2], resize_shape[1]))[starts[1]:starts[1] + final_shape[1], 503 | starts[2]:starts[2] + final_shape[2]].astype( 504 | np.float32) 
/ 255 - mean) 505 | 506 | clip = np.transpose(np.array(clip, dtype=np.float32), (3, 0, 1, 2)) 507 | for i, f in enumerate(use_flip): 508 | if f: 509 | clip = np.flip(clip, i + 1).copy() # avoid negative strides 510 | return clip 511 | 512 | 513 | def pose_flip(pose, use_flip): 514 | ''' 515 | 516 | :param pose: T V C[x,y] M 517 | :param use_flip: 518 | :return: 519 | ''' 520 | pose_new = pose 521 | if use_flip[0]: 522 | pose_new = pose[::-1] 523 | if use_flip[1] and use_flip[2]: 524 | pose_new = 1 - pose 525 | pose_new[pose_new == 1] = 0 526 | elif use_flip[1]: 527 | pose_new = 1 - pose 528 | pose_new[pose_new == 1] = 0 529 | pose_new[:, :, 0, :] = pose[:, :, 0, :] 530 | elif use_flip[2]: 531 | pose_new = 1 - pose 532 | pose_new[pose_new == 1] = 0 533 | pose_new[:, :, 1, :] = pose[:, :, 1, :] 534 | 535 | return pose_new 536 | 537 | 538 | def pose_crop(pose_old, start, cshape, width, height): 539 | ''' 540 | 541 | :param pose_old: T V C M 542 | :param start: T H,W 543 | :param cshape:T H,W 544 | :param width: 545 | :param height: 546 | :return: 547 | ''' 548 | # temporal crop 549 | pose_new = pose_old[start[0]:start[0] + cshape[0]] 550 | T, V, C, M = pose_new.shape 551 | # 复原到图像大小 552 | pose_new = pose_new * (np.array([width, height]).reshape([1, 1, C, 1])) # T V C M 553 | # 减去边框 554 | pose_new -= np.array([start[2], start[1]]).reshape([1, 1, C, 1]) 555 | # 小于0的置0 556 | pose_new[(np.min(pose_new, -2) < 0).reshape(T, V, 1, M).repeat(C, -2)] = 0 557 | # 新位置除以crop后的大小 558 | pose_new /= np.array([cshape[2], cshape[1]]).reshape([1, 1, C, 1]) 559 | # 大于1的值1 560 | pose_new[(np.max(pose_new, -2) > 1).reshape(T, V, 1, M).repeat(C, -2)] = 0 561 | return pose_new 562 | 563 | 564 | def gen_clip_simple(paths, starts, resize_shape, final_shape, mean, use_flip, other_aug=False): 565 | try: 566 | if type(paths[0]) is str: 567 | imgs = [] 568 | paths = sample_uniform_list(paths, resize_shape[0]) 569 | for path in paths: 570 | try: 571 | img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) 572 | except: 573 | print(path, ' is wrong') 574 | img = np.zeros([*final_shape[1:], 3], dtype=np.uint8) 575 | imgs.append(img) 576 | clip = resize_crop(imgs, resize_shape, final_shape, starts, mean, use_flip, other_aug) 577 | return clip 578 | elif type(paths[0]) is tuple: 579 | imgs, poses = np.array([i[0] for i in paths]), np.array([i[1] for i in paths]) 580 | if len(imgs) != len(poses): 581 | imgs = np.array(sample_uniform_list(imgs, len(poses))) 582 | if poses.shape[2] >= 3: # T,V,C,M 583 | poses = poses[:, :, :2] 584 | clip = resize_crop(imgs, resize_shape, final_shape, starts, mean, use_flip, other_aug).transpose( 585 | (1, 0, 2, 3)) 586 | poses = np.array(sample_uniform_list(poses, resize_shape[0])) 587 | poses = pose_crop(poses, starts, final_shape, resize_shape[2], resize_shape[1]) 588 | poses = pose_flip(poses, use_flip) 589 | return clip, poses 590 | else: 591 | imgs = paths 592 | clip = resize_crop(imgs, resize_shape, final_shape, starts, mean, use_flip, other_aug) 593 | return clip 594 | except: 595 | print(paths) 596 | 597 | 598 | def gen_clip(paths, starts, cshape, final_shape, mean, use_flip=(0, 0, 0), other_aug=False): 599 | try: 600 | if type(paths[0]) is str: 601 | imgs = [] 602 | for path in paths: 603 | try: 604 | img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) 605 | except: 606 | print(path, ' is wrong') 607 | img = np.zeros([*final_shape[1:], 3], dtype=np.uint8) 608 | imgs.append(img) 609 | clip = crop_resize(imgs, starts, cshape, final_shape, mean, use_flip, other_aug) 610 | return clip 
611 | elif type(paths[0]) is tuple: 612 | imgs, poses = np.array([i[0] for i in paths]), np.array([i[1] for i in paths]) 613 | if len(imgs) != len(poses): 614 | imgs = np.array(sample_uniform_list(imgs, len(poses))) 615 | if poses.shape[2] >= 3: # T,V,C,M 616 | poses = poses[:, :, :2] 617 | clip = crop_resize(imgs, starts, cshape, final_shape, mean, use_flip, other_aug).transpose( 618 | (1, 0, 2, 3)).copy() 619 | poses = poses[starts[0]:starts[0] + cshape[0]] 620 | poses = pose_crop(poses, starts, cshape, imgs[0].shape[1], imgs[0].shape[0]) 621 | poses = pose_flip(poses, use_flip).copy() 622 | return clip, poses 623 | else: 624 | imgs = paths 625 | clip = crop_resize(imgs, starts, cshape, final_shape, mean, use_flip, other_aug) 626 | return clip 627 | except: 628 | print(paths) 629 | 630 | 631 | def train_video_simple(paths, resize_shape, final_shape, mean, use_flip=(0, 0, 0), other_aug=False): 632 | """ 633 | 634 | :param paths: [frame1, frame2 ....] 635 | :param resize_shape: [l, h, w] 636 | :param final_shape: [l, h, w] 637 | :param mean: [l, h, w, 3] 638 | :param use_flip: [0,0,0] 639 | :return: 640 | """ 641 | gap = [resize_shape[i] - final_shape[i] for i in range(3)] 642 | 643 | starts = [int(a * random.random()) for a in gap] 644 | 645 | clip = gen_clip_simple(paths, starts, resize_shape, final_shape, mean, use_flip, other_aug=other_aug) 646 | 647 | return clip 648 | 649 | 650 | def val_video_simple(paths, resize_shape, final_shape, mean, use_flip=(0, 0, 0), other_aug=False): 651 | """ 652 | 653 | :param paths: [frame1, frame2 ....] 654 | :param resize_shape: [l, h, w] 655 | :param final_shape: [l, h, w] 656 | :param mean: [l, h, w, 3] 657 | :param use_flip: [0,0,0 658 | :return: 659 | """ 660 | 661 | gap = [resize_shape[i] - final_shape[i] for i in range(3)] 662 | 663 | starts = [int(a * 0.5) for a in gap] 664 | clip = gen_clip_simple(paths, starts, resize_shape, final_shape, mean, use_flip, other_aug=other_aug) 665 | 666 | return clip 667 | 668 | 669 | def eval_video(paths, crop_ratios, crop_positions, final_shape, mean, use_flip=(0, 0, 0)): 670 | """ 671 | 672 | :param paths: [frame1, frame2 ....] 673 | :param crop_ratios: [[t0, t1 ...], [h0, h1, ...], [w0, w1, ...]] 0-1 674 | :param crop_positions: [[t0, t1 ...], [h0, h1, ...], [w0, w1, ...]] 0-1 675 | :param final_shape: [l, h, w] 676 | :param mean: [l, h, w, 3] 677 | :param use_flip: [False, False, False] 678 | :return: 679 | """ 680 | pre_shape = judge_type(paths, final_shape) 681 | 682 | clips = [] 683 | for crop_t in crop_ratios[0]: 684 | for crop_h in crop_ratios[1]: 685 | for crop_w in crop_ratios[2]: 686 | cshape = [int(x) for x in [crop_t * pre_shape[0], crop_h * pre_shape[1], crop_w * pre_shape[2]]] 687 | 688 | gap = [pre_shape[i] - cshape[i] for i in range(3)] 689 | for p_t in crop_positions[0]: 690 | for p_h in crop_positions[1]: 691 | for p_w in crop_positions[2]: 692 | starts = [int(a * b) for a in gap for b in [p_t, p_h, p_w]] 693 | clip = gen_clip(paths, starts, cshape, final_shape, mean) 694 | clips.append(clip) # clhw 695 | for i, f in enumerate(use_flip): 696 | if f: 697 | clip_flip = np.flip(clip, i + 1).copy() 698 | clips.append(clip_flip) 699 | 700 | return clips 701 | 702 | 703 | def train_video(paths, crop_ratios, crop_positions, final_shape, mean, use_flip=(0, 0, 0)): 704 | """ 705 | 706 | :param paths: [frame1, frame2 ....] 
707 | :param crop_ratios: [[t0, t1 ...], [h0, h1, ...], [w0, w1, ...]] 0-1 708 | :param crop_positions: [[t0, t1 ...], [h0, h1, ...], [w0, w1, ...]] 0-1 709 | :param final_shape: [l, h, w] 710 | :param mean: [l, h, w, 3] 711 | :param use_flip: True or False 712 | :return: 713 | """ 714 | pre_shape = judge_type(paths, final_shape) 715 | 716 | crop_t = random.sample(crop_ratios[0], 1)[0] 717 | crop_h = random.sample(crop_ratios[1], 1)[0] 718 | crop_w = random.sample(crop_ratios[2], 1)[0] 719 | cshape = [int(x) for x in [crop_t * pre_shape[0], crop_h * pre_shape[1], crop_w * pre_shape[2]]] 720 | 721 | gap = [pre_shape[i] - cshape[i] for i in range(3)] 722 | 723 | p_t = random.sample(crop_positions[0], 1)[0] 724 | p_h = random.sample(crop_positions[1], 1)[0] 725 | p_w = random.sample(crop_positions[2], 1)[0] 726 | 727 | starts = [int(a * b) for a, b in list(zip(gap, [p_t, p_h, p_w]))] 728 | clip = gen_clip(paths, starts, cshape, final_shape, mean, use_flip, other_aug=True) 729 | 730 | # for i, f in enumerate(use_flip): 731 | # if f: 732 | # clip = np.flip(clip, i + 1) 733 | 734 | return clip 735 | 736 | 737 | def val_video(paths, final_shape, mean): 738 | """ 739 | 740 | :param paths: [frame1, frame2 ....] 741 | :param final_shape: [l, h, w] 742 | :param mean: [l, h, w, 3] 743 | :return: 744 | """ 745 | pre_shape = judge_type(paths, final_shape) 746 | 747 | crop_t = 1 748 | crop_h = 1 749 | crop_w = 1 750 | cshape = [int(x) for x in [crop_t * pre_shape[0], crop_h * pre_shape[1], crop_w * pre_shape[2]]] 751 | 752 | gap = [pre_shape[i] - cshape[i] for i in range(3)] 753 | 754 | p_t = 0.5 755 | p_h = 0.5 756 | p_w = 0.5 757 | 758 | starts = [int(a * b) for a, b in list(zip(gap, [p_t, p_h, p_w]))] 759 | clip = gen_clip(paths, starts, cshape, final_shape, mean) 760 | 761 | return clip 762 | -------------------------------------------------------------------------------- /method_choose/__pycache__/data_choose.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/method_choose/__pycache__/data_choose.cpython-36.pyc -------------------------------------------------------------------------------- /method_choose/__pycache__/loss_choose.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/method_choose/__pycache__/loss_choose.cpython-36.pyc -------------------------------------------------------------------------------- /method_choose/__pycache__/lr_scheduler_choose.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/method_choose/__pycache__/lr_scheduler_choose.cpython-36.pyc -------------------------------------------------------------------------------- /method_choose/__pycache__/model_choose.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/method_choose/__pycache__/model_choose.cpython-36.pyc -------------------------------------------------------------------------------- /method_choose/__pycache__/optimizer_choose.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/method_choose/__pycache__/optimizer_choose.cpython-36.pyc -------------------------------------------------------------------------------- /method_choose/__pycache__/tra_val_choose.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/method_choose/__pycache__/tra_val_choose.cpython-36.pyc -------------------------------------------------------------------------------- /method_choose/data_choose.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | from torch.utils.data import DataLoader 4 | 5 | import torch 6 | import numpy as np 7 | import random 8 | import shutil 9 | import inspect 10 | from dataset.ntu_skeleton import NTU_SKE 11 | from dataset.dhg_skeleton import DHG_SKE 12 | from dataset.preparedata import adni, adni_val 13 | 14 | def init_seed(x): 15 | # pass 16 | torch.cuda.manual_seed_all(1) 17 | torch.manual_seed(1) 18 | np.random.seed(1) 19 | random.seed(1) 20 | torch.backends.cudnn.enabled = False 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | 24 | 25 | def data_choose(args, block): 26 | if args.mode == 'test' or args.mode == 'watch_off': 27 | if args.data == 'ntu_skeleton': 28 | workers = args.worker 29 | data_set_val = NTU_SKE(mode='eval_rot', **args.data_param['val_data_param']) 30 | else: 31 | raise (RuntimeError('No data loader')) 32 | data_loader_val = DataLoader(data_set_val, batch_size=args.batch_size, shuffle=False, 33 | num_workers=workers, drop_last=False, pin_memory=args.pin_memory, 34 | worker_init_fn=init_seed) 35 | data_loader_train = None 36 | 37 | block.log('Data load finished: ' + args.data) 38 | 39 | shutil.copy2(__file__, args.model_saved_name) 40 | return data_loader_train, data_loader_val 41 | else: 42 | if args.data == 'ntu_skeleton': 43 | workers = args.worker 44 | data_set_train = NTU_SKE(mode='train', **args.data_param['train_data_param']) 45 | data_set_val = NTU_SKE(mode='val', **args.data_param['val_data_param']) 46 | elif args.data == 'adhd': 47 | workers = args.worker 48 | data_set_train = adhd(mode='train', **args.data_param['train_data_param']) 49 | data_set_val = adhd_val(mode='val', **args.data_param['val_data_param']) 50 | elif args.data == 'adni': 51 | workers = args.worker 52 | data_set_train = adni(mode='train', **args.data_param['train_data_param']) 53 | data_set_val = adni_val(mode='val', **args.data_param['val_data_param']) 54 | elif args.data == 'dhg_skeleton': 55 | workers = args.worker 56 | data_set_train = DHG_SKE(mode='train', **args.data_param['train_data_param']) 57 | data_set_val = DHG_SKE(mode='val', **args.data_param['val_data_param']) 58 | elif args.data == 'shrec_skeleton': 59 | workers = args.worker 60 | data_set_train = DHG_SKE(mode='train', **args.data_param['train_data_param']) 61 | data_set_val = DHG_SKE(mode='val', **args.data_param['val_data_param']) 62 | else: 63 | raise (RuntimeError('No data loader')) 64 | data_loader_val = DataLoader(data_set_val, batch_size=args.batch_size, shuffle=False, 65 | num_workers=workers, drop_last=False, pin_memory=args.pin_memory, 66 | worker_init_fn=init_seed) 67 | data_loader_train = 
DataLoader(data_set_train, batch_size=args.batch_size, shuffle=True, 68 | num_workers=workers, drop_last=True, pin_memory=args.pin_memory, 69 | worker_init_fn=init_seed) 70 | 71 | block.log('Data load finished: ' + args.data) 72 | 73 | shutil.copy2(__file__, args.model_saved_name) 74 | return data_loader_train, data_loader_val 75 | -------------------------------------------------------------------------------- /method_choose/loss_choose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utility.log import TimerBlock 3 | from train_val_test.loss import L1, L2 4 | import torch.nn.functional as func 5 | import shutil 6 | import inspect 7 | import torch.nn as nn 8 | 9 | 10 | # def to_onehot(num_class, label, alpha): 11 | # return torch.zeros((label.shape[0], num_class)).fill_(alpha).scatter_(1, label.unsqueeze(1), 1 - alpha) 12 | 13 | 14 | # class naive_cross_entropy_loss(nn.Module): 15 | # def __init__(self, num_class, alpha): 16 | # self.num_class = num_class 17 | # self.alpha = alpha 18 | # super(naive_cross_entropy_loss, self).__init__() 19 | # 20 | # def forward(self, inputs, target): 21 | # target = to_onehot(self.num_class, target, self.alpha) 22 | # return - (func.log_softmax(inputs, dim=-1) * target).sum(dim=-1).mean() 23 | 24 | 25 | class multi_cross_entropy_loss(nn.Module): 26 | def __init__(self): 27 | self.loss = torch.nn.CrossEntropyLoss(size_average=True) 28 | super(multi_cross_entropy_loss, self).__init__() 29 | 30 | def forward(self, inputs, target): 31 | ''' 32 | 33 | :param inputs: N C S 34 | :param target: N C 35 | :return: 36 | ''' 37 | num = inputs.shape[-1] 38 | inputs_splits = torch.chunk(inputs, num, dim=-1) 39 | loss = self.loss(inputs_splits[0].squeeze(-1), target) 40 | for i in range(1, num): 41 | loss += self.loss(inputs_splits[i].squeeze(-1), target) 42 | loss /= num 43 | return loss 44 | 45 | 46 | def naive_cross_entropy_loss(inputs, target): 47 | return - (func.log_softmax(inputs, dim=-1) * target).sum(dim=-1).mean() 48 | 49 | 50 | # def multi_cross_entropy_loss(inputs, target): 51 | # ''' 52 | # 53 | # :param inputs: N C S 54 | # :param target: N C 55 | # :return: 56 | # ''' 57 | # num = inputs.shape[-1] 58 | # inputs_splits = torch.chunk(inputs, num, dim=-1) 59 | # loss = - (func.log_softmax(inputs_splits[0].squeeze(-1), dim=-1) * target).sum(dim=-1).mean() 60 | # for i in range(1, num): 61 | # loss += - (func.log_softmax(inputs_splits[i].squeeze(-1), dim=-1) * target).sum(dim=-1).mean() 62 | # loss /= num 63 | # return loss 64 | 65 | # from warpctc_pytorch import CTCLoss 66 | # class CTC(nn.Module): 67 | # def __init__(self, input_len, target_len): 68 | # super(CTC, self).__init__() 69 | # self.ctc = CTCLoss(size_average=True, length_average=False) # TNC 70 | # self.input_len = input_len 71 | # self.target_len = target_len 72 | # 73 | # def forward(self, input, target): 74 | # """ 75 | # blank is default as 0 in ctc, but is -1 in model prob 76 | # :param input: TxNxcls 77 | # :param target: N, begin with 0 78 | # :return: 79 | # """ 80 | # batch_size = target.shape[0] 81 | # input_ = torch.cat([input[:,:,-1:], input[:,:,:-1]], dim=-1).clone() 82 | # target_ = target + 1 83 | # input_.requires_grad_(True) 84 | # in_len = torch.IntTensor([self.input_len]*batch_size) # .to(input_.get_device()) 85 | # out_len = torch.IntTensor([self.target_len]*batch_size) # .to(input_.get_device()) 86 | # ls = self.ctc(input_.cpu(), target_.cpu(), in_len, out_len) 87 | # return ls 88 | 89 | 90 | class CTC(nn.Module): 
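    # This wrapper adapts torch.nn.CTCLoss to the model's output layout.
    # The network emits the blank class in the last channel, while
    # nn.CTCLoss expects blank index 0 by default, so forward() rotates the
    # last channel to the front, shifts the integer targets by +1, and
    # evaluates the loss on log-softmax scores of shape T x N x (cls + 1)
    # with fixed input/target lengths per sample (see the __main__ example
    # at the bottom of this file for a minimal usage sketch).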
91 | def __init__(self, input_len, target_len, blank=0): 92 | super(CTC, self).__init__() 93 | self.ctc = nn.CTCLoss(blank=blank, reduction='mean', zero_infinity=True) 94 | self.input_len = input_len 95 | self.target_len = target_len 96 | 97 | def forward(self, input, target): 98 | """ 99 | 100 | :param input: TxNxcls 101 | :param target: N 102 | :return: 103 | """ 104 | batch_size = target.shape[0] 105 | input_ = torch.cat([input[:,:,-1:], input[:,:,:-1]], dim=-1).clone() 106 | target_ = target + 1 107 | target_ = target_.unsqueeze(-1) 108 | # target = torch.cat([target.unsqueeze(-1), target.unsqueeze(-1)], dim=1) 109 | ls = self.ctc(input_.log_softmax(2), target_, [self.input_len]*batch_size, [self.target_len]*batch_size) 110 | return ls 111 | 112 | 113 | def loss_choose(args, block): 114 | loss = args.loss 115 | if loss == 'cross_entropy': 116 | # if args.mix_up_num > 0: 117 | loss_function = torch.nn.CrossEntropyLoss(size_average=True) 118 | # else: 119 | elif loss == 'cross_entropy_naive': 120 | loss_function = naive_cross_entropy_loss 121 | elif loss == 'ctc': 122 | p = args.ls_param 123 | loss_function = CTC(p.input_len, p.target_len) 124 | elif loss == 'multi_cross_entropy': 125 | loss_function = multi_cross_entropy_loss() 126 | elif loss == 'mse_ce': 127 | loss_function = [torch.nn.MSELoss(), torch.nn.CrossEntropyLoss(size_average=True)] 128 | elif loss == 'l1loss': 129 | loss_function = L1() 130 | elif loss == 'l2loss': 131 | loss_function = L2() 132 | else: 133 | loss_function = torch.nn.CrossEntropyLoss(size_average=True) 134 | 135 | block.log('Using loss: ' + loss) 136 | # shutil.copy2(inspect.getfile(loss_function), args.model_saved_name) 137 | shutil.copy2(__file__, args.model_saved_name) 138 | return loss_function 139 | 140 | 141 | if __name__ == '__main__': 142 | res_ctc = torch.Tensor([[[0, 0, 1]], [[0.5, 0.6, 0.2]], [[0, 0., 1.]]]) 143 | b = 1 # batch 144 | c = 2 # label有多少类 145 | in_len = 3 # 预测每个序列有多少个label 146 | label_len = 1 # 实际每个序列有多少个label 147 | 148 | # res_ctc = torch.rand([in_len, b, c+1]) 149 | target = torch.zeros([b*label_len,], dtype=torch.long) 150 | 151 | loss_ctc = CTC(in_len, label_len) 152 | ls_ctc = loss_ctc(res_ctc, target) 153 | 154 | # loss_ctcp = CTCP(in_len, label_len) 155 | # ls_ctcp = loss_ctcp(res_ctc, target) 156 | 157 | # print(ls_ctc, ls_ctcp) 158 | -------------------------------------------------------------------------------- /method_choose/lr_scheduler_choose.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR, CosineAnnealingLR 2 | from utility.log import TimerBlock 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | import shutil 5 | import inspect 6 | 7 | 8 | class GradualWarmupScheduler(): 9 | """ Gradually warm-up(increasing) learning rate in optimizer. 10 | Args: 11 | optimizer (Optimizer): Wrapped optimizer. 12 | total_epoch: target learning rate is reached at total_epoch, gradually 13 | after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau) 14 | """ 15 | 16 | def __init__(self, optimizer, total_epoch, after_scheduler=None, last_epoch=-1): 17 | self.total_epoch = total_epoch 18 | self.after_scheduler = after_scheduler 19 | self.finished = False 20 | self.last_epoch = last_epoch 21 | self.optimizer = optimizer 22 | if last_epoch == -1: 23 | for group in optimizer.param_groups: 24 | group.setdefault('initial_lr', group['lr']) 25 | else: 26 | for i, group in enumerate(optimizer.param_groups): 27 | if 'initial_lr' not in group: 28 | raise KeyError("param 'initial_lr' is not specified " 29 | "in param_groups[{}] when resuming an optimizer".format(i)) 30 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 31 | # 提供epoch0的初始学习率,改掉继承optimizer时继承的 32 | # if type(after_scheduler) is ReduceLROnPlateau: 33 | # print(type(after_scheduler)) 34 | # if after_scheduler.mode=='min': 35 | # init_metric = 100 36 | # else: 37 | # init_metric = 0 38 | # self.step(metric=init_metric, epoch=last_epoch+1) 39 | # else: 40 | # self.step(epoch=last_epoch+1) 41 | # super().__init__(optimizer) 42 | 43 | def get_lr(self): 44 | return [base_lr * (self.last_epoch + 1) / self.total_epoch for base_lr in self.base_lrs] 45 | 46 | def step(self, epoch=None, metric=None): 47 | if self.last_epoch >= self.total_epoch - 1: 48 | if metric is None: 49 | return self.after_scheduler.step(epoch) 50 | else: 51 | return self.after_scheduler.step(metric, epoch) 52 | else: 53 | # return super(GradualWarmupScheduler, self).step(epoch) 54 | if epoch is None: 55 | epoch = self.last_epoch + 1 56 | self.last_epoch = epoch 57 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 58 | param_group['lr'] = lr 59 | 60 | 61 | def lr_scheduler_choose(optimizer, args, last_epoch, block): 62 | lr_args = args.lr_param 63 | if args.lr_scheduler == 'reduce_by_acc': 64 | lr_patience = lr_args['lr_patience'] 65 | lr_threshold = lr_args['lr_threshold'] 66 | lr_delay = lr_args['lr_delay'] 67 | block.log('lr scheduler: lr:{} DecayRatio:{} Patience:{} Threshold:{} Before_epoch:{}' 68 | .format(args.lr, args.lr_decay_ratio, lr_patience, lr_threshold, lr_delay)) 69 | lr_scheduler_pre = ReduceLROnPlateau(optimizer, mode='max', factor=args.lr_decay_ratio, 70 | patience=lr_patience, verbose=True, 71 | threshold=lr_threshold, threshold_mode='abs', 72 | cooldown=lr_delay) 73 | lr_scheduler = GradualWarmupScheduler(optimizer, total_epoch=args.warm_up_epoch, 74 | after_scheduler=lr_scheduler_pre, last_epoch=last_epoch) 75 | elif args.lr_scheduler == 'reduce_by_loss': 76 | lr_patience = lr_args['lr_patience'] 77 | lr_threshold = lr_args['lr_threshold'] 78 | lr_delay = lr_args['lr_delay'] 79 | block.log('lr scheduler: lr: {} DecayRatio: {} Patience: {} Threshold: {} Before_epoch: {}' 80 | .format(args.lr, args.lr_decay_ratio, lr_patience, lr_threshold, lr_delay)) 81 | lr_scheduler_pre = ReduceLROnPlateau(optimizer, mode='min', factor=args.lr_decay_ratio, 82 | patience=lr_patience, verbose=True, 83 | threshold=lr_threshold, threshold_mode='abs', 84 | cooldown=lr_delay) 85 | lr_scheduler = GradualWarmupScheduler(optimizer, total_epoch=args.warm_up_epoch, 86 | after_scheduler=lr_scheduler_pre, last_epoch=last_epoch) 87 | elif args.lr_scheduler == 'reduce_by_epoch': 88 | step = lr_args['step'] 89 | block.log('lr scheduler: Reduce by epoch, step: ' + str(step)) 90 | lr_scheduler_pre = MultiStepLR(optimizer, step, last_epoch=last_epoch, gamma=args.lr_decay_ratio) 91 | lr_scheduler = GradualWarmupScheduler(optimizer, 
total_epoch=args.warm_up_epoch, 92 | after_scheduler=lr_scheduler_pre, last_epoch=last_epoch) 93 | elif args.lr_scheduler == 'cosine_annealing_lr': 94 | lr_scheduler_pre = CosineAnnealingLR(optimizer, lr_args.max_epoch + 1, eta_min=0.0001, last_epoch=last_epoch) 95 | lr_scheduler = GradualWarmupScheduler(optimizer, total_epoch=args.warm_up_epoch, 96 | after_scheduler=lr_scheduler_pre, last_epoch=last_epoch) 97 | else: 98 | raise ValueError() 99 | # shutil.copy2(inspect.getfile(lr_scheduler), args.model_saved_name) 100 | shutil.copy2(__file__, args.model_saved_name) 101 | return lr_scheduler 102 | -------------------------------------------------------------------------------- /method_choose/model_choose.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import torch 3 | import torch.nn as nn 4 | from collections import OrderedDict 5 | import shutil 6 | import inspect 7 | from model.dstanet import DSTANet 8 | 9 | 10 | def rm_module(old_dict): 11 | new_state_dict = OrderedDict() 12 | for k, v in old_dict.items(): 13 | head = k[:7] 14 | if head == 'module.': 15 | name = k[7:] # remove `module.` 16 | else: 17 | name = k 18 | new_state_dict[name] = v 19 | return new_state_dict 20 | 21 | 22 | def model_choose(args, block): 23 | m = args.model 24 | if m == 'dstanet': 25 | model = DSTANet(num_class=args.class_num, **args.model_param) 26 | shutil.copy2(inspect.getfile(DSTANet), args.model_saved_name) 27 | else: 28 | raise (RuntimeError("No modules")) 29 | 30 | shutil.copy2(__file__, args.model_saved_name) 31 | block.log('Model load finished: ' + args.model + ' mode: train') 32 | optimizer_dict = None 33 | 34 | if args.pre_trained_model is not None: 35 | model_dict = model.state_dict() 36 | pretrained_dict = torch.load(args.pre_trained_model) # ['state_dict'] 37 | if type(pretrained_dict) is dict and ('optimizer' in pretrained_dict.keys()): 38 | optimizer_dict = pretrained_dict['optimizer'] 39 | pretrained_dict = pretrained_dict['model'] 40 | pretrained_dict = rm_module(pretrained_dict) 41 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 42 | keys = list(pretrained_dict.keys()) 43 | for key in keys: 44 | for weight in args.ignore_weights: 45 | if weight in key: 46 | if pretrained_dict.pop(key) is not None: 47 | block.log('Sucessfully Remove Weights: {}.'.format(key)) 48 | else: 49 | block.log('Can Not Remove Weights: {}.'.format(key)) 50 | block.log('following weight not load: ' + str(set(model_dict) - set(pretrained_dict))) 51 | model_dict.update(pretrained_dict) 52 | # block.log(model_dict) 53 | model.load_state_dict(model_dict) 54 | block.log('Pretrained model load finished: ' + args.pre_trained_model) 55 | 56 | global_step = 0 57 | global_epoch = 0 58 | # The name for model must be **_**-$(step).state 59 | if args.last_model is not None: 60 | model_dict = model.state_dict() 61 | pretrained_dict = torch.load(args.last_model) # ['state_dict'] 62 | if type(pretrained_dict) is dict and ('optimizer' in pretrained_dict.keys()): 63 | optimizer_dict = pretrained_dict['optimizer'] 64 | pretrained_dict = pretrained_dict['model'] 65 | pretrained_dict = rm_module(pretrained_dict) 66 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 67 | block.log('In last model, following weight not load: ' + str(set(model_dict) - set(pretrained_dict))) 68 | model_dict.update(pretrained_dict) 69 | model.load_state_dict(model_dict) 70 | 71 | try: 72 | global_step = 
int(args.last_model[:-6].split('-')[2]) 73 | global_epoch = int(args.last_model[:-6].split('-')[1]) 74 | except: 75 | global_epoch = global_step = 0 76 | block.log('Training continue, last model load finished, step is {}, epoch is {}'.format(str(global_step), 77 | str(global_epoch))) 78 | print('--------------') 79 | print(torch.cuda.is_available()) 80 | model.cuda() 81 | model = nn.DataParallel(model, device_ids=args.device_id) 82 | block.log('copy model to gpu') 83 | return global_step, global_epoch, model, optimizer_dict 84 | -------------------------------------------------------------------------------- /method_choose/optimizer_choose.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import torch 4 | from utility.log import TimerBlock 5 | from torch.optim.sgd import SGD 6 | import shutil 7 | import inspect 8 | 9 | 10 | def optimizer_choose(model, args, writer, block): 11 | params = [] 12 | for key, value in model.named_parameters(): 13 | if value.requires_grad: 14 | params += [{'params': [value], 'lr': args.lr, 'key': key, 'weight_decay': args.wd}] 15 | 16 | if args.optimizer == 'adam': 17 | optimizer = torch.optim.Adam(params) 18 | block.log('Using Adam optimizer') 19 | elif args.optimizer == 'sgd': 20 | momentum = 0.9 21 | optimizer = SGD(params, momentum=momentum) 22 | block.log('Using SGD with momentum ' + str(momentum)) 23 | elif args.optimizer == 'sgd_nev': 24 | momentum = 0.9 25 | optimizer = SGD(params, momentum=momentum, nesterov=True) 26 | block.log('Using SGD with momentum ' + str(momentum) + 'and nesterov') 27 | else: 28 | momentum = 0.9 29 | optimizer = SGD(params, momentum=momentum) 30 | block.log('Using SGD with momentum ' + str(momentum)) 31 | 32 | # shutil.copy2(inspect.getfile(optimizer), args.model_saved_name) 33 | shutil.copy2(__file__, args.model_saved_name) 34 | return optimizer 35 | -------------------------------------------------------------------------------- /method_choose/tra_val_choose.py: -------------------------------------------------------------------------------- 1 | from train_val_test import train_val_model 2 | import shutil 3 | import inspect 4 | 5 | 6 | def train_val_choose(args, block): 7 | if args.train == 'classify': 8 | train_net = train_val_model.train_classifier 9 | val_net = train_val_model.val_classifier 10 | else: 11 | raise ValueError("args of train val is not right") 12 | 13 | shutil.copy2(inspect.getfile(train_net), args.model_saved_name) 14 | shutil.copy2(__file__, args.model_saved_name) 15 | 16 | return train_net, val_net 17 | 18 | 19 | if __name__ == '__main__': 20 | train_net = train_val_model.train_classifier 21 | print(inspect.getfile(train_net)) 22 | -------------------------------------------------------------------------------- /model/__pycache__/dstanet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/model/__pycache__/dstanet.cpython-36.pyc -------------------------------------------------------------------------------- /model/dstanet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | 6 | 7 | def conv_init(conv): 8 | nn.init.kaiming_normal_(conv.weight, mode='fan_out') 9 | # nn.init.constant_(conv.bias, 0) 10 | 11 | 12 | def bn_init(bn, scale): 13 | 
nn.init.constant_(bn.weight, scale) 14 | nn.init.constant_(bn.bias, 0) 15 | 16 | 17 | def fc_init(fc): 18 | nn.init.xavier_normal_(fc.weight) 19 | nn.init.constant_(fc.bias, 0) 20 | 21 | 22 | class PositionalEncoding(nn.Module): 23 | 24 | def __init__(self, channel, joint_num, time_len, domain): 25 | super(PositionalEncoding, self).__init__() 26 | self.joint_num = joint_num 27 | self.time_len = time_len 28 | 29 | self.domain = domain 30 | 31 | if domain == "temporal": 32 | # temporal embedding 33 | pos_list = [] 34 | for t in range(self.time_len): 35 | for j_id in range(self.joint_num): 36 | pos_list.append(t) 37 | elif domain == "spatial": 38 | # spatial embedding 39 | pos_list = [] 40 | for t in range(self.time_len): 41 | for j_id in range(self.joint_num): 42 | pos_list.append(j_id) 43 | 44 | position = torch.from_numpy(np.array(pos_list)).unsqueeze(1).float() 45 | # pe = position/position.max()*2 -1 46 | # pe = pe.view(time_len, joint_num).unsqueeze(0).unsqueeze(0) 47 | # Compute the positional encodings once in log space. 48 | pe = torch.zeros(self.time_len * self.joint_num, channel) 49 | 50 | div_term = torch.exp(torch.arange(0, channel, 2).float() * 51 | -(math.log(10000.0) / channel)) # channel//2 52 | pe[:, 0::2] = torch.sin(position * div_term) 53 | pe[:, 1::2] = torch.cos(position * div_term) 54 | pe = pe.view(time_len, joint_num, channel).permute(2, 0, 1).unsqueeze(0) 55 | self.register_buffer('pe', pe) 56 | 57 | def forward(self, x): # nctv 58 | x = x + self.pe[:, :, :x.size(2)] 59 | return x 60 | 61 | 62 | class STAttentionBlock(nn.Module): 63 | def __init__(self, in_channels, out_channels, inter_channels, num_subset=3, num_node=25, num_frame=32, 64 | kernel_size=1, stride=1, glo_reg_s=True, att_s=True, glo_reg_t=True, att_t=True, 65 | use_temporal_att=True, use_spatial_att=True, attentiondrop=0, use_pes=True, use_pet=True): 66 | super(STAttentionBlock, self).__init__() 67 | self.inter_channels = inter_channels 68 | self.out_channels = out_channels 69 | self.in_channels = in_channels 70 | self.num_subset = num_subset 71 | self.glo_reg_s = glo_reg_s 72 | self.att_s = att_s 73 | self.glo_reg_t = glo_reg_t 74 | self.att_t = att_t 75 | self.use_pes = use_pes 76 | self.use_pet = use_pet 77 | 78 | pad = int((kernel_size - 1) / 2) 79 | self.use_spatial_att = use_spatial_att 80 | if use_spatial_att: 81 | atts = torch.zeros((1, num_subset, num_node, num_node)) 82 | self.register_buffer('atts', atts) 83 | self.pes = PositionalEncoding(in_channels, num_node, num_frame, 'spatial') 84 | self.ff_nets = nn.Sequential( 85 | nn.Conv2d(out_channels, out_channels, 1, 1, padding=0, bias=True), 86 | nn.BatchNorm2d(out_channels), 87 | ) 88 | if att_s: 89 | self.in_nets = nn.Conv2d(in_channels, 2 * num_subset * inter_channels, 1, bias=True) 90 | self.alphas = nn.Parameter(torch.ones(1, num_subset, 1, 1), requires_grad=True) 91 | if glo_reg_s: 92 | self.attention0s = nn.Parameter(torch.ones(1, num_subset, num_node, num_node) / num_node, 93 | requires_grad=True) 94 | 95 | self.out_nets = nn.Sequential( 96 | nn.Conv2d(in_channels * num_subset, out_channels, 1, bias=True), 97 | nn.BatchNorm2d(out_channels), 98 | ) 99 | else: 100 | self.out_nets = nn.Sequential( 101 | nn.Conv2d(in_channels, out_channels, (1, 3), padding=(0, 1), bias=True, stride=1), 102 | nn.BatchNorm2d(out_channels), 103 | ) 104 | self.use_temporal_att = use_temporal_att 105 | if use_temporal_att: 106 | attt = torch.zeros((1, num_subset, num_frame, num_frame)) 107 | self.register_buffer('attt', attt) 108 | self.pet = 
PositionalEncoding(out_channels, num_node, num_frame, 'temporal') 109 | self.ff_nett = nn.Sequential( 110 | nn.Conv2d(out_channels, out_channels, (kernel_size, 1), (stride, 1), padding=(pad, 0), bias=True), 111 | nn.BatchNorm2d(out_channels), 112 | ) 113 | if att_t: 114 | self.in_nett = nn.Conv2d(out_channels, 2 * num_subset * inter_channels, 1, bias=True) 115 | self.alphat = nn.Parameter(torch.ones(1, num_subset, 1, 1), requires_grad=True) 116 | if glo_reg_t: 117 | self.attention0t = nn.Parameter(torch.zeros(1, num_subset, num_frame, num_frame) + torch.eye(num_frame), 118 | requires_grad=True) 119 | self.out_nett = nn.Sequential( 120 | nn.Conv2d(out_channels * num_subset, out_channels, 1, bias=True), 121 | nn.BatchNorm2d(out_channels), 122 | ) 123 | else: 124 | self.out_nett = nn.Sequential( 125 | nn.Conv2d(out_channels, out_channels, (7, 1), padding=(3, 0), bias=True, stride=(stride, 1)), 126 | nn.BatchNorm2d(out_channels), 127 | ) 128 | 129 | if in_channels != out_channels or stride != 1: 130 | if use_spatial_att: 131 | self.downs1 = nn.Sequential( 132 | nn.Conv2d(in_channels, out_channels, 1, bias=True), 133 | nn.BatchNorm2d(out_channels), 134 | ) 135 | self.downs2 = nn.Sequential( 136 | nn.Conv2d(in_channels, out_channels, 1, bias=True), 137 | nn.BatchNorm2d(out_channels), 138 | ) 139 | if use_temporal_att: 140 | self.downt1 = nn.Sequential( 141 | nn.Conv2d(out_channels, out_channels, 1, 1, bias=True), 142 | nn.BatchNorm2d(out_channels), 143 | ) 144 | self.downt2 = nn.Sequential( 145 | nn.Conv2d(out_channels, out_channels, (kernel_size, 1), (stride, 1), padding=(pad, 0), bias=True), 146 | nn.BatchNorm2d(out_channels), 147 | ) 148 | else: 149 | if use_spatial_att: 150 | self.downs1 = lambda x: x 151 | self.downs2 = lambda x: x 152 | if use_temporal_att: 153 | self.downt1 = lambda x: x 154 | self.downt2 = lambda x: x 155 | 156 | self.soft = nn.Softmax(-2) 157 | self.tan = nn.Tanh() 158 | self.relu = nn.LeakyReLU(0.1) 159 | self.drop = nn.Dropout(attentiondrop) 160 | 161 | def forward(self, x): 162 | 163 | N, C, T, V = x.size() 164 | if self.use_spatial_att: 165 | attention = self.atts 166 | if self.use_pes: 167 | y = self.pes(x) 168 | else: 169 | y = x 170 | if self.att_s: 171 | q, k = torch.chunk(self.in_nets(y).view(N, 2 * self.num_subset, self.inter_channels, T, V), 2, 172 | dim=1) # nctv -> n num_subset c'tv 173 | attention = attention + self.tan( 174 | torch.einsum('nsctu,nsctv->nsuv', [q, k]) / (self.inter_channels * T)) * self.alphas 175 | if self.glo_reg_s: 176 | attention = attention + self.attention0s.repeat(N, 1, 1, 1) 177 | attention = self.drop(attention) 178 | y = torch.einsum('nctu,nsuv->nsctv', [x, attention]).contiguous() \ 179 | .view(N, self.num_subset * self.in_channels, T, V) 180 | y = self.out_nets(y) # nctv 181 | y = self.relu(self.downs1(x) + y) 182 | y = self.ff_nets(y) 183 | y = self.relu(self.downs2(x) + y) 184 | else: 185 | y = self.out_nets(x) 186 | y = self.relu(self.downs2(x) + y) 187 | 188 | if self.use_temporal_att: 189 | attention = self.attt 190 | if self.use_pet: 191 | z = self.pet(y) 192 | else: 193 | z = y 194 | if self.att_t: 195 | q, k = torch.chunk(self.in_nett(z).view(N, 2 * self.num_subset, self.inter_channels, T, V), 2, 196 | dim=1) # nctv -> n num_subset c'tv 197 | attention = attention + self.tan( 198 | torch.einsum('nsctv,nscqv->nstq', [q, k]) / (self.inter_channels * V)) * self.alphat 199 | if self.glo_reg_t: 200 | attention = attention + self.attention0t.repeat(N, 1, 1, 1) 201 | attention = self.drop(attention) 202 | z = 
torch.einsum('nctv,nstq->nscqv', [y, attention]).contiguous() \ 203 | .view(N, self.num_subset * self.out_channels, T, V) 204 | z = self.out_nett(z) # nctv 205 | 206 | z = self.relu(self.downt1(y) + z) 207 | 208 | z = self.ff_nett(z) 209 | 210 | z = self.relu(self.downt2(y) + z) 211 | else: 212 | z = self.out_nett(y) 213 | z = self.relu(self.downt2(y) + z) 214 | return z 215 | 216 | 217 | 218 | class STKernelAttentionBlock(nn.Module): 219 | def __init__(self, in_channels, out_channels, inter_channels, num_subset=3, num_node=25, num_frame=32, 220 | kernel_size=1, stride=1, glo_reg_s=True, att_s=True, glo_reg_t=True, att_t=True, 221 | use_temporal_att=True, use_spatial_att=True, attentiondrop=0, use_pes=True, use_pet=True, diffusion=0): 222 | super(STKernelAttentionBlock, self).__init__() 223 | self.inter_channels = inter_channels 224 | self.out_channels = out_channels 225 | self.in_channels = in_channels 226 | self.num_subset = num_subset 227 | self.glo_reg_s = glo_reg_s 228 | self.att_s = att_s 229 | self.glo_reg_t = glo_reg_t 230 | self.att_t = att_t 231 | self.use_pes = use_pes 232 | self.use_pet = use_pet 233 | self.diffusion = diffusion 234 | 235 | pad = int((kernel_size - 1) / 2) 236 | self.use_spatial_att = use_spatial_att 237 | self.dfw2 = nn.Parameter(torch.ones(1), requires_grad=True) 238 | self.dfw3 = nn.Parameter(torch.ones(1), requires_grad=True) 239 | self.dfw4 = nn.Parameter(torch.ones(1), requires_grad=True) 240 | self.dfw5 = nn.Parameter(torch.ones(1), requires_grad=True) 241 | 242 | if use_spatial_att: 243 | self.theta = torch.nn.Parameter(torch.zeros((1, num_subset, 1, 1))) 244 | atts = torch.zeros((1, num_subset, num_node, num_node)) 245 | self.register_buffer('atts', atts) 246 | self.pes = PositionalEncoding(in_channels, num_node, num_frame, 'spatial') 247 | self.ff_nets = nn.Sequential( 248 | nn.Conv2d(out_channels, out_channels, 1, 1, padding=0, bias=True), 249 | nn.BatchNorm2d(out_channels), 250 | ) 251 | if att_s: 252 | self.in_nets = nn.Conv2d(in_channels, 2 * num_subset * inter_channels, 1, bias=True) 253 | self.alphas = nn.Parameter(torch.ones(1, num_subset, 1, 1), requires_grad=True) 254 | if glo_reg_s: 255 | self.attention0s = nn.Parameter(torch.ones(1, num_subset, num_node, num_node) / num_node, 256 | requires_grad=True) 257 | 258 | self.out_nets = nn.Sequential( 259 | nn.Conv2d(in_channels * num_subset, out_channels, 1, bias=True), 260 | nn.BatchNorm2d(out_channels), 261 | ) 262 | else: 263 | self.out_nets = nn.Sequential( 264 | nn.Conv2d(in_channels, out_channels, (1, 3), padding=(0, 1), bias=True, stride=1), 265 | nn.BatchNorm2d(out_channels), 266 | ) 267 | self.use_temporal_att = use_temporal_att 268 | if use_temporal_att: 269 | attt = torch.zeros((1, num_subset, num_frame, num_frame)) 270 | self.register_buffer('attt', attt) 271 | self.pet = PositionalEncoding(out_channels, num_node, num_frame, 'temporal') 272 | self.ff_nett = nn.Sequential( 273 | nn.Conv2d(out_channels, out_channels, (kernel_size, 1), (stride, 1), padding=(pad, 0), bias=True), 274 | nn.BatchNorm2d(out_channels), 275 | ) 276 | if att_t: 277 | self.in_nett = nn.Conv2d(out_channels, 2 * num_subset * inter_channels, 1, bias=True) 278 | self.alphat = nn.Parameter(torch.ones(1, num_subset, 1, 1), requires_grad=True) 279 | if glo_reg_t: 280 | self.attention0t = nn.Parameter(torch.zeros(1, num_subset, num_frame, num_frame) + torch.eye(num_frame), 281 | requires_grad=True) 282 | self.out_nett = nn.Sequential( 283 | nn.Conv2d(out_channels * num_subset, out_channels, 1, bias=True), 284 | 
nn.BatchNorm2d(out_channels), 285 | ) 286 | else: 287 | self.out_nett = nn.Sequential( 288 | nn.Conv2d(out_channels, out_channels, (7, 1), padding=(3, 0), bias=True, stride=(stride, 1)), 289 | nn.BatchNorm2d(out_channels), 290 | ) 291 | 292 | if in_channels != out_channels or stride != 1: 293 | if use_spatial_att: 294 | self.downs1 = nn.Sequential( 295 | nn.Conv2d(in_channels, out_channels, 1, bias=True), 296 | nn.BatchNorm2d(out_channels), 297 | ) 298 | self.downs2 = nn.Sequential( 299 | nn.Conv2d(in_channels, out_channels, 1, bias=True), 300 | nn.BatchNorm2d(out_channels), 301 | ) 302 | if use_temporal_att: 303 | self.downt1 = nn.Sequential( 304 | nn.Conv2d(out_channels, out_channels, 1, 1, bias=True), 305 | nn.BatchNorm2d(out_channels), 306 | ) 307 | self.downt2 = nn.Sequential( 308 | nn.Conv2d(out_channels, out_channels, (kernel_size, 1), (stride, 1), padding=(pad, 0), bias=True), 309 | nn.BatchNorm2d(out_channels), 310 | ) 311 | else: 312 | if use_spatial_att: 313 | self.downs1 = lambda x: x 314 | self.downs2 = lambda x: x 315 | if use_temporal_att: 316 | self.downt1 = lambda x: x 317 | self.downt2 = lambda x: x 318 | 319 | self.soft = nn.Softmax(-2) 320 | self.tan = nn.Tanh() 321 | self.relu = nn.LeakyReLU(0.1) 322 | self.drop = nn.Dropout(attentiondrop) 323 | 324 | def forward(self, x): 325 | 326 | N, C, T, V = x.size() 327 | if self.use_spatial_att: 328 | attention = self.atts 329 | if self.use_pes: 330 | y = self.pes(x) 331 | else: 332 | y = x 333 | if self.att_s: 334 | y = y.permute(0, 3, 1, 2).contiguous() 335 | dism = torch.cdist(y.view(y.shape[0], y.shape[1], y.shape[2]*y.shape[3]), 336 | y.view(y.shape[0], y.shape[1], y.shape[2]*y.shape[3]), p=2) 337 | dism = torch.pow(dism /(C * T), 2) 338 | dsexp = torch.exp(self.theta) 339 | dism = torch.unsqueeze(dism, 1).repeat(1, self.num_subset, 1, 1) 340 | kernel_atten = torch.exp(-1 * dsexp * dism) 341 | 342 | if self.diffusion == 2: 343 | attention_df2 = torch.matmul(kernel_atten, kernel_atten) / kernel_atten.size()[3] 344 | kernel_atten = 0.5 * kernel_atten + attention_df2 * self.dfw2 * 0.5 345 | elif self.diffusion == 3: 346 | attention_df2 = torch.matmul(kernel_atten, kernel_atten) / kernel_atten.size()[3] 347 | attention_df3 = torch.matmul(attention_df2, kernel_atten) / kernel_atten.size()[3] 348 | kernel_atten = 0.4 * kernel_atten + attention_df2 * self.dfw2 * 0.3 + attention_df3 * self.dfw3 * 0.3 349 | elif self.diffusion == 4: 350 | attention_df2 = torch.matmul(kernel_atten, kernel_atten) / kernel_atten.size()[3] 351 | attention_df3 = torch.matmul(attention_df2, kernel_atten) / kernel_atten.size()[3] 352 | attention_df4 = torch.matmul(attention_df3, kernel_atten) / kernel_atten.size()[3] 353 | kernel_atten = 0.25 * kernel_atten + attention_df2 * self.dfw2 * 0.25 + attention_df3 * self.dfw3 * 0.25 + attention_df4 * self.dfw4 * 0.25 354 | elif self.diffusion == 5: 355 | attention_df2 = torch.matmul(kernel_atten, kernel_atten) / kernel_atten.size()[3] 356 | attention_df3 = torch.matmul(attention_df2, kernel_atten) / kernel_atten.size()[3] 357 | attention_df4 = torch.matmul(attention_df3, kernel_atten) / kernel_atten.size()[3] 358 | attention_df5 = torch.matmul(attention_df3, kernel_atten) / kernel_atten.size()[3] 359 | kernel_atten = 0.2 * kernel_atten + attention_df2 * self.dfw2 * 0.2 + attention_df3 * self.dfw3 * 0.2 + attention_df4 * self.dfw4 * 0.2 + attention_df5 * self.dfw5 * 0.2 360 | 361 | attention = attention + self.tan(kernel_atten) * self.alphas 362 | 363 | if self.glo_reg_s: 364 | attention = attention + 
self.attention0s.repeat(N, 1, 1, 1) 365 | attention = self.drop(attention) 366 | y = torch.einsum('nctu,nsuv->nsctv', [x, attention]).contiguous() \ 367 | .view(N, self.num_subset * self.in_channels, T, V) 368 | y = self.out_nets(y) # nctv 369 | y = self.relu(self.downs1(x) + y) 370 | y = self.ff_nets(y) 371 | y = self.relu(self.downs2(x) + y) 372 | else: 373 | y = self.out_nets(x) 374 | y = self.relu(self.downs2(x) + y) 375 | 376 | if self.use_temporal_att: 377 | attention = self.attt 378 | if self.use_pet: 379 | z = self.pet(y) 380 | else: 381 | z = y 382 | if self.att_t: 383 | q, k = torch.chunk(self.in_nett(z).view(N, 2 * self.num_subset, self.inter_channels, T, V), 2, 384 | dim=1) # nctv -> n num_subset c'tv 385 | attention = attention + self.tan( 386 | torch.einsum('nsctv,nscqv->nstq', [q, k]) / (self.inter_channels * V)) * self.alphat 387 | if self.glo_reg_t: 388 | attention = attention + self.attention0t.repeat(N, 1, 1, 1) 389 | attention = self.drop(attention) 390 | z = torch.einsum('nctv,nstq->nscqv', [y, attention]).contiguous() \ 391 | .view(N, self.num_subset * self.out_channels, T, V) 392 | z = self.out_nett(z) # nctv 393 | 394 | z = self.relu(self.downt1(y) + z) 395 | 396 | z = self.ff_nett(z) 397 | 398 | z = self.relu(self.downt2(y) + z) 399 | else: 400 | z = self.out_nett(y) 401 | z = self.relu(self.downt2(y) + z) 402 | return z 403 | 404 | 405 | class DSTANet(nn.Module): 406 | def __init__(self, num_class=60, num_point=25, num_frame=32, num_subset=3, dropout=0., config=None, num_person=2, 407 | num_channel=3, glo_reg_s=True, att_s=True, glo_reg_t=False, att_t=True, 408 | use_temporal_att=True, use_spatial_att=True, attentiondrop=0, dropout2d=0, use_pet=True, use_pes=True, kernelattention=True, diffusion=0): 409 | super(DSTANet, self).__init__() 410 | 411 | self.out_channels = config[-1][1] 412 | in_channels = config[0][0] 413 | 414 | self.input_map = nn.Sequential( 415 | nn.Conv2d(num_channel, in_channels, 1), 416 | nn.BatchNorm2d(in_channels), 417 | nn.LeakyReLU(0.1), 418 | ) 419 | 420 | param = { 421 | 'num_node': num_point, 422 | 'num_subset': num_subset, 423 | 'glo_reg_s': glo_reg_s, 424 | 'att_s': att_s, 425 | 'glo_reg_t': glo_reg_t, 426 | 'att_t': att_t, 427 | 'use_spatial_att': use_spatial_att, 428 | 'use_temporal_att': use_temporal_att, 429 | 'use_pet': use_pet, 430 | 'use_pes': use_pes, 431 | 'attentiondrop': attentiondrop, 432 | 'diffusion': diffusion 433 | } 434 | self.graph_layers = nn.ModuleList() 435 | for index, (in_channels, out_channels, inter_channels, stride) in enumerate(config): 436 | if not kernelattention:# use original linear attention 437 | self.graph_layers.append( 438 | STAttentionBlock(in_channels, out_channels, inter_channels, stride=stride, num_frame=num_frame, 439 | **param)) 440 | else: # use kernel attention. 
if diffusion >1, diffusion will be applied 441 | self.graph_layers.append( 442 | STKernelAttentionBlock(in_channels, out_channels, inter_channels, stride=stride, num_frame=num_frame, 443 | **param)) 444 | num_frame = int(num_frame / stride + 0.5) 445 | 446 | self.fc = nn.Linear(self.out_channels, num_class) 447 | 448 | self.drop_out = nn.Dropout(dropout) 449 | self.drop_out2d = nn.Dropout2d(dropout2d) 450 | 451 | for m in self.modules(): 452 | if isinstance(m, nn.Conv2d): 453 | conv_init(m) 454 | elif isinstance(m, nn.BatchNorm2d): 455 | bn_init(m, 1) 456 | elif isinstance(m, nn.Linear): 457 | fc_init(m) 458 | 459 | def forward(self, x): 460 | """ 461 | 462 | :param x: N M C T V 463 | :return: classes scores 464 | """ 465 | #N, C, T, V, M = x.shape 466 | N, M, T, V, C = x.shape 467 | x = x.permute(0, 1, 4, 2, 3).contiguous().view(N * M, C, T, V) 468 | x = self.input_map(x) 469 | 470 | for i, m in enumerate(self.graph_layers): 471 | x = m(x) 472 | 473 | # NM, C, T, V 474 | x = x.view(N, M, self.out_channels, -1) 475 | x = x.permute(0, 1, 3, 2).contiguous().view(N, -1, self.out_channels, 1) # whole channels of one spatial 476 | x = self.drop_out2d(x) 477 | x = x.mean(3).mean(1) 478 | x = self.drop_out(x) # whole spatial of one channel 479 | return self.fc(x) 480 | 481 | 482 | if __name__ == '__main__': 483 | config = [[64, 64, 16, 1], [64, 64, 16, 1], 484 | [64, 128, 32, 2], [128, 128, 32, 1], 485 | [128, 256, 64, 2], [256, 256, 64, 1], 486 | [256, 256, 64, 1], [256, 256, 64, 1], 487 | ] 488 | net = DSTANet(config=config) # .cuda() 489 | ske = torch.rand([2, 3, 32, 25, 2]) # .cuda() 490 | #print(net(ske).shape) 491 | -------------------------------------------------------------------------------- /prepare/ADNI/ADNI_data_gen.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from tqdm import tqdm 4 | import sys 5 | 6 | sys.path.extend(['../']) 7 | 8 | import numpy as np 9 | import os 10 | from scipy.io import loadmat 11 | 12 | max_frame = 130 13 | 14 | def gendata(data_path, out_path, roinum, part,isplit): 15 | inputfilename = data_path + 'ADNI_TS_' + roinum + 'ROI__' + part + '_5split_' + str(isplit) + '.mat' 16 | matdata = loadmat(inputfilename) 17 | 18 | sample_label = matdata['label'].transpose().tolist()[0] 19 | sample_name = [] 20 | for i, s in enumerate(tqdm(range(len(sample_label)))): 21 | sample_name.append(part + '_TS_' + roinum + str(i)) 22 | 23 | with open('{}/{}_5split_{}_{}_label.pkl'.format(out_path, roinum, part,isplit), 'wb') as f: 24 | pickle.dump((sample_name, list(sample_label)), f) 25 | 26 | fp = np.zeros((len(sample_label), 1, max_frame, int(roinum),6), dtype=np.float32) 27 | 28 | for i, s in enumerate(tqdm(range(len(sample_label)))): 29 | data = matdata['ROISignals'][i][0] 30 | 31 | data = np.stack([data, 32 | matdata['ROISignals_5bands'][0][0][i][0], 33 | matdata['ROISignals_5bands'][1][0][i][0], 34 | matdata['ROISignals_5bands'][2][0][i][0], 35 | matdata['ROISignals_5bands'][3][0][i][0], 36 | matdata['ROISignals_5bands'][4][0][i][0]],axis=2) 37 | fp[i, 0, 0:data.shape[0], :,:] = data 38 | # nframe = data.shape[1] 39 | # nrep = int(np.floor(max_frame / nframe)) 40 | # for irep in range(nrep): 41 | # fp[i, :, (0 + irep * nframe):((irep + 1) * nframe), :] = data 42 | # fp[i, :, (0 + nrep * nframe):max_frame, :] = data[:,0:(max_frame-nrep*nframe),:] 43 | 44 | #fp = pre_normalization(fp) 45 | np.save('{}/{}_5split_{}_{}_data.npy'.format(out_path, roinum, part,isplit), fp) 46 | 47 | 48 | if __name__ 
== '__main__': 49 | parser = argparse.ArgumentParser(description='ADHD Data Converter.') 50 | parser.add_argument('--data_path', default='/data/datasets/ADNI/') 51 | parser.add_argument('--out_folder', default='/data/datasets/ADNI/Processed/') 52 | 53 | Roinums = ['90', '42'] 54 | part = ['train', 'test'] 55 | arg = parser.parse_args() 56 | 57 | for roinum in Roinums: 58 | for p in part: 59 | out_path = os.path.join(arg.out_folder, roinum) 60 | if not os.path.exists(out_path): 61 | os.makedirs(out_path) 62 | for isplit in range(1,11): 63 | print(roinum, p,isplit) 64 | 65 | gendata( 66 | arg.data_path, 67 | out_path, 68 | roinum=roinum, 69 | part=p, isplit=isplit) 70 | -------------------------------------------------------------------------------- /prepare/dhg/gendata.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from tqdm import tqdm 3 | import sys 4 | from dataset.rotation import * 5 | from dataset.normalize_skeletons import normalize_skeletons 6 | sys.path.extend(['../../']) 7 | 8 | import numpy as np 9 | import os 10 | 11 | 12 | def read_skeleton(ske_txt): 13 | ske_txt = open(ske_txt, 'r').readlines() 14 | skeletons = [] 15 | for line in ske_txt: 16 | nums = line.split(' ') 17 | # num_frame = int(nums[0]) + 1 18 | coords_frame = np.array(nums).reshape((22, 3)).astype(np.float32) 19 | skeletons.append(coords_frame) 20 | num_frame = len(skeletons) 21 | skeletons = np.expand_dims(np.array(skeletons).transpose((2, 0, 1)), axis=-1) # CTVM 22 | skeletons = np.transpose(skeletons, [3, 1, 2, 0]) # CTVM-MTVC 23 | return skeletons, num_frame 24 | 25 | 26 | def gendata(): 27 | root = '/your/path/to/dhg_hand/' 28 | split_txt = 'informations_troncage_sequences.txt' 29 | split = open(os.path.join(root, split_txt), 'r').readlines() 30 | num_sub = 20 31 | skeletons_all_train = [[] for i in range(num_sub)] 32 | names_all_train = [[] for i in range(num_sub)] 33 | labels14_all_train = [[] for i in range(num_sub)] 34 | labels28_all_train = [[] for i in range(num_sub)] 35 | skeletons_all_val = [[] for i in range(num_sub)] 36 | names_all_val = [[] for i in range(num_sub)] 37 | labels14_all_val = [[] for i in range(num_sub)] 38 | labels28_all_val = [[] for i in range(num_sub)] 39 | for line in tqdm(split): 40 | line = line.split("\n")[0] 41 | data = line.split(" ") 42 | g_id = data[0] 43 | f_id = data[1] 44 | sub_id = data[2] 45 | e_id = data[3] 46 | start_frame = int(data[4]) 47 | end_frame = int(data[5]) 48 | src_path = os.path.join(root, "gesture_{}/finger_{}/subject_{}/essai_{}/skeleton_world.txt" 49 | .format(g_id, f_id, sub_id, e_id)) 50 | skeletons, num_frame = read_skeleton(src_path) 51 | skeletons = skeletons[:, start_frame:end_frame + 1] 52 | skeletons = normalize_skeletons(skeletons, origin=0) 53 | # ske_vis(skeletons, view=1, pause=0.1) 54 | label14 = int(g_id) - 1 55 | if int(f_id) == 1: 56 | label28 = int(g_id) - 1 57 | else: 58 | label28 = int(g_id) - 1 + 14 59 | for id in range(num_sub): 60 | if id == int(sub_id) - 1: 61 | skeletons_all_val[id].append(skeletons) 62 | labels14_all_val[id].append(label14) 63 | labels28_all_val[id].append(label28) 64 | names_all_val[id].append("{}_{}_{}_{}".format(g_id, f_id, sub_id, e_id)) 65 | else: 66 | skeletons_all_train[id].append(skeletons) 67 | labels14_all_train[id].append(label14) 68 | labels28_all_train[id].append(label28) 69 | names_all_train[id].append("{}_{}_{}_{}".format(g_id, f_id, sub_id, e_id)) 70 | for id in range(num_sub): 71 | pickle.dump(skeletons_all_train[id], 
open(os.path.join(root, 'train_skeleton_{}.pkl'.format(id)), 'wb')) 72 | pickle.dump(skeletons_all_val[id], open(os.path.join(root, 'val_skeleton_{}.pkl'.format(id)), 'wb')) 73 | pickle.dump([names_all_train[id], labels14_all_train[id]], 74 | open(os.path.join(root, 'train_label_{}_14.pkl'.format(id)), 'wb')) 75 | pickle.dump([names_all_val[id], labels14_all_val[id]], 76 | open(os.path.join(root, 'val_label_{}_14.pkl'.format(id)), 'wb')) 77 | pickle.dump([names_all_train[id], labels28_all_train[id]], 78 | open(os.path.join(root, 'train_label_{}_28.pkl'.format(id)), 'wb')) 79 | pickle.dump([names_all_val[id], labels28_all_val[id]], 80 | open(os.path.join(root, 'val_label_{}_28.pkl'.format(id)), 'wb')) 81 | 82 | 83 | def ske_vis(data, **kwargs): 84 | from dataset.skeleton import vis 85 | from dataset.dhg_skeleton import edge 86 | vis(data, edge=edge, **kwargs) 87 | 88 | 89 | if __name__ == '__main__': 90 | gendata() -------------------------------------------------------------------------------- /prepare/dhg/joints.txt: -------------------------------------------------------------------------------- 1 | Wrist 2 | Palm 3 | T1 4 | T2 5 | T3 6 | T4 7 | I1 8 | I2 9 | I3 10 | I4 11 | M1 12 | M2 13 | M3 14 | M4 15 | R1 16 | R2 17 | R3 18 | R4 19 | L1 20 | L2 21 | L3 22 | L4 23 | -------------------------------------------------------------------------------- /prepare/dhg/label.txt: -------------------------------------------------------------------------------- 1 | Grab 2 | Tap 3 | Expand 4 | Pinch 5 | Rotation Clockwise 6 | Rotation Counter Clockwise 7 | Swipe Right 8 | Swipe Left 9 | Swipe Up 10 | Swipe Down 11 | Swipe X 12 | Swipe + 13 | Swipe V 14 | Shake -------------------------------------------------------------------------------- /prepare/ntu_120/gendata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import os 4 | import sys 5 | from prepare.ntu_60.gendata import gendata 6 | 7 | sys.path.extend(['../../']) 8 | 9 | training_subjects = [1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 10 | 38, 45, 46, 47, 49, 50, 52, 53, 54, 55, 56, 57, 58, 59, 70, 74, 78, 11 | 80, 81, 82, 83, 84, 85, 86, 89, 91, 92, 93, 94, 95, 97, 98, 100, 103] 12 | training_cameras = [2, 3] 13 | max_body_true = 2 14 | max_body_kinect = 4 15 | num_joint = 25 16 | max_frame = 300 17 | max_channel = 5 # xyz+xy 18 | num_channel = 3 19 | channel_name = ['x', 'y', 'z', 'colorX', 'colorY'] 20 | 21 | 22 | if __name__ == '__main__': 23 | parser = argparse.ArgumentParser(description='NTU-RGB-D Data Converter.') 24 | parser.add_argument('--data_path', default='/your/path/to/ntu_120_raw/') 25 | parser.add_argument('--ignored_sample_path', 26 | default='/your/path/to/ntu_120_raw/samples_with_missing_skeletons_120.txt') 27 | parser.add_argument('--out_folder', default='/your/path/to/ntu_120/') 28 | 29 | benchmark = ['xset', 'xsub'] 30 | part = ['train', 'val'] 31 | arg = parser.parse_args() 32 | 33 | for b in benchmark: 34 | for p in part: 35 | out_path = os.path.join(arg.out_folder, b) 36 | if not os.path.exists(out_path): 37 | os.makedirs(out_path) 38 | print(b, p) 39 | gendata( 40 | arg.data_path, 41 | out_path, 42 | arg.ignored_sample_path, 43 | benchmark=b, 44 | part=p, training_subjects=training_subjects) 45 | -------------------------------------------------------------------------------- /prepare/ntu_120/joints.txt: -------------------------------------------------------------------------------- 1 | 1-base of the spine 2 | 
2-middle of the spine 3 | 3-neck 4 | 4-head 5 | 5-left shoulder 6 | 6-left elbow 7 | 7-left wrist 8 | 8-left hand 9 | 9-right shoulder 10 | 10-right elbow 11 | 11-right wrist 12 | 12-right hand 13 | 13-left hip 14 | 14-left knee 15 | 15-left ankle 16 | 16-left foot 17 | 17-right hip 18 | 18-right knee 19 | 19-right ankle 20 | 20-right foot 21 | 21-spine 22 | 22-tip of the left hand 23 | 23-left thumb 24 | 24-tip of the right hand 25 | 25-right thumb -------------------------------------------------------------------------------- /prepare/ntu_120/label.txt: -------------------------------------------------------------------------------- 1 | drink water. 2 | eat meal/snack. 3 | brushing teeth. 4 | brushing hair. 5 | drop. 6 | pickup. 7 | throw. 8 | sitting down. 9 | standing up (from sitting position). 10 | clapping. 11 | reading. 12 | writing. 13 | tear up paper. 14 | wear jacket. 15 | take off jacket. 16 | wear a shoe. 17 | take off a shoe. 18 | wear on glasses. 19 | take off glasses. 20 | put on a hat/cap. 21 | take off a hat/cap. 22 | cheer up. 23 | hand waving. 24 | kicking something. 25 | reach into pocket. 26 | hopping (one foot jumping). 27 | jump up. 28 | make a phone call/answer phone. 29 | playing with phone/tablet. 30 | typing on a keyboard. 31 | pointing to something with finger. 32 | taking a selfie. 33 | check time (from watch). 34 | rub two hands together. 35 | nod head/bow. 36 | shake head. 37 | wipe face. 38 | salute. 39 | put the palms together. 40 | cross hands in front (say stop). 41 | sneeze/cough. 42 | staggering. 43 | falling. 44 | touch head (headache). 45 | touch chest (stomachache/heart pain). 46 | touch back (backache). 47 | touch neck (neckache). 48 | nausea or vomiting condition. 49 | use a fan (with hand or paper)/feeling warm. 50 | punching/slapping other person. 51 | kicking other person. 52 | pushing other person. 53 | pat on back of other person. 54 | point finger at the other person. 55 | hugging other person. 56 | giving something to other person. 57 | touch other person's pocket. 58 | handshaking. 59 | walking towards each other. 60 | walking apart from each other. 61 | put on headphone. 62 | take off headphone. 63 | shoot at the basket. 64 | bounce ball. 65 | tennis bat swing. 66 | juggling table tennis balls. 67 | hush (quite). 68 | flick hair. 69 | thumb up. 70 | thumb down. 71 | make ok sign. 72 | make victory sign. 73 | staple book. 74 | counting money. 75 | cutting nails. 76 | cutting paper (using scissors). 77 | snapping fingers. 78 | open bottle. 79 | sniff (smell). 80 | squat down. 81 | toss a coin. 82 | fold paper. 83 | ball up paper. 84 | play magic cube. 85 | apply cream on face. 86 | apply cream on hand back. 87 | put on bag. 88 | take off bag. 89 | put something into a bag. 90 | take something out of a bag. 91 | open a box. 92 | move heavy objects. 93 | shake fist. 94 | throw up cap/hat. 95 | hands up (both hands). 96 | cross arms. 97 | arm circles. 98 | arm swings. 99 | running on the spot. 100 | butt kicks (kick backward). 101 | cross toe touch. 102 | side kick. 103 | yawn. 104 | stretch oneself. 105 | blow nose. 106 | hit other person with something. 107 | wield knife towards other person. 108 | knock over other person (hit with body). 109 | grab other person’s stuff. 110 | shoot at other person with a gun. 111 | step on foot. 112 | high-five. 113 | cheers and drink. 114 | carry something with other person. 115 | take a photo of other person. 116 | follow other person. 117 | whisper in other person’s ear. 
118 | exchange things with other person. 119 | support somebody with hand. 120 | finger-guessing game (playing rock-paper-scissors). -------------------------------------------------------------------------------- /prepare/ntu_60/gendata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from tqdm import tqdm 4 | import sys 5 | from dataset.normalize_skeletons import normalize_skeletons 6 | 7 | sys.path.extend(['../../']) 8 | training_subjects = [ 9 | 1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38 10 | ] 11 | training_cameras = [2, 3] 12 | max_body_true = 2 13 | max_body_kinect = 4 14 | num_joint = 25 15 | max_frame = 300 16 | max_channel = 5 # xyz+xy 17 | num_channel = 3 18 | channel_name = ['x', 'y', 'z', 'colorX', 'colorY'] 19 | 20 | import numpy as np 21 | import os 22 | 23 | 24 | def read_skeleton_filter(file): 25 | with open(file, 'r') as f: 26 | skeleton_sequence = {} 27 | skeleton_sequence['numFrame'] = int(f.readline()) 28 | skeleton_sequence['frameInfo'] = [] 29 | # num_body = 0 30 | for t in range(skeleton_sequence['numFrame']): 31 | frame_info = {} 32 | frame_info['numBody'] = int(f.readline()) 33 | frame_info['bodyInfo'] = [] 34 | 35 | for m in range(frame_info['numBody']): 36 | body_info = {} 37 | body_info_key = [ 38 | 'bodyID', 'clipedEdges', 'handLeftConfidence', 39 | 'handLeftState', 'handRightConfidence', 'handRightState', 40 | 'isResticted', 'leanX', 'leanY', 'trackingState' 41 | ] 42 | body_info = { 43 | k: float(v) 44 | for k, v in zip(body_info_key, f.readline().split()) 45 | } 46 | body_info['bodyID'] = int(body_info['bodyID']) 47 | body_info['numJoint'] = int(f.readline()) 48 | body_info['jointInfo'] = [] 49 | for v in range(body_info['numJoint']): 50 | joint_info_key = [ 51 | 'x', 'y', 'z', 'depthX', 'depthY', 'colorX', 'colorY', 52 | 'orientationW', 'orientationX', 'orientationY', 53 | 'orientationZ', 'trackingState' 54 | ] 55 | joint_info = { 56 | k: float(v) 57 | for k, v in zip(joint_info_key, f.readline().split()) 58 | } 59 | body_info['jointInfo'].append(joint_info) 60 | frame_info['bodyInfo'].append(body_info) 61 | skeleton_sequence['frameInfo'].append(frame_info) 62 | 63 | return skeleton_sequence 64 | 65 | 66 | def get_body_info(skeletons): 67 | num_frames = skeletons['numFrame'] 68 | bodys = {} 69 | for index_f, frames in enumerate(skeletons['frameInfo']): 70 | num_body = frames['numBody'] 71 | frame_bodyid = [] 72 | for body in frames['bodyInfo']: 73 | body_id = body['bodyID'] 74 | if body_id not in frame_bodyid: 75 | frame_bodyid.append(body_id) 76 | else: 77 | while body_id in frame_bodyid: 78 | body_id += 1 79 | if body_id not in bodys.keys(): 80 | bodys[body_id] = np.zeros((max_channel, max_frame, num_joint)) 81 | for c, c_str in enumerate(channel_name): 82 | for j in range(num_joint): 83 | bodys[body_id][c, index_f, j] = body['jointInfo'][j][c_str] 84 | return bodys 85 | 86 | 87 | def get_nonzero_std(s): # ctv 88 | s = s - s[:, :, 0:1] # sub center joint 89 | index = s[:3].sum(0).sum(-1) != 0 # select valid frames 90 | s = s[:, index] 91 | if len(s) != 0: 92 | s = s[0].std() + s[1].std() + s[2].std() # std of three channels 93 | else: 94 | s = 0 95 | return s 96 | 97 | 98 | def xy_valid(body): 99 | ''' 100 | Judge whether the body is valid 101 | :param body: 102 | :return: True or False 103 | ''' 104 | index = body[:num_channel].sum(0).sum(-1) != 0 # select valid frames 105 | body = body[:, index] 106 | x = body[0, 0].max() - body[0, 0].min() 107 | 
y = body[1, 0].max() - body[1, 0].min() 108 | return y * 0.8 > x 109 | 110 | 111 | def filter_body(bodys): 112 | ''' 113 | Filter bodys to max number person, return mctv 114 | :param bodys: MCTV m=5 115 | :return: MCTV 116 | ''' 117 | if len(bodys) == 1: 118 | bodys = np.array([item for k, item in bodys.items()]) 119 | bodys = np.transpose(bodys, [0, 2, 3, 1]) # M, T, V, C 120 | return bodys 121 | 122 | bodys = np.array([item for k, item in bodys.items()]) 123 | # bodys[:, :, :1] = 0 # remove first frame 124 | # bodys = bodys[bodys[:, :num_channel].sum(-1).sum(-1).sum(-1) != 0] # remove 0 body 125 | 126 | # body sort by energy 127 | energy = np.array([get_nonzero_std(x) for x in bodys]) 128 | index = energy.argsort()[::-1] 129 | bodys = bodys[index] # 0.63 0.5 130 | 131 | # filter objs 132 | energy = np.array([get_nonzero_std(x) for x in bodys]) 133 | energy_min = max(energy) * 0.85 134 | del_list = np.where(np.array(energy < energy_min) == True)[0] 135 | for i in del_list[::-1]: 136 | if not xy_valid(bodys[i]): # delete obj should be obj 137 | bodys = np.concatenate([bodys[:i], bodys[i + 1:]], 0) 138 | 139 | # concat by durs 140 | # body_durs = [] 141 | # for i, body in enumerate(bodys): 142 | # valid_frames = np.where(body.sum(0).sum(-1) != 0)[0] 143 | # body_durs.append([valid_frames.min(), valid_frames.max()]) 144 | # 145 | # del_list = [] 146 | # for i, (begin, end) in enumerate(body_durs): 147 | # if begin == end: 148 | # continue 149 | # if i in del_list: 150 | # continue 151 | # for j, (begin2, end2) in enumerate(body_durs): 152 | # if j in del_list: 153 | # continue 154 | # if np.abs(begin2 - end) < 10: 155 | # pass 156 | # if end == begin2 - 1: 157 | # bodys[i] = bodys[i] + bodys[j] 158 | # del_list.append(j) 159 | # break 160 | # for i in del_list[::-1]: 161 | # bodys = np.concatenate([bodys[:i], bodys[i+1:]], 0) 162 | 163 | # del bodys that are too short 164 | # body_frames = [] 165 | # del_list = [] 166 | # for i, body in enumerate(bodys): 167 | # valid_frames = np.where(body.sum(0).sum(-1) != 0)[0] 168 | # body_frames.append(valid_frames.max() - valid_frames.min()) 169 | # body_frame_max = max(body_frames) 170 | # for i, f in enumerate(body_frames): 171 | # if f < body_frame_max * 0.2: 172 | # del_list.append(i) 173 | # for i in del_list[::-1]: 174 | # bodys = np.concatenate([bodys[:i], bodys[i + 1:]], 0) 175 | 176 | # remove incomplete frames 有些双人某个人很少 177 | # begins = [] 178 | # ends = [] 179 | # for i, body in enumerate(bodys): 180 | # valid_frames = np.where(body.sum(0).sum(-1) != 0)[0] 181 | # begins.append(valid_frames.min()) 182 | # ends.append(valid_frames.max()) 183 | # bodys[:, :, min(begins):max(begins)] = 0 184 | # bodys[:, :, min(ends):max(ends)] = 0 185 | 186 | # save max num bodys for new bodys 187 | energy = np.array([get_nonzero_std(x) for x in bodys]) 188 | index = energy.argsort()[::-1][0:max_body_true] 189 | bodys = bodys[index] 190 | 191 | bodys = np.transpose(bodys, [0, 2, 3, 1]) # M, T, V, C 192 | 193 | return bodys 194 | 195 | 196 | def gendata(data_path, out_path, ignored_sample_path=None, benchmark='xview', part='eval', training_subjects=None): 197 | if ignored_sample_path != None: 198 | with open(ignored_sample_path, 'r') as f: 199 | ignored_samples = [ 200 | line.strip() + '.skeleton' for line in f.readlines() 201 | ] 202 | else: 203 | ignored_samples = [] 204 | 205 | sample_names = [] 206 | sample_labels = [] 207 | for filename in os.listdir(data_path): 208 | if filename in ignored_samples: 209 | continue 210 | setup_class = int( 211 | 
filename[filename.find('S') + 1:filename.find('S') + 4]) 212 | action_class = int( 213 | filename[filename.find('A') + 1:filename.find('A') + 4]) 214 | subject_id = int( 215 | filename[filename.find('P') + 1:filename.find('P') + 4]) 216 | camera_id = int( 217 | filename[filename.find('C') + 1:filename.find('C') + 4]) 218 | 219 | if benchmark == 'xview': 220 | istraining = (camera_id in training_cameras) 221 | elif benchmark == 'xsub': 222 | istraining = (subject_id in training_subjects) 223 | elif benchmark == 'xset': 224 | istraining = (setup_class % 2 == 1) 225 | else: 226 | raise ValueError() 227 | 228 | if part == 'train': 229 | issample = istraining 230 | elif part == 'val': 231 | issample = not (istraining) 232 | else: 233 | raise ValueError() 234 | 235 | if issample: 236 | sample_names.append(filename) 237 | sample_labels.append(action_class - 1) 238 | 239 | with open('{}/{}_label.pkl'.format(out_path, part), 'wb') as f: 240 | pickle.dump((sample_names, list(sample_labels)), f) 241 | 242 | data_skeleton = np.zeros((len(sample_labels), num_channel, max_frame, num_joint, max_body_true), dtype=np.float32) 243 | data_rgb_position = np.zeros((len(sample_labels), max_channel - num_channel, max_frame, num_joint, max_body_true), 244 | dtype=np.float32) 245 | num_frames = [] 246 | for i, sample_name in enumerate(tqdm(sample_names)): 247 | seq_info = read_skeleton_filter(os.path.join(data_path, sample_name)) 248 | num_frames.append(seq_info['numFrame']) 249 | bodys = get_body_info(seq_info) 250 | bodys = filter_body(bodys) # mtvc 251 | num_body = bodys.shape[0] 252 | skeletons = normalize_skeletons(bodys[..., :3], origin=0, base_bone=[0, 20], zaxis=[0, 20], xaxis=[20, 5]) 253 | # use this to see the preprocessed skeletons 254 | # ske_vis(skeletons, view=1, pause=0.1) 255 | data_skeleton[i, :, :, :, :num_body] = skeletons # ctvm 256 | data_rgb_position[i, :, :, :, :num_body] = bodys.transpose((3, 1, 2, 0))[3:5] 257 | print(max(num_frames)) # 15-300 300挺多的 平均75帧 258 | np.save('{}/{}_data_joint.npy'.format(out_path, part), data_skeleton) 259 | np.save('{}/{}_joint_position_in_img.npy'.format(out_path, part), data_rgb_position) 260 | 261 | 262 | def ske_vis(data, **kwargs): 263 | from dataset.skeleton import vis 264 | from dataset.ntu_skeleton import edge 265 | vis(data, edge=edge, **kwargs) 266 | 267 | 268 | if __name__ == '__main__': 269 | parser = argparse.ArgumentParser(description='NTU-RGB-D Data Converter.') 270 | parser.add_argument('--data_path', default='/mnt/data/datasets/nturgb+d_skeletons/') 271 | parser.add_argument('--ignored_sample_path', 272 | default='./missing_new.txt') 273 | parser.add_argument('--out_folder', default='/mnt/data/datasets/ntu_60/') 274 | 275 | benchmark = ['xsub', 'xview'] 276 | part = ['train', 'val'] 277 | arg = parser.parse_args() 278 | 279 | for b in benchmark: 280 | for p in part: 281 | out_path = os.path.join(arg.out_folder, b) 282 | if not os.path.exists(out_path): 283 | os.makedirs(out_path) 284 | print(b, p) 285 | gendata( 286 | arg.data_path, 287 | out_path, 288 | arg.ignored_sample_path, 289 | benchmark=b, 290 | part=p, 291 | training_subjects=training_subjects) 292 | -------------------------------------------------------------------------------- /prepare/ntu_60/joints.txt: -------------------------------------------------------------------------------- 1 | 1-base of the spine 2 | 2-middle of the spine 3 | 3-neck 4 | 4-head 5 | 5-left shoulder 6 | 6-left elbow 7 | 7-left wrist 8 | 8-left hand 9 | 9-right shoulder 10 | 10-right elbow 11 | 11-right 
wrist 12 | 12-right hand 13 | 13-left hip 14 | 14-left knee 15 | 15-left ankle 16 | 16-left foot 17 | 17-right hip 18 | 18-right knee 19 | 19-right ankle 20 | 20-right foot 21 | 21-spine 22 | 22-tip of the left hand 23 | 23-left thumb 24 | 24-tip of the right hand 25 | 25-right thumb -------------------------------------------------------------------------------- /prepare/ntu_60/label.txt: -------------------------------------------------------------------------------- 1 | drink water. 2 | eat meal/snack. 3 | brushing teeth. 4 | brushing hair. 5 | drop. 6 | pickup. 7 | throw. 8 | sitting down. 9 | standing up (from sitting position). 10 | clapping. 11 | reading. 12 | writing. 13 | tear up paper. 14 | wear jacket. 15 | take off jacket. 16 | wear a shoe. 17 | take off a shoe. 18 | wear on glasses. 19 | take off glasses. 20 | put on a hat/cap. 21 | take off a hat/cap. 22 | cheer up. 23 | hand waving. 24 | kicking something. 25 | reach into pocket. 26 | hopping (one foot jumping). 27 | jump up. 28 | make a phone call/answer phone. 29 | playing with phone/tablet. 30 | typing on a keyboard. 31 | pointing to something with finger. 32 | taking a selfie. 33 | check time (from watch). 34 | rub two hands together. 35 | nod head/bow. 36 | shake head. 37 | wipe face. 38 | salute. 39 | put the palms together. 40 | cross hands in front (say stop). 41 | sneeze/cough. 42 | staggering. 43 | falling. 44 | touch head (headache). 45 | touch chest (stomachache/heart pain). 46 | touch back (backache). 47 | touch neck (neckache). 48 | nausea or vomiting condition. 49 | use a fan (with hand or paper)/feeling warm. 50 | punching/slapping other person. 51 | kicking other person. 52 | pushing other person. 53 | pat on back of other person. 54 | point finger at the other person. 55 | hugging other person. 56 | giving something to other person. 57 | touch other person's pocket. 58 | handshaking. 59 | walking towards each other. 60 | walking apart from each other. 
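For reference, gendata() above writes, per benchmark and split, a {part}_label.pkl holding the tuple (sample_names, labels) and a {part}_data_joint.npy shaped (N, num_channel, max_frame, num_joint, max_body_true). A minimal sketch of reading these back, not part of the repository; the output directory is a placeholder assembled from the --out_folder default and the benchmark name:

import pickle
import numpy as np

out_path = '/mnt/data/datasets/ntu_60/xsub'   # placeholder: <out_folder>/<benchmark>
part = 'val'

# labels were dumped as a (sample_names, labels) tuple in gendata()
with open('{}/{}_label.pkl'.format(out_path, part), 'rb') as f:
    sample_names, labels = pickle.load(f)

# skeletons: (num_samples, num_channel, max_frame, num_joint, max_body_true), CTVM per sample
data = np.load('{}/{}_data_joint.npy'.format(out_path, part))
print(len(sample_names), data.shape)
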
-------------------------------------------------------------------------------- /prepare/shrec/gendata.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from tqdm import tqdm 3 | import sys 4 | from dataset.rotation import * 5 | from dataset.normalize_skeletons import normalize_skeletons 6 | 7 | sys.path.extend(['../../']) 8 | 9 | import numpy as np 10 | import os 11 | 12 | 13 | 14 | def read_skeleton(ske_txt): 15 | ske_txt = open(ske_txt, 'r').readlines() 16 | skeletons = [] 17 | for line in ske_txt: 18 | nums = line.split(' ') 19 | # num_frame = int(nums[0]) + 1 20 | coords_frame = np.array(nums).reshape((22, 3)).astype(np.float32) 21 | skeletons.append(coords_frame) 22 | num_frame = len(skeletons) 23 | skeletons = np.expand_dims(np.array(skeletons).transpose((2, 0, 1)), axis=-1) # CTVM 24 | skeletons = np.transpose(skeletons, [3, 1, 2, 0]) # M, T, V, C 25 | return skeletons, num_frame 26 | 27 | 28 | def gendata(): 29 | root = '/your/path/to/shrec_hand/' 30 | train_split = open(os.path.join(root, 'train_gestures.txt'), 'r').readlines() 31 | val_split = open(os.path.join(root, 'test_gestures.txt'), 'r').readlines() 32 | 33 | skeletons_all_train = [] 34 | names_all_train = [] 35 | labels14_all_train = [] 36 | labels28_all_train = [] 37 | skeletons_all_val = [] 38 | names_all_val = [] 39 | labels14_all_val = [] 40 | labels28_all_val = [] 41 | 42 | for line in tqdm(train_split): 43 | line = line.rstrip() 44 | g_id, f_id, sub_id, e_id, label_14, label_28, size_seq = map(int, line.split(" ")) 45 | src_path = os.path.join(root, "gesture_{}/finger_{}/subject_{}/essai_{}/skeletons_world.txt" 46 | .format(g_id, f_id, sub_id, e_id)) 47 | skeletons, num_frame = read_skeleton(src_path) 48 | skeletons = normalize_skeletons(skeletons, origin=0, base_bone=[0, 10]) 49 | # ske_vis(skeletons, view=1, pause=0.1) 50 | skeletons_all_train.append(skeletons) 51 | labels14_all_train.append(label_14-1) 52 | labels28_all_train.append(label_28-1) 53 | names_all_train.append("{}_{}_{}_{}".format(g_id, f_id, sub_id, e_id)) 54 | 55 | pickle.dump(skeletons_all_train, open(os.path.join(root, 'train_skeleton.pkl'), 'wb')) 56 | pickle.dump([names_all_train, labels14_all_train], 57 | open(os.path.join(root, 'train_label_14.pkl'), 'wb')) 58 | pickle.dump([names_all_train, labels28_all_train], 59 | open(os.path.join(root, 'train_label_28.pkl'), 'wb')) 60 | 61 | for line in tqdm(val_split): 62 | line = line.rstrip() 63 | g_id, f_id, sub_id, e_id, label_14, label_28, size_seq = map(int, line.split(" ")) 64 | src_path = os.path.join(root, "gesture_{}/finger_{}/subject_{}/essai_{}/skeletons_world.txt" 65 | .format(g_id, f_id, sub_id, e_id)) 66 | skeletons, num_frame = read_skeleton(src_path) 67 | skeletons = normalize_skeletons(skeletons, origin=0, base_bone=[0, 10]) 68 | 69 | skeletons_all_val.append(skeletons) 70 | labels14_all_val.append(label_14-1) 71 | labels28_all_val.append(label_28-1) 72 | names_all_val.append("{}_{}_{}_{}".format(g_id, f_id, sub_id, e_id)) 73 | 74 | pickle.dump(skeletons_all_val, open(os.path.join(root, 'val_skeleton.pkl'), 'wb')) 75 | pickle.dump([names_all_val, labels14_all_val], 76 | open(os.path.join(root, 'val_label_14.pkl'), 'wb')) 77 | pickle.dump([names_all_val, labels28_all_val], 78 | open(os.path.join(root, 'val_label_28.pkl'), 'wb')) 79 | 80 | 81 | def ske_vis(data, **kwargs): 82 | from dataset.skeleton import vis 83 | from dataset.fpha_skeleton import edge 84 | vis(data, edge=edge, **kwargs) 85 | 86 | 87 | if __name__ == '__main__': 
88 | gendata() 89 | -------------------------------------------------------------------------------- /prepare/shrec/joints.txt: -------------------------------------------------------------------------------- 1 | Wrist 2 | Palm 3 | T1 4 | T2 5 | T3 6 | T4 7 | I1 8 | I2 9 | I3 10 | I4 11 | M1 12 | M2 13 | M3 14 | M4 15 | R1 16 | R2 17 | R3 18 | R4 19 | L1 20 | L2 21 | L3 22 | L4 23 | -------------------------------------------------------------------------------- /prepare/shrec/label.txt: -------------------------------------------------------------------------------- 1 | 0Grab 2 | 1Tap 3 | 2Expand 4 | 3Pinch 5 | 4Rotation Clockwise 6 | 5Rotation Counter Clockwise 7 | 6Swipe Right 8 | 7Swipe Left 9 | 8Swipe Up 10 | 9Swipe Down 11 | 10Swipe X 12 | 11Swipe + 13 | 12Swipe V 14 | 13Shake -------------------------------------------------------------------------------- /prepare/shrec/label_28.txt: -------------------------------------------------------------------------------- 1 | 0Grab1 2 | 1Grab5 3 | 2Tap1 4 | 3Tap5 5 | 4Expand1 6 | 5Expand5 7 | 6Pinch1 8 | 7Pinch5 9 | 8Rotation Clockwise1 10 | 9Rotation Clockwise5 11 | 10Rotation Counter Clockwise1 12 | 11Rotation Counter Clockwise5 13 | 12Swipe Right1 14 | 13Swipe Right5 15 | 14Swipe Left1 16 | 15Swipe Left5 17 | 16Swipe Up1 18 | 17Swipe Up5 19 | 18Swipe Down1 20 | 19Swipe Down5 21 | 20Swipe X1 22 | 21Swipe X5 23 | 22Swipe +1 24 | 23Swipe +5 25 | 24Swipe V1 26 | 25Swipe V5 27 | 26Shake1 28 | 27Shake5 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | colorama==0.4.4 2 | cycler==0.11.0 3 | dataclasses==0.8 4 | distro==1.5.0 5 | easydict==1.9 6 | entmax==1.0 7 | imutils==0.5.4 8 | joblib==1.1.0 9 | kiwisolver==1.3.1 10 | matplotlib==3.3.4 11 | numpy==1.19.5 12 | packaging==21.0 13 | pandas==1.1.5 14 | Pillow==8.3.1 15 | pkg-resources==0.0.0 16 | plyfile==0.7.4 17 | protobuf==3.17.3 18 | pyparsing==2.4.7 19 | python-dateutil==2.8.2 20 | pytz==2021.3 21 | PyYAML==5.4.1 22 | scikit-build==0.11.1 23 | scikit-learn==0.24.2 24 | scipy==1.5.4 25 | setproctitle==1.2.2 26 | six==1.16.0 27 | tensorboardX==1.6 28 | threadpoolctl==3.0.0 29 | torch==1.9.0+cu111 30 | torchaudio==0.9.0 31 | torchvision==0.10.0+cu111 32 | tqdm==4.59.0 33 | typing-extensions==3.10.0.0 34 | -------------------------------------------------------------------------------- /train_val_test/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/train_val_test/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /train_val_test/__pycache__/parser_args.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/train_val_test/__pycache__/parser_args.cpython-36.pyc -------------------------------------------------------------------------------- /train_val_test/__pycache__/train_val_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/train_val_test/__pycache__/train_val_model.cpython-36.pyc 
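For the SHREC splits, prepare/shrec/gendata.py above pickles one list of normalized skeleton sequences per split ({train,val}_skeleton.pkl) and [names, labels] pairs for both the 14-class and 28-class protocols. A minimal sketch of loading them, not part of the repository; the root path is the same placeholder used in gendata():

import os
import pickle

root = '/your/path/to/shrec_hand/'  # placeholder, as in gendata()

# each entry is one sequence in the layout produced by normalize_skeletons()
with open(os.path.join(root, 'train_skeleton.pkl'), 'rb') as f:
    train_skeletons = pickle.load(f)
with open(os.path.join(root, 'train_label_14.pkl'), 'rb') as f:
    names, labels_14 = pickle.load(f)
print(len(train_skeletons), train_skeletons[0].shape, labels_14[:5])
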
-------------------------------------------------------------------------------- /train_val_test/config/config.yaml: -------------------------------------------------------------------------------- 1 | data: 'adni' 2 | data_param: 3 | train_data_param: 4 | data_path: /data/datasets/SyntheticDataForKernelTransformer/train_data.npy 5 | label_path: /data/datasets/SyntheticDataForKernelTransformer/train_label.pkl 6 | random_choose: True 7 | center_choose: False 8 | random_noise: True 9 | random_scale: True 10 | window_size: 200 11 | final_size: 180 12 | num_skip_frame: 13 | decouple_spatial: False 14 | val_data_param: 15 | data_path: /data/datasets/SyntheticDataForKernelTransformer/test_data.npy 16 | label_path: /data/datasets/SyntheticDataForKernelTransformer/test_label.pkl 17 | random_choose: True 18 | center_choose: False 19 | random_noise: False 20 | random_scale: False 21 | window_size: 200 22 | final_size: 180 23 | num_skip_frame: 24 | decouple_spatial: False 25 | augtimes: 1 26 | 27 | # model 28 | model: 'dstanet' 29 | class_num: 2 30 | model_param: 31 | num_point: 42 32 | num_frame: 180 33 | num_subset: 6 34 | num_channel: 6 35 | num_person: 1 36 | glo_reg_s: False 37 | att_s: True 38 | glo_reg_t: False 39 | att_t: False 40 | dropout: 0 41 | attentiondrop: 0 42 | dropout2d: 0 43 | use_spatial_att: True 44 | use_temporal_att: False 45 | use_pet: True 46 | use_pes: True 47 | kernelattention: True 48 | diffusion: 5 49 | config: [[64, 64, 16, 1], [64, 64, 16, 1], 50 | [64, 128, 32, 2], [128, 128, 32, 1], 51 | ] 52 | 53 | 54 | train: 'classify' 55 | mode: 'train_val' 56 | loss: 'cross_entropy' 57 | batch_size: 10 58 | worker: 32 59 | pin_memory: False 60 | num_epoch_per_save: 200 61 | model_saved_name: './work_dir/adni/' 62 | last_model: 63 | pre_trained_model: 64 | ignore_weights: ['fc'] 65 | label_smoothing_num: 0 66 | mix_up_num: 0 67 | device_id: [0] 68 | cuda_visible_device: '0' 69 | debug: False 70 | 71 | # lr 72 | lr_scheduler: 'reduce_by_epoch' 73 | lr_param: 74 | step: [60, 80] 75 | # lr_patience: 10 76 | # lr_threshold: 0.0001 77 | # lr_delay: 0 78 | warm_up_epoch: 5 79 | max_epoch: 100 80 | lr: 0.1 81 | wd: 0.0005 82 | lr_decay_ratio: 0.1 83 | lr_multi_keys: [] 84 | 85 | # optimizer 86 | optimizer: 'sgd_nev' 87 | freeze_keys: [] 88 | 89 | -------------------------------------------------------------------------------- /train_val_test/ensemble.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--label', default='/your/path/to/ntu_60/xsub/val_label.pkl', 9 | help='') 10 | parser.add_argument('--spatial_temporal', default='./work_dir/ntu60/dstanet_drop0_6090120_128_ST') 11 | parser.add_argument('--spatial', default='./work_dir/ntu60/dstanet_drop0_6090120_128_S') 12 | parser.add_argument('--temporal_slow', default='./work_dir/ntu60/dstanet_drop0_6090120_128_T1') 13 | parser.add_argument('--temporal_fast', default='./work_dir/ntu60/dstanet_drop0_6090120_128_T1') 14 | parser.add_argument('--alpha', default=[1, 1, 1, 1], help='weighted summation') 15 | arg = parser.parse_args() 16 | 17 | label = open(arg.label, 'rb') 18 | label = np.array(pickle.load(label)) 19 | r1 = open('{}/score.pkl'.format(arg.spatial_temporal), 'rb') 20 | r1 = list(pickle.load(r1).items()) 21 | r2 = open('{}/score.pkl'.format(arg.spatial), 'rb') 22 | r2 = list(pickle.load(r2).items()) 23 | r3 = 
open('{}/score.pkl'.format(arg.temporal_slow), 'rb') 24 | r3 = list(pickle.load(r3).items()) 25 | r4 = open('{}/score.pkl'.format(arg.temporal_fast), 'rb') 26 | r4 = list(pickle.load(r4).items()) 27 | right_num = total_num = right_num_5 = 0 28 | for i in tqdm(range(len(label[0]))): 29 | _, l = label[:, i] 30 | _, r11 = r1[i] 31 | _, r22 = r2[i] 32 | _, r33 = r3[i] 33 | _, r44 = r4[i] 34 | r = r11 * arg.alpha[0] + r22 * arg.alpha[1] + r33 * arg.alpha[2] + r44 * arg.alpha[3] 35 | rank_5 = r.argsort()[-5:] 36 | right_num_5 += int(int(l) in rank_5) 37 | r = np.argmax(r) 38 | right_num += int(r == int(l)) 39 | total_num += 1 40 | acc = right_num / total_num 41 | acc5 = right_num_5 / total_num 42 | print(acc, acc5) 43 | -------------------------------------------------------------------------------- /train_val_test/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | print('Python %s on %s' % (sys.version, sys.platform)) 5 | sys.path.extend(['../']) 6 | import pickle 7 | from train_val_test import train_val_model, parser_args 8 | from utility.log import TimerBlock, IteratorTimer 9 | from method_choose.data_choose import data_choose, init_seed 10 | from method_choose.model_choose import model_choose 11 | from method_choose.loss_choose import loss_choose 12 | 13 | with TimerBlock("Good Luck") as block: 14 | # params 15 | args = parser_args.parser_args(block) 16 | init_seed(1) 17 | 18 | data_loader_train, data_loader_val = data_choose(args, block) 19 | global_step, start_epoch, model, optimizer_dict = model_choose(args, block) 20 | loss_function = loss_choose(args, block) 21 | 22 | model.cuda() 23 | model.eval() 24 | 25 | print('Validate') 26 | loss, acc, score_dict, all_pre_true, wrong_path_pre_true = train_val_model.val_classifier(data_loader_val, model, 27 | loss_function, 0, args, 28 | None) 29 | save_score = os.path.join(args.model_saved_name, 'score.pkl') 30 | with open(save_score, 'wb') as f: 31 | pickle.dump(score_dict, f) 32 | with open(args.model_saved_name + '/all_pre_true.txt', 'w') as f: 33 | f.writelines(all_pre_true) 34 | with open(args.model_saved_name + '/wrong_path_pre_true.txt', 'w') as f: 35 | f.writelines(wrong_path_pre_true) 36 | print('Final: {}'.format(float(acc))) 37 | -------------------------------------------------------------------------------- /train_val_test/loss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Portions of this code copyright 2017, Clement Pinard 3 | ''' 4 | 5 | # freda (todo) : adversarial loss 6 | 7 | import torch 8 | import torch.nn as nn 9 | import math 10 | 11 | 12 | class L1(nn.Module): 13 | def __init__(self): 14 | super(L1, self).__init__() 15 | 16 | def forward(self, output, target): 17 | lossvalue = torch.abs(output - target).mean() 18 | return lossvalue 19 | 20 | 21 | class L2(nn.Module): 22 | def __init__(self): 23 | super(L2, self).__init__() 24 | 25 | def forward(self, output, target): 26 | lossvalue = torch.norm(output - target, p=2, dim=1).mean() 27 | return lossvalue 28 | 29 | 30 | class L1Loss(nn.Module): 31 | def __init__(self, args): 32 | super(L1Loss, self).__init__() 33 | self.args = args 34 | self.loss = L1() 35 | 36 | def forward(self, output, target): 37 | lossvalue = self.loss(output, target) 38 | epevalue = EPE(output, target) 39 | return ['L1', 'EPE'], [lossvalue, epevalue] 40 | 41 | 42 | class L2Loss(nn.Module): 43 | def __init__(self, args): 44 | super(L2Loss, self).__init__() 45 | self.args = args 46 
| self.loss = L2() 47 | 48 | def forward(self, output, target): 49 | lossvalue = self.loss(output, target) 50 | epevalue = EPE(output, target) 51 | return ['L2', 'EPE'], [lossvalue, epevalue] 52 | 53 | -------------------------------------------------------------------------------- /train_val_test/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer, required 3 | import numpy as np 4 | 5 | 6 | class SGD(Optimizer): 7 | def __init__(self, params, lr=required, momentum=0, dampening=0, 8 | weight_decay=0, nesterov=False, writer=None): 9 | if lr is not required and lr < 0.0: 10 | raise ValueError("Invalid learning rate: {}".format(lr)) 11 | if momentum < 0.0: 12 | raise ValueError("Invalid momentum value: {}".format(momentum)) 13 | if weight_decay < 0.0: 14 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 15 | 16 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 17 | weight_decay=weight_decay, nesterov=nesterov) 18 | if nesterov and (momentum <= 0 or dampening != 0): 19 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 20 | if not writer is None: 21 | self.writer = writer 22 | self.num = 0 23 | super(SGD, self).__init__(params, defaults) 24 | 25 | def __setstate__(self, state): 26 | super(SGD, self).__setstate__(state) 27 | for group in self.param_groups: 28 | group.setdefault('nesterov', False) 29 | 30 | def step(self, closure=None): 31 | """Performs a single optimization step. 32 | 33 | Arguments: 34 | closure (callable, optional): A closure that reevaluates the model 35 | and returns the loss. 36 | """ 37 | loss = None 38 | if closure is not None: 39 | loss = closure() 40 | self.num += 1 41 | for i, group in enumerate(self.param_groups): # each param each group 42 | weight_decay = group['weight_decay'] 43 | momentum = group['momentum'] 44 | dampening = group['dampening'] 45 | nesterov = group['nesterov'] 46 | for j, p in enumerate(group['params']): 47 | if p.grad is None: 48 | continue 49 | d_p = p.grad.data 50 | if weight_decay != 0: 51 | d_p.add_(weight_decay, p.data) 52 | if momentum != 0: 53 | param_state = self.state[p] 54 | if 'momentum_buffer' not in param_state: 55 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 56 | buf.mul_(momentum).add_(d_p) 57 | else: 58 | buf = param_state['momentum_buffer'] 59 | buf.mul_(momentum).add_(1 - dampening, d_p) 60 | if nesterov: 61 | d_p = d_p.add(momentum, buf) 62 | else: 63 | d_p = buf 64 | # if i == 0 and j == 0 and self.writer is not None: 65 | # self.writer.add_histogram('Ibuf', buf, self.num) 66 | p.data.add_(-group['lr'], d_p) 67 | 68 | return loss 69 | -------------------------------------------------------------------------------- /train_val_test/parser_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from utility.log import TimerBlock 4 | import colorama 5 | import torch 6 | import shutil 7 | import yaml 8 | from easydict import EasyDict as ed 9 | 10 | 11 | def parser_args(block): 12 | # params 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('-config', default='') 15 | parser.add_argument('-model', default='resnet3d_50') 16 | parser.add_argument('-model_param', default={}, help=None) 17 | # classify_multi_crop classify classify_pose 18 | parser.add_argument('-train', default='classify') 19 | parser.add_argument('-val_first', default=False) 20 | 
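    # dataset selection: the -data string is resolved to train/val data loaders in method_choose/data_choose.py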
parser.add_argument('-data', default='jmdbgulp') 21 | parser.add_argument('-data_param', default={}, help='') 22 | # train_val test train_test 23 | parser.add_argument('-mode', default='train_val') 24 | # cross_entropy mse_ce 25 | parser.add_argument('-loss', default='cross_entropy') 26 | parser.add_argument('-ls_param', default={ 27 | }) 28 | # reduce_by_acc reduce_by_loss reduce_by_epoch cosine_annealing_lr 29 | parser.add_argument('-lr_scheduler', default='reduce_by_acc') 30 | parser.add_argument('-lr_param', default={}) 31 | parser.add_argument('-warm_up_epoch', default=0) 32 | parser.add_argument('-step', default=[80, ]) 33 | parser.add_argument('-lr', default=0.01) # 0.001 34 | parser.add_argument('-wd', default=1e-4) # 5e-4 35 | parser.add_argument('-lr_decay_ratio', default=0.1) 36 | parser.add_argument('-lr_multi_keys', default=[ 37 | ['fc', 1, 1, 0], ['bn', 1, 1, 0], 38 | ], help='key, lr ratio, wd ratio, epoch') 39 | parser.add_argument('-optimizer', default='sgd_nev') 40 | parser.add_argument('-freeze_keys', default=[ 41 | ['PA', 5], 42 | ], help='key, epoch') 43 | 44 | parser.add_argument('-class_num', default=12) 45 | parser.add_argument('-batch_size', default=32) 46 | parser.add_argument('-worker', default=16) 47 | parser.add_argument('-pin_memory', default=False) 48 | parser.add_argument('-max_epoch', default=50) 49 | 50 | parser.add_argument('-num_epoch_per_save', default=2) 51 | parser.add_argument('-model_saved_name', default='') 52 | parser.add_argument('-last_model', default=None, help='') 53 | parser.add_argument('-ignore_weights', default=['fc']) 54 | parser.add_argument('-pre_trained_model', default='') 55 | parser.add_argument('--label_smoothing_num', default=0, help='0-1: 0 denotes no smoothing') 56 | parser.add_argument('--mix_up_num', default=0, help='0-1: 1 denotes uniform distribution, smaller, more concave') 57 | parser.add_argument('-device_id', type=int, default=[0]) 58 | parser.add_argument('-debug', default=False) 59 | parser.add_argument('-cuda_visible_device', default='0, 1, 2, 3, 4, 5, 6, 7') 60 | parser.add_argument('-grad_clip', default=0) 61 | #parser.add_argument('-num_subsets', type=int, default=3) 62 | #parser.add_argument('-use_kernel_attention', type=int, default=5) 63 | p = parser.parse_args() 64 | 65 | if p.config is not None: 66 | with open(p.config, 'r') as f: 67 | default_arg = yaml.load(f) 68 | key = vars(p).keys() 69 | for k in default_arg.keys(): 70 | if k not in key: 71 | print('WRONG ARG: {}'.format(k)) 72 | assert (k in key) 73 | parser.set_defaults(**default_arg) 74 | 75 | args = parser.parse_args() 76 | 77 | os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_device 78 | 79 | if args.debug: 80 | args.device_id = [0] 81 | args.batch_size = 1 82 | args.worker = 0 83 | os.environ['DISPLAY'] = 'localhost:10.0' 84 | block.addr = os.path.join(args.model_saved_name, 'log.txt') 85 | 86 | if os.path.isdir(args.model_saved_name) and not args.last_model and not args.debug: 87 | print('log_dir: ' + args.model_saved_name + ' already exist') 88 | #answer = input('delete it? 
y/n:') 89 | # if answer == 'y': 90 | # shutil.rmtree(args.model_saved_name) 91 | # print('Dir removed: ' + args.model_saved_name) 92 | # input('refresh it') 93 | # else: 94 | # print('Dir not removed: ' + args.model_saved_name) 95 | 96 | if not os.path.exists(args.model_saved_name): 97 | os.makedirs(args.model_saved_name) 98 | # Get argument defaults (has tag #this is a hack) 99 | parser.add_argument('--IGNORE', action='store_true') 100 | # 会返回列表 101 | defaults = vars(parser.parse_args(['--IGNORE'])) 102 | # Print all arguments, color the non-defaults 103 | for argument, value in sorted(vars(args).items()): 104 | reset = colorama.Style.RESET_ALL 105 | color = reset if value == defaults[argument] else colorama.Fore.MAGENTA 106 | block.log('{}{}: {}{}'.format(color, argument, value, reset)) 107 | 108 | shutil.copy2(__file__, args.model_saved_name) 109 | shutil.copy2(args.config, args.model_saved_name) 110 | 111 | args = ed(vars(args)) 112 | return args 113 | -------------------------------------------------------------------------------- /train_val_test/train.py: -------------------------------------------------------------------------------- 1 | # from __future__ import print_function, division 2 | 3 | import os 4 | 5 | os.environ['DISPLAY'] = 'localhost:10.0' 6 | import sys 7 | 8 | print('Python %s on %s' % (sys.version, sys.platform)) 9 | sys.path.extend(['../']) 10 | #os.environ["CUDA_VISIBLE_DEVICES"] = '2' 11 | 12 | import time 13 | import torch 14 | import setproctitle 15 | # from tensorboard_logger import configure, log_value 16 | from tensorboardX import SummaryWriter 17 | from tqdm import tqdm 18 | from method_choose.data_choose import data_choose, init_seed 19 | from method_choose.loss_choose import loss_choose 20 | from method_choose.lr_scheduler_choose import lr_scheduler_choose 21 | from method_choose.model_choose import model_choose 22 | from method_choose.optimizer_choose import optimizer_choose 23 | from method_choose.tra_val_choose import train_val_choose 24 | from train_val_test import parser_args 25 | from utility.log import TimerBlock, IteratorTimer 26 | from collections import OrderedDict 27 | import pickle 28 | import os 29 | 30 | def rm_module(old_dict): 31 | new_state_dict = OrderedDict() 32 | for k, v in old_dict.items(): 33 | head = k[:7] 34 | if head == 'module.': 35 | name = k[7:] # remove `module.` 36 | else: 37 | name = k 38 | new_state_dict[name] = v 39 | return new_state_dict 40 | 41 | def train(args, block): 42 | #with TimerBlock("Good Luck") as block: 43 | init_seed(1) 44 | setproctitle.setproctitle(args.model_saved_name) 45 | block.log('work dir: ' + args.model_saved_name) 46 | if args.mode == 'train_val': 47 | train_writer = SummaryWriter(os.path.join(args.model_saved_name, 'train'), 'train') 48 | val_writer = SummaryWriter(os.path.join(args.model_saved_name, 'val'), 'val') 49 | else: 50 | train_writer = val_writer = SummaryWriter(os.path.join(args.model_saved_name, 'test'), 'test') 51 | 52 | global_step, start_epoch, model, optimizer_dict = model_choose(args, block) 53 | optimizer = optimizer_choose(model, args, val_writer, block) 54 | 55 | if optimizer_dict is not None and args.last_model is not None: 56 | try: 57 | optimizer.load_state_dict(optimizer_dict) 58 | block.log('load optimizer from state dict') 59 | except: 60 | block.log('optimizer not matched') 61 | else: 62 | block.log('no pretrained optimizer is loaded') 63 | 64 | loss_function = loss_choose(args, block) 65 | 66 | data_loader_train, data_loader_val = data_choose(args, block) 67 | 68 | 
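    # build the learning-rate scheduler named by args.lr_scheduler (see method_choose/lr_scheduler_choose.py);
    # the per-epoch step() calls below feed it the epoch, the validation accuracy or the validation loss
    # depending on whether the policy is reduce_by_epoch, reduce_by_acc or reduce_by_loss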
lr_scheduler = lr_scheduler_choose(optimizer, args, start_epoch - 1, block) 69 | 70 | train_net, val_net = train_val_choose(args, block) 71 | 72 | best_accu = 0 73 | best_step = 0 74 | best_epoch = 0 75 | acc = 0 76 | loss = 100 77 | process = tqdm(range(start_epoch, args.max_epoch), 'Process: ' + args.model_saved_name) 78 | block.log('start epoch {} -> max epoch {}'.format(start_epoch, args.max_epoch)) 79 | if args.val_first: 80 | model.eval() 81 | loss, acc, score_dict, all_pre_true, wrong_path_pre_true = val_net(data_loader_val, model, loss_function, global_step, args, val_writer) 82 | block.log('Init ACC: {}'.format(acc)) 83 | # lr = optimizer.param_groups[0]['lr'] 84 | for epoch in process: 85 | last_epoch_time = time.time() 86 | model.train() # Set model to training mode 87 | 88 | if args.lr_scheduler == 'reduce_by_epoch': 89 | lr_scheduler.step(epoch=epoch) 90 | elif args.lr_scheduler == 'reduce_by_acc': 91 | lr_scheduler.step(metric=acc, epoch=epoch) 92 | elif args.lr_scheduler == 'reduce_by_loss': 93 | lr_scheduler.step(metric=loss, epoch=epoch) 94 | else: 95 | lr_scheduler.step(epoch=epoch) 96 | lr = optimizer.param_groups[0]['lr'] 97 | block.log('Current lr: {}'.format(lr)) 98 | 99 | for key, value in model.named_parameters(): 100 | value.requires_grad = True 101 | for freeze_key, freeze_epoch in args.freeze_keys: 102 | if freeze_epoch > epoch: 103 | block.log('{} is froze'.format(freeze_key)) 104 | for key, value in model.named_parameters(): 105 | if freeze_key in key: 106 | # block.log('{} is froze'.format(key)) 107 | value.requires_grad = False 108 | 109 | for lr_key, ratio_lr, ratio_wd, lr_epoch in args.lr_multi_keys: 110 | if lr_epoch > epoch: 111 | block.log('lr for {}: {}*{}, wd: {}*{}'.format(lr_key, lr, ratio_lr, args.wd, ratio_wd)) 112 | for param in optimizer.param_groups: 113 | if lr_key in param['key']: 114 | param['lr'] *= ratio_lr 115 | param['weight_decay'] *= ratio_wd 116 | 117 | global_step = train_net(data_loader_train, model, loss_function, optimizer, global_step, args, train_writer) 118 | block.log('Training finished for epoch {}'.format(epoch)) 119 | model.eval() 120 | loss, acc, score_dict, all_pre_true, wrong_path_pre_true = val_net(data_loader_val, model, loss_function, global_step, args, val_writer) 121 | block.log('Validation finished for epoch {}'.format(epoch)) 122 | 123 | if args.mode == 'train_val': 124 | train_writer.add_scalar('epoch', epoch, global_step) 125 | train_writer.add_scalar('lr', lr, global_step) 126 | train_writer.add_scalar('epoch_time', time.time() - last_epoch_time, global_step) 127 | 128 | if acc > best_accu: 129 | best_accu = acc 130 | best_step = global_step 131 | best_epoch = epoch 132 | save_score = args.model_saved_name + '/score.pkl' 133 | with open(save_score, 'wb') as f: 134 | pickle.dump(score_dict, f) 135 | with open(args.model_saved_name + '/all_pre_true.txt', 'w') as f: 136 | f.writelines(all_pre_true) 137 | with open(args.model_saved_name + '/wrong_path_pre_true.txt', 'w') as f: 138 | f.writelines(wrong_path_pre_true) 139 | 140 | # save model 141 | # m = rm_module(model.state_dict()) 142 | # save = { 143 | # 'model': m, 144 | # 'optimizer': optimizer.state_dict() 145 | # } 146 | # torch.save(save, args.model_saved_name + '-latest.state') 147 | # if (epoch + 1) % args.num_epoch_per_save == 0: 148 | # torch.save(save, args.model_saved_name + '-' + str(epoch) + '-' + str(global_step) + '.state') 149 | 150 | process.set_description('Process: ' + args.model_saved_name + ' lr: ' + str(lr)) 151 | block.log('EPOCH: {}, ACC: 
{:4f}, LOSS: {:4f}, EPOCH_TIME: {:4f}, LR: {}, BEST_ACC: {:4f}' 152 | .format(epoch, acc, loss, time.time() - last_epoch_time, lr, best_accu)) 153 | 154 | if lr < 1e-5: 155 | break 156 | 157 | m = rm_module(model.cpu().state_dict()) 158 | # save = { 159 | # 'model': m, 160 | # 'optimizer': optimizer.state_dict() 161 | # } 162 | # torch.save(save, args.model_saved_name + '-' + str(epoch) + '-' + str(global_step) + '.state') 163 | block.log( 164 | 'Best model: ' + args.model_saved_name + '-' + str(best_epoch) + '-' + str(best_step) + '.state, acc: ' + str( 165 | best_accu)) 166 | # block.save(os.path.join(args.model_saved_name, 'log.txt')) 167 | 168 | return best_accu 169 | 170 | if __name__ == '__main__': 171 | with TimerBlock("Good Luck") as block: 172 | print(torch.cuda.is_available()) 173 | # params 174 | args = parser_args.parser_args(block) 175 | best = train(args, block) 176 | 177 | -------------------------------------------------------------------------------- /train_val_test/train_val_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | from tqdm import tqdm 5 | from utility.log import IteratorTimer 6 | # import torchvision 7 | import numpy as np 8 | import time 9 | import pickle 10 | #import cv2 11 | 12 | 13 | def to_onehot(num_class, label, alpha): 14 | return torch.zeros((label.shape[0], num_class)).fill_(alpha).scatter_(1, label.unsqueeze(1), 1 - alpha) 15 | 16 | 17 | def mixup(input, target, gamma): 18 | # target is onehot format! 19 | perm = torch.randperm(input.size(0)) 20 | perm_input = input[perm] 21 | perm_target = target[perm] 22 | return input.mul_(gamma).add_(1 - gamma, perm_input), target.mul_(gamma).add_(1 - gamma, perm_target) 23 | 24 | 25 | def clip_grad_norm_(parameters, max_grad): 26 | if isinstance(parameters, torch.Tensor): 27 | parameters = [parameters] 28 | parameters = list(filter(lambda p: p[1].grad is not None, parameters)) 29 | max_grad = float(max_grad) 30 | 31 | for name, p in parameters: 32 | grad = p.grad.data.abs() 33 | if grad.isnan().any(): 34 | ind = grad.isnan() 35 | p.grad.data[ind] = 0 36 | grad = p.grad.data.abs() 37 | if grad.isinf().any(): 38 | ind = grad.isinf() 39 | p.grad.data[ind] = 0 40 | grad = p.grad.data.abs() 41 | if grad.max() > max_grad: 42 | ind = grad>max_grad 43 | p.grad.data[ind] = p.grad.data[ind]/grad[ind]*max_grad # sign x val 44 | 45 | 46 | def train_classifier(data_loader, model, loss_function, optimizer, global_step, args, writer): 47 | process = tqdm(IteratorTimer(data_loader), desc='Train: ') 48 | for index, (inputs, labels) in enumerate(process): 49 | 50 | # label_onehot = to_onehot(args.class_num, labels, args.label_smoothing_num) 51 | if args.mix_up_num > 0: 52 | # self.print_log('using mixup data: ', self.arg.mix_up_num) 53 | targets = to_onehot(args.class_num, labels, args.label_smoothing_num) 54 | inputs, targets = mixup(inputs, targets, np.random.beta(args.mix_up_num, args.mix_up_num)) 55 | elif args.label_smoothing_num != 0 or args.loss == 'cross_entropy_naive': 56 | targets = to_onehot(args.class_num, labels, args.label_smoothing_num) 57 | else: 58 | targets = labels 59 | 60 | # inputs, labels = Variable(inputs.cuda(non_blocking=True)), Variable(labels.cuda(non_blocking=True)) 61 | inputs, targets, labels = inputs.cuda(non_blocking=True), targets.cuda(non_blocking=True), labels.cuda(non_blocking=True) 62 | # net = torch.nn.DataParallel(model, device_ids=args.device_id) 63 | outputs = 
model(inputs) 64 | loss = loss_function(outputs, targets) 65 | optimizer.zero_grad() 66 | loss.backward() 67 | if args.grad_clip: 68 | clip_grad_norm_(model.named_parameters(), args.grad_clip) 69 | optimizer.step() 70 | global_step += 1 71 | if len(outputs.data.shape) == 3: # T N cls 72 | _, predict_label = torch.max(outputs.data[:, :, :-1].mean(0), 1) 73 | else: 74 | _, predict_label = torch.max(outputs.data, 1) 75 | loss = loss_function(outputs, targets) 76 | ls = loss.data.item() 77 | acc = torch.mean((predict_label == labels.data).float()).item() 78 | # ls = loss.data[0] 79 | # acc = torch.mean((predict_label == labels.data).float()) 80 | lr = optimizer.param_groups[0]['lr'] 81 | process.set_description( 82 | 'Train: acc: {:4f}, loss: {:4f}, batch time: {:4f}, lr: {:4f}'.format(acc, ls, 83 | process.iterable.last_duration, 84 | lr)) 85 | 86 | # 每个batch记录一次 87 | if args.mode == 'train_val': 88 | writer.add_scalar('acc', acc, global_step) 89 | writer.add_scalar('loss', ls, global_step) 90 | writer.add_scalar('batch_time', process.iterable.last_duration, global_step) 91 | # if len(inputs.shape) == 5: 92 | # if index % 500 == 0: 93 | # img = inputs.data.cpu().permute(2, 0, 1, 3, 4) 94 | # # NCLHW->LNCHW 95 | # img = torchvision.utils.make_grid(img[0::4, 0][0:4], normalize=True) 96 | # writer.add_image('img', img, global_step=global_step) 97 | # elif len(inputs.shape) == 4: 98 | # if index % 500 == 0: 99 | # writer.add_image('img', ((inputs.cpu().numpy()[0] + 128) * 1).astype(np.uint8).transpose(1, 2, 0), 100 | # global_step=global_step) 101 | 102 | process.close() 103 | return global_step 104 | 105 | 106 | def val_classifier(data_loader, model, loss_function, global_step, args, writer): 107 | right_num_total = 0 108 | total_num = 0 109 | loss_total = 0 110 | step = 0 111 | process = tqdm(IteratorTimer(data_loader), desc='Val: ') 112 | # s = time.time() 113 | # t=0 114 | score_frag = [] 115 | all_pre_true = [] 116 | wrong_path_pre_ture = [] 117 | for index, (inputs, labels, path) in enumerate(process): 118 | #inputs = np.squeeze(inputs,axis=0) 119 | # label_onehot = to_onehot(args.class_num, labels, args.label_smoothing_num) 120 | if args.loss == 'cross_entropy_naive': 121 | targets = to_onehot(args.class_num, labels, args.label_smoothing_num) 122 | else: 123 | targets = labels 124 | 125 | with torch.no_grad(): 126 | inputs, targets, labels = inputs.cuda(non_blocking=True), targets.cuda(non_blocking=True), labels.cuda( 127 | non_blocking=True) 128 | outputs = model(inputs) 129 | #outputs = outputs.mean(0, keepdim=True) 130 | #outputs = torch.max(outputs,0,True) 131 | if len(outputs.data.shape) == 3: # T N cls 132 | _, predict_label = torch.max(outputs.data[:, :, :-1].mean(0), 1) 133 | score_frag.append(outputs.data.cpu().numpy().transpose(1,0,2)) 134 | else: 135 | _, predict_label = torch.max(outputs.data, 1) 136 | score_frag.append(outputs.data.cpu().numpy()) 137 | loss = loss_function(outputs, targets) 138 | 139 | predict = list(predict_label.cpu().numpy()) 140 | true = list(labels.data.cpu().numpy()) 141 | for i, x in enumerate(predict): 142 | all_pre_true.append(str(x) + ',' + str(true[i]) + '\n') 143 | if x != true[i]: 144 | wrong_path_pre_ture.append(str(path[i]) + ',' + str(x) + ',' + str(true[i]) + '\n') 145 | 146 | right_num = torch.sum(predict_label == labels.data).item() 147 | # right_num = torch.sum(predict_label == labels.data) 148 | batch_num = labels.data.size(0) 149 | acc = right_num / batch_num 150 | ls = loss.data.item() 151 | # ls = loss.data[0] 152 | 153 | 
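        # accumulate per-batch counts so the epoch-level metrics reported after the loop
        # (loss_total / step and right_num_total / total_num) cover the whole validation set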
right_num_total += right_num 154 | total_num += batch_num 155 | loss_total += ls 156 | step += 1 157 | 158 | process.set_description( 159 | 'Val-batch: acc: {:4f}, loss: {:4f}, time: {:4f}'.format(acc, ls, process.iterable.last_duration)) 160 | # process.set_description_str( 161 | # 'Val: acc: {:4f}, loss: {:4f}, time: {:4f}'.format(t, t, t), refresh=False) 162 | # if len(inputs.shape) == 5: 163 | # if index % 50 == 0 and (writer is not None) and args.mode == 'train_val': 164 | # # NCLHW->LNCHW 165 | # img = inputs.data.cpu().permute(2, 0, 1, 3, 4) 166 | # img = torchvision.utils.make_grid(img[0::4, 0][0:4], normalize=True) 167 | # writer.add_image('img', img, global_step=global_step) 168 | # elif len(inputs.shape) == 4: 169 | # if index % 50 == 0 and (writer is not None) and args.mode == 'train_val': 170 | # writer.add_image('img', ((inputs.cpu().numpy()[0] + 128) * 1).astype(np.uint8).transpose(1, 2, 0), 171 | # global_step=global_step) 172 | # t = time.time()-s 173 | # print('time: ', t) 174 | score = np.concatenate(score_frag) 175 | score_dict = dict(zip(data_loader.dataset.sample_name, score)) 176 | 177 | process.close() 178 | loss = loss_total / step 179 | accuracy = right_num_total / total_num 180 | print('Accuracy: ', accuracy) 181 | if args.mode == 'train_val' and writer is not None: 182 | writer.add_scalar('loss', loss, global_step) 183 | writer.add_scalar('acc', accuracy, global_step) 184 | writer.add_scalar('batch time', process.iterable.last_duration, global_step) 185 | 186 | return loss, accuracy, score_dict, all_pre_true, wrong_path_pre_ture 187 | 188 | -------------------------------------------------------------------------------- /utility/__pycache__/log.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seuzjj/Diffusion_kernel_attention_network/e9141dc335192b74b0f7b60b37821133ce48f172/utility/__pycache__/log.cpython-36.pyc -------------------------------------------------------------------------------- /utility/log.py: -------------------------------------------------------------------------------- 1 | # import torchvision 2 | import torch 3 | import numpy as np 4 | # import matplotlib.pyplot as plt 5 | import time 6 | 7 | 8 | class TimerBlock: 9 | """ 10 | with TimerBlock(title) as block: 11 | block.log(msg) 12 | block.log2file(addr,msg) 13 | """ 14 | 15 | def __init__(self, title): 16 | print("{}".format(title)) 17 | self.content = [] 18 | self.addr = None 19 | 20 | def __enter__(self): 21 | self.start = time.time() 22 | return self 23 | 24 | def __exit__(self, exc_type, exc_value, traceback): 25 | self.end = time.time() 26 | self.interval = self.end - self.start 27 | 28 | if exc_type is not None: 29 | self.log("Operation failed\n") 30 | else: 31 | self.log("Operation finished\n") 32 | 33 | def log(self, string): 34 | # duration = time.time() - self.start 35 | # units = 's' 36 | # if duration > 60: 37 | # duration = duration / 60. 38 | # units = 'm' 39 | # s = " [{:.3f}{}] {}".format(duration, units, string) 40 | s = time.ctime() + ' ' + string 41 | print(s) 42 | self.content.append(s + '\n') 43 | fid = open(self.addr, 'a') 44 | fid.write("%s\n" % (s)) 45 | fid.close() 46 | 47 | def save(self, fid): 48 | f = open(fid, 'a') 49 | f.writelines(self.content) 50 | f.close() 51 | 52 | def log2file(self, fid, string): 53 | fid = open(fid, 'a') 54 | fid.write("%s\n" % (string)) 55 | fid.close() 56 | 57 | 58 | class IteratorTimer(): 59 | """ 60 | An iterator to produce duration. 
self.last_duration 61 | """ 62 | 63 | def __init__(self, iterable): 64 | self.iterable = iterable 65 | self.iterator = self.iterable.__iter__() 66 | 67 | def __iter__(self): 68 | return self 69 | 70 | def __len__(self): 71 | return len(self.iterable) 72 | 73 | def __next__(self): 74 | start = time.time() 75 | n = self.iterator.__next__() 76 | self.last_duration = (time.time() - start) 77 | return n 78 | 79 | next = __next__ 80 | 81 | 82 | if __name__ == '__main__': 83 | with TimerBlock('Test') as block: 84 | block.log('1') 85 | block.log('2') 86 | block.save('../train_val_test/runs/test.txt') 87 | --------------------------------------------------------------------------------
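Usage note: training and evaluation are driven by the YAML config above through train_val_test/parser_args.py. A minimal sketch of the intended invocation, with placeholder paths; checkpoint loading itself is handled in method_choose/model_choose.py:

cd train_val_test
python train.py -config ./config/config.yaml
python eval.py -config ./config/config.yaml -last_model /path/to/saved.state
# optional late fusion of several runs' score.pkl files (defaults in ensemble.py point at NTU-60 work dirs):
python ensemble.py --label /your/path/to/val_label.pkl

train.py logs to <model_saved_name>/log.txt and to TensorBoard event files, and both train.py and eval.py dump score.pkl, all_pre_true.txt and wrong_path_pre_true.txt into model_saved_name.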