├── module
│   ├── __init__.py
│   ├── Temporal_shift
│   │   ├── __init__.py
│   │   ├── cuda
│   │   │   ├── __init__.py
│   │   │   ├── setup.py
│   │   │   ├── shift_cuda.cpp
│   │   │   ├── shift.py
│   │   │   └── shift_cuda_kernel.cu
│   │   ├── readme.txt
│   │   ├── run.sh
│   │   └── demo.py
│   ├── adapter.py
│   ├── ntu_rgb_d.py
│   ├── gcn
│   │   ├── utils
│   │   │   ├── tgcn.py
│   │   │   └── graph.py
│   │   └── st_gcn.py
│   └── shift_gcn.py
├── descriptions
│   ├── __init__.py
│   ├── pku_labelmap.txt
│   ├── ntu_labelmap.txt
│   ├── pku_des.txt
│   ├── ntu60_des.txt
│   ├── ntu120_des.txt
│   └── ntu_parts_from_GAP.txt
├── assets
│   ├── prototype.png
│   ├── testing.png
│   └── training.png
├── dataset.py
├── LICENSE
├── KLLoss.py
├── .gitignore
├── logger.py
├── test_logger.py
├── one_shot_logger.py
├── one_shot_test_logger.py
├── sentence_bert_embedding.py
├── test_config.py
├── config.py
├── one_shot_test_config.py
├── one_shot_config.py
├── tool.py
├── README.md
├── one_shot_main.py
├── one_shot_test_main.py
└── test_main.py
/module/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/descriptions/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/module/Temporal_shift/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/module/Temporal_shift/cuda/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/assets/prototype.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaai520/PGFA/HEAD/assets/prototype.png
--------------------------------------------------------------------------------
/assets/testing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaai520/PGFA/HEAD/assets/testing.png
--------------------------------------------------------------------------------
/assets/training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaai520/PGFA/HEAD/assets/training.png
--------------------------------------------------------------------------------
/module/Temporal_shift/readme.txt:
--------------------------------------------------------------------------------
1 | Run `bash run.sh` to compile the CUDA extension for the temporal adaptive shift op.
2 | This implementation borrows from the active-shift Caffe implementation by jyh2986; we modified it for temporal adaptive shift.
--------------------------------------------------------------------------------
/module/Temporal_shift/run.sh:
--------------------------------------------------------------------------------
1 | cd cuda
2 | 
3 | rm -rf ./__pycache__
4 | rm -rf ./dist
5 | rm -rf ./build
6 | rm -rf ./shift_cuda_linear_cpp.egg-info
7 | 
8 | python setup.py install
9 | 
10 | cd ..
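# Note (editorial assumption, not part of the original script): setup.py builds a
# CUDAExtension, so this step needs a CUDA-enabled PyTorch install and a matching
# CUDA toolkit (nvcc). A quick sanity check before building:
#   python -c "import torch; print(torch.cuda.is_available(), torch.version.cuda)"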
11 | 12 | CUDA_VISIBLE_DEVICES="0" python demo.py 13 | -------------------------------------------------------------------------------- /module/Temporal_shift/cuda/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension 3 | 4 | setup( 5 | name='shift_cuda_linear_cpp', 6 | ext_modules=[ 7 | CUDAExtension('shift_cuda', [ 8 | 'shift_cuda.cpp', 9 | 'shift_cuda_kernel.cu', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class DataSet(torch.utils.data.Dataset): 5 | 6 | def __init__(self, 7 | data_path: str, 8 | label_path: str,): 9 | 10 | self.data_path = data_path 11 | self.label_path = label_path 12 | self.load_data() 13 | 14 | 15 | def load_data(self): 16 | 17 | self.data = np.load(self.data_path) 18 | self.label = np.load(self.label_path) 19 | self.size = len(self.label) 20 | 21 | def __len__(self) -> int: 22 | return self.size 23 | 24 | def __getitem__(self, index: int) -> tuple: 25 | 26 | data = self.data[index] 27 | label = self.label[index] 28 | 29 | return data, label 30 | -------------------------------------------------------------------------------- /module/Temporal_shift/demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import argparse 5 | import torch 6 | 7 | import torch.nn as nn 8 | from torch.autograd import Variable, gradcheck 9 | 10 | from cuda.shift import ShiftFunction,Shift 11 | 12 | 13 | shift_layer = Shift(channel=5,stride=2) 14 | 15 | input = Variable(torch.ones(1,5,8,4).cuda().float(), requires_grad=True) 16 | out = shift_layer(input) 17 | sum_out = torch.sum(out) 18 | sum_out.backward() 19 | 20 | print('*'*20 + ' input') 21 | print(input) 22 | print('*'*20 + ' out') 23 | print(out) 24 | print('*'*20 + ' input.grad') 25 | print(input.grad) 26 | print('*'*20 + ' shift_layer.temporal_position') 27 | print(shift_layer.ypos) 28 | print('*'*20 + ' shift_layer.temporal_position.grad') 29 | print(shift_layer.ypos.grad) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 KaiZhou-cs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /descriptions/pku_labelmap.txt: -------------------------------------------------------------------------------- 1 | Bow 2 | Brushing hair 3 | Brushing teeth 4 | Check time (from watch) 5 | Cheer up 6 | Clapping 7 | Cross hands in front (say stop) 8 | Drink water 9 | Drop 10 | Eat meal/snack 11 | Falling 12 | Giving something to other person 13 | Hand waving 14 | Handshaking 15 | Hopping (one foot jumping) 16 | Hugging other person 17 | Jumping up 18 | Kicking other person 19 | Kicking something 20 | Making a phone call/answering phone 21 | Patting on back of other person 22 | Pickup 23 | Playing with phone/tablet 24 | Pointing finger at the other person 25 | Pointing to something with finger 26 | Punching/slapping other person 27 | Pushing other person 28 | Put on a hat/cap 29 | Put something inside pocket 30 | Reading 31 | Rub two hands together 32 | Salute 33 | Sitting down 34 | Standing up 35 | Take off a hat/cap 36 | Take off glasses 37 | Take off jacket 38 | Take out something from pocket 39 | Taking a selfie 40 | Tear up paper 41 | Throw 42 | Touch back (backache) 43 | Touch chest (stomachache/heart pain) 44 | Touch head (headache) 45 | Touch neck (neckache) 46 | Typing on a keyboard 47 | Use a fan (with hand or paper)/feeling warm 48 | Wear jacket 49 | Wear on glasses 50 | Wipe face 51 | Writing -------------------------------------------------------------------------------- /module/Temporal_shift/cuda/shift_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | at::Tensor shift_cuda_forward( 5 | at::Tensor input,at::Tensor xpos,at::Tensor ypos,const int stride); 6 | 7 | std::vector shift_cuda_backward( 8 | at::Tensor grad_output, 9 | at::Tensor input, 10 | at::Tensor output, 11 | at::Tensor xpos, 12 | at::Tensor ypos, 13 | const int stride); 14 | 15 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 16 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 17 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 18 | 19 | at::Tensor shift_forward( 20 | at::Tensor input,at::Tensor xpos,at::Tensor ypos,const int stride) { 21 | CHECK_INPUT(input); 22 | return shift_cuda_forward(input,xpos,ypos,stride); 23 | } 24 | 25 | std::vector shift_backward( 26 | at::Tensor grad_output, 27 | at::Tensor input, 28 | at::Tensor output, 29 | at::Tensor xpos, 30 | at::Tensor ypos, 31 | const int stride) 32 | { 33 | CHECK_INPUT(grad_output); 34 | CHECK_INPUT(output); 35 | return shift_cuda_backward( 36 | grad_output, 37 | input, 38 | output, 39 | xpos, 40 | ypos, 41 | stride); 42 | } 43 | 44 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 45 | m.def("forward", &shift_forward, "shift forward (CUDA)"); 46 | m.def("backward", &shift_backward, "shift backward (CUDA)"); 47 | } 48 | -------------------------------------------------------------------------------- /module/Temporal_shift/cuda/shift.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module, Parameter 2 | from torch.autograd import Function 3 | 4 | import torch 5 | import shift_cuda 6 | 7 | 
import numpy as np 8 | 9 | class ShiftFunction(Function): 10 | 11 | @staticmethod 12 | def forward(ctx, input,xpos,ypos,stride=1): 13 | if stride==1: 14 | xpos = xpos 15 | ypos = ypos 16 | else: 17 | ypos = ypos + 0.5 18 | # ypos = ypos + 0.5 19 | output = shift_cuda.forward(input,xpos,ypos,stride) 20 | ctx.save_for_backward(input, output, xpos, ypos) 21 | ctx.stride = stride 22 | return output 23 | 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | grad_output = grad_output.contiguous() 27 | input, output, xpos, ypos = ctx.saved_variables 28 | grad_input,grad_xpos,grad_ypos = shift_cuda.backward(grad_output, input, output, xpos, ypos, ctx.stride) 29 | return grad_input, grad_xpos, grad_ypos, None 30 | 31 | class Shift(Module): 32 | 33 | def __init__(self, channel, stride, init_scale=3): 34 | super(Shift, self).__init__() 35 | 36 | self.stride = stride 37 | 38 | self.xpos = Parameter(torch.zeros(channel,requires_grad=True,device='cuda')*1.5) 39 | self.ypos = Parameter(torch.zeros(channel,requires_grad=True,device='cuda')*1.5) 40 | 41 | self.xpos.data.uniform_(-1e-8,1e-8) 42 | self.ypos.data.uniform_(-init_scale,init_scale) 43 | 44 | def forward(self, input): 45 | return ShiftFunction.apply(input,self.xpos,self.ypos,self.stride) -------------------------------------------------------------------------------- /KLLoss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | 4 | class KLLoss(nn.Module): 5 | """Loss that uses a 'hinge' on the lower bound. 6 | This means that for samples with a label value smaller than the threshold, the loss is zero if the prediction is 7 | also smaller than that threshold. 8 | args: 9 | error_matric: What base loss to use (MSE by default). 10 | threshold: Threshold to use for the hinge. 11 | clip: Clip the loss if it is above this value. 
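    Note: as actually implemented below, this class simply wraps nn.KLDivLoss;
    forward() compares log_softmax(prediction) with softmax(label * 10) and
    scales the result by the batch size. The threshold/clip arguments described
    above are not used.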
12 | """ 13 | 14 | def __init__(self, error_metric=nn.KLDivLoss(size_average=True, reduce=True)): 15 | super().__init__() 16 | print('=========using KL Loss=and has temperature and * bz==========') 17 | self.error_metric = error_metric 18 | 19 | def forward(self, prediction, label): 20 | batch_size = prediction.shape[0] 21 | probs1 = F.log_softmax(prediction, 1) 22 | probs2 = F.softmax(label * 10, 1) 23 | loss = self.error_metric(probs1, probs2) * batch_size 24 | return loss 25 | 26 | class KDLoss(nn.Module): 27 | def __init__(self, error_metric=nn.KLDivLoss(size_average=True, reduce=True), T=0.1): 28 | super().__init__() 29 | print('=========using KL Loss=and has temperature and * bz==========') 30 | self.error_metric = error_metric 31 | self.T = T 32 | 33 | def forward(self, prediction, teacher_logits): 34 | batch_size = prediction.shape[0] 35 | probs1 = F.log_softmax(prediction/self.T, 1) 36 | probs2 = F.softmax(teacher_logits/self.T, 1) 37 | loss = self.error_metric(probs1, probs2) * batch_size 38 | return loss -------------------------------------------------------------------------------- /module/adapter.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | def conv_init(conv): 8 | if conv.weight is not None: 9 | nn.init.kaiming_normal_(conv.weight, mode='fan_out') 10 | if conv.bias is not None: 11 | nn.init.constant_(conv.bias, 0) 12 | 13 | class Linear(nn.Module): 14 | def __init__(self, hidden_size=256, output_size=768): 15 | super(Linear, self).__init__() 16 | self.adapter = nn.Linear(hidden_size, output_size) 17 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)).cuda() 18 | self.logit_scale_v2 = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)).cuda() 19 | conv_init(self.adapter) 20 | 21 | def forward(self, x): 22 | return self.adapter(x) 23 | 24 | def get_logit_scale(self): 25 | return self.logit_scale 26 | 27 | def get_logit_scale_v2(self): 28 | return self.logit_scale_v2 29 | 30 | 31 | 32 | 33 | class Adapter(nn.Module): 34 | def __init__(self, hidden_size=256, output_size=768): 35 | super(Adapter, self).__init__() 36 | self.fc1 = nn.Linear(hidden_size, hidden_size) 37 | self.fc2 = nn.Linear(hidden_size, output_size) 38 | self.fc3 = nn.Linear(hidden_size, output_size, bias=False) 39 | self.act = nn.GELU() 40 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)).cuda() 41 | conv_init(self.fc1) 42 | conv_init(self.fc2) 43 | conv_init(self.fc3) 44 | 45 | def forward(self, x): 46 | xs = self.fc1(x) 47 | xs = self.act(xs) 48 | xs = self.fc2(x) 49 | return self.fc3(x) + xs 50 | 51 | def get_logit_scale(self): 52 | return self.logit_scale 53 | 54 | 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *.cover 45 | .hypothesis/ 46 | .pytest_cache/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | db.sqlite3 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # Environments 83 | .env 84 | .venv 85 | env/ 86 | venv/ 87 | ENV/ 88 | env.bak/ 89 | venv.bak/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | /data 103 | .git 104 | data.zip 105 | /output 106 | /ckpts 107 | 108 | -------------------------------------------------------------------------------- /module/ntu_rgb_d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def edge2mat(link, num_node): 5 | A = np.zeros((num_node, num_node)) 6 | for i, j in link: 7 | A[j, i] = 1 8 | return A 9 | 10 | 11 | def normalize_digraph(A): # 除以每列的和 12 | Dl = np.sum(A, 0) 13 | h, w = A.shape 14 | Dn = np.zeros((w, w)) 15 | for i in range(w): 16 | if Dl[i] > 0: 17 | Dn[i, i] = Dl[i] ** (-1) 18 | AD = np.dot(A, Dn) 19 | return AD 20 | 21 | 22 | def get_spatial_graph(num_node, self_link, inward, outward): 23 | I = edge2mat(self_link, num_node) 24 | In = normalize_digraph(edge2mat(inward, num_node)) 25 | Out = normalize_digraph(edge2mat(outward, num_node)) 26 | A = np.stack((I, In, Out)) 27 | return A 28 | 29 | 30 | num_node = 25 31 | self_link = [(i, i) for i in range(num_node)] 32 | inward_ori_index = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6), 33 | (8, 7), (9, 21), (10, 9), (11, 10), (12, 11), (13, 1), 34 | (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), (19, 18), 35 | (20, 19), (22, 23), (23, 8), (24, 25), (25, 12)] 36 | inward = [(i - 1, j - 1) for (i, j) in inward_ori_index] 37 | outward = [(j, i) for (i, j) in inward] 38 | neighbor = inward + outward 39 | 40 | 41 | class Graph: 42 | def __init__(self, labeling_mode='spatial'): 43 | self.A = self.get_adjacency_matrix(labeling_mode) 44 | self.num_node = num_node 45 | self.self_link = self_link 46 | self.inward = inward 47 | self.outward = outward 48 | self.neighbor = neighbor 49 | 50 | def get_adjacency_matrix(self, labeling_mode=None): 51 | if labeling_mode is None: 52 | return self.A 53 | if labeling_mode == 'spatial': 54 | A = get_spatial_graph(num_node, self_link, inward, outward) 55 | else: 56 | raise ValueError() 57 | return A 58 | 59 | 60 | if __name__ == '__main__': 61 | import matplotlib.pyplot as plt 62 | import os 63 | 64 | # os.environ['DISPLAY'] = 'localhost:11.0' 65 | A = Graph('spatial').get_adjacency_matrix() 66 | for i in A: 67 | plt.imshow(i, cmap='gray') 68 | plt.show() 69 | print(A) 70 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 
| from config import * 4 | 5 | def get_logger(filename, verbosity=1, name=None): 6 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} 7 | formatter = logging.Formatter( 8 | "[%(asctime)s][%(levelname)s] %(message)s" 9 | ) 10 | os.makedirs(os.path.dirname(filename), exist_ok=True) 11 | logger = logging.getLogger(name) 12 | logger.setLevel(level_dict[verbosity]) 13 | 14 | fh = logging.FileHandler(filename, "w") 15 | fh.setFormatter(formatter) 16 | logger.addHandler(fh) 17 | 18 | sh = logging.StreamHandler() 19 | sh.setFormatter(formatter) 20 | logger.addHandler(sh) 21 | 22 | return logger 23 | 24 | class Log: 25 | @ex.capture 26 | def __init__(self, log_path) -> None: 27 | self.batch_data = dict() 28 | self.epoch_data = dict() 29 | self.max_data = {'best_epoch':-1, 'test_acc':-1} 30 | self.logger = get_logger(log_path) 31 | self.logger.info('Start') 32 | 33 | def update_batch(self, name, value): 34 | if name not in self.batch_data: 35 | self.batch_data[name] = list() 36 | self.batch_data[name].append(value) 37 | 38 | def info(self, msg): 39 | # logging.info(msg) 40 | self.logger.info(msg) 41 | 42 | @ex.capture 43 | def update_epoch(self, epoch, epoch_num, train_mode): 44 | self.logger.info('Epoch:[{}/{}]'.format(epoch + 1 , epoch_num)) 45 | for name in self.batch_data.keys(): 46 | if name not in self.epoch_data: 47 | self.epoch_data[name] = list() 48 | epoch_value = np.mean(self.batch_data[name]) 49 | self.epoch_data[name].append(epoch_value) 50 | self.batch_data[name] = list() 51 | if 'test/cls_acc' in name and epoch_value > self.max_data['test_acc']: 52 | self.max_data['test_acc'] = epoch_value 53 | self.max_data['best_epoch'] = epoch 54 | self.logger.info("{}: {}".format(name, self.epoch_data[name][-1])) 55 | if "loadweight" in train_mode: 56 | self.logger.info("Epoch:[{}] get the best test acc: {}" 57 | .format(self.max_data['best_epoch'], self.max_data['test_acc'])) 58 | -------------------------------------------------------------------------------- /test_logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | from test_config import * 4 | 5 | def get_logger(filename, verbosity=1, name=None): 6 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} 7 | formatter = logging.Formatter( 8 | "[%(asctime)s][%(levelname)s] %(message)s" 9 | ) 10 | os.makedirs(os.path.dirname(filename), exist_ok=True) 11 | logger = logging.getLogger(name) 12 | logger.setLevel(level_dict[verbosity]) 13 | 14 | fh = logging.FileHandler(filename, "w") 15 | fh.setFormatter(formatter) 16 | logger.addHandler(fh) 17 | 18 | sh = logging.StreamHandler() 19 | sh.setFormatter(formatter) 20 | logger.addHandler(sh) 21 | 22 | return logger 23 | 24 | class Log: 25 | @ex.capture 26 | def __init__(self, log_path) -> None: 27 | self.batch_data = dict() 28 | self.epoch_data = dict() 29 | self.max_data = {'best_epoch':-1, 'test_acc':-1} 30 | self.logger = get_logger(log_path) 31 | self.logger.info('Start') 32 | 33 | def update_batch(self, name, value): 34 | if name not in self.batch_data: 35 | self.batch_data[name] = list() 36 | self.batch_data[name].append(value) 37 | 38 | def info(self, msg): 39 | # logging.info(msg) 40 | self.logger.info(msg) 41 | 42 | @ex.capture 43 | def update_epoch(self, epoch, epoch_num, train_mode): 44 | self.logger.info('Epoch:[{}/{}]'.format(epoch + 1 , epoch_num)) 45 | for name in self.batch_data.keys(): 46 | if name not in self.epoch_data: 47 | self.epoch_data[name] = list() 48 | 
epoch_value = np.mean(self.batch_data[name]) 49 | self.epoch_data[name].append(epoch_value) 50 | self.batch_data[name] = list() 51 | if 'test/cls_acc' in name and epoch_value > self.max_data['test_acc']: 52 | self.max_data['test_acc'] = epoch_value 53 | self.max_data['best_epoch'] = epoch 54 | self.logger.info("{}: {}".format(name, self.epoch_data[name][-1])) 55 | if "loadweight" in train_mode: 56 | self.logger.info("Epoch:[{}] get the best test acc: {}" 57 | .format(self.max_data['best_epoch'], self.max_data['test_acc'])) 58 | -------------------------------------------------------------------------------- /one_shot_logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | from one_shot_config import * 4 | 5 | def get_logger(filename, verbosity=1, name=None): 6 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} 7 | formatter = logging.Formatter( 8 | "[%(asctime)s][%(levelname)s] %(message)s" 9 | ) 10 | os.makedirs(os.path.dirname(filename), exist_ok=True) 11 | logger = logging.getLogger(name) 12 | logger.setLevel(level_dict[verbosity]) 13 | 14 | fh = logging.FileHandler(filename, "w") 15 | fh.setFormatter(formatter) 16 | logger.addHandler(fh) 17 | 18 | sh = logging.StreamHandler() 19 | sh.setFormatter(formatter) 20 | logger.addHandler(sh) 21 | 22 | return logger 23 | 24 | class Log: 25 | @ex.capture 26 | def __init__(self, log_path) -> None: 27 | self.batch_data = dict() 28 | self.epoch_data = dict() 29 | self.max_data = {'best_epoch':-1, 'test_acc':-1} 30 | self.logger = get_logger(log_path) 31 | self.logger.info('Start') 32 | 33 | def update_batch(self, name, value): 34 | if name not in self.batch_data: 35 | self.batch_data[name] = list() 36 | self.batch_data[name].append(value) 37 | 38 | def info(self, msg): 39 | # logging.info(msg) 40 | self.logger.info(msg) 41 | 42 | @ex.capture 43 | def update_epoch(self, epoch, epoch_num, train_mode): 44 | self.logger.info('Epoch:[{}/{}]'.format(epoch + 1 , epoch_num)) 45 | for name in self.batch_data.keys(): 46 | if name not in self.epoch_data: 47 | self.epoch_data[name] = list() 48 | epoch_value = np.mean(self.batch_data[name]) 49 | self.epoch_data[name].append(epoch_value) 50 | self.batch_data[name] = list() 51 | if 'test/cls_acc' in name and epoch_value > self.max_data['test_acc']: 52 | self.max_data['test_acc'] = epoch_value 53 | self.max_data['best_epoch'] = epoch 54 | self.logger.info("{}: {}".format(name, self.epoch_data[name][-1])) 55 | if "loadweight" in train_mode: 56 | self.logger.info("Epoch:[{}] get the best test acc: {}" 57 | .format(self.max_data['best_epoch'], self.max_data['test_acc'])) 58 | -------------------------------------------------------------------------------- /one_shot_test_logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | from one_shot_test_config import * 4 | 5 | def get_logger(filename, verbosity=1, name=None): 6 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING} 7 | formatter = logging.Formatter( 8 | "[%(asctime)s][%(levelname)s] %(message)s" 9 | ) 10 | os.makedirs(os.path.dirname(filename), exist_ok=True) 11 | logger = logging.getLogger(name) 12 | logger.setLevel(level_dict[verbosity]) 13 | 14 | fh = logging.FileHandler(filename, "w") 15 | fh.setFormatter(formatter) 16 | logger.addHandler(fh) 17 | 18 | sh = logging.StreamHandler() 19 | sh.setFormatter(formatter) 20 | logger.addHandler(sh) 21 | 22 | return 
logger 23 | 24 | class Log: 25 | @ex.capture 26 | def __init__(self, log_path) -> None: 27 | self.batch_data = dict() 28 | self.epoch_data = dict() 29 | self.max_data = {'best_epoch':-1, 'test_acc':-1} 30 | self.logger = get_logger(log_path) 31 | self.logger.info('Start') 32 | 33 | def update_batch(self, name, value): 34 | if name not in self.batch_data: 35 | self.batch_data[name] = list() 36 | self.batch_data[name].append(value) 37 | 38 | def info(self, msg): 39 | # logging.info(msg) 40 | self.logger.info(msg) 41 | 42 | @ex.capture 43 | def update_epoch(self, epoch, epoch_num, train_mode): 44 | self.logger.info('Epoch:[{}/{}]'.format(epoch + 1 , epoch_num)) 45 | for name in self.batch_data.keys(): 46 | if name not in self.epoch_data: 47 | self.epoch_data[name] = list() 48 | epoch_value = np.mean(self.batch_data[name]) 49 | self.epoch_data[name].append(epoch_value) 50 | self.batch_data[name] = list() 51 | if 'test/cls_acc' in name and epoch_value > self.max_data['test_acc']: 52 | self.max_data['test_acc'] = epoch_value 53 | self.max_data['best_epoch'] = epoch 54 | self.logger.info("{}: {}".format(name, self.epoch_data[name][-1])) 55 | if "loadweight" in train_mode: 56 | self.logger.info("Epoch:[{}] get the best test acc: {}" 57 | .format(self.max_data['best_epoch'], self.max_data['test_acc'])) 58 | -------------------------------------------------------------------------------- /sentence_bert_embedding.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer, LoggingHandler 2 | import numpy as np 3 | import logging 4 | from pprint import pprint 5 | 6 | def read_file(path): 7 | with open(path, 'r') as f: 8 | return f.read().split('\n') 9 | 10 | def write_file(content,path): 11 | with open(path, 'w') as f: 12 | f.write('\n'.join(content)) 13 | 14 | #### Just some code to print debug information to stdout 15 | np.set_printoptions(threshold=100) 16 | 17 | logging.basicConfig(format='%(asctime)s - %(message)s', 18 | datefmt='%Y-%m-%d %H:%M:%S', 19 | level=logging.INFO, 20 | handlers=[LoggingHandler()]) 21 | #### /print debug information to stdout 22 | def count_words(sentences_list): 23 | counter = 0 24 | for sentence in sentences_list: 25 | words = sentence.split() 26 | counter += len(words) 27 | return counter 28 | 29 | # path = './descriptions/ntu_chatgpt_des.txt' 30 | # des_list = read_file(path) 31 | # max_len = 0 32 | # label_list = [] 33 | # part_list = [] 34 | # for i, des in enumerate(des_list): 35 | # label, part_des = des.split(';',1) 36 | # label_list.append(label) 37 | # part_list.append(part_des) 38 | 39 | # label_path = './descriptions/ntu_labelmap.txt' 40 | # part_path = './descriptions/ntu_part_des.txt' 41 | # write_file(label_list, label_path) 42 | # write_file(part_list, part_path) 43 | # path = './descriptions/ntu_parts_des.txt' 44 | # path_des = './descriptions/ntu120_des.txt' 45 | 46 | # path = './descriptions/ntu_parts_from_GAP.txt' 47 | # part_list = read_file(path) 48 | # des_list = [des.replace(';', '.') for des in des_list] 49 | # des_part_list = part_list # [] 50 | # for i, part in enumerate(part_list): 51 | # label, part_des = part.split(';',1) 52 | # des_part_list.append(part_des) 53 | 54 | # write_file(des_part_list, './descriptions/ntu_des+parts_des.txt') 55 | 56 | # des_list = [des.replace(';', '.') for des in des_part_list] 57 | # ntu_60_list = des_list[:60] 58 | 59 | # path = './descriptions/ntu_labelmap.txt' 60 | path = './descriptions/pku_des.txt' 61 | 62 | des_list = 
read_file(path) 63 | ntu_60_list = des_list[:60] 64 | # pprint(ntu_60_list) 65 | 66 | 67 | # Load pre-trained Sentence Transformer Model. It will be downloaded automatically 68 | model = SentenceTransformer('all-mpnet-base-v2') 69 | 70 | 71 | sentence_embeddings = model.encode(des_list) 72 | print(sentence_embeddings.shape) 73 | save_path = './data/language/pku_des_embeddings.npy' 74 | np.save(save_path, sentence_embeddings) -------------------------------------------------------------------------------- /module/gcn/utils/tgcn.py: -------------------------------------------------------------------------------- 1 | # The based unit of graph convolutional networks. 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class ConvTemporalGraphical(nn.Module): 7 | 8 | r"""The basic module for applying a graph convolution. 9 | 10 | Args: 11 | in_channels (int): Number of channels in the input sequence data 12 | out_channels (int): Number of channels produced by the convolution 13 | kernel_size (int): Size of the graph convolving kernel 14 | t_kernel_size (int): Size of the temporal convolving kernel 15 | t_stride (int, optional): Stride of the temporal convolution. Default: 1 16 | t_padding (int, optional): Temporal zero-padding added to both sides of 17 | the input. Default: 0 18 | t_dilation (int, optional): Spacing between temporal kernel elements. 19 | Default: 1 20 | bias (bool, optional): If ``True``, adds a learnable bias to the output. 21 | Default: ``True`` 22 | 23 | Shape: 24 | - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format 25 | - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format 26 | - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format 27 | - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format 28 | 29 | where 30 | :math:`N` is a batch size, 31 | :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, 32 | :math:`T_{in}/T_{out}` is a length of input/output sequence, 33 | :math:`V` is the number of graph nodes. 
34 | """ 35 | 36 | def __init__(self, 37 | in_channels, 38 | out_channels, 39 | kernel_size, 40 | t_kernel_size=1, 41 | t_stride=1, 42 | t_padding=0, 43 | t_dilation=1, 44 | bias=True): 45 | super().__init__() 46 | 47 | self.kernel_size = kernel_size 48 | self.conv = nn.Conv2d( 49 | in_channels, 50 | out_channels * kernel_size, 51 | kernel_size=(t_kernel_size, 1), 52 | padding=(t_padding, 0), 53 | stride=(t_stride, 1), 54 | dilation=(t_dilation, 1), 55 | bias=bias) 56 | 57 | def forward(self, x, A): 58 | assert A.size(0) == self.kernel_size 59 | 60 | x = self.conv(x) 61 | 62 | n, kc, t, v = x.size() 63 | x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v) 64 | x = torch.einsum('nkctv,kvw->nctw', (x, A)) 65 | 66 | return x.contiguous(), A 67 | -------------------------------------------------------------------------------- /test_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sacred import Experiment 3 | 4 | ex = Experiment("baseline", save_git_info=False) 5 | 6 | @ex.config 7 | def my_config(): 8 | track = "main" # main or sota 9 | split = '1' 10 | dataset = "ntu60" # ntu60: split 1-3, sota_split 5,12; ntu120: split 4-6, sota_split 10,24; pku: split 7-9 11 | lr = 0.05 #1e-5 12 | margin = 0.1 13 | weight_decay = 0.0005 14 | epoch_num = 50 15 | batch_size = 128 #128 16 | loss_type = "kl" # "kl" or "mse" or "kl+mse" or "kl+kd" or "kl+cosface" or "kl+sphereface" or "kl+margin" 17 | alpha = 1 18 | beta = 1 19 | m = 1 20 | DA = True 21 | support_factor = 0.9 # alpha in paper [0,1], 0.9 for NTU-60, 0.4 for NTU-120, 1.0 for PKU-MMD 22 | weight_path = './ckpts/split_{}_des_DA_epoch50_lr{}.pt'.format(split, lr) 23 | log_path = './output/log/test.log' # modify what you want 24 | save_path = './output/model/test.pt' 25 | loss_mode = "step" # "step" or "cos" 26 | step = [50, 80] 27 | 28 | ############################## ST-GCN ############################### 29 | in_channels = 3 30 | hidden_channels = 16 31 | hidden_dim = 256 32 | dropout = 0.5 33 | graph_args = { 34 | "layout" : 'ntu-rgb+d', 35 | "strategy" : 'spatial' 36 | } 37 | edge_importance_weighting = True 38 | ############################# downstream ############################# 39 | split_1 = [4,19,31,47,51] 40 | split_2 = [12,29,32,44,59] 41 | split_3 = [7,20,28,39,58] 42 | split_4 = [3, 18, 26, 38, 41, 60, 87, 99, 102, 110] 43 | split_5 = [5, 12, 14, 15, 17, 42, 67, 82, 100, 119] 44 | split_6 = [6, 20, 27, 33, 42, 55, 71, 97, 104, 118] 45 | split_7 = [1, 9, 20, 34, 50] 46 | split_8 = [3, 14, 29, 31, 49] 47 | split_9 = [2, 15, 39, 41, 43] 48 | unseen_label = eval('split_'+split) 49 | visual_size = 256 50 | language_size = 768 51 | max_frame = 50 52 | language_path = "./data/language/"+dataset+"_des_embeddings.npy" # des best 53 | train_list = "./data/zeroshot/"+dataset+"/split_"+split+"/seen_train_data.npy" 54 | train_label = "./data/zeroshot/"+dataset+"/split_"+split+"/seen_train_label.npy" 55 | test_list = "./data/zeroshot/"+dataset+"/split_"+split+"/unseen_data.npy" 56 | test_label = "./data/zeroshot/"+dataset+"/split_"+split+"/unseen_label.npy" 57 | ############################ sota compare ############################ 58 | sota_split = "5" 59 | unseen_label_5 = [10,11,19,26,56] 60 | unseen_label_12 = [3,5,9,12,15,40,42,47,51,56,58,59] 61 | unseen_label_10 = [4,13,37,43,49,65,88,95,99,106] 62 | unseen_label_24 = [5,9,11,16,18,20,22,29,35,39,45,49,59,68,70,81,84,87,93,94,104,113,114,119] 63 | sota_unseen = eval('unseen_label_'+sota_split) 64 | sota_train_list = 
"./sourcedata/sota/split_"+sota_split+"/train.npy" 65 | sota_train_label = "./sourcedata/sota/split_"+sota_split+"/train_label.npy" 66 | sota_test_list = "./sourcedata/sota/split_"+sota_split+"/test.npy" 67 | sota_test_label = "./sourcedata/sota/split_"+sota_split+"/test_label.npy" 68 | # %% 69 | -------------------------------------------------------------------------------- /descriptions/ntu_labelmap.txt: -------------------------------------------------------------------------------- 1 | Drink water 2 | Eat meal/snack 3 | Brushing teeth 4 | Brushing hair 5 | Drop 6 | Pickup 7 | Throw 8 | Sitting down 9 | Standing up (from sitting position) 10 | Clapping 11 | Reading 12 | Writing 13 | Tear up paper 14 | Wear jacket 15 | Take off jacket 16 | Wear a shoe 17 | Take off a shoe 18 | Wear on glasses 19 | Take off glasses 20 | Put on a hat/cap 21 | Take off a hat/cap 22 | Cheer up 23 | Hand waving 24 | Kicking something 25 | Reach into pocket 26 | Hopping (one foot jumping) 27 | Jumping up 28 | Make a phone call/Answer phone 29 | Playing with phone/tablet 30 | Typing on a keyboard 31 | Pointing to something with finger 32 | Taking a selfie 33 | Checking time (from watch) 34 | Rubbing two hands together 35 | Nod headbow 36 | Shake head 37 | Wiping face 38 | Salute 39 | Putting the palms together 40 | Crossing hands in front (saying stop) 41 | Sneeze/cough 42 | Staggering 43 | Falling 44 | Touching head (headache) 45 | Touching chest (stomachache/heart pain) 46 | Touching back (backache) 47 | Touching neck (neckache) 48 | Nausea or vomiting 49 | Using a fan (with hand or paper) 50 | Punching/Slapping other person 51 | Kicking other person 52 | Pushing other person 53 | Patting on back of other person 54 | Pointing finger at the other person 55 | Hugging other person 56 | Giving something to other person 57 | Touching other person's pocket 58 | Handshaking 59 | Walking towards each other 60 | Walking apart from each other 61 | Putting on headphones 62 | Taking off headphones 63 | Shooting at the basket 64 | Bouncing a ball 65 | Tennis bat swing 66 | Juggling table tennis balls 67 | Hushing 68 | Flicking hair 69 | Thumbs up 70 | Thumbs down 71 | Making OK sign 72 | Making victory sign 73 | Stapling book 74 | Counting money 75 | Cutting nails 76 | Cutting paper (using scissors) 77 | Snapping fingers 78 | Opening bottle 79 | Sniffing 80 | Squatting down 81 | Tossing a coin 82 | Folding paper 83 | Balling up paper 84 | Playing with a magic cube 85 | Applying cream on face 86 | Applying cream on hand back 87 | Putting on a bag 88 | Taking off a bag 89 | Putting something into a bag 90 | Taking something out of a bag 91 | Opening a box 92 | Moving heavy objects 93 | Shaking fist 94 | Throwing up a cap/hat 95 | Raising hands up 96 | Crossing arms 97 | Making arm circles 98 | Swinging arms 99 | Running on the spot 100 | Butt kicks (kick backward) 101 | Cross toe touch 102 | Side kick 103 | Yawning 104 | Stretching oneself 105 | Blowing nose 106 | Hitting other person with something 107 | Wielding knife towards other person 108 | Knocking over other person (hit with body) 109 | Grabbing other person's stuff 110 | Shooting at other person with a gun 111 | Stepping on foot 112 | High-five 113 | Cheers and drink 114 | Carrying something with other person 115 | Taking a photo of other person 116 | Following other person 117 | Whispering in other person's ear 118 | Exchanging things with other person 119 | Supporting somebody with hand 120 | Finger-guessing game (playing rock-paper-scissors) 
-------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sacred import Experiment 3 | 4 | ex = Experiment("baseline", save_git_info=False) 5 | 6 | @ex.config 7 | def my_config(): 8 | track = "main" # main or sota 9 | split = '1' 10 | dataset = "ntu60" # ntu60: split 1-3, sota_split 5,12; ntu120: split 4-6, sota_split 10,24; pku: split 7-9 11 | lr = 0.05 # 0.05 for ntu60 and pku, 0.005 for ntu120 12 | margin = 0.1 13 | weight_decay = 0.0005 14 | epoch_num = 50 15 | batch_size = 128 #128 16 | loss_type = "kl" 17 | alpha = 1 18 | beta = 1 19 | m = 1 20 | DA = True # DA means using our prototype-guided text feature alignment 21 | fix_encoder = False 22 | finetune = False 23 | support_factor = 0.9 # 0.9 for ntu60, 0.4 for ntu120, 1.0 for pku 24 | weight_path = './module/gcn/model/split_'+split+".pt" # only using when set fix_encoder/finetune, copy from SMIE 25 | log_path = './output/log/split_{}_{}_DA_des_support_factor{}_lr{}.log'.format(split,loss_type,support_factor,lr) 26 | # log_path = './output/log/sota_split10_des_DA_epoch100_lr{}_support_factor{}.log'.format(lr, support_factor) 27 | 28 | save_path = './output/model/split_{}_{}_DA_des_support_factor{}_lr{}.pt'.format(split,loss_type,support_factor,lr) 29 | loss_mode = "step" # "step" or "cos" 30 | step = [50, 80] 31 | ############################## ST-GCN ############################### 32 | in_channels = 3 33 | hidden_channels = 16 34 | hidden_dim = 256 35 | dropout = 0.5 36 | graph_args = { 37 | "layout" : 'ntu-rgb+d', 38 | "strategy" : 'spatial' 39 | } 40 | edge_importance_weighting = True 41 | ############################# downstream ############################# 42 | split_1 = [4,19,31,47,51] 43 | split_2 = [12,29,32,44,59] 44 | split_3 = [7,20,28,39,58] 45 | split_4 = [3, 18, 26, 38, 41, 60, 87, 99, 102, 110] 46 | split_5 = [5, 12, 14, 15, 17, 42, 67, 82, 100, 119] 47 | split_6 = [6, 20, 27, 33, 42, 55, 71, 97, 104, 118] 48 | split_7 = [1, 9, 20, 34, 50] 49 | split_8 = [3, 14, 29, 31, 49] 50 | split_9 = [2, 15, 39, 41, 43] 51 | unseen_label = eval('split_'+split) 52 | visual_size = 256 53 | language_size = 768 54 | max_frame = 50 55 | language_path = "./data/language/"+dataset+"_des_embeddings.npy" # des best 56 | train_list = "./data/zeroshot/"+dataset+"/split_"+split+"/seen_train_data.npy" 57 | train_label = "./data/zeroshot/"+dataset+"/split_"+split+"/seen_train_label.npy" 58 | test_list = "./data/zeroshot/"+dataset+"/split_"+split+"/unseen_data.npy" 59 | test_label = "./data/zeroshot/"+dataset+"/split_"+split+"/unseen_label.npy" 60 | ############################ sota compare ############################ 61 | sota_split = "5" # 5 or 12 or 10 or 24 62 | model_choice_for_sota = 'shift-gcn' # shift-gcn or st-gcn 63 | unseen_label_5 = [10,11,19,26,56] 64 | unseen_label_12 = [3,5,9,12,15,40,42,47,51,56,58,59] 65 | unseen_label_10 = [4,13,37,43,49,65,88,95,99,106] 66 | unseen_label_24 = [5,9,11,16,18,20,22,29,35,39,45,49,59,68,70,81,84,87,93,94,104,113,114,119] 67 | sota_unseen = eval('unseen_label_'+sota_split) 68 | sota_train_list = "./data/zeroshot/"+dataset+"/unseen_label_"+sota_split+"/seen_train_data.npy" 69 | sota_train_label = "./data/zeroshot/"+dataset+"/unseen_label_"+sota_split+"/seen_train_label.npy" 70 | sota_test_list = "./data/zeroshot/"+dataset+"/unseen_label_"+sota_split+"/unseen_data.npy" 71 | sota_test_label = 
"./data/zeroshot/"+dataset+"/unseen_label_"+sota_split+"/unseen_label.npy" 72 | # %% 73 | -------------------------------------------------------------------------------- /one_shot_test_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sacred import Experiment 3 | 4 | ex = Experiment("baseline", save_git_info=False) 5 | 6 | @ex.config 7 | def my_config(): 8 | track = "main" 9 | split = 'one_shot_full_10' 10 | dataset = "ntu60" 11 | lr = 0.05 #1e-5, 0.05 12 | margin = 0.1 13 | weight_decay = 0.0005 14 | epoch_num = 100 15 | batch_size = 128 #128 16 | loss_type = "kl" # 17 | alpha = 1 18 | beta = 1 19 | m = 1 20 | DA = True 21 | fix_encoder = False 22 | finetune = False 23 | support_factor = 0.9 #0.9 24 | weight_path = "your_model_weight_path" 25 | log_path = './output/log/your_log_path' 26 | save_path = './output/log/your_save_path' 27 | 28 | loss_mode = "step" # "step" or "cos" 29 | step = [50, 80] 30 | ############################## ST-GCN ############################### 31 | in_channels = 3 32 | hidden_channels = 16 33 | hidden_dim = 256 34 | dropout = 0.5 35 | graph_args = { 36 | "layout" : 'ntu-rgb+d', 37 | "strategy" : 'spatial' 38 | } 39 | edge_importance_weighting = True 40 | ############################# one-shot ############################# 41 | split_1 = [4,19,31,47,51] 42 | split_2 = [12,29,32,44,59] 43 | split_3 = [7,20,28,39,58] 44 | split_4 = [3, 18, 26, 38, 41, 60, 87, 99, 102, 110] 45 | split_5 = [5, 12, 14, 15, 17, 42, 67, 82, 100, 119] 46 | split_6 = [6, 20, 27, 33, 42, 55, 71, 97, 104, 118] 47 | split_7 = [1, 9, 20, 34, 50] 48 | split_8 = [3, 14, 29, 31, 49] 49 | split_9 = [2, 15, 39, 41, 43] 50 | one_shot_full_10 = [0, 6, 12, 18, 24, 30, 36, 42, 48, 54] 51 | one_shot_full_20 = [0, 6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96, 102, 108, 114] 52 | one_shot_pku_10 = [10, 30, 40, 0, 5, 35, 45, 15, 20, 25] 53 | unseen_label = eval(split) 54 | visual_size = 256 55 | language_size = 768 56 | max_frame = 50 57 | language_path = "./data/language/"+dataset+"_des_embeddings.npy" # des best 58 | one_shot_exemplar_data_path = "./data/zeroshot/"+dataset+"/"+split+"/one_shot_exemplar_data.npy" 59 | train_list = "./data/zeroshot/"+dataset+"/"+split+"/one_shot_train_data.npy" 60 | train_label = "./data/zeroshot/"+dataset+"/"+split+"/one_shot_train_label.npy" 61 | test_list = "./data/zeroshot/"+dataset+"/"+split+"/one_shot_test_data.npy" 62 | test_label = "./data/zeroshot/"+dataset+"/"+split+"/one_shot_test_label.npy" 63 | ############################ sota compare ############################ 64 | sota_split = "10" 65 | model_choice_for_sota = 'shift-gcn' # shift-gcn or st-gcn 66 | unseen_label_5 = [10,11,19,26,56] 67 | unseen_label_12 = [3,5,9,12,15,40,42,47,51,56,58,59] 68 | unseen_label_10 = [4,13,37,43,49,65,88,95,99,106] 69 | unseen_label_24 = [5,9,11,16,18,20,22,29,35,39,45,49,59,68,70,81,84,87,93,94,104,113,114,119] 70 | sota_unseen = eval('unseen_label_'+sota_split) 71 | sota_train_list = "./data/zeroshot/"+dataset+"/unseen_label_"+sota_split+"/seen_train_data.npy" 72 | sota_train_label = "./data/zeroshot/"+dataset+"/unseen_label_"+sota_split+"/seen_train_label.npy" 73 | sota_test_list = "./data/zeroshot/"+dataset+"/unseen_label_"+sota_split+"/unseen_data.npy" 74 | sota_test_label = "./data/zeroshot/"+dataset+"/unseen_label_"+sota_split+"/unseen_label.npy" 75 | # sota_train_list = "./sourcedata/sota/split_"+sota_split+"/train.npy" 76 | # sota_train_label = 
"./sourcedata/sota/split_"+sota_split+"/train_label.npy" 77 | # sota_test_list = "./sourcedata/sota/split_"+sota_split+"/test.npy" 78 | # sota_test_label = "./sourcedata/sota/split_"+sota_split+"/test_label.npy" 79 | # %% 80 | -------------------------------------------------------------------------------- /one_shot_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sacred import Experiment 3 | 4 | ex = Experiment("baseline", save_git_info=False) 5 | 6 | @ex.config 7 | def my_config(): 8 | track = "main" 9 | split = 'one_shot_full_10' 10 | dataset = "ntu60" # ntu60: one_shot_full_10; ntu120: one_shot_full_20; pku: one_shot_pku_10 11 | lr = 0.05 #1e-5, 0.05 12 | margin = 0.1 13 | weight_decay = 0.0005 14 | epoch_num = 100 15 | batch_size = 128 #128 16 | loss_type = "kl" # 17 | alpha = 1 18 | beta = 1 19 | m = 1 20 | DA = True 21 | fix_encoder = False 22 | finetune = False 23 | support_factor = 0.9 24 | seed = 1314 # if you modify this, please also modify the seed of one_shot_main.py 25 | # weight_path = './module/gcn/model/split_'+split+".pt" 26 | log_path = './output/log/'+split+'_des_epoch{}_lr{}_alpha{}_seed{}.log'.format(epoch_num,lr,support_factor,seed) 27 | save_path = './output/model/'+split+'_des_epoch{}_lr{}_alpha{}_seed{}.pt'.format(epoch_num,lr,support_factor,seed) 28 | loss_mode = "step" # "step" or "cos" 29 | step = [50, 80] 30 | ############################## ST-GCN ############################### 31 | in_channels = 3 32 | hidden_channels = 16 33 | hidden_dim = 256 34 | dropout = 0.5 35 | graph_args = { 36 | "layout" : 'ntu-rgb+d', 37 | "strategy" : 'spatial' 38 | } 39 | edge_importance_weighting = True 40 | ############################# one-shot ############################# 41 | split_1 = [4,19,31,47,51] 42 | split_2 = [12,29,32,44,59] 43 | split_3 = [7,20,28,39,58] 44 | split_4 = [3, 18, 26, 38, 41, 60, 87, 99, 102, 110] 45 | split_5 = [5, 12, 14, 15, 17, 42, 67, 82, 100, 119] 46 | split_6 = [6, 20, 27, 33, 42, 55, 71, 97, 104, 118] 47 | split_7 = [1, 9, 20, 34, 50] 48 | split_8 = [3, 14, 29, 31, 49] 49 | split_9 = [2, 15, 39, 41, 43] 50 | one_shot_full_10 = [0, 6, 12, 18, 24, 30, 36, 42, 48, 54] 51 | one_shot_full_20 = [0, 6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96, 102, 108, 114] 52 | one_shot_pku_10 = [10, 30, 40, 0, 5, 35, 45, 15, 20, 25] 53 | unseen_label = eval(split) 54 | visual_size = 256 55 | language_size = 768 56 | max_frame = 50 57 | language_path = "./data/language/"+dataset+"_des_embeddings.npy" # des best 58 | one_shot_exemplar_data_path = "./data/zeroshot/"+dataset+"/"+split+"/one_shot_exemplar_data.npy" 59 | train_list = "./data/zeroshot/"+dataset+"/"+split+"/one_shot_train_data.npy" 60 | train_label = "./data/zeroshot/"+dataset+"/"+split+"/one_shot_train_label.npy" 61 | test_list = "./data/zeroshot/"+dataset+"/"+split+"/one_shot_test_data.npy" 62 | test_label = "./data/zeroshot/"+dataset+"/"+split+"/one_shot_test_label.npy" 63 | ############################ sota compare ############################ 64 | sota_split = "10" 65 | model_choice_for_sota = 'shift-gcn' # shift-gcn or st-gcn 66 | unseen_label_5 = [10,11,19,26,56] 67 | unseen_label_12 = [3,5,9,12,15,40,42,47,51,56,58,59] 68 | unseen_label_10 = [4,13,37,43,49,65,88,95,99,106] 69 | unseen_label_24 = [5,9,11,16,18,20,22,29,35,39,45,49,59,68,70,81,84,87,93,94,104,113,114,119] 70 | sota_unseen = eval('unseen_label_'+sota_split) 71 | sota_train_list = 
"./data/zeroshot/"+dataset+"/unseen_label_"+sota_split+"/seen_train_data.npy" 72 | sota_train_label = "./data/zeroshot/"+dataset+"/unseen_label_"+sota_split+"/seen_train_label.npy" 73 | sota_test_list = "./data/zeroshot/"+dataset+"/unseen_label_"+sota_split+"/unseen_data.npy" 74 | sota_test_label = "./data/zeroshot/"+dataset+"/unseen_label_"+sota_split+"/unseen_label.npy" 75 | # sota_train_list = "./sourcedata/sota/split_"+sota_split+"/train.npy" 76 | # sota_train_label = "./sourcedata/sota/split_"+sota_split+"/train_label.npy" 77 | # sota_test_list = "./sourcedata/sota/split_"+sota_split+"/test.npy" 78 | # sota_test_label = "./sourcedata/sota/split_"+sota_split+"/test_label.npy" 79 | # %% 80 | -------------------------------------------------------------------------------- /tool.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch 3 | import math 4 | 5 | mlambda = [ 6 | lambda x: x ** 0, 7 | lambda x: x ** 1, 8 | lambda x: 2 * x ** 2 - 1, 9 | lambda x: 4 * x ** 3 - 3 * x, 10 | lambda x: 8 * x ** 4 - 8 * x ** 2 + 1, 11 | lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x 12 | ] 13 | def gen_label(labels): 14 | num = len(labels) 15 | gt = numpy.zeros(shape=(num,num)) 16 | for i, label in enumerate(labels): 17 | for k in range(num): 18 | if labels[k] == label: 19 | gt[i,k] = 1 20 | return gt 21 | 22 | def gen_label_from_text_sim(x): 23 | x = x / x.norm(dim=-1, keepdim=True) 24 | return x @ x.t() 25 | 26 | def get_m_theta(cos_theta, m=4): 27 | cos_m_theta = mlambda[m](cos_theta) 28 | temp = cos_theta.clone().detach() 29 | theta = torch.acos(temp.clamp(-1.+1e-6, 1.-1e-6)) 30 | k = (theta*m / math.pi).floor() 31 | sign = -2 * torch.remainder(k, 2) + 1 # (-1)**k 32 | phi_theta = sign * cos_m_theta - 2. 
* k 33 | return phi_theta 34 | # d_theta = phi_theta - cos_theta 35 | # return d_theta + x 36 | 37 | 38 | def create_logits(x1, x2, logit_scale, exp=True): 39 | x1 = x1 / x1.norm(dim=-1, keepdim=True) 40 | x2 = x2 / x2.norm(dim=-1, keepdim=True) 41 | if exp: 42 | scale = logit_scale.exp() 43 | else: 44 | scale = logit_scale 45 | 46 | # cosine similarity as logits 47 | logits_per_x1 = scale * x1 @ x2.t() 48 | logits_per_x2 = logits_per_x1.t() 49 | 50 | # shape = [global_batch_size, global_batch_size] 51 | return logits_per_x1, logits_per_x2 52 | 53 | def create_sim_matrix(x1, x2, alpha=1): 54 | x1 = x1 / x1.norm(dim=-1, keepdim=True) 55 | x2 = x2 / x2.norm(dim=-1, keepdim=True) 56 | x1x1 = alpha * x1 @ x1.t() 57 | x1x2 = alpha * x1 @ x2.t() 58 | x2x2 = alpha * x2 @ x2.t() 59 | return x1x1,x1x2,x2x2 60 | 61 | 62 | def get_acc(x1, x2, unseen_label, label): 63 | x1 = x1 / x1.norm(dim=-1, keepdim=True) 64 | x2 = x2 / x2.norm(dim=-1, keepdim=True) 65 | logits = x1 @ x2.t() # 128, 5 66 | pred = torch.argmax(logits, dim=1) 67 | unseen_label = torch.tensor(unseen_label).cuda() 68 | pred = torch.index_select(unseen_label,0,pred) 69 | acc = pred.eq(label.view_as(pred)).float().mean() 70 | return acc, pred 71 | 72 | def get_acc_v2(x1, x2, unseen_label, label): 73 | x1 = x1 / x1.norm(dim=-1, keepdim=True) 74 | x2 = x2 / x2.norm(dim=-1, keepdim=True) 75 | logits = x1 @ x2.t() # 128, 5 76 | pred = torch.argmax(logits, dim=1) 77 | unseen_len = len(unseen_label) 78 | 79 | old_pred = pred 80 | ent = softmax_entropy(logits) 81 | 82 | # unseen_len = len(unseen_label) 83 | # for i in range(unseen_len): 84 | # class_support_set = x1[pred == i] 85 | # class_logit = logits[pred == i] 86 | # class_ent = softmax_entropy(class_logit) 87 | # _, indices = torch.topk(class_ent, 5) 88 | # z = torch.mean(class_support_set[indices], dim=-1) 89 | # z_list.append(z) 90 | 91 | 92 | unseen_label = torch.tensor(unseen_label).cuda() 93 | pred = torch.index_select(unseen_label,0,pred) 94 | acc = pred.eq(label.view_as(pred)).float().mean() 95 | return acc, pred, old_pred, ent, x1 96 | 97 | def get_acc_v3(x1, x2, unseen_label, label): 98 | x1 = x1 / x1.norm(dim=-1, keepdim=True) 99 | x2 = x2 / x2.norm(dim=-1, keepdim=True) 100 | logits = x1 @ x2.t() # 128, 5 101 | pred = torch.argmax(logits, dim=1) 102 | ent = softmax_entropy(logits) 103 | unseen_label = torch.tensor(unseen_label).cuda() 104 | pred = torch.index_select(unseen_label,0,pred) 105 | acc = pred.eq(label.view_as(pred)).float().mean() 106 | return acc, pred, ent 107 | 108 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 109 | """Entropy of softmax distribution from logits.""" 110 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 111 | 112 | # def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 113 | # """Entropy of softmax distribution from logits.""" 114 | # return -(x.softmax(1) * math.log2(math.e) * x.log_softmax(1)).sum(1) 115 | -------------------------------------------------------------------------------- /descriptions/pku_des.txt: -------------------------------------------------------------------------------- 1 | "Bow" is lowering one's head or body as a gesture of respect, greeting, or acknowledgment. 2 | "Brushing hair" is using a brush or comb to groom and style one's hair. 3 | "Brushing teeth" is using a toothbrush and toothpaste to clean one's teeth and maintain oral hygiene. 4 | "Check time (from watch)" is looking at one's wristwatch or clock to determine the current time. 
5 | "Cheer up" is offering encouragement or support to someone who is feeling down or sad. 6 | "Clapping" is striking one's hands together repeatedly to make a sound, typically as a form of applause or appreciation. 7 | "Cross hands in front (say stop)" is placing one's hands in front of one's body with the palms facing outward. 8 | "Drink water" is consuming water or another liquid to quench thirst or hydrate the body. 9 | "Drop" is releasing an object or item from one's hand or grasp, typically accidentally or intentionally. 10 | "Eat meal/snack" is consuming food or a small portion of food as sustenance or for pleasure. 11 | "Falling" is losing one's balance or support and dropping to the ground or another surface. 12 | "Giving something to other person" is transferring an object or item to another person, typically as a gift, gesture of kindness. 13 | "Hand waving" is making a waving motion with one's hand or hands, typically as a greeting, farewell, or to get someone's attention. 14 | "Handshaking" is grasping another person's hand and shaking it up and down, typically as a form of greeting or agreement. 15 | "Hopping (one foot jumping)" is jumping on one foot repeatedly, typically as a form of exercise or play. 16 | "Hugging other person" is wrapping one's arms around another person as a gesture of affection or comfort. 17 | "Jumping up" is propelling oneself off the ground or another surface with one or both feet. 18 | "Kicking other person" is striking another person with one's foot, typically as a form of aggression or defense. 19 | "Kicking something" is striking an object with one's foot, typically as a form of play or to achieve a desired result. 20 | "Making a phone call/answering phone" is using a phone to communicate with another person, either by initiating or receiving a call. 21 | "Patting on back of other person" is striking another person's back with one's hand in a gentle or congratulatory manner. 22 | "Pickup" is lifting or taking hold of an object or item from a surface or another person's grasp. 23 | "Playing with phone/tablet" is interacting with a mobile phone or tablet device. 24 | "Pointing finger at the other person" is extending one's finger in the direction of another person. 25 | "Pointing to something with finger" is extending one's finger in the direction of an object or location. 26 | "Punching/slapping other person" is striking another person with one's hand, typically as a form of aggression or defense. 27 | "Pushing other person" is using physical force to move another person away from oneself or a particular location. 28 | "Put on a hat/cap" is placing a hat or cap on one's head, typically for protection from the elements or as a fashion accessory. 29 | "Put something inside pocket" is placing an object or item into a pocket or pouch, typically for storage or convenience purposes. 30 | "Reading" is processing and interpreting written or printed text, typically for education, entertainment, or information purposes. 31 | "Rub two hands together" is applying friction by rubbing one's hands together, typically to warm them up or to create a soothing sensation. 32 | "Salute" is making a gesture of respect or acknowledgment by raising one's hand to the forehead or eyebrow level. 33 | "Sitting down" is lowering one's body onto a seat or other surface for the purpose of resting or engaging in an activity. 34 | "Standing up" is lifting one's body from a seated or prone position to an upright posture. 
35 | "Take off a hat/cap" is removing a hat or cap from one's head, typically for comfort or convenience purposes. 36 | "Take off glasses" is removing eyeglasses or sunglasses from one's face or head, typically for comfort or convenience purposes. 37 | "Take off jacket" is removing a jacket or outer layer of clothing from one's body, typically for comfort or temperature regulation purposes. 38 | "Take out something from pocket" is removing an object or item from a pocket or pouch, typically for use or inspection purposes. 39 | "Taking a selfie" is using a camera or phone to take a photograph of oneself, typically for personal or social media purposes. 40 | "Tear up paper" is separating paper or a similar material into smaller pieces by hand or with a tool, typically for disposal or artistic purposes. 41 | "Throw" is propelling an object or item through the air with force, typically as a form of play, sport, or aggression. 42 | "Touch back (backache)" is placing one's hand on the back to alleviate or indicate pain in that area. 43 | "Touch chest (stomachache/heart pain)" is placing one's hand on the chest to alleviate or indicate pain in that area. 44 | "Touch head (headache)" is placing one's hand on the head. 45 | "Touch neck (neckache)" is placing one's hand on the neck. 46 | "Typing on a keyboard" is inputting text or commands into a computer or device using a keyboard or similar input device. 47 | "Use a fan (with hand or paper)/feeling warm" is employing a fan or makeshift device to create a flow of air or coolness. 48 | "Wear jacket" is putting on a jacket or outer layer of clothing, typically for protection or warmth. 49 | "Wear on glasses" is putting on eyeglasses or sunglasses to aid in vision or eye protection. 50 | "Wipe face" is using a cloth, tissue, or one's hand to remove sweat or moisture from one's face or to clean it. 51 | "Writing" is creating written or printed text using a pen, pencil, or similar tool. 
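Label description files such as the one above are matched against skeleton features by the cosine-similarity helpers defined in tool.py earlier in this repository (create_logits, get_acc). The toy call below is only a sketch of how get_acc is invoked; every shape, feature value, and label id is invented for illustration, and a CUDA device is assumed because tool.py moves the label tensor to the GPU.

```python
import torch

from tool import get_acc  # repository helper shown above in tool.py

# Invented shapes: 8 test clips, 768-d aligned features, 5 unseen classes.
skel_feat = torch.randn(8, 768).cuda()   # adapter outputs for the test skeletons
text_feat = torch.randn(5, 768).cuda()   # text features of the 5 unseen classes
unseen_label = [10, 19, 26, 35, 48]      # hypothetical dataset label ids of those classes
label = torch.tensor([10, 19, 48, 26, 10, 35, 19, 48]).cuda()  # hypothetical ground truth

# get_acc normalizes both sides, takes the argmax over cosine similarities,
# and maps that index back to a dataset label id before computing accuracy.
acc, pred = get_acc(skel_feat, text_feat, unseen_label, label)
print(acc.item(), pred.tolist())
```

The test loops in the *_main.py scripts consume get_acc in the same way, collecting the per-batch accuracy together with predictions already mapped back to dataset label ids.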
--------------------------------------------------------------------------------
/module/gcn/st_gcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from .utils.tgcn import ConvTemporalGraphical
6 | from .utils.graph import Graph
7 | 
8 | class Model(nn.Module):
9 | r"""Spatial temporal graph convolutional networks."""
10 | 
11 | def __init__(self, in_channels, hidden_channels, hidden_dim, graph_args,
12 | edge_importance_weighting, **kwargs):
13 | super().__init__()
14 | 
15 | # load graph
16 | self.graph = Graph(**graph_args)
17 | A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
18 | self.register_buffer('A', A)
19 | self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
20 | # build networks
21 | spatial_kernel_size = A.size(0)
22 | temporal_kernel_size = 9
23 | kernel_size = (temporal_kernel_size, spatial_kernel_size)
24 | kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
25 | self.st_gcn_networks = nn.ModuleList((
26 | st_gcn(in_channels, hidden_channels, kernel_size, 1, residual=False, **kwargs0),
27 | st_gcn(hidden_channels, hidden_channels, kernel_size, 1, **kwargs),
28 | st_gcn(hidden_channels, hidden_channels, kernel_size, 1, **kwargs),
29 | st_gcn(hidden_channels, hidden_channels, kernel_size, 1, **kwargs),
30 | st_gcn(hidden_channels, hidden_channels * 2, kernel_size, 2, **kwargs),
31 | st_gcn(hidden_channels * 2, hidden_channels * 2, kernel_size, 1, **kwargs),
32 | st_gcn(hidden_channels * 2, hidden_channels * 2, kernel_size, 1, **kwargs),
33 | st_gcn(hidden_channels * 2, hidden_channels * 4, kernel_size, 2, **kwargs),
34 | st_gcn(hidden_channels * 4, hidden_channels * 4, kernel_size, 1, **kwargs),
35 | st_gcn(hidden_channels * 4, hidden_dim, kernel_size, 1, **kwargs),
36 | ))
37 | 
38 | # initialize parameters for edge importance weighting
39 | if edge_importance_weighting:
40 | self.edge_importance = nn.ParameterList([
41 | nn.Parameter(torch.ones(self.A.size()))
42 | for i in self.st_gcn_networks
43 | ])
44 | else:
45 | self.edge_importance = [1] * len(self.st_gcn_networks)
46 | 
47 | def forward(self, x, ignore_joint=[]):
48 | 
49 | # data normalization
50 | N, C, T, V, M = x.size()
51 | x = x.permute(0, 4, 3, 1, 2).contiguous()
52 | x = x.view(N * M, V * C, T)
53 | x = self.data_bn(x)
54 | x = x.view(N, M, V, C, T)
55 | x = x.permute(0, 1, 3, 4, 2).contiguous()
56 | x = x.view(N * M, C, T, V)
57 | 
58 | # 1. Get the joints that are not masked out
59 | all_joint = set(range(V))
60 | remain_joint = list(all_joint - set(ignore_joint))
61 | remain_joint = sorted(remain_joint)
62 | x = x[:,:,:,remain_joint]
63 | 
64 | for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
65 | x, _ = gcn(x, self.A * importance, remain_joint)
66 | 
67 | # print(x.shape)
68 | 
69 | x = F.avg_pool2d(x, x.size()[2:])
70 | # print(x.shape)
71 | x = x.view(N, M, -1).mean(dim=1)
72 | 
73 | return x
74 | 
75 | 
76 | class st_gcn(nn.Module):
77 | r"""Applies a spatial temporal graph convolution over an input graph sequence.
78 | Args:
79 | in_channels (int): Number of channels in the input sequence data
80 | out_channels (int): Number of channels produced by the convolution
81 | kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel
82 | stride (int, optional): Stride of the temporal convolution. Default: 1
83 | dropout (int, optional): Dropout rate of the final output. Default: 0
84 | residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``
85 | Shape:
86 | - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
87 | - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
88 | - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
89 | - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
90 | where
91 | :math:`N` is the batch size,
92 | :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
93 | :math:`T_{in}/T_{out}` is the length of the input/output sequence,
94 | :math:`V` is the number of graph nodes.
95 | """
96 | 
97 | def __init__(self,
98 | in_channels,
99 | out_channels,
100 | kernel_size,
101 | stride=1,
102 | dropout=0,
103 | residual=True):
104 | super().__init__()
105 | 
106 | assert len(kernel_size) == 2
107 | assert kernel_size[0] % 2 == 1
108 | padding = ((kernel_size[0] - 1) // 2, 0)
109 | 
110 | self.gcn = ConvTemporalGraphical(in_channels, out_channels,
111 | kernel_size[1])
112 | 
113 | self.tcn = nn.Sequential(
114 | nn.BatchNorm2d(out_channels),
115 | nn.ReLU(inplace=True),
116 | nn.Conv2d(
117 | out_channels,
118 | out_channels,
119 | (kernel_size[0], 1),
120 | (stride, 1),
121 | padding,
122 | ),
123 | nn.BatchNorm2d(out_channels),
124 | nn.Dropout(dropout, inplace=True),
125 | )
126 | 
127 | if not residual:
128 | self.residual = lambda x: 0
129 | 
130 | elif (in_channels == out_channels) and (stride == 1):
131 | self.residual = lambda x: x
132 | 
133 | else:
134 | self.residual = nn.Sequential(
135 | nn.Conv2d(
136 | in_channels,
137 | out_channels,
138 | kernel_size=1,
139 | stride=(stride, 1)),
140 | nn.BatchNorm2d(out_channels),
141 | )
142 | 
143 | self.relu = nn.ReLU(inplace=True)
144 | 
145 | def forward(self, x, A, remain_joint):
146 | 
147 | A = A[:,remain_joint,:]
148 | A = A[:,:,remain_joint]
149 | res = self.residual(x)
150 | x, A = self.gcn(x, A)
151 | x = self.tcn(x) + res
152 | return self.relu(x), A
153 | 
--------------------------------------------------------------------------------
/module/gcn/utils/graph.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | class Graph():
4 | """ The Graph to model the skeletons extracted by OpenPose
5 | 
6 | Args:
7 | strategy (string): must be one of the following candidates
8 | - uniform: Uniform Labeling
9 | - distance: Distance Partitioning
10 | - spatial: Spatial Configuration
11 | For more information, please refer to the section 'Partition Strategies'
12 | in our paper (https://arxiv.org/abs/1801.07455).
13 | 
14 | layout (string): must be one of the following candidates
15 | - openpose: It consists of 18 joints. For more information, please
16 | refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output
17 | - ntu-rgb+d: It consists of 25 joints.
For more information, please 18 | refer to https://github.com/shahroudy/NTURGB-D 19 | 20 | max_hop (int): the maximal distance between two connected nodes 21 | dilation (int): controls the spacing between the kernel points 22 | 23 | """ 24 | 25 | def __init__(self, 26 | layout='openpose', 27 | strategy='uniform', 28 | max_hop=1, 29 | dilation=1): 30 | self.max_hop = max_hop 31 | self.dilation = dilation 32 | 33 | self.get_edge(layout) 34 | self.hop_dis = get_hop_distance( 35 | self.num_node, self.edge, max_hop=max_hop) 36 | self.get_adjacency(strategy) 37 | 38 | def __str__(self): 39 | return self.A 40 | 41 | def get_edge(self, layout): 42 | if layout == 'openpose': 43 | self.num_node = 18 44 | self_link = [(i, i) for i in range(self.num_node)] 45 | neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, 11), 46 | (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1), 47 | (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)] 48 | self.edge = self_link + neighbor_link 49 | self.center = 1 50 | elif layout == 'ntu-rgb+d': 51 | self.num_node = 25 52 | self_link = [(i, i) for i in range(self.num_node)] 53 | neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), 54 | (6, 5), (7, 6), (8, 7), (9, 21), (10, 9), 55 | (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), 56 | (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), 57 | (22, 23), (23, 8), (24, 25), (25, 12)] 58 | neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] 59 | self.edge = self_link + neighbor_link 60 | self.center = 21 - 1 61 | elif layout == 'nw-ucla': 62 | self.num_node = 20 63 | self_link = [(i, i) for i in range(self.num_node)] 64 | neighbor_1base = [(1, 2), (2, 3), (4, 3), (5, 3), (6, 5), (7, 6), 65 | (8, 7), (9, 3), (10, 9), (11, 10), (12, 11), (13, 1), 66 | (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), (19, 18), 67 | (20, 19)] 68 | neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] 69 | self.edge = self_link + neighbor_link 70 | self.center = 3 - 1 71 | else: 72 | raise ValueError("Do Not Exist This Layout.") 73 | 74 | def get_adjacency(self, strategy): 75 | valid_hop = range(0, self.max_hop + 1, self.dilation) 76 | adjacency = np.zeros((self.num_node, self.num_node)) 77 | for hop in valid_hop: 78 | adjacency[self.hop_dis == hop] = 1 79 | normalize_adjacency = normalize_digraph(adjacency) 80 | 81 | if strategy == 'uniform': 82 | A = np.zeros((1, self.num_node, self.num_node)) 83 | A[0] = normalize_adjacency 84 | self.A = A 85 | elif strategy == 'distance': 86 | A = np.zeros((len(valid_hop), self.num_node, self.num_node)) 87 | for i, hop in enumerate(valid_hop): 88 | A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == 89 | hop] 90 | self.A = A 91 | elif strategy == 'spatial': 92 | A = [] 93 | for hop in valid_hop: 94 | a_root = np.zeros((self.num_node, self.num_node)) 95 | a_close = np.zeros((self.num_node, self.num_node)) 96 | a_further = np.zeros((self.num_node, self.num_node)) 97 | for i in range(self.num_node): 98 | for j in range(self.num_node): 99 | if self.hop_dis[j, i] == hop: 100 | if self.hop_dis[j, self.center] == self.hop_dis[ 101 | i, self.center]: 102 | a_root[j, i] = normalize_adjacency[j, i] 103 | elif self.hop_dis[j, self. 104 | center] > self.hop_dis[i, self. 
105 | center]: 106 | a_close[j, i] = normalize_adjacency[j, i] 107 | else: 108 | a_further[j, i] = normalize_adjacency[j, i] 109 | if hop == 0: 110 | A.append(a_root) 111 | else: 112 | A.append(a_root + a_close) 113 | A.append(a_further) 114 | A = np.stack(A) 115 | self.A = A 116 | else: 117 | raise ValueError("Do Not Exist This Strategy") 118 | 119 | 120 | def get_hop_distance(num_node, edge, max_hop=1): 121 | A = np.zeros((num_node, num_node)) 122 | for i, j in edge: 123 | A[j, i] = 1 124 | A[i, j] = 1 125 | 126 | # compute hop steps 127 | hop_dis = np.zeros((num_node, num_node)) + np.inf 128 | transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)] 129 | arrive_mat = (np.stack(transfer_mat) > 0) 130 | for d in range(max_hop, -1, -1): 131 | hop_dis[arrive_mat[d]] = d 132 | return hop_dis 133 | 134 | 135 | def normalize_digraph(A): 136 | Dl = np.sum(A, 0) 137 | num_node = A.shape[0] 138 | Dn = np.zeros((num_node, num_node)) 139 | for i in range(num_node): 140 | if Dl[i] > 0: 141 | Dn[i, i] = Dl[i]**(-1) 142 | AD = np.dot(A, Dn) 143 | return AD 144 | 145 | 146 | def normalize_undigraph(A): 147 | Dl = np.sum(A, 0) 148 | num_node = A.shape[0] 149 | Dn = np.zeros((num_node, num_node)) 150 | for i in range(num_node): 151 | if Dl[i] > 0: 152 | Dn[i, i] = Dl[i]**(-0.5) 153 | DAD = np.dot(np.dot(Dn, A), Dn) 154 | return DAD 155 | -------------------------------------------------------------------------------- /descriptions/ntu60_des.txt: -------------------------------------------------------------------------------- 1 | "Drink water" refers to the act of taking in water by mouth to hydrate or quench thirst as a human action. 2 | "Eat meal/snack" is consuming food for sustenance, nourishment, or pleasure. 3 | "Brushing teeth" is the act of cleaning teeth and gums using a toothbrush and toothpaste for oral hygiene. 4 | "Brushing hair" involves using a brush or comb to groom and style hair. 5 | "Drop" refers to accidentally or intentionally releasing an object from one's grasp, allowing it to fall to the ground or another surface. 6 | “Pickup” involves lifting or grasping an object with one's hands. 7 | "Throw" is the act of projecting an object through the air by using the hand or arm. 8 | "Sitting down" involves moving from a standing position to a seated position. 9 | "Standing up (from sitting position)" involves transitioning from a seated position to an upright position by pushing up with the legs. 10 | "Clapping" involves striking one's palms together to make a sound, typically as a form of applause or approval. 11 | "Reading" involves interpreting and understanding written or printed words and symbols. 12 | "Writing" involves using a pen, pencil, or other writing tool to create letters, words, or symbols on a surface such as paper or a screen. 13 | "Tear up paper" involves using one's hands to rip apart paper into smaller pieces. 14 | "Wear jacket" is the act of putting on a garment designed to cover the upper body and arms. 15 | "Take off jacket" involves removing a jacket or outer garment from one's body. 16 | "Wear a shoe" involves putting a shoe onto one's foot for protection or fashion purposes. 17 | "Take off a shoe" involves removing a shoe from one's foot using one's hands. 18 | "Wear on glasses" involves putting eyeglasses or spectacles on one's face to correct vision or protect the eyes. 19 | "Take off glasses" involves removing eyeglasses from one's face with one's hands. 
20 | "Put on a hat/cap" involves placing headwear over one's head, typically for protection from the sun. 21 | "Take off a hat/cap" involves removing a piece of headwear from one's head, typically by lifting or pulling it off. 22 | "Cheer up" involves trying to make oneself or someone else feel more positive or happy. 23 | "Hand waving" is a gesture involving moving one's hand or hands to signal, greet, or draw attention. 24 | "Kicking something" is a human action that involves striking an object with the foot. 25 | "Reach into pocket" is the act of extending one's hand into a pocket to retrieve an item or to adjust the pocket contents. 26 | "Hopping (one foot jumping)" is a human action that involves jumping repeatedly on one foot. 27 | "Jumping up" involves leaping upward from a standing position as part of physical training. 28 | "Make a phone call/Answer phone" involves using a phone to either initiate or receive a phone conversation. 29 | "Playing with phone/tablet" involves manipulating electronic devices through various interactions such as tapping, swiping, or scrolling. 30 | "Typing on a keyboard" is the act of pressing keys on a keyboard to input information or commands into a computer or other electronic device. 31 | "Pointing to something with finger" involves extending the arm and using the finger to indicate or draw attention to a specific object or location. 32 | "Taking a selfie" is using a camera to take a self-portrait photograph, typically for sharing on social media or personal keepsake. 33 | "Checking time (from watch)" involves looking at a wristwatch or other timekeeping device to determine the current time. 34 | "Rubbing two hands together" involves moving one's hands back and forth against each other, often for warmth. 35 | "Nod headbow" involves lowering one's head briefly in a sign of respect, greeting, or acknowledgment. 36 | "Shake head" is moving the head from side to side in a rapid or deliberate manner, often to indicate disagreement. 37 | "Wiping face" involves using a cloth or one's hands to remove dirt, sweat, or moisture from one's face. 38 | "Salute" is a gesture of respect or greeting, typically performed by raising one's hand to the forehead or brim of a hat. 39 | "Putting the palms together" is the act of pressing one's hands together, often as a sign of respect, greeting, or prayer. 40 | "Crossing hands in front (saying stop)" is a gesture where one crosses their arms in front of their body and says "stop". 41 | "Sneeze/cough" is the involuntary or deliberate act of expelling air and sometimes mucus from the nose and mouth, typically due to illness or irritation. 42 | "Staggering" is the unsteady movement or swaying of the body, usually caused by intoxication, dizziness, or fatigue. 43 | "Falling" is the sudden loss of balance resulting in a person dropping or collapsing to the ground or a lower surface. 44 | "Touching head (headache)" involves placing one's hands on the head to alleviate pain or discomfort in the head or neck area. 45 | "Touching chest (stomachache/heart pain)" is a human action that involves placing a hand on the chest to alleviate discomfort or pain in the chest area. 46 | "Touching back (backache)" involves placing one's hands on the back to alleviate or investigate pain, discomfort, or tension in the back. 47 | "Touching neck (neckache)" refers to the action of pressing or rubbing the neck with one's hands, often to relieve pain or discomfort in the neck area. 
48 | "Nausea or vomiting" is a condition characterized by the urge to vomit or the act of forcefully expelling stomach contents through the mouth. 49 | "Using a fan (with hand or paper) to feel warm" involves creating airflow with a handheld or paper fan for the purpose of cooling down. 50 | "Punching/Slapping other person" is the act of striking another person with one's fist or open hand, respectively, typically for physical harm or aggression. 51 | "Kicking other person" involves striking another person with one's foot, often for physical harm or aggression. 52 | "Pushing other person" involves applying force to another person with one's body, typically with the intent to move or manipulate them. 53 | "Patting on back of other person" is the act of lightly striking or tapping another person's back with one's hand. 54 | "Pointing finger at the other person" involves extending one's finger towards another person, typically to indicate something or express disapproval. 55 | "Hugging other person" involves embracing another person with one's arms, often as a gesture of affection or comfort. 56 | "Giving something to other person" involves transferring an object or item to another person, typically as a gift or as part of an exchange. 57 | "Touching other person's pocket" involves coming into contact with or manipulating the pocket of another person's clothing. 58 | "Handshaking" involves greeting another person by clasping and shaking their hand, typically as a gesture of respect or friendliness. 59 | "Walking towards each other" is the act of moving one's body towards another person while standing or walking. 60 | "Walking apart from each other" involves moving away from another person while standing or walking. -------------------------------------------------------------------------------- /module/shift_gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import math 6 | 7 | import sys 8 | sys.path.append("./module/Temporal_shift/") 9 | from cuda.shift import Shift 10 | 11 | # from Temporal_shift.cuda.shift import Shift 12 | sys.path.append("./module/") 13 | from ntu_rgb_d import Graph 14 | 15 | 16 | def import_class(name): 17 | components = name.split('.') 18 | mod = __import__(components[0]) 19 | for comp in components[1:]: 20 | mod = getattr(mod, comp) 21 | return mod 22 | 23 | def conv_init(conv): 24 | nn.init.kaiming_normal(conv.weight, mode='fan_out') 25 | nn.init.constant(conv.bias, 0) 26 | 27 | 28 | def bn_init(bn, scale): 29 | nn.init.constant(bn.weight, scale) 30 | nn.init.constant(bn.bias, 0) 31 | 32 | 33 | class tcn(nn.Module): 34 | def __init__(self, in_channels, out_channels, kernel_size=9, stride=1): 35 | super(tcn, self).__init__() 36 | pad = int((kernel_size - 1) / 2) 37 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0), 38 | stride=(stride, 1)) 39 | 40 | self.bn = nn.BatchNorm2d(out_channels) 41 | self.relu = nn.ReLU() 42 | conv_init(self.conv) 43 | bn_init(self.bn, 1) 44 | 45 | def forward(self, x): 46 | x = self.bn(self.conv(x)) 47 | return x 48 | 49 | 50 | class Shift_tcn(nn.Module): 51 | def __init__(self, in_channels, out_channels, kernel_size=9, stride=1): 52 | super(Shift_tcn, self).__init__() 53 | 54 | self.in_channels = in_channels 55 | self.out_channels = out_channels 56 | 57 | self.bn = nn.BatchNorm2d(in_channels) 58 | self.bn2 = nn.BatchNorm2d(in_channels) 59 | 
bn_init(self.bn2, 1) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.shift_in = Shift(channel=in_channels, stride=1, init_scale=1) 62 | self.shift_out = Shift(channel=out_channels, stride=stride, init_scale=1) 63 | 64 | self.temporal_linear = nn.Conv2d(in_channels, out_channels, 1) 65 | nn.init.kaiming_normal(self.temporal_linear.weight, mode='fan_out') 66 | 67 | def forward(self, x): 68 | x = self.bn(x) 69 | # shift1 70 | x = self.shift_in(x.contiguous()) 71 | x = self.temporal_linear(x) 72 | x = self.relu(x) 73 | # shift2 74 | x = self.shift_out(x) 75 | x = self.bn2(x) 76 | return x 77 | 78 | 79 | class Shift_gcn(nn.Module): 80 | def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3): 81 | super(Shift_gcn, self).__init__() 82 | self.in_channels = in_channels 83 | self.out_channels = out_channels 84 | if in_channels != out_channels: 85 | self.down = nn.Sequential( 86 | nn.Conv2d(in_channels, out_channels, 1), 87 | nn.BatchNorm2d(out_channels) 88 | ) 89 | else: 90 | self.down = lambda x: x 91 | 92 | self.Linear_weight = nn.Parameter(torch.zeros(in_channels, out_channels, requires_grad=True, device='cuda'), requires_grad=True) 93 | nn.init.normal_(self.Linear_weight, 0,math.sqrt(1.0/out_channels)) 94 | 95 | self.Linear_bias = nn.Parameter(torch.zeros(1,1,out_channels,requires_grad=True,device='cuda'),requires_grad=True) 96 | nn.init.constant(self.Linear_bias, 0) 97 | 98 | self.Feature_Mask = nn.Parameter(torch.ones(1,25,in_channels, requires_grad=True,device='cuda'),requires_grad=True) 99 | nn.init.constant(self.Feature_Mask, 0) 100 | 101 | self.bn = nn.BatchNorm1d(25*out_channels) 102 | self.relu = nn.ReLU() 103 | 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | conv_init(m) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | bn_init(m, 1) 109 | 110 | index_array = np.empty(25*in_channels).astype(np.int) 111 | for i in range(25): 112 | for j in range(in_channels): 113 | index_array[i*in_channels + j] = (i*in_channels + j + j*in_channels)%(in_channels*25) 114 | self.shift_in = nn.Parameter(torch.from_numpy(index_array),requires_grad=False) 115 | 116 | index_array = np.empty(25*out_channels).astype(np.int) 117 | for i in range(25): 118 | for j in range(out_channels): 119 | index_array[i*out_channels + j] = (i*out_channels + j - j*out_channels)%(out_channels*25) 120 | self.shift_out = nn.Parameter(torch.from_numpy(index_array),requires_grad=False) 121 | 122 | 123 | def forward(self, x0): 124 | n, c, t, v = x0.size() 125 | x = x0.permute(0,2,3,1).contiguous() 126 | 127 | # shift1 128 | x = x.view(n*t,v*c) 129 | x = torch.index_select(x, 1, self.shift_in) 130 | x = x.view(n*t,v,c) 131 | x = x * (torch.tanh(self.Feature_Mask)+1) 132 | 133 | x = torch.einsum('nwc,cd->nwd', (x, self.Linear_weight)).contiguous() # nt,v,c 134 | x = x + self.Linear_bias 135 | 136 | # shift2 137 | x = x.view(n*t,-1) 138 | x = torch.index_select(x, 1, self.shift_out) 139 | x = self.bn(x) 140 | x = x.view(n,t,v,self.out_channels).permute(0,3,1,2) # n,c,t,v 141 | 142 | x = x + self.down(x0) 143 | x = self.relu(x) 144 | return x 145 | 146 | 147 | class TCN_GCN_unit(nn.Module): 148 | def __init__(self, in_channels, out_channels, A, stride=1, residual=True): 149 | super(TCN_GCN_unit, self).__init__() 150 | self.gcn1 = Shift_gcn(in_channels, out_channels, A) 151 | self.tcn1 = Shift_tcn(out_channels, out_channels, stride=stride) 152 | self.relu = nn.ReLU() 153 | 154 | if not residual: 155 | self.residual = lambda x: 0 156 | 157 | elif (in_channels == out_channels) and (stride == 
1): 158 | self.residual = lambda x: x 159 | else: 160 | self.residual = tcn(in_channels, out_channels, kernel_size=1, stride=stride) 161 | 162 | def forward(self, x): 163 | x = self.tcn1(self.gcn1(x)) + self.residual(x) 164 | return self.relu(x) 165 | 166 | 167 | class Model(nn.Module): 168 | def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, graph_args={'labeling_mode': 'spatial'}, in_channels=3): 169 | super(Model, self).__init__() 170 | 171 | # if graph is None: 172 | # raise ValueError() 173 | # else: 174 | # # Graph = import_class(graph) 175 | # self.graph = Graph(**graph_args) 176 | self.graph = Graph(**graph_args) 177 | A = self.graph.A 178 | self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point) 179 | 180 | self.l1 = TCN_GCN_unit(3, 64, A, residual=False) 181 | self.l2 = TCN_GCN_unit(64, 64, A) 182 | self.l3 = TCN_GCN_unit(64, 64, A) 183 | self.l4 = TCN_GCN_unit(64, 64, A) 184 | self.l5 = TCN_GCN_unit(64, 128, A, stride=2) 185 | self.l6 = TCN_GCN_unit(128, 128, A) 186 | self.l7 = TCN_GCN_unit(128, 128, A) 187 | self.l8 = TCN_GCN_unit(128, 256, A, stride=2) 188 | self.l9 = TCN_GCN_unit(256, 256, A) 189 | self.l10 = TCN_GCN_unit(256, 256, A) 190 | 191 | self.fc = nn.Linear(256, num_class) 192 | nn.init.normal(self.fc.weight, 0, math.sqrt(2. / num_class)) 193 | bn_init(self.data_bn, 1) 194 | 195 | def forward(self, x): 196 | N, C, T, V, M = x.size() # B, 3, 50, 25, 2 197 | 198 | x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T) 199 | x = self.data_bn(x) 200 | x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V) 201 | 202 | x = self.l1(x) 203 | x = self.l2(x) 204 | x = self.l3(x) 205 | x = self.l4(x) 206 | x = self.l5(x) 207 | x = self.l6(x) 208 | x = self.l7(x) 209 | x = self.l8(x) 210 | x = self.l9(x) 211 | x = self.l10(x) 212 | 213 | # N*M,C,T,V 214 | c_new = x.size(1) 215 | x = x.view(N, M, c_new, -1) 216 | x = x.mean(3).mean(1) 217 | 218 | # x = F.avg_pool2d(x, x.size()[2:]) 219 | # # print(x.shape) 220 | # x = x.view(N, M, -1).mean(dim=1) 221 | 222 | # return self.fc(x) 223 | return x 224 | 225 | if __name__ == '__main__': 226 | model = Model() 227 | model.cuda() 228 | data = torch.randn(32, 3, 64, 25, 2).cuda() # 一定要cuda, 229 | output = model(data) 230 | print(output.shape) 231 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Zero-Shot Skeleton-Based Action Recognition With Prototype-Guided Feature Alignment [TIP 2025] 2 | > [Kai Zhou](https://kaai520.github.io), [Shuhai Zhang](https://zshsh98.github.io), [Zeng You](https://www.youzeng.com.cn), [Jinwu Hu](https://fhujinwu.github.io), [Mingkui Tan](https://tanmingkui.github.io/), and [Fei Liu](https://scholar.google.com/citations?user=gC-YMYgAAAAJ)\ 3 | South China University of Technology 4 | 5 | 6 | 7 | 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/zero-shot-skeleton-based-action-recognition-2/zero-shot-skeletal-action-recognition-on-pku)](https://paperswithcode.com/sota/zero-shot-skeletal-action-recognition-on-pku?p=zero-shot-skeleton-based-action-recognition-2) 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/zero-shot-skeleton-based-action-recognition-2/zero-shot-skeletal-action-recognition-on-ntu)](https://paperswithcode.com/sota/zero-shot-skeletal-action-recognition-on-ntu?p=zero-shot-skeleton-based-action-recognition-2) 10 | 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/zero-shot-skeleton-based-action-recognition-2/zero-shot-skeletal-action-recognition-on-ntu-1)](https://paperswithcode.com/sota/zero-shot-skeletal-action-recognition-on-ntu-1?p=zero-shot-skeleton-based-action-recognition-2) 11 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/zero-shot-skeleton-based-action-recognition-2/one-shot-3d-action-recognition-on-ntu-rgbd)](https://paperswithcode.com/sota/one-shot-3d-action-recognition-on-ntu-rgbd?p=zero-shot-skeleton-based-action-recognition-2) 12 | 13 | This is an official PyTorch implementation of **"Zero-Shot Skeleton-Based Action Recognition With Prototype-Guided Feature Alignment" in [IEEE TIP 2025](https://doi.org/10.1109/TIP.2025.3586487)** (DOI: 10.1109/TIP.2025.3586487). 14 | 15 | ## Abstract 16 | Zero-shot skeleton-based action recognition aims to classify unseen skeleton-based human actions without prior exposure to such categories during training. This task is extremely challenging due to the difficulty in generalizing from known to unknown actions. Previous studies typically use two-stage training: pre-training skeleton encoders on seen action categories using cross-entropy loss and then aligning pre-extracted skeleton and text features, enabling knowledge transfer to unseen classes through skeleton-text alignment and language models' generalization. 17 | However, their efficacy is hindered by 1) insufficient discrimination for skeleton features, as the fixed skeleton encoder fails to capture necessary alignment information for effective skeleton-text alignment; 2) the neglect of alignment bias between skeleton and unseen text features during testing. 18 | To this end, we propose a prototype-guided feature alignment paradigm for zero-shot skeleton-based action recognition, termed PGFA. 19 | Specifically, we develop an end-to-end cross-modal contrastive training framework to improve skeleton-text alignment, ensuring sufficient discrimination for skeleton features. Additionally, we introduce a prototype-guided text feature alignment strategy to mitigate the adverse impact of the distribution discrepancy during testing. 20 | We provide a theoretical analysis to support our prototype-guided text feature alignment strategy and empirically evaluate our overall PGFA on three well-known datasets. 21 | Compared with the top competitor SMIE method, our PGFA achieves absolute accuracy improvements of 22.96\%, 12.53\%, and 18.54\% on the NTU-60, NTU-120, and PKU-MMD datasets, respectively. 22 | 23 | ## Framework 24 | ### Training Framework 25 | ![traing](./assets/training.png) 26 | ### Testing Framework 27 | ![testing](./assets/testing.png) 28 | 29 | prototype 30 | 31 | ## Requirements 32 | ![python = 3.11](https://img.shields.io/badge/python-3.7.11-green) 33 | ``` 34 | sacred 35 | tqdm 36 | einops 37 | torch==1.13.1 38 | logging 39 | sentence-transformers 40 | pprint 41 | scikit-learn 42 | ``` 43 | 44 | ## Installation 45 | ```bash 46 | # Install the python libraries 47 | $ cd PGFA 48 | $ pip install -r requirements.txt 49 | 50 | # Install the ShiftGCN 51 | $ cd ./module/Temporal_shift 52 | $ bash run.sh 53 | ``` 54 | 55 | Please consult the official installation tutorial (e.g., [ShiftGCN](https://github.com/kchengiva/Shift-GCN) and [PyTorch](https://pytorch.org/get-started/previous-versions/)) if you experience any difficulties. 
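Because the ShiftGCN install step above compiles a CUDA extension (the `bash run.sh` step), it can save time to confirm that the pinned PyTorch build actually sees a GPU before compiling. The check below is not part of the repository; it is a minimal sanity test assuming the torch==1.13.1 environment listed in the requirements.

```python
import torch

print(torch.__version__)    # expected: 1.13.1 for this repository's pinned environment
print(torch.version.cuda)   # CUDA toolkit version the installed wheel was built against
assert torch.cuda.is_available(), "A CUDA device is required to build and run the shift op"
print(torch.cuda.get_device_name(0))
```

If the assertion fails, the CUDA shift op cannot run, so fix the PyTorch/CUDA setup before proceeding.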
56 | 
57 | ## Data Preparation
58 | We apply the same dataset processing as [SMIE](https://github.com/YujieOuO/SMIE). You can download **data.zip** from [BaiduYun](https://pan.baidu.com/s/1xVdL4vTsZYEBPTqPzYpa0g?pwd=pgfa)/[Hugging Face](https://huggingface.co/kaai520/PGFA-Data/resolve/main/data.zip). Please download and extract it to the current folder (PGFA).
59 | 
60 | The subfolder "zero-shot" of "data" contains the processed skeleton data for each dataset, already split into seen and unseen categories. The subfolder "language" contains the pre-extracted text features obtained using Sentence-Bert.
61 | 
62 | * [dataset]_embeddings.npy: based on label names using Sentence-Bert.
63 | * [dataset]_des_embeddings.npy: based on complete descriptions using Sentence-Bert.
64 | * [dataset]_ska_embeddings.npy: based on skeleton-focused descriptions using Sentence-Bert.
65 | 
66 | If you want to process the zero-shot data yourself, please refer to the Data Preparation section in [SMIE](https://github.com/YujieOuO/SMIE). We additionally process one-shot data; for example, 'data/zeroshot/ntu60/one_shot_full_10' contains the one-shot skeleton data of NTU-60.
67 | 
68 | ## Action Label Descriptions
69 | The full label descriptions can be found in ./descriptions.
70 | 
71 | ## Different Experiment Settings
72 | Our PGFA employs two experiment settings for zero-shot learning and one experiment setting for one-shot learning.
73 | * Setting 1: three datasets are used (NTU-60, NTU-120, PKU-MMD), and each dataset has three random splits. The skeleton feature extractor is the classical ST-GCN.
74 | * Setting 2: two datasets are used, split_5 and split_12 on NTU-60, and split_10 and split_24 on NTU-120. The skeleton feature extractor is Shift-GCN.
75 | * One-shot setting: three datasets are used, one_shot_full_10 on NTU-60, one_shot_full_20 on NTU-120, and one_shot_pku_10 on PKU-MMD. The skeleton feature extractor is the classical ST-GCN.
76 | 
77 | ### Setting 1
78 | 
79 | #### Training & Testing
80 | Example for training and testing on NTU-60 split_1.
81 | ```bash
82 | # Setting 1
83 | $ python main.py with 'train_mode="main"'
84 | ```
85 | You can change some settings in config.py.
86 | 
87 | ### Setting 2
88 | #### Training & Testing
89 | Example for training and testing on NTU-60 split_5 data.
90 | ```bash
91 | # Setting 2
92 | $ python main.py with 'train_mode="sota"'
93 | ```
94 | You can also choose a different split id in config.py (SOTA comparison part).
95 | 
96 | ### One-shot setting
97 | #### Training & Testing
98 | Example for training and testing on NTU-60 one_shot_full_10 data.
99 | ```bash
100 | # One-shot setting
101 | $ python one_shot_main.py
102 | ```
103 | You can also choose a different split id in one_shot_config.py.
104 | 
105 | 
109 | 
110 | ## Checkpoints
111 | We also provide the model checkpoints trained under Setting 1. You can download **ckpts.zip** from [BaiduYun](https://pan.baidu.com/s/1xA7ph9qg1c2fRS3pUz3PsQ?pwd=pgfa)/[Google Drive](https://drive.google.com/file/d/1rhB_6fdgOpYpg4Ikdtl53ZqfNCu8LMvu/view?usp=sharing)/[Hugging Face](https://huggingface.co/datasets/kaai520/PGFA-Data/resolve/main/ckpts.zip?download=true), and extract it to the current folder (PGFA).
112 | 
113 | Example for loading the checkpoint and testing on NTU-60 split_1.
114 | ```bash
115 | # Setting 1
116 | $ python test_main.py with 'train_mode="main"'
117 | ```
118 | You can change some settings in test_config.py.
119 | 
120 | 
121 | ## Acknowledgement
122 | * The codebase is from [MS2L](https://github.com/LanglandsLin/MS2L).
123 | * The skeleton backbones are based on [ST-GCN](https://github.com/yysijie/st-gcn/blob/master/OLD_README.md) and [ShiftGCN](https://github.com/kchengiva/Shift-GCN). 124 | * The text feature is based on [Sentence-Bert](https://github.com/UKPLab/sentence-transformers). 125 | * The baseline methods are from [SMIE](https://github.com/YujieOuO/SMIE). 126 | 127 | ## Licence 128 | This project is licensed under the terms of the MIT license. 129 | 130 | ## Citation 131 | 132 | If you use this code in your research or implementations, please cite the following paper: 133 | 134 | ```bibtex 135 | @article{zhou2025pgfa, 136 | title={Zero-Shot Skeleton-Based Action Recognition With Prototype-Guided Feature Alignment}, 137 | author={Kai Zhou and Shuhai Zhang and Zeng You and Jinwu Hu and Mingkui Tan and Fei Liu}, 138 | journal={IEEE Transactions on Image Processing}, 139 | year={2025}, 140 | volume={34}, 141 | pages={4602-4617}, 142 | publisher={IEEE}, 143 | doi={10.1109/TIP.2025.3586487} 144 | } 145 | ``` 146 | 147 | ## Contact 148 | For any questions, feel free to contact: kayjoe0723@gmail.com 149 | -------------------------------------------------------------------------------- /one_shot_main.py: -------------------------------------------------------------------------------- 1 | from one_shot_config import * 2 | # from model import * 3 | from dataset import DataSet 4 | from one_shot_logger import Log 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import random 11 | from math import pi, cos 12 | from tqdm import tqdm 13 | 14 | from module.gcn.st_gcn import Model 15 | from module.shift_gcn import Model as ShiftGCN 16 | from module.adapter import Adapter, Linear 17 | from KLLoss import KLLoss, KDLoss 18 | from tool import gen_label, create_logits, get_acc, create_sim_matrix, gen_label_from_text_sim, get_m_theta, get_acc_v2 19 | 20 | def setup_seed(seed): 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed) 24 | random.seed(seed) 25 | torch.backends.cudnn.deterministic = True 26 | 27 | setup_seed(1314) # 0->2025 28 | 29 | # %% 30 | class Processor: 31 | 32 | @ex.capture 33 | def load_data(self, train_list, train_label, test_list, test_label, batch_size, language_path, one_shot_exemplar_data_path): 34 | self.dataset = dict() 35 | self.data_loader = dict() 36 | self.best_epoch = -1 37 | self.best_acc = -1 38 | self.dim_loss = -1 39 | self.test_acc = -1 40 | self.test_aug_acc = -1 41 | self.best_aug_acc = -1 42 | self.best_aug_epoch = -1 43 | 44 | self.full_language = np.load(language_path) 45 | self.full_language = torch.Tensor(self.full_language) 46 | self.full_language = self.full_language.cuda() 47 | 48 | self.one_shot_exemplar_data = np.load(one_shot_exemplar_data_path) 49 | self.one_shot_exemplar_data = torch.Tensor(self.one_shot_exemplar_data).cuda() 50 | 51 | self.dataset['train'] = DataSet(train_list, train_label) 52 | self.dataset['test'] = DataSet(test_list, test_label) 53 | 54 | self.data_loader['train'] = torch.utils.data.DataLoader( 55 | dataset=self.dataset['train'], 56 | batch_size=batch_size, 57 | num_workers=16, 58 | shuffle=True, 59 | drop_last=True) 60 | 61 | self.data_loader['test'] = torch.utils.data.DataLoader( 62 | dataset=self.dataset['test'], 63 | batch_size=64, 64 | num_workers=16, 65 | shuffle=False) 66 | 67 | def load_weights(self, model=None, weight_path=None): 68 | pretrained_dict = torch.load(weight_path) 69 | model.load_state_dict(pretrained_dict) 70 | 71 | def 
adjust_learning_rate(self,optimizer,current_epoch, max_epoch,lr_min=0,lr_max=0.1,warmup_epoch=15, loss_mode='step', step=[50, 80]): 72 | 73 | if current_epoch < warmup_epoch: 74 | lr = lr_max * current_epoch / warmup_epoch 75 | elif loss_mode == 'cos': 76 | lr = lr_min + (lr_max-lr_min)*(1 + cos(pi * (current_epoch - warmup_epoch) / (max_epoch - warmup_epoch))) / 2 77 | elif loss_mode == 'step': 78 | lr = lr_max * (0.1 ** np.sum(current_epoch >= np.array(step))) 79 | else: 80 | raise Exception('Please check loss_mode!') 81 | 82 | for param_group in optimizer.param_groups: 83 | param_group['lr'] = lr 84 | # if i == 0: 85 | # param_group['lr'] = lr * 0.1 86 | # else: 87 | # param_group['lr'] = lr 88 | 89 | def layernorm(self, feature): 90 | 91 | num = feature.shape[0] 92 | mean = torch.mean(feature, dim=1).reshape(num, -1) 93 | var = torch.var(feature, dim=1).reshape(num, -1) 94 | out = (feature-mean) / torch.sqrt(var) 95 | 96 | return out 97 | 98 | @ex.capture 99 | def load_model(self,in_channels,hidden_channels,hidden_dim, 100 | dropout,graph_args,edge_importance_weighting, loss_type): 101 | self.encoder = Model(in_channels=in_channels, hidden_channels=hidden_channels, 102 | hidden_dim=hidden_dim,dropout=dropout, 103 | graph_args=graph_args, 104 | edge_importance_weighting=edge_importance_weighting, 105 | ) 106 | self.encoder = self.encoder.cuda() 107 | self.adapter = Linear().cuda() 108 | if loss_type == "kl": 109 | self.loss = KLLoss().cuda() 110 | else: 111 | raise Exception('loss_type Error!') 112 | self.logit_scale = self.adapter.get_logit_scale() 113 | self.logit_scale_v2 = self.adapter.get_logit_scale_v2() 114 | 115 | # self.model = MI(visual_size, language_size).cuda() 116 | 117 | 118 | @ex.capture 119 | def load_optim(self, lr, epoch_num, weight_decay): 120 | # self.optimizer = torch.optim.Adam([ 121 | # {'params': self.encoder.parameters()}, 122 | # {'params': self.model.parameters()}], 123 | # lr=lr, 124 | # weight_decay=weight_decay, 125 | # ) 126 | # self.optimizer = torch.optim.Adam([ 127 | # {'params': self.encoder.parameters()}, 128 | # {'params': self.adapter.parameters()}], 129 | # lr=lr, 130 | # weight_decay=weight_decay, 131 | # ) 132 | self.optimizer = torch.optim.SGD([ 133 | {'params': self.encoder.parameters()}, 134 | {'params': self.adapter.parameters()}], 135 | lr=lr, 136 | weight_decay=weight_decay, 137 | momentum=0.9, 138 | nesterov=False 139 | ) 140 | # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, 100) 141 | 142 | @ex.capture 143 | def optimize(self, epoch_num, DA): # print -> log.info 144 | self.log.info("main track") 145 | for epoch in range(epoch_num): 146 | self.train_epoch(epoch) 147 | with torch.no_grad(): 148 | self.test_epoch(epoch=epoch) 149 | self.log.info("epoch [{}] train loss: {}".format(epoch,self.dim_loss)) 150 | self.log.info("epoch [{}] test acc: {}".format(epoch,self.test_acc)) 151 | self.log.info("epoch [{}] gets the best acc: {}".format(self.best_epoch,self.best_acc)) 152 | if DA: 153 | self.log.info("epoch [{}] DA test acc: {}".format(epoch,self.test_aug_acc)) 154 | self.log.info("epoch [{}] gets the best DA acc: {}".format(self.best_aug_epoch,self.best_aug_acc)) 155 | # if epoch > 5: 156 | # self.log.info("epoch [{}] test acc: {}".format(epoch,self.test_acc)) 157 | # self.log.info("epoch [{}] gets the best acc: {}".format(self.best_epoch,self.best_acc)) 158 | # else: 159 | # self.log.info("epoch [{}] : warm up epoch.".format(epoch)) 160 | 161 | @ex.capture 162 | def train_epoch(self, epoch, lr, loss_mode, 
step, loss_type, alpha, beta, m, fix_encoder): 163 | self.encoder.train() # eval -> train 164 | if fix_encoder: 165 | self.encoder.eval() 166 | self.adapter.train() 167 | self.adjust_learning_rate(self.optimizer, current_epoch=epoch, max_epoch=100, lr_max=lr, warmup_epoch=5, loss_mode=loss_mode, step=step) 168 | running_loss = [] 169 | loader = self.data_loader['train'] 170 | for data, label in tqdm(loader): 171 | data = data.type(torch.FloatTensor).cuda() 172 | # print(data.shape) #128,3,50,25,2 173 | # label = label.type(torch.LongTensor).cuda() 174 | label_g = gen_label(label) 175 | label = label.type(torch.LongTensor).cuda() 176 | # print(label.shape) # 128 177 | # print(label) # int 178 | seen_language = self.full_language[label] # 128, 768 179 | # print(seen_language.shape) 180 | 181 | feat = self.encoder(data) 182 | if fix_encoder: 183 | feat = feat.detach() 184 | skleton_feat = self.adapter(feat) 185 | if loss_type == "kl": 186 | logits_per_skl, logits_per_text = create_logits(skleton_feat, seen_language, self.logit_scale, exp=True) 187 | ground_truth = torch.tensor(label_g, dtype=skleton_feat.dtype).cuda() 188 | # ground_truth = gen_label_from_text_sim(seen_language) 189 | loss_skls = self.loss(logits_per_skl, ground_truth) 190 | loss_texts = self.loss(logits_per_text, ground_truth) 191 | loss = (loss_skls + loss_texts) / 2 192 | else: 193 | raise Exception('loss_type Error!') 194 | 195 | running_loss.append(loss) 196 | self.optimizer.zero_grad() 197 | loss.backward() 198 | self.optimizer.step() 199 | 200 | running_loss = torch.tensor(running_loss) 201 | self.dim_loss = running_loss.mean().item() 202 | 203 | @ex.capture 204 | def test_epoch(self, unseen_label, epoch, DA, support_factor): 205 | self.encoder.eval() 206 | self.adapter.eval() 207 | 208 | loader = self.data_loader['test'] 209 | y_true = [] 210 | y_pred = [] 211 | acc_list = [] 212 | ent_list = [] 213 | feat_list = [] 214 | old_pred_list = [] 215 | for data, label in tqdm(loader): 216 | 217 | # y_t = label.numpy().tolist() 218 | # y_true += y_t 219 | 220 | data = data.type(torch.FloatTensor).cuda() 221 | label = label.type(torch.LongTensor).cuda() 222 | # unseen_language = self.full_language[unseen_label] 223 | one_shot_skeleton = self.adapter(self.encoder(self.one_shot_exemplar_data)) 224 | # inference 225 | feature = self.encoder(data) 226 | feature = self.adapter(feature) 227 | if DA: 228 | # acc_batch, pred = get_acc(feature, unseen_language, unseen_label, label) 229 | acc_batch, pred, old_pred, ent, feat = get_acc_v2(feature, one_shot_skeleton, unseen_label, label) 230 | ent_list.append(ent) 231 | feat_list.append(feat) 232 | old_pred_list.append(old_pred) 233 | else: 234 | acc_batch, pred = get_acc(feature, one_shot_skeleton, unseen_label, label) 235 | 236 | # y_p = pred.cpu().numpy().tolist() 237 | # y_pred += y_p 238 | 239 | 240 | acc_list.append(acc_batch) 241 | 242 | acc_list = torch.tensor(acc_list) 243 | acc = acc_list.mean() 244 | if acc > self.best_acc: 245 | self.best_acc = acc 246 | self.best_epoch = epoch 247 | self.save_model() 248 | # y_true = np.array(y_true) 249 | # y_pred = np.array(y_pred) 250 | # np.save("y_true_3.npy",y_true) 251 | # np.save("y_pred_3.npy",y_pred) 252 | # print("save ok!") 253 | self.test_acc = acc 254 | 255 | if DA: 256 | ent_all = torch.cat(ent_list) 257 | feat_all = torch.cat(feat_list) 258 | old_pred_all = torch.cat(old_pred_list) 259 | z_list = [] 260 | for i in range(len(unseen_label)): 261 | mask = old_pred_all == i 262 | class_support_set = feat_all[mask] 263 | class_ent 
= ent_all[mask] 264 | class_len = class_ent.shape[0] 265 | if int(class_len*support_factor) < 1: 266 | z = self.full_language[unseen_label[i:i+1]] 267 | else: 268 | _, indices = torch.topk(-class_ent, int(class_len*support_factor)) 269 | z = torch.mean(class_support_set[indices], dim=0, keepdim=True) 270 | z_list.append(z) 271 | 272 | z_tensor = torch.cat(z_list) 273 | aug_acc_list = [] 274 | for data, label in tqdm(loader): 275 | # y_t = label.numpy().tolist() 276 | # y_true += y_t 277 | 278 | data = data.type(torch.FloatTensor).cuda() 279 | label = label.type(torch.LongTensor).cuda() 280 | one_shot_skeleton = z_tensor 281 | # inference 282 | feature = self.encoder(data) 283 | feature = self.adapter(feature) 284 | # acc_batch, pred = get_acc(feature, unseen_language, unseen_label, label) 285 | acc_batch, pred = get_acc(feature, one_shot_skeleton, unseen_label, label) 286 | 287 | # y_p = pred.cpu().numpy().tolist() 288 | # y_pred += y_p 289 | aug_acc_list.append(acc_batch) 290 | aug_acc = torch.tensor(aug_acc_list).mean() 291 | if aug_acc > self.best_aug_acc: 292 | self.best_aug_acc = aug_acc 293 | self.best_aug_epoch = epoch 294 | self.save_model() 295 | self.test_aug_acc = aug_acc 296 | 297 | 298 | 299 | def initialize(self): 300 | self.load_data() 301 | self.load_model() 302 | self.load_optim() 303 | self.log = Log() 304 | 305 | @ex.capture 306 | def save_model(self, save_path): 307 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 308 | torch.save({'encoder':self.encoder.state_dict(), 'adapter':self.adapter.state_dict()}, save_path) 309 | 310 | def start(self): 311 | self.initialize() 312 | self.optimize() 313 | # self.save_model() 314 | 315 | 316 | 317 | 318 | # %% 319 | @ex.automain 320 | def main(track): 321 | if "main" in track: 322 | p = Processor() 323 | p.start() 324 | -------------------------------------------------------------------------------- /descriptions/ntu120_des.txt: -------------------------------------------------------------------------------- 1 | "Drink water" refers to the act of taking in water by mouth to hydrate or quench thirst as a human action. 2 | "Eat meal/snack" is consuming food for sustenance, nourishment, or pleasure. 3 | "Brushing teeth" is the act of cleaning teeth and gums using a toothbrush and toothpaste for oral hygiene. 4 | "Brushing hair" involves using a brush or comb to groom and style hair. 5 | "Drop" refers to accidentally or intentionally releasing an object from one's grasp, allowing it to fall to the ground or another surface. 6 | “Pickup” involves lifting or grasping an object with one's hands. 7 | "Throw" is the act of projecting an object through the air by using the hand or arm. 8 | "Sitting down" involves moving from a standing position to a seated position. 9 | "Standing up (from sitting position)" involves transitioning from a seated position to an upright position by pushing up with the legs. 10 | "Clapping" involves striking one's palms together to make a sound, typically as a form of applause or approval. 11 | "Reading" involves interpreting and understanding written or printed words and symbols. 12 | "Writing" involves using a pen, pencil, or other writing tool to create letters, words, or symbols on a surface such as paper or a screen. 13 | "Tear up paper" involves using one's hands to rip apart paper into smaller pieces. 14 | "Wear jacket" is the act of putting on a garment designed to cover the upper body and arms. 15 | "Take off jacket" involves removing a jacket or outer garment from one's body. 
16 | "Wear a shoe" involves putting a shoe onto one's foot for protection or fashion purposes. 17 | "Take off a shoe" involves removing a shoe from one's foot using one's hands. 18 | "Wear on glasses" involves putting eyeglasses or spectacles on one's face to correct vision or protect the eyes. 19 | "Take off glasses" involves removing eyeglasses from one's face with one's hands. 20 | "Put on a hat/cap" involves placing headwear over one's head, typically for protection from the sun. 21 | "Take off a hat/cap" involves removing a piece of headwear from one's head, typically by lifting or pulling it off. 22 | "Cheer up" involves trying to make oneself or someone else feel more positive or happy. 23 | "Hand waving" is a gesture involving moving one's hand or hands to signal, greet, or draw attention. 24 | "Kicking something" is a human action that involves striking an object with the foot. 25 | "Reach into pocket" is the act of extending one's hand into a pocket to retrieve an item or to adjust the pocket contents. 26 | "Hopping (one foot jumping)" is a human action that involves jumping repeatedly on one foot. 27 | "Jumping up" involves leaping upward from a standing position as part of physical training. 28 | "Make a phone call/Answer phone" involves using a phone to either initiate or receive a phone conversation. 29 | "Playing with phone/tablet" involves manipulating electronic devices through various interactions such as tapping, swiping, or scrolling. 30 | "Typing on a keyboard" is the act of pressing keys on a keyboard to input information or commands into a computer or other electronic device. 31 | "Pointing to something with finger" involves extending the arm and using the finger to indicate or draw attention to a specific object or location. 32 | "Taking a selfie" is using a camera to take a self-portrait photograph, typically for sharing on social media or personal keepsake. 33 | "Checking time (from watch)" involves looking at a wristwatch or other timekeeping device to determine the current time. 34 | "Rubbing two hands together" involves moving one's hands back and forth against each other, often for warmth. 35 | "Nod headbow" involves lowering one's head briefly in a sign of respect, greeting, or acknowledgment. 36 | "Shake head" is moving the head from side to side in a rapid or deliberate manner, often to indicate disagreement. 37 | "Wiping face" involves using a cloth or one's hands to remove dirt, sweat, or moisture from one's face. 38 | "Salute" is a gesture of respect or greeting, typically performed by raising one's hand to the forehead or brim of a hat. 39 | "Putting the palms together" is the act of pressing one's hands together, often as a sign of respect, greeting, or prayer. 40 | "Crossing hands in front (saying stop)" is a gesture where one crosses their arms in front of their body and says "stop". 41 | "Sneeze/cough" is the involuntary or deliberate act of expelling air and sometimes mucus from the nose and mouth, typically due to illness or irritation. 42 | "Staggering" is the unsteady movement or swaying of the body, usually caused by intoxication, dizziness, or fatigue. 43 | "Falling" is the sudden loss of balance resulting in a person dropping or collapsing to the ground or a lower surface. 44 | "Touching head (headache)" involves placing one's hands on the head to alleviate pain or discomfort in the head or neck area. 
45 | "Touching chest (stomachache/heart pain)" is a human action that involves placing a hand on the chest to alleviate discomfort or pain in the chest area. 46 | "Touching back (backache)" involves placing one's hands on the back to alleviate or investigate pain, discomfort, or tension in the back. 47 | "Touching neck (neckache)" refers to the action of pressing or rubbing the neck with one's hands, often to relieve pain or discomfort in the neck area. 48 | "Nausea or vomiting" is a condition characterized by the urge to vomit or the act of forcefully expelling stomach contents through the mouth. 49 | "Using a fan (with hand or paper) to feel warm" involves creating airflow with a handheld or paper fan for the purpose of cooling down. 50 | "Punching/Slapping other person" is the act of striking another person with one's fist or open hand, respectively, typically for physical harm or aggression. 51 | "Kicking other person" involves striking another person with one's foot, often for physical harm or aggression. 52 | "Pushing other person" involves applying force to another person with one's body, typically with the intent to move or manipulate them. 53 | "Patting on back of other person" is the act of lightly striking or tapping another person's back with one's hand. 54 | "Pointing finger at the other person" involves extending one's finger towards another person, typically to indicate something or express disapproval. 55 | "Hugging other person" involves embracing another person with one's arms, often as a gesture of affection or comfort. 56 | "Giving something to other person" involves transferring an object or item to another person, typically as a gift or as part of an exchange. 57 | "Touching other person's pocket" involves coming into contact with or manipulating the pocket of another person's clothing. 58 | "Handshaking" involves greeting another person by clasping and shaking their hand, typically as a gesture of respect or friendliness. 59 | "Walking towards each other" is the act of moving one's body towards another person while standing or walking. 60 | "Walking apart from each other" involves moving away from another person while standing or walking. 61 | "Putting on headphones" involves placing earphones or headphones over one's ears to listen to music or other audio content. 62 | "Taking off headphones" involves removing earphones or headphones from one's ears. 63 | "Shooting at the basket" is the act of throwing a ball towards a basketball hoop with the intent of scoring points. 64 | "Bouncing a ball" involves repeatedly throwing a ball against a surface such as the ground or a wall, often as a form of exercise or entertainment. 65 | "Tennis bat swing" involves swinging a tennis racket to hit a ball during a game or practice session. 66 | "Juggling table tennis balls" is the act of tossing and catching multiple table tennis balls in the air, often for entertainment or as a skill-building exercise. 67 | "Hushing" involves making a gesture or sound to indicate to others to be quiet or silent. 68 | "Flicking hair" involves using one's hand or fingers to quickly move or toss one's hair, often for grooming or styling purposes. 69 | "Thumbs up" is the act of extending one's thumb upward, typically as a sign of approval or agreement. 70 | "Thumbs down" involves extending one's thumb downward, typically as a sign of disapproval or disagreement. 
71 | "Making OK sign" involves forming a circle with one's thumb and index finger, with the other fingers extended, typically as a sign of agreement. 72 | "Making victory sign" involves forming a V-shape with one's index and middle finger, typically as a sign of victory or success. 73 | "Stapling book" involves using a stapler to bind multiple pages together, typically for the purpose of creating a booklet or document. 74 | "Counting money" involves physically sorting and tallying currency notes or coins, typically as part of a financial transaction. 75 | "Cutting nails" involves trimming one's fingernails or toenails with nail clippers or scissors, typically for grooming or hygiene purposes. 76 | "Cutting paper (using scissors)" involves using scissors to cut through paper or other thin materials. 77 | "Snapping fingers" involves creating a snapping sound by quickly pressing one's thumb and middle finger together and then releasing them. 78 | "Opening bottle" involves removing the cap or seal from a bottle, typically with the aid of a bottle opener or one's hands. 79 | "Sniffing" involves inhaling through one's nose to detect or identify a scent or odor. 80 | "Squatting down" involves bending one's knees and lowering one's body to a lower position, typically for the purpose of sitting, picking something up. 81 | "Tossing a coin" involves flipping a coin in the air and letting it fall to the ground, typically to determine a decision or outcome. 82 | "Folding paper" involves creasing paper to create folds or pleats, typically for origami or paper crafts. 83 | "Balling up paper" involves crumpling or compressing paper into a compact ball, typically for disposal or as a stress-relieving exercise. 84 | "Playing with a magic cube" involves manipulating and solving a Rubik's Cube or other similar puzzle toy. 85 | "Applying cream on face" involves spreading cosmetic cream or lotion onto one's face, typically for moisturizing or cosmetic purposes. 86 | "Applying cream on hand back" involves spreading cosmetic cream or lotion onto the back of one's hand, typically for moisturizing or cosmetic purposes. 87 | "Putting on a bag" involves placing a bag or backpack onto one's back or shoulder, typically to carry items or belongings. 88 | "Taking off a bag" involves removing a bag or backpack from one's back or shoulder. 89 | "Putting something into a bag" involves placing an object or item into a bag or backpack, typically for carrying or transporting. 90 | "Taking something out of a bag" involves removing an object or item from a bag or backpack. 91 | "Opening a box" involves removing the lid or cover from a box or container, typically to access or retrieve its contents. 92 | "Moving heavy objects" involves exerting physical effort to lift, push, or pull a heavy object from one location to another. 93 | "Shaking fist" involves making a fist and shaking it in a rapid or aggressive manner, typically as a sign of anger or frustration. 94 | "Throwing up a cap/hat" involves tossing a cap or hat into the air and catching it, typically for entertainment or as a celebratory gesture. 95 | "Raising hands up" involves lifting both hands upwards, typically as a sign of surrender, excitement, or joy. 96 | "Crossing arms" involves placing one arm over the other across one's chest, typically as a defensive or protective gesture. 97 | "Making arm circles" involves rotating one's arms in circular motions, typically as part of a warm-up or exercise routine. 
98 | "Swinging arms" involves moving one's arms back and forth in a rhythmic or repetitive motion. 99 | "Running on the spot" involves jogging in place without moving forward, typically as a form of exercise or warm-up activity. 100 | "Butt kicks (kick backward)" involves kicking one's legs backwards, aiming to touch one's heels to one's buttocks. 101 | "Cross toe touch" involves touching one's toes with the opposite hand while crossing one's leg over the other. 102 | "Side kick" involves extending one's leg sideways to kick, typically as a form of martial arts or self-defense. 103 | "Yawning" involves opening one's mouth wide and inhaling deeply, typically as a sign of tiredness or boredom. 104 | "Stretching oneself" involves extending one's limbs or muscles to alleviate tension or discomfort. 105 | "Blowing nose" involves expelling mucus or other materials from one's nose through forceful exhalation, typically using a tissue or handkerchief. 106 | "Hitting other person with something" involves striking another person with an object or item, typically as a form of aggression or defense. 107 | "Wielding knife towards other person" involves threatening or brandishing a knife in the direction of another person. 108 | "Knocking over other person (hit with body)" involves using one's body to forcefully push or strike another person. 109 | "Grabbing other person's stuff" involves taking hold of or snatching an object or item that belongs to another person without their permission or consent. 110 | "Shooting at other person with a gun" involves discharging a firearm in the direction of another person. 111 | "Stepping on foot" involves placing one's foot on top of another person's foot, typically accidentally or as a form of intimidation. 112 | "High-five" involves slapping one's open hand against another person's hand as a celebratory gesture or form of congratulations. 113 | "Cheers and drink" is raising and clinking glasses or cups with another person as a gesture of good wishes or celebration, and then taking a sip of the drink. 114 | "Carrying something with other person" is holding or transporting an object or item in conjunction with another person. 115 | "Taking a photo of other person" is capturing an image of another person, typically for commemoration or documentation. 116 | "Following other person" is walking or moving in the same direction as another person. 117 | "Whispering in other person's ear" is speaking softly or confidentially into another person's ear. 118 | "Exchanging things with other person" is giving or receiving an object or item in return for something else. 119 | "Supporting somebody with hand" is providing physical assistance or comfort to another person by placing one's hand on their back or shoulder. 120 | "Finger-guessing game (playing rock-paper-scissors)" is a game played by revealing one of three hand signs (rock, paper, or scissors). 
-------------------------------------------------------------------------------- /one_shot_test_main.py: -------------------------------------------------------------------------------- 1 | from one_shot_test_config import * 2 | # from model import * 3 | from dataset import DataSet 4 | from one_shot_test_logger import Log 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import random 11 | from math import pi, cos 12 | from tqdm import tqdm 13 | 14 | from module.gcn.st_gcn import Model 15 | from module.shift_gcn import Model as ShiftGCN 16 | from module.adapter import Adapter, Linear 17 | from KLLoss import KLLoss, KDLoss 18 | from tool import gen_label, create_logits, get_acc, create_sim_matrix, gen_label_from_text_sim, get_m_theta, get_acc_v2 19 | 20 | def setup_seed(seed): 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed) 24 | random.seed(seed) 25 | torch.backends.cudnn.deterministic = True 26 | 27 | setup_seed(0) # 0->2025 28 | 29 | # %% 30 | class Processor: 31 | 32 | @ex.capture 33 | def load_data(self, train_list, train_label, test_list, test_label, batch_size, language_path, one_shot_exemplar_data_path): 34 | self.dataset = dict() 35 | self.data_loader = dict() 36 | self.best_epoch = -1 37 | self.best_acc = -1 38 | self.dim_loss = -1 39 | self.test_acc = -1 40 | self.test_aug_acc = -1 41 | self.best_aug_acc = -1 42 | self.best_aug_epoch = -1 43 | 44 | self.full_language = np.load(language_path) 45 | self.full_language = torch.Tensor(self.full_language) 46 | self.full_language = self.full_language.cuda() 47 | 48 | self.one_shot_exemplar_data = np.load(one_shot_exemplar_data_path) 49 | self.one_shot_exemplar_data = torch.Tensor(self.one_shot_exemplar_data).cuda() 50 | 51 | self.dataset['train'] = DataSet(train_list, train_label) 52 | self.dataset['test'] = DataSet(test_list, test_label) 53 | 54 | self.data_loader['train'] = torch.utils.data.DataLoader( 55 | dataset=self.dataset['train'], 56 | batch_size=batch_size, 57 | num_workers=16, 58 | shuffle=True, 59 | drop_last=True) 60 | 61 | self.data_loader['test'] = torch.utils.data.DataLoader( 62 | dataset=self.dataset['test'], 63 | batch_size=64, 64 | num_workers=16, 65 | shuffle=False) 66 | 67 | def load_weights(self, model=None, weight_path=None): 68 | pretrained_dict = torch.load(weight_path) 69 | model.load_state_dict(pretrained_dict) 70 | 71 | def adjust_learning_rate(self,optimizer,current_epoch, max_epoch,lr_min=0,lr_max=0.1,warmup_epoch=15, loss_mode='step', step=[50, 80]): 72 | 73 | if current_epoch < warmup_epoch: 74 | lr = lr_max * current_epoch / warmup_epoch 75 | elif loss_mode == 'cos': 76 | lr = lr_min + (lr_max-lr_min)*(1 + cos(pi * (current_epoch - warmup_epoch) / (max_epoch - warmup_epoch))) / 2 77 | elif loss_mode == 'step': 78 | lr = lr_max * (0.1 ** np.sum(current_epoch >= np.array(step))) 79 | else: 80 | raise Exception('Please check loss_mode!') 81 | 82 | for param_group in optimizer.param_groups: 83 | param_group['lr'] = lr 84 | # if i == 0: 85 | # param_group['lr'] = lr * 0.1 86 | # else: 87 | # param_group['lr'] = lr 88 | 89 | def layernorm(self, feature): 90 | 91 | num = feature.shape[0] 92 | mean = torch.mean(feature, dim=1).reshape(num, -1) 93 | var = torch.var(feature, dim=1).reshape(num, -1) 94 | out = (feature-mean) / torch.sqrt(var) 95 | 96 | return out 97 | 98 | @ex.capture 99 | def load_model(self,in_channels,hidden_channels,hidden_dim, 100 | dropout,graph_args,edge_importance_weighting, loss_type, 
weight_path): 101 | self.encoder = Model(in_channels=in_channels, hidden_channels=hidden_channels, 102 | hidden_dim=hidden_dim,dropout=dropout, 103 | graph_args=graph_args, 104 | edge_importance_weighting=edge_importance_weighting, 105 | ) 106 | self.encoder = self.encoder.cuda() 107 | self.adapter = Linear().cuda() 108 | if loss_type == "kl": 109 | self.loss = KLLoss().cuda() 110 | else: 111 | raise Exception('loss_type Error!') 112 | self.logit_scale = self.adapter.get_logit_scale() 113 | self.logit_scale_v2 = self.adapter.get_logit_scale_v2() 114 | pretrained_dict = torch.load(weight_path) 115 | # print(pretrained_dict['encoder']) 116 | self.encoder.load_state_dict(pretrained_dict['encoder']) 117 | self.adapter.load_state_dict(pretrained_dict['adapter']) 118 | 119 | # self.model = MI(visual_size, language_size).cuda() 120 | 121 | 122 | @ex.capture 123 | def load_optim(self, lr, epoch_num, weight_decay): 124 | # self.optimizer = torch.optim.Adam([ 125 | # {'params': self.encoder.parameters()}, 126 | # {'params': self.model.parameters()}], 127 | # lr=lr, 128 | # weight_decay=weight_decay, 129 | # ) 130 | # self.optimizer = torch.optim.Adam([ 131 | # {'params': self.encoder.parameters()}, 132 | # {'params': self.adapter.parameters()}], 133 | # lr=lr, 134 | # weight_decay=weight_decay, 135 | # ) 136 | self.optimizer = torch.optim.SGD([ 137 | {'params': self.encoder.parameters()}, 138 | {'params': self.adapter.parameters()}], 139 | lr=lr, 140 | weight_decay=weight_decay, 141 | momentum=0.9, 142 | nesterov=False 143 | ) 144 | # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, 100) 145 | 146 | @ex.capture 147 | def optimize(self, epoch_num, DA): # print -> log.info 148 | self.log.info("main track") 149 | epoch = 0 150 | with torch.no_grad(): 151 | self.test_epoch(epoch=epoch) 152 | self.log.info("epoch [{}] test acc: {}".format(epoch,self.test_acc)) 153 | # self.log.info("epoch [{}] gets the best acc: {}".format(self.best_epoch,self.best_acc)) 154 | if DA: 155 | self.log.info("epoch [{}] DA test acc: {}".format(epoch,self.test_aug_acc)) 156 | 157 | @ex.capture 158 | def train_epoch(self, epoch, lr, loss_mode, step, loss_type, alpha, beta, m, fix_encoder): 159 | self.encoder.train() # eval -> train 160 | if fix_encoder: 161 | self.encoder.eval() 162 | self.adapter.train() 163 | self.adjust_learning_rate(self.optimizer, current_epoch=epoch, max_epoch=100, lr_max=lr, warmup_epoch=5, loss_mode=loss_mode, step=step) 164 | running_loss = [] 165 | loader = self.data_loader['train'] 166 | for data, label in tqdm(loader): 167 | data = data.type(torch.FloatTensor).cuda() 168 | # print(data.shape) #128,3,50,25,2 169 | # label = label.type(torch.LongTensor).cuda() 170 | label_g = gen_label(label) 171 | label = label.type(torch.LongTensor).cuda() 172 | # print(label.shape) # 128 173 | # print(label) # int 174 | seen_language = self.full_language[label] # 128, 768 175 | # print(seen_language.shape) 176 | 177 | feat = self.encoder(data) 178 | if fix_encoder: 179 | feat = feat.detach() 180 | skleton_feat = self.adapter(feat) 181 | if loss_type == "kl": 182 | logits_per_skl, logits_per_text = create_logits(skleton_feat, seen_language, self.logit_scale, exp=True) 183 | ground_truth = torch.tensor(label_g, dtype=skleton_feat.dtype).cuda() 184 | # ground_truth = gen_label_from_text_sim(seen_language) 185 | loss_skls = self.loss(logits_per_skl, ground_truth) 186 | loss_texts = self.loss(logits_per_text, ground_truth) 187 | loss = (loss_skls + loss_texts) / 2 188 | else: 189 | raise 
Exception('loss_type Error!') 190 | 191 | running_loss.append(loss) 192 | self.optimizer.zero_grad() 193 | loss.backward() 194 | self.optimizer.step() 195 | 196 | running_loss = torch.tensor(running_loss) 197 | self.dim_loss = running_loss.mean().item() 198 | 199 | @ex.capture 200 | def test_epoch(self, unseen_label, epoch, DA, support_factor, da_iterations=3): # Add da_iterations parameter 201 | self.encoder.eval() 202 | self.adapter.eval() 203 | 204 | loader = self.data_loader['test'] 205 | acc_list = [] 206 | 207 | # First get baseline accuracy without DA 208 | for data, label in tqdm(loader): 209 | data = data.type(torch.FloatTensor).cuda() 210 | label = label.type(torch.LongTensor).cuda() 211 | one_shot_skeleton = self.adapter(self.encoder(self.one_shot_exemplar_data)) 212 | feature = self.encoder(data) 213 | feature = self.adapter(feature) 214 | acc_batch, pred = get_acc(feature, one_shot_skeleton, unseen_label, label) 215 | acc_list.append(acc_batch) 216 | 217 | acc_list = torch.tensor(acc_list) 218 | acc = acc_list.mean() 219 | if acc > self.best_acc: 220 | self.best_acc = acc 221 | self.best_epoch = epoch 222 | self.save_model() 223 | self.test_acc = acc 224 | 225 | if DA: 226 | current_skeleton = self.adapter(self.encoder(self.one_shot_exemplar_data)) 227 | best_aug_acc = 0 228 | 229 | # Perform multiple DA iterations 230 | for da_iter in range(da_iterations): 231 | ent_list = [] 232 | feat_list = [] 233 | old_pred_list = [] 234 | 235 | # Collect features and entropy for refinement 236 | for data, label in tqdm(loader): 237 | data = data.type(torch.FloatTensor).cuda() 238 | label = label.type(torch.LongTensor).cuda() 239 | feature = self.encoder(data) 240 | feature = self.adapter(feature) 241 | _, pred, old_pred, ent, feat = get_acc_v2(feature, current_skeleton, unseen_label, label) 242 | ent_list.append(ent) 243 | feat_list.append(feat) 244 | old_pred_list.append(old_pred) 245 | 246 | # Refine prototypes 247 | ent_all = torch.cat(ent_list) 248 | feat_all = torch.cat(feat_list) 249 | old_pred_all = torch.cat(old_pred_list) 250 | z_list = [] 251 | for i in range(len(unseen_label)): 252 | mask = old_pred_all == i 253 | class_support_set = feat_all[mask] 254 | class_ent = ent_all[mask] 255 | class_len = class_ent.shape[0] 256 | if int(class_len*support_factor) < 1: 257 | z = self.full_language[unseen_label[i:i+1]] 258 | else: 259 | _, indices = torch.topk(-class_ent, int(class_len*support_factor)) 260 | z = torch.mean(class_support_set[indices], dim=0, keepdim=True) 261 | z_list.append(z) 262 | 263 | current_skeleton = torch.cat(z_list) 264 | 265 | # Evaluate with refined prototypes 266 | aug_acc_list = [] 267 | for data, label in tqdm(loader): 268 | data = data.type(torch.FloatTensor).cuda() 269 | label = label.type(torch.LongTensor).cuda() 270 | feature = self.encoder(data) 271 | feature = self.adapter(feature) 272 | acc_batch, pred = get_acc(feature, current_skeleton, unseen_label, label) 273 | aug_acc_list.append(acc_batch) 274 | 275 | aug_acc = torch.tensor(aug_acc_list).mean() 276 | 277 | # Update best augmented accuracy 278 | if aug_acc > best_aug_acc: 279 | best_aug_acc = aug_acc 280 | best_skeleton = current_skeleton.clone() 281 | 282 | # Update class attributes with best results 283 | if best_aug_acc > self.best_aug_acc: 284 | self.best_aug_acc = best_aug_acc 285 | self.best_aug_epoch = epoch 286 | self.test_aug_acc = best_aug_acc 287 | 288 | # @ex.capture 289 | # def test_epoch(self, unseen_label, epoch, DA, support_factor): 290 | # self.encoder.eval() 291 | # 
self.adapter.eval() 292 | 293 | # loader = self.data_loader['test'] 294 | # y_true = [] 295 | # y_pred = [] 296 | # acc_list = [] 297 | # ent_list = [] 298 | # feat_list = [] 299 | # old_pred_list = [] 300 | # for data, label in tqdm(loader): 301 | 302 | # # y_t = label.numpy().tolist() 303 | # # y_true += y_t 304 | 305 | # data = data.type(torch.FloatTensor).cuda() 306 | # label = label.type(torch.LongTensor).cuda() 307 | # # unseen_language = self.full_language[unseen_label] 308 | # one_shot_skeleton = self.adapter(self.encoder(self.one_shot_exemplar_data)) 309 | # # inference 310 | # feature = self.encoder(data) 311 | # feature = self.adapter(feature) 312 | # if DA: 313 | # # acc_batch, pred = get_acc(feature, unseen_language, unseen_label, label) 314 | # acc_batch, pred, old_pred, ent, feat = get_acc_v2(feature, one_shot_skeleton, unseen_label, label) 315 | # ent_list.append(ent) 316 | # feat_list.append(feat) 317 | # old_pred_list.append(old_pred) 318 | # else: 319 | # acc_batch, pred = get_acc(feature, one_shot_skeleton, unseen_label, label) 320 | 321 | # # y_p = pred.cpu().numpy().tolist() 322 | # # y_pred += y_p 323 | 324 | 325 | # acc_list.append(acc_batch) 326 | 327 | # acc_list = torch.tensor(acc_list) 328 | # acc = acc_list.mean() 329 | # if acc > self.best_acc: 330 | # self.best_acc = acc 331 | # self.best_epoch = epoch 332 | # self.save_model() 333 | # # y_true = np.array(y_true) 334 | # # y_pred = np.array(y_pred) 335 | # # np.save("y_true_3.npy",y_true) 336 | # # np.save("y_pred_3.npy",y_pred) 337 | # # print("save ok!") 338 | # self.test_acc = acc 339 | 340 | # if DA: 341 | # ent_all = torch.cat(ent_list) 342 | # feat_all = torch.cat(feat_list) 343 | # old_pred_all = torch.cat(old_pred_list) 344 | # z_list = [] 345 | # for i in range(len(unseen_label)): 346 | # mask = old_pred_all == i 347 | # class_support_set = feat_all[mask] 348 | # class_ent = ent_all[mask] 349 | # class_len = class_ent.shape[0] 350 | # if int(class_len*support_factor) < 1: 351 | # z = self.full_language[unseen_label[i:i+1]] 352 | # else: 353 | # _, indices = torch.topk(-class_ent, int(class_len*support_factor)) 354 | # z = torch.mean(class_support_set[indices], dim=0, keepdim=True) 355 | # z_list.append(z) 356 | 357 | # z_tensor = torch.cat(z_list) 358 | # aug_acc_list = [] 359 | # for data, label in tqdm(loader): 360 | # # y_t = label.numpy().tolist() 361 | # # y_true += y_t 362 | 363 | # data = data.type(torch.FloatTensor).cuda() 364 | # label = label.type(torch.LongTensor).cuda() 365 | # one_shot_skeleton = z_tensor 366 | # # inference 367 | # feature = self.encoder(data) 368 | # feature = self.adapter(feature) 369 | # # acc_batch, pred = get_acc(feature, unseen_language, unseen_label, label) 370 | # acc_batch, pred = get_acc(feature, one_shot_skeleton, unseen_label, label) 371 | 372 | # # y_p = pred.cpu().numpy().tolist() 373 | # # y_pred += y_p 374 | # aug_acc_list.append(acc_batch) 375 | # aug_acc = torch.tensor(aug_acc_list).mean() 376 | # if aug_acc > self.best_aug_acc: 377 | # self.best_aug_acc = aug_acc 378 | # self.best_aug_epoch = epoch 379 | # self.test_aug_acc = aug_acc 380 | 381 | 382 | 383 | def initialize(self): 384 | self.load_data() 385 | self.load_model() 386 | self.load_optim() 387 | self.log = Log() 388 | 389 | @ex.capture 390 | def save_model(self, save_path): 391 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 392 | torch.save({'encoder':self.encoder.state_dict(), 'adapter':self.adapter.state_dict()}, save_path) 393 | 394 | def start(self): 395 | self.initialize() 396 
| self.optimize() 397 | # self.save_model() 398 | 399 | 400 | 401 | 402 | # %% 403 | @ex.automain 404 | def main(track): 405 | if "main" in track: 406 | p = Processor() 407 | p.start() 408 | -------------------------------------------------------------------------------- /module/Temporal_shift/cuda/shift_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | 6 | #include <vector> 7 | #include <iostream> 8 | #include <cmath> 9 | 10 | namespace { 11 | template <typename scalar_t> 12 | __global__ void shift_cuda_forward_kernel( 13 | const scalar_t* __restrict__ input, 14 | scalar_t* output, 15 | scalar_t* xpos, 16 | scalar_t* ypos, 17 | const int batch, 18 | const int channel, 19 | const int bottom_height, 20 | const int bottom_width, 21 | const int top_height, 22 | const int top_width, 23 | const int stride) 24 | { 25 | const int index = blockIdx.x * blockDim.x + threadIdx.x; 26 | 27 | 28 | if (index < batch*channel*top_height*top_width) 29 | { 30 | const int top_sp_dim = top_height * top_width; 31 | const int bottom_sp_dim = bottom_height * bottom_width; 32 | const int n = index/(channel * top_sp_dim); 33 | const int idx = index%(channel * top_sp_dim); 34 | const int c_out = idx/top_sp_dim; 35 | const int c_in = c_out; 36 | const int sp_idx = idx%top_sp_dim; 37 | const int h = sp_idx/top_width; 38 | const int w = sp_idx%top_width; 39 | const scalar_t* data_im_ptr = input + n*channel*bottom_sp_dim + c_in*bottom_sp_dim; // ->(n,c) 40 | 41 | const int h_offset = h * stride; // h on input feature map 42 | const int w_offset = w; // w on input feature map 43 | 44 | scalar_t val = 0; 45 | const scalar_t x = xpos[c_in]; 46 | const scalar_t y = ypos[c_in]; 47 | 48 | int h_im, w_im; 49 | int x1 = floorf(x); 50 | int x2 = x1+1; 51 | int y1 = floorf(y); 52 | int y2 = y1+1; 53 | 54 | h_im = h_offset + y1; 55 | w_im = w_offset + x1; 56 | scalar_t q11 = (h_im >= 0 && w_im >= 0 && h_im < bottom_height && w_im < bottom_width) ? data_im_ptr[h_im*bottom_width + w_im] : 0; 57 | 58 | h_im = h_offset + y1; 59 | w_im = w_offset + x2; 60 | scalar_t q21 = (h_im >= 0 && w_im >= 0 && h_im < bottom_height && w_im < bottom_width) ? data_im_ptr[h_im*bottom_width + w_im] : 0; 61 | 62 | h_im = h_offset + y2; 63 | w_im = w_offset + x1; 64 | scalar_t q12 = (h_im >= 0 && w_im >= 0 && h_im < bottom_height && w_im < bottom_width) ? data_im_ptr[h_im*bottom_width + w_im] : 0; 65 | 66 | h_im = h_offset + y2; 67 | w_im = w_offset + x2; 68 | scalar_t q22 = (h_im >= 0 && w_im >= 0 && h_im < bottom_height && w_im < bottom_width) ? 
data_im_ptr[h_im*bottom_width + w_im] : 0; 69 | 70 | scalar_t dx = x-x1; 71 | scalar_t dy = y-y1; 72 | 73 | val = q11*(1-dx)*(1-dy) + q21*dx*(1-dy) + q12*(1-dx)*dy + q22*dx*dy; 74 | output[index] = val; 75 | } 76 | } 77 | 78 | template <typename scalar_t> 79 | __global__ void Shift_Bottom_Backward_Stride1( 80 | const scalar_t* __restrict__ grad_output, 81 | scalar_t* grad_input, 82 | scalar_t* xpos, 83 | scalar_t* ypos, 84 | const int batch, 85 | const int channel, 86 | const int bottom_height, 87 | const int bottom_width) 88 | { 89 | const int index = blockIdx.x * blockDim.x + threadIdx.x; 90 | 91 | if (index < batch*channel*bottom_height*bottom_width) 92 | { 93 | const int top_sp_dim = bottom_height * bottom_width; // h*w 94 | const int bottom_sp_dim = bottom_height * bottom_width; 95 | const int n = index/(channel * bottom_sp_dim); 96 | const int idx = index%(channel * bottom_sp_dim); 97 | const int c_in = idx/bottom_sp_dim; 98 | const int c_out = c_in; 99 | const int sp_idx = idx%bottom_sp_dim; 100 | const int h_col = sp_idx/bottom_width; 101 | const int w_col = sp_idx%bottom_width; 102 | const scalar_t* top_diff_ptr = grad_output + n*channel*top_sp_dim + c_out*top_sp_dim; 103 | 104 | const int h_offset = h_col; 105 | const int w_offset = w_col; 106 | 107 | scalar_t val = 0; 108 | const scalar_t x = -xpos[c_in]; //reverse position 109 | const scalar_t y = -ypos[c_in]; 110 | 111 | int h_im, w_im; 112 | 113 | int x1 = floorf(x); 114 | int x2 = x1+1; 115 | int y1 = floorf(y); 116 | int y2 = y1+1; 117 | 118 | //q11 119 | scalar_t q11 = 0; 120 | 121 | h_im = (h_offset + y1); 122 | w_im = (w_offset + x1); 123 | q11 = (h_im >= 0 && w_im >= 0 && h_im < bottom_height && w_im < bottom_width) ? top_diff_ptr[h_im*bottom_width + w_im] : 0; 124 | 125 | //q21 126 | scalar_t q21 = 0; 127 | 128 | h_im = (h_offset + y1); 129 | w_im = (w_offset + x2); 130 | q21 = (h_im >= 0 && w_im >= 0 && h_im < bottom_height && w_im < bottom_width) ? top_diff_ptr[h_im*bottom_width + w_im] : 0; 131 | 132 | //q12 133 | scalar_t q12 = 0; 134 | 135 | h_im = (h_offset + y2); 136 | w_im = (w_offset + x1); 137 | q12 = (h_im >= 0 && w_im >= 0 && h_im < bottom_height && w_im < bottom_width) ? top_diff_ptr[h_im*bottom_width + w_im] : 0; 138 | 139 | //q22 140 | scalar_t q22 = 0; 141 | 142 | h_im = (h_offset + y2); 143 | w_im = (w_offset + x2); 144 | q22 = (h_im >= 0 && w_im >= 0 && h_im < bottom_height && w_im < bottom_width) ? 
top_diff_ptr[h_im*bottom_width + w_im] : 0; 145 | 146 | scalar_t dx = x-x1; 147 | scalar_t dy = y-y1; 148 | 149 | val = q11*(1-dx)*(1-dy) + q21*dx*(1-dy) + q12*(1-dx)*dy + q22*dx*dy; 150 | grad_input[index] = val; 151 | } 152 | } 153 | 154 | 155 | template <typename scalar_t> 156 | __global__ void Shift_Bottom_Backward( 157 | const scalar_t* __restrict__ grad_output, 158 | scalar_t* grad_input, 159 | scalar_t* xpos, 160 | scalar_t* ypos, 161 | const int batch, 162 | const int channel, 163 | const int bottom_height, 164 | const int bottom_width) 165 | { 166 | const int index = blockIdx.x * blockDim.x + threadIdx.x; 167 | 168 | 169 | if (index < batch*channel*bottom_height*bottom_width) 170 | { 171 | 172 | const int top_height = bottom_height/2; 173 | const int top_width = bottom_width; 174 | const int stride = 2; 175 | const int top_sp_dim = top_height * top_width; 176 | const int bottom_sp_dim = bottom_height * bottom_width; 177 | const int n = index/(channel * bottom_sp_dim); 178 | const int idx = index%(channel * bottom_sp_dim); 179 | const int c_in = idx/bottom_sp_dim; 180 | const int c_out = c_in; 181 | const int sp_idx = idx%bottom_sp_dim; 182 | const int h_col = sp_idx/bottom_width; 183 | const int w_col = sp_idx%bottom_width; 184 | const scalar_t* top_diff_ptr = grad_output + n*channel*top_sp_dim + c_out*top_sp_dim; 185 | 186 | const int h_offset = h_col; 187 | const int w_offset = w_col; 188 | 189 | 190 | scalar_t val = 0; 191 | const scalar_t x = -xpos[c_in]; 192 | const scalar_t y = -ypos[c_in]; 193 | 194 | int h_im, w_im; 195 | int x1 = floorf(x); 196 | int x2 = x1+1; 197 | int y1 = floorf(y); 198 | int y2 = y1+1; 199 | 200 | //q11 201 | scalar_t q11 = 0; 202 | 203 | h_im = (h_offset + y1); 204 | w_im = (w_offset + x1); 205 | if(h_im%stride == 0) 206 | { 207 | h_im=h_im/stride; 208 | 209 | q11 = (h_im >= 0 && w_im >= 0 && h_im < top_height && w_im < top_width) ? top_diff_ptr[h_im*top_width + w_im] : 0; 210 | } 211 | 212 | //q21 213 | scalar_t q21 = 0; 214 | 215 | h_im = (h_offset + y1); 216 | w_im = (w_offset + x2); 217 | if(h_im%stride == 0) 218 | { 219 | h_im=h_im/stride; 220 | 221 | q21 = (h_im >= 0 && w_im >= 0 && h_im < top_height && w_im < top_width) ? top_diff_ptr[h_im*top_width + w_im] : 0; 222 | } 223 | 224 | //q12 225 | scalar_t q12 = 0; 226 | 227 | h_im = (h_offset + y2); 228 | w_im = (w_offset + x1); 229 | 230 | if(h_im%stride == 0) 231 | { 232 | h_im=h_im/stride; 233 | 234 | q12 = (h_im >= 0 && w_im >= 0 && h_im < top_height && w_im < top_width) ? top_diff_ptr[h_im*top_width + w_im] : 0; 235 | } 236 | 237 | //q22 238 | scalar_t q22 = 0; 239 | 240 | h_im = (h_offset + y2); 241 | w_im = (w_offset + x2); 242 | 243 | if(h_im%stride == 0) 244 | { 245 | h_im=h_im/stride; 246 | 247 | q22 = (h_im >= 0 && w_im >= 0 && h_im < top_height && w_im < top_width) ? 
top_diff_ptr[h_im*top_width + w_im] : 0; 248 | } 249 | 250 | scalar_t dx = x-x1; 251 | scalar_t dy = y-y1; 252 | 253 | val = q11*(1-dx)*(1-dy) + q21*dx*(1-dy) + q12*(1-dx)*dy + q22*dx*dy; 254 | grad_input[index] = val; 255 | } 256 | } // namespace 257 | 258 | 259 | 260 | template <typename scalar_t> 261 | __inline__ __device__ void myAtomicAdd(scalar_t *buf, scalar_t val); 262 | 263 | template <> 264 | __inline__ __device__ void myAtomicAdd(float *buf, float val) 265 | { 266 | atomicAdd(buf, val); 267 | } 268 | 269 | template <> 270 | __inline__ __device__ void myAtomicAdd(double *buf, double val) 271 | { 272 | //Not Supported 273 | } 274 | 275 | 276 | 277 | template <typename scalar_t> 278 | __global__ void Shift_Position_Backward( 279 | const scalar_t* __restrict__ input, 280 | const scalar_t* __restrict__ grad_output, 281 | scalar_t* grad_input, 282 | scalar_t* xpos, 283 | scalar_t* ypos, 284 | scalar_t* grad_xpos_bchw, 285 | scalar_t* grad_ypos_bchw, 286 | const int batch, 287 | const int channel, 288 | const int bottom_height, 289 | const int bottom_width, 290 | const int stride) 291 | { 292 | const int index = blockIdx.x * blockDim.x + threadIdx.x; 293 | 294 | const int top_height = bottom_height/stride; 295 | const int top_width = bottom_width; 296 | 297 | 298 | if (index < batch*channel*top_height*top_width) 299 | { 300 | const int top_sp_dim = top_height * top_width; 301 | const int bottom_sp_dim = bottom_height * bottom_width; 302 | const int n = index/(channel * top_sp_dim); 303 | const int idx = index%(channel * top_sp_dim); 304 | const int c_mul = 1; 305 | const int c_out = idx/top_sp_dim; 306 | const int c_in = c_out/c_mul; 307 | const int sp_idx = idx%top_sp_dim; 308 | const int h = sp_idx/top_width; 309 | const int w = sp_idx%top_width; 310 | const scalar_t* data_im_ptr = input + n*channel*bottom_sp_dim + c_in*bottom_sp_dim; 311 | 312 | const int h_offset = h * stride; 313 | const int w_offset = w; 314 | 315 | //output : 2*(C) x (1*H*W) 316 | const int kernel_offset = top_sp_dim; 317 | const int c_off = c_out % c_mul; 318 | 319 | scalar_t val_x = 0, val_y = 0; 320 | 321 | const scalar_t shiftX = xpos[c_in]; 322 | const scalar_t shiftY = ypos[c_in]; 323 | 324 | 325 | const int ix1 = floorf(shiftX); 326 | const int ix2 = ix1+1; 327 | const int iy1 = floorf(shiftY); 328 | const int iy2 = iy1+1; 329 | const scalar_t dx = shiftX-ix1; 330 | const scalar_t dy = shiftY-iy1; 331 | 332 | const int h_im1 = h_offset + iy1; 333 | const int h_im2 = h_offset + iy2; 334 | 335 | const int w_im1 = w_offset + ix1; 336 | const int w_im2 = w_offset + ix2; 337 | 338 | const scalar_t q11 = (h_im1 >= 0 && w_im1 >= 0 && h_im1 < bottom_height && w_im1 < bottom_width) ? data_im_ptr[h_im1*bottom_width + w_im1] : 0; 339 | const scalar_t q21 = (h_im1 >= 0 && w_im2 >= 0 && h_im1 < bottom_height && w_im2 < bottom_width) ? data_im_ptr[h_im1*bottom_width + w_im2] : 0; 340 | const scalar_t q12 = (h_im2 >= 0 && w_im1 >= 0 && h_im2 < bottom_height && w_im1 < bottom_width) ? data_im_ptr[h_im2*bottom_width + w_im1] : 0; 341 | const scalar_t q22 = (h_im2 >= 0 && w_im2 >= 0 && h_im2 < bottom_height && w_im2 < bottom_width) ? 
data_im_ptr[h_im2*bottom_width + w_im2] : 0; 342 | 343 | val_x = (1-dy)*(q21-q11)+dy*(q22-q12); 344 | val_y = (1-dx)*(q12-q11)+dx*(q22-q21); 345 | 346 | 347 | 348 | grad_xpos_bchw[index] = val_x * grad_output[index]; 349 | grad_ypos_bchw[index] = val_y * grad_output[index]; 350 | 351 | //grad_xpos_bchw[index] = val_x; 352 | //grad_ypos_bchw[index] = val_y; 353 | 354 | //grad_xpos_bchw[index] = 0; 355 | //grad_ypos_bchw[index] = 0; 356 | 357 | //scalar_t* out_ptr_x = grad_xpos_bchw + index; 358 | //scalar_t* out_ptr_y = grad_ypos_bchw + index; 359 | 360 | //myAtomicAdd(out_ptr_x, val_x * grad_output[index]); 361 | //myAtomicAdd(out_ptr_y, val_y * grad_output[index]); 362 | } 363 | } // namespace 364 | 365 | 366 | 367 | 368 | 369 | 370 | template <typename scalar_t> 371 | __global__ void applyShiftConstraint( 372 | scalar_t* grad_xpos, 373 | scalar_t* grad_ypos, 374 | const int channel) 375 | { 376 | const int index = blockIdx.x * blockDim.x + threadIdx.x; 377 | 378 | if (index < channel) 379 | { 380 | const scalar_t dx = grad_xpos[index]; 381 | const scalar_t dy = grad_ypos[index]; 382 | const scalar_t dr = sqrt(dy*dy); 383 | 384 | if(dr!=0) 385 | { 386 | grad_xpos[index] = dx/dr*0.0; 387 | grad_ypos[index] = dy/dr*0.01; 388 | } 389 | else // without this, the grad_ypos may be large. 390 | { 391 | grad_xpos[index] = 0.0; 392 | grad_ypos[index] = 0.0001; 393 | } 394 | } 395 | } // namespace 396 | 397 | 398 | 399 | 400 | } 401 | 402 | 403 | 404 | 405 | at::Tensor shift_cuda_forward( 406 | at::Tensor input,at::Tensor xpos,at::Tensor ypos,const int stride) { 407 | 408 | auto output = at::zeros({input.size(0), input.size(1), input.size(2)/stride, input.size(3)}, input.options()); 409 | 410 | const dim3 blocks((input.size(0)*input.size(1)*input.size(2)*input.size(3)/stride+1024-1)/1024); 411 | const int threads = 1024; 412 | 413 | AT_DISPATCH_FLOATING_TYPES(input.type(), "shift_forward_cuda", ([&] { 414 | shift_cuda_forward_kernel<scalar_t><<<blocks, threads>>>( 415 | input.data<scalar_t>(), 416 | output.data<scalar_t>(), 417 | xpos.data<scalar_t>(), 418 | ypos.data<scalar_t>(), 419 | input.size(0), 420 | input.size(1), 421 | input.size(2), 422 | input.size(3), 423 | input.size(2)/stride, 424 | input.size(3), 425 | stride); 426 | })); 427 | 428 | //std::cout << output[0] << std::endl; 429 | 430 | return output; 431 | } 432 | 433 | std::vector<at::Tensor> shift_cuda_backward( 434 | at::Tensor grad_output, 435 | at::Tensor input, 436 | at::Tensor output, 437 | at::Tensor xpos, 438 | at::Tensor ypos, 439 | const int stride) { 440 | auto grad_input = at::zeros_like(input); 441 | 442 | 443 | 444 | 445 | const dim3 blocks((input.size(0)*input.size(1)*input.size(2)*input.size(3)+1024-1)/1024); 446 | const int threads = 1024; 447 | 448 | if(stride==1) 449 | { 450 | AT_DISPATCH_FLOATING_TYPES(input.type(), "Shift_Bottom_Backward_Stride1_", ([&] { 451 | Shift_Bottom_Backward_Stride1<scalar_t><<<blocks, threads>>>( 452 | grad_output.data<scalar_t>(), 453 | grad_input.data<scalar_t>(), 454 | xpos.data<scalar_t>(), 455 | ypos.data<scalar_t>(), 456 | input.size(0), 457 | input.size(1), 458 | input.size(2), 459 | input.size(3)); 460 | })); 461 | } 462 | else 463 | { 464 | AT_DISPATCH_FLOATING_TYPES(input.type(), "Shift_Bottom_Backward_", ([&] { 465 | Shift_Bottom_Backward<scalar_t><<<blocks, threads>>>( 466 | grad_output.data<scalar_t>(), 467 | grad_input.data<scalar_t>(), 468 | xpos.data<scalar_t>(), 469 | ypos.data<scalar_t>(), 470 | input.size(0), 471 | input.size(1), 472 | input.size(2), 473 | input.size(3)); 474 | })); 475 | } 476 | 477 | 478 | 479 | 480 | auto grad_xpos_bchw = at::zeros({output.size(0), output.size(1), output.size(2), output.size(3)}, output.options()); // (b,c,h,w) 481 | auto grad_ypos_bchw = 
at::zeros({output.size(0), output.size(1), output.size(2), output.size(3)}, output.options()); // (b,c,h,w) 482 | 483 | const dim3 blocks_output((output.size(0)*output.size(1)*output.size(2)*output.size(3)+1024-1)/1024); 484 | 485 | AT_DISPATCH_FLOATING_TYPES(input.type(), "Shift_Position_Backward_", ([&] { 486 | Shift_Position_Backward<scalar_t><<<blocks_output, threads>>>( 487 | input.data<scalar_t>(), 488 | grad_output.data<scalar_t>(), 489 | grad_input.data<scalar_t>(), 490 | xpos.data<scalar_t>(), 491 | ypos.data<scalar_t>(), 492 | grad_xpos_bchw.data<scalar_t>(), 493 | grad_ypos_bchw.data<scalar_t>(), 494 | input.size(0), 495 | input.size(1), 496 | input.size(2), 497 | input.size(3), 498 | stride); 499 | })); 500 | 501 | auto grad_xpos_chw = at::mean(grad_xpos_bchw, 0, false); 502 | auto grad_xpos_ch = at::sum(grad_xpos_chw, 2, false); 503 | auto grad_xpos_c = at::sum(grad_xpos_ch, 1, false); 504 | auto grad_xpos = grad_xpos_c; 505 | 506 | auto grad_ypos_chw = at::mean(grad_ypos_bchw, 0, false); 507 | auto grad_ypos_ch = at::sum(grad_ypos_chw, 2, false); 508 | auto grad_ypos_c = at::sum(grad_ypos_ch, 1, false); 509 | auto grad_ypos = grad_ypos_c; 510 | 511 | 512 | 513 | const dim3 blocks_norm((output.size(1)+1024-1)/1024); 514 | 515 | AT_DISPATCH_FLOATING_TYPES(input.type(), "applyShiftConstraint_", ([&] { 516 | applyShiftConstraint<scalar_t><<<blocks_norm, threads>>>( 517 | grad_xpos.data<scalar_t>(), 518 | grad_ypos.data<scalar_t>(), 519 | output.size(1)); 520 | })); 521 | 522 | return {grad_input,grad_xpos,grad_ypos}; 523 | } 524 | -------------------------------------------------------------------------------- /descriptions/ntu_parts_from_GAP.txt: -------------------------------------------------------------------------------- 1 | drink water;head tilts back slightly; hand grasps cup; arm lifts cup to mouth; hip remains stationary; leg remains stationary; foot remains stationary. 2 | eat meal;head tilts slightly forward; hand brings food to mouth; arm supports hand; hip remains stationary; leg remains stationary; foot remains stationary. 3 | brush teeth;head tilts forward slightly; hand brings toothbrush up to mouth; arm extends forward; hip remains stationary; leg remains stationary; foot remains stationary. 4 | brush hair;head is tilted slightly forward as he runs the brush through hair; hand is moving back and forth across head;arm is moving up and down, following the motion of the brush; hip is slightly tilted to the side; leg is slightly bent at the knee; foot is planted firmly on the ground. 5 | drop;head falls forward; hands go to sides; arms hang down; hips drop; legs bend at the knees; feet come off the ground. 6 | pick up;head tilts slightly forward; hand reaches down; arm extends; hip doesn't move; leg doesn't move; foot doesn't move. 7 | throw;head turns to the direction of the throw as arm goes back; hand goes back as he cocks arm;arm goes back and then forward, releasing the object;hip turns to the direction of the throw as arm goes back;leg goes back as he cocks arm;foot goes back as he cocks arm. 8 | sit down;head tilts slightly forward; hand supports body weight on thighs; arm hangs loosely at sides; hip bends at a 90-degree angle; leg bends at a 90-degree angle; foot rests flat on the floor. 9 | stand up;head tilts slightly forward; hand pushes down on thighs; arm extends to straighten legs; hip lifts body upward; leg straightens at knee and hip; foot comes down flat on the floor. 
10 | clapping; head nodding up and down; hand slapping two palms together; arm swinging up and down; hip not moving; leg not moving; foot not moving 11 | reading;head is tilted slightly forward and to the left; hand right hand is holding the book up close to face; left arm is bent at the elbow and resting on the arm of the chair; hips are level with shoulders; egs are crossed at the ankles; feet are flat on the floor. 12 | writing;head is tilted slightly forward and to the left; right hand is holding a pen and left hand is holding the paper; right arm is moving the pen across the paper; hips are stationary; legs are stationary; feet are flat on the floor. 13 | tear up paper;head is tilted back slightly as he looks at the paper in hand;hand is holding the paper close to face;arm is extended slightly;hip is slightly rotated;leg is slightly bent;feet are flat on the ground. 14 | put on jacket;head looks up at the jacket; hand reaches up and grasps the jacket; arm brings the jacket down to level; hip slides arms into the jacket; leg stands up straight; foot stands on both feet. 15 | take off jacket;head tilts back slightly; grabs the bottom of jacket with both hands, brings hands up the jacket; arms straighten as the jacket falls downs;hip steps out of the jacket;legs straight;feet on the ground. 16 | put on a shoe;head tilts slightly forward; hand reaches down and grasps shoe; arm extends down and forward; hip remains stationary; leg bends at the knee, bringing foot closer to the hand; foot inserts into shoe. 17 | take off a shoe;head tilts slightly forward; hand reaches down to the shoe; arm extends down to the shoe; hip remains stationary; leg bends at the knee; foot grasps the shoe and pulls it off 18 | put on glasses;head tilt slightly backwards; hand reach up to forehead; arm bend at elbow; hip stay stationary; leg stay stationary; foot stay stationary. 19 | take off glasses;head tilts slightly forward; hand comes up to face; arm elbow bends; hip doesn't move; leg doesn't move; foot doesn't move. 20 | put on a hat/cap;head tilts forward slightly and descends; hand grasps edges of hat; arm raises hat to head level; hip tilts to one side; leg shifts weight to one side; foot remains stationary. 21 | take off a hat/cap; head tilts slightly forward; hand grasps the brim of the hat/cap; arm: raises the hat/cap; hip remains stationary; leg remains stationary; foot remains stationary 22 | cheer up;head tilts slightly forward; hand comes up to chest level with palms facing in; arm elbow is bent and held close to body; hip remains stationary; leg remains stationary; foot remains stationary. 23 | hand waving;head nodding; hand waving; arm swinging; hip swaying; leg shifting; foot stomping 24 | kicking something; head focus on his target;hands are clenched into fists as he prepares to strike;arm is raised and ready to strike;hip is cocked and ready to strike;leg is raised and ready to strike;foot is raised and ready to strike. 25 | reach into pocket;head tilts slightly forward; hand reaches into pocket; arm extends forward; hip remains stationary; leg remains stationary; foot remains stationary. 26 | hopping;head moves up and down; hand moves up and down; arm moves up and down; hip moves up and down; leg moves up and down; foot moves up and down. 
27 | jump up;head tilts back and the chin points up; hands come up to the chest; arms bend at the elbows and the forearms come up; hips push forward and the legs bend at the knees; legs push off the ground and the feet come up; feet land on the ground and the legs bend at the knees. 28 | phone call; head tilts slightly forward; hand brings the phone up to ear; arm supports the phone and hand holds it in place; hip shifts weight to one side; crosses legs at the ankle; foot taps on the ground as he talks. 29 | play with phone/tablet;head looking down at the phone/tablet in hands; hand holding the phone/tablet; arm supporting the weight of the phone/tablet; hip keeping the body upright; leg keeping the body balanced; foot keeping the body stable. 30 | type on a keyboard;head looking at the keyboard; hand typing on the keyboard; arm supporting the hand; hip keeping the body stable; leg keeping the body stable; foot keeping the body stable. 31 | point to something;head points head in the direction of the object; hand raises and points finger at the object; arm is extended and finger is pointing at the object; hips are neither turned nor twisted; leg nearest to the object is slightly bent at the knee; foot nearest the object is pointing in the direction of the object. 32 | taking a selfie;head tilts slightly to the side; hand holds up phone in front of face; arm is extended out in front of body; hip is shifted to the side; leg is bent at the knee; foot is pointing slightly inward. 33 | check time (from watch);head turns to look at watch; hand raises arm to eye level; arm bends at elbow to bring watch closer to eyes; hip remains stationary; leg remains stationary; foot remains stationary 34 | rub two hands;head rubs hands together near head; hand rubs palms together; arms are extended and hands are moving back and forth; hips are stationary; legs are slightly apart; feet are flat on the ground. 35 | nod head/bow;head tilt forward slightly at the neck; hand move forward and downward; arm bend slightly at the elbow; hip remain stationary; leg move forward and downward; foot move forward and downward. 36 | shake head;head shake from side to side; hand hang down; arm swing slightly; hip stay in place; leg stay in place; foot stay in place. 37 | wipe face;head turns to the side so that cheek is facing hand; hand comes up to cheek and wipes from the bottom of eye down to jawline; arm extends out from body and hand moves across face; hip remains stationary; leg remains stationary; foot remains stationary. 38 | salute;head raises head and looks straight ahead; hand raises right hand and brings it to forehead; arm is extended straight out from shoulder; hips are level with shoulders; legs are straight and together; feet are together and pointing forward. 39 | put palms together;head tilts slightly forward; hand presses palms together; arm extends slightly forward; hip remains stationary; leg remains stationary; foot remains stationary. 40 | cross hands in front;head tilts slightly forward; hand crosses in front of body at waist level; arm extends slightly forward; hip remains stationary; leg remains stationary; foot remains stationary. 41 | sneeze/cough;head will tilt back slightly and their chin will raise up; bring their hand up to their face, cupping their mouth and nose; arm the person's arm will raise up to their face; hip will not move; leg will not move; foot will not move. 
42 | staggering; head is tilting to one side; hand is gripping something for support; arm is hanging down; hip is tilted to one side; leg is not moving much; foot is dragging. 43 | falling down;head hits the ground first; hand tries to break the fall; arm bends at the elbow; hip twists to the side; leg bends at the knee; foot turns inward. 44 | headache;head holding head with both hands; hand holding head with both hands; arms holding head with both hands; hip sitting with legs crossed; leg sitting with legs crossed; foot sitting with legs crossed. 45 | chest pain;man clutches head in pain; man grabs chest in pain; man's arm falls limp at side; man doubles over in pain; man collapses to the ground; man's foot twitches uncontrollably. 46 | back pain;head he is looking down at feet; hand he is reaching down to lower back; arm is extended behind him; hip is flexed; leg is straight; foot is flat on the ground. 47 | neck pain;head is holding head with hand; hand is holding head with hand; arm is holding head with hand; hip is holding head with hand; leg is holding head with hand; foot is holding head with hand. 48 | nausea/vomiting; head is tilted back and eyes are closed; One hand is on stomach and the other is on chest; arms are slightly bent at the elbows; hips are slightly bent; legs are slightly bent at the knees; feet are flat on the ground. 49 | fan self;head tilts back slightly; hand raises up; arm extends out; hip doesn't move; leg doesn't move; foot doesn't move. 50 | punch/slap;The head moves forward; hand hits the target with a clenched fist; The arm is extended and hits the target; The hip is used for balance; The leg is used for balance; The foot is used for balance. 51 | kicking; head: turns to the direction of the kick; hand: pushes the leg that is kicking backward; arm: supports the body; hip: provides power to the kick; leg: kicks backward; foot: makes contact with the ball. 52 | pushing; tilts head slightly forward;pushes forward on an object with hand; arm extended; hip remains stationary; leg remains stationary; foot remains stationary. 53 | pat on back;head tilts slightly forward; hand comes up from behind and pats; arm extends and bends at elbow; hip remains stationary; leg remains stationary; foot remains stationary. 54 | point finger;head turns to the direction where the finger is pointing; hand extends the index finger and points to the direction; arm extends the arm and points the finger; hip keeps still; leg keeps still; foot keeps still. 55 | hugging; head tilts to the side to rest on shoulder; hand wraps around waist; arm hugs close; hip presses against hip; leg crosses in front of leg; foot steps on foot. 56 | giving object; nods head slightly and extends arm out; holding the object in hand; arm supports hand; hip moves slightly forward; leg supports body;foot remains stationary. 57 | touch pocket;head tilts slightly forward; hand reaches into pocket; arm extends forward; hip remains stationary; leg remains stationary; foot remains stationary. 58 | shaking hands;head nods slightly; hand grasps other person's hand firmly and shakes it up and down; arm extends from shoulder; hip remains stationary; leg remains stationary; foot remains stationary. 
59 | walking towards;head looking forward; hand swinging at sides; arm swinging at sides; hip swinging at sides; leg moving forward; foot moving forward 60 | walking apart; head is looking forward; hand is swinging by side; arm is swinging by side; hip is swinging to the opposite direction of leg; leg is taking a step forward; foot is pushing off the ground. 61 | put on headphone;head moves slightly forward as he brings the headphone up to ear;hand moves up to head, holding the headphone in place;arm moves up to head;hip moves slightly forward;leg moves slightly forward;foot stand still. 62 | take off headphone;head moves head slightly to the side so that he can see the headphone jack; hand moves to the headphone jack and unplugs the headphones; arm moves the headphones away from head; hip moves to the side; leg moves to the side; foot moves to the side. 63 | shoot at basket; turns head to look at the basket; hold the object; extends arm to shoot the ball; shifts hip to line up the shot; leg bends to give power to the shot; foot remains stationary. 64 | bounce ball;head moves slightly forward and backward; hand holds the ball; arm is slightly bent; hip is slightly bent; leg is slightly bent; foot is slightly bent. 65 | tennis bat swing;head is tilted back; hand right hand is holding the tennis bat; arm right arm is swinging the tennis bat; hips are rotated to the right; right leg is extended to the right; right foot is pointing to the right. 66 | juggle table tennis ball;head tilts slightly to the left; hand moves quickly to the left to catch the ball; arm extends to the left to catch the ball; hip remains stationary; leg remains stationary; foot remains stationary 67 | hush;head he lowers head; hand he raises index finger to lips; arm is bent at the elbow; hip is not involved in the action; leg is not involved in the action; foot is not involved in the action. 68 | flick hair;head flick hair; hand hold head; arm move head; hip stay still; leg stay still; foot stay still. 69 | thumb up;head tilts slightly back; hand raises up; arm bends at elbow; hip doesn't move; leg doesn't move; foot doesn't move. 70 | thumb down;head tilts slightly forward; hand bends at the wrist so that the thumb points down; arm hangs straight down; hip remains stationary; leg remains stationary; foot remains stationary. 71 | make OK sign;head tilts slightly to the side; hand forms a circle with the thumb and first two fingers; arm extends straight out from the shoulder; hip remains at a neutral position; leg remains at a neutral position; foot remains at a neutral position. 72 | make victory sign;head tilts slightly forward; hand forms a V shape with the index and middle fingers; arm extends fully; hip remains stationary; leg remains stationary; foot remains stationary. 73 | staple book;head tilts slightly forward; hand holds the book in place; arm supports the book; hip keeps the body stable; leg keeps the body stable; foot keeps the body stable. 74 | counting money;head looking down at the money in hand; hand holding a stack of bills; arm supporting hand; hip keeping body upright; leg keeping body balanced; foot keeping body stable. 75 | cutting nails;head looking down at hands; hand holding a nail clipper; arm keeping the hand steady; hip keeping the body steady; leg keeping the body steady; foot keeping the body steady. 76 | cutting paper;head looking down at the paper; hand holding the scissors; arm moving the scissors back and forth; hip staying in place; leg staying in place; foot staying in place. 
77 | snap fingers;head turns slightly to the side; hand moves quickly to meet the other hand in the middle and create a snapping sound; arm remains at the side; hip remains stationary; leg remains stationary; foot remains stationary. 78 | open bottle;head tilts back slightly; hand grasps the neck of the bottle; arm extends the arm holding the bottle; hip remains stationary; leg remains stationary; foot remains stationary. 79 | sniff/smell;head tilts slightly forward; hand brings object close to nose; arm supports hand; hip remains stationary; leg remains stationary; foot remains stationary. 80 | squat down;head tilts slightly forward; hand supports body weight on thighs; arm hangs loosely at sides; hip lowers until thighs are parallel to ground; leg bends at knees; foot remains flat on ground. 81 | toss a coin;head turns to the side to watch the coin; hand opens and closes to release the coin; arm extends forward to release the coin; hip remains stationary; leg remains stationary; foot remains stationary. 82 | fold paper;head looks down at the paper; hand grasps the paper with their fingers; arm is bent at the elbow; hip is not moving; leg is not moving; foot is not moving. 83 | ball up paper;head tilts slightly forward; hand grasps paper; arm bends at elbow; hip remains stationary; leg remains stationary; foot remains stationary 84 | play magic cube;head is looking at the cube in hand; hand is holding the cube in hand and moving it around; arm is moving the cube around; hip is not moving; leg is not moving; foot is not moving. 85 | apply cream on face;head turns to the side; hand reaches up; arm extends; hip stays stationary; leg stays stationary; foot stays stationary. 86 | apply cream on hand;head none; hand takes some cream in hand and rubs it on skin; arm moves as he rubs the cream on hand; hip none; leg none; foot none. 87 | put on bag;head looks down at bag, then bends down to pick it up; hand grasps the bag; arm goes down to hand; hip moves down as he bends down; leg bends as he moves down; foot moves with leg as he bends down. 88 | take off bag;head looks down at the bag; hand reaches down and grabs the bag's handles; arm lifts the bag off the ground; hip shifts the bag's weight to other hand; leg steps forward; foot comes down on the ground. 89 | put object into bag;head looking at object; hand holding object; arm moving object towards bag; hip keeping body stable; leg keeping body stable; foot keeping body stable. 90 | take object out of bag;head is tilted slightly forward as he looks down into the bag; hand reaches into the bag and grasps the object; arm pulls the object out of the bag; hip is slightly elevated; leg is slightly bent; foot is planted firmly on the ground. 91 | open a box;head tilts slightly forward; hand reaches out and grasps the edge of the lid; arm extends forward; hip remains stationary; leg remains stationary; foot remains stationary. 92 | move heavy objects;head turns to look at the object; hand reaches out and grasps the object; arm pulls the object towards the body; hip moves forward to create leverage; leg pushes against the ground to create power; foot stabilizes the body. 93 | shake fist;head shake head slightly; hand grip tightly and shake up and down; arm move up and down at the elbow; hip stay in place; leg stay in place; foot stay in place. 
94 | throw up cap/hat;head is tilted back and eyes are closed; hand is holding the hat at the brim; arm is extended straight up; hip is not involved in the action; leg is not involved in the action; foot is not involved in the action. 95 | capitulate; head nods in surrender;hand raises in the air; arm hangs limply at side;hip sags; leg buckles;foot drags. 96 | cross arms;head tilts slightly; hand grasps the right arm just above the elbow with the left hand; both arms are now bent at the elbow and held close to the body; hip remains at a neutral position; leg both legs remain straight; foot both feet remain flat on the ground. 97 | arm circles;head turns to look at the arm; hand grasps the arm at the elbow; arm circles around the body; hip remains stationary; leg remains stationary; foot remains stationary. 98 | arm swings;head turns to look at the arm; hand grips the arm; arm swings in a circle; hip remains stationary; leg remains stationary; foot remains stationary. 99 | run on the spot;head nodding up and down; hand swinging back and forth; arm pumping back and forth; hip thrusting back and forth; leg kicking back and forth; foot stomping back and forth. 100 | butt kicks;head turns to look at target; hand reaches back and grasps heel of leg to be extended; arm pulls leg back so that thigh is close to buttocks; hip extends leg forcefully; leg extends and then flexes at knee; foot makes contact with target and then returns to ground. 101 | cross toe touch;head tilts slightly forward; hand reaches out and down to touch the opposite foot's toes; arm extends out and down; hip remains stationary; leg crosses in front of the body and bends at the knee to touch the hand to the foot; foot remains stationary 102 | side kick;head turns to the side; hand grabs the ankle; arm pulls the leg back; hip lifts the leg up; leg kicks out to the side; foot makes contact with the target 103 | yawn;head tilts back; hand comes to mouth; arm supports head; hip doesn't move; leg doesn't move; foot doesn't move. 104 | stretch oneself; tilts head back; raises arms above head; stretches arms; raises hips; stretches legs; foot points toes. 105 | blow nose;head tilts forward slightly; hand brings tissue to nose; arm supports head; hip doesn't move; leg doesn't move; foot doesn't move. 106 | hit with object; head hit; hand grip; arm swing; hip rotate; leg step; foot stomp 107 | wield knife; turns head to look at the knife; grasps it in hand; arm raises it up; hip twists so that he is facing the direction of the knife; he steps forward with leg; planting foot firmly on the ground. 108 | knock over;head hit the ground first; hand hit the ground next to the head; arm hit the ground next to the hand; hip hit the ground next to the arm; leg hit the ground next to the hip; foot hit the ground next to the leg. 109 | grab stuff;head tilts slightly forward; hand reaches out and grasps object; arm extends forward; hip remains stationary; leg remains stationary; foot remains stationary. 110 | shoot with gun;head tilts slightly back; hand grips gun with fingers; arm supports gun with arm; hip stands with one foot in front of the other; leg bends knees; foot stands on toes 111 | step on foot;head is upright; hands are at sides; arms are at sides; hips are level; legs are straight; feet are flat on the ground. 
112 | high-five;head tilts slightly to the side; hand moves up to meet the other person's hand in the air; arm extends up to meet the other person's hand; hip remains stationary; leg remains stationary; foot remains stationary. 113 | cheers and drink;head tilts back slightly; hand brings glass to lips; arm supports glass; hip remains stationary; leg remains stationary; foot remains stationary. 114 | carry object;head tilts slightly forward; hand grasps object; arm supports object; hip supports object; leg supports object; foot supports object. 115 | take a photo;head tilts slightly to the left; hand holds the camera up to the right eye; arm extends straight out to the right; hip shifts slightly to the right; leg remains stationary; foot remains stationary 116 | follow; nods head; points hand; extends arm; sways hip; bends leg; taps foot. 117 | whisper;head he brings head close to the person who is talking to; raises hand to mouth; extends arm to the person who is talking to; stands with hip at a 90 degree angle; stands with legs shoulder width apart; stands with feet shoulder width apart. 118 | exchange things; nods head in greeting; takes the other person's hand and shakes it; wraps arm around the other person in a hug; sways hips back and forth in a friendly manner; leg stands still; stands still on both feet. 119 | support somebody; head tilts slightly forward; hand grasps the other person's arm just above the elbow; arm supports the other person's arm; hip stands upright; leg stands upright; foot stands flat on the ground. 120 | rock-paper-scissors; nods head; raises hand up; bends arm at the elbow; sways hip; stands still on leg; stands still on foot. -------------------------------------------------------------------------------- /test_main.py: -------------------------------------------------------------------------------- 1 | from test_config import * 2 | # from model import * 3 | from dataset import DataSet 4 | from test_logger import Log 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import random 11 | from math import pi, cos 12 | from tqdm import tqdm 13 | 14 | from module.gcn.st_gcn import Model 15 | from module.adapter import Adapter, Linear 16 | from KLLoss import KLLoss, KDLoss 17 | from tool import gen_label, create_logits, get_acc, create_sim_matrix, gen_label_from_text_sim, get_m_theta, get_acc_v2, get_acc_v3 18 | from sklearn.metrics import confusion_matrix 19 | 20 | def setup_seed(seed): 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed) 24 | random.seed(seed) 25 | torch.backends.cudnn.deterministic = True 26 | 27 | setup_seed(0) 28 | 29 | # %% 30 | class Processor: 31 | 32 | @ex.capture 33 | def load_data(self, train_list, train_label, test_list, test_label, batch_size, language_path): 34 | self.dataset = dict() 35 | self.data_loader = dict() 36 | self.best_epoch = -1 37 | self.best_acc = -1 38 | self.dim_loss = -1 39 | self.test_acc = -1 40 | self.test_aug_acc = -1 41 | self.best_aug_acc = -1 42 | self.best_aug_epoch = -1 43 | 44 | self.full_language = np.load(language_path) 45 | self.full_language = torch.Tensor(self.full_language) 46 | self.full_language = self.full_language.cuda() 47 | self.dataset['train'] = DataSet(train_list, train_label) 48 | self.dataset['test'] = DataSet(test_list, test_label) 49 | 50 | self.data_loader['train'] = torch.utils.data.DataLoader( 51 | dataset=self.dataset['train'], 52 | batch_size=batch_size, 53 | num_workers=16, 54 | 
shuffle=True, 55 | drop_last=True) 56 | 57 | self.data_loader['test'] = torch.utils.data.DataLoader( 58 | dataset=self.dataset['test'], 59 | batch_size=64, 60 | num_workers=16, 61 | shuffle=False) 62 | 63 | def load_weights(self, model=None, weight_path=None): 64 | pretrained_dict = torch.load(weight_path) 65 | model.load_state_dict(pretrained_dict) 66 | 67 | def adjust_learning_rate(self,optimizer,current_epoch, max_epoch,lr_min=0,lr_max=0.1,warmup_epoch=15, loss_mode='step', step=[50, 80]): 68 | 69 | if current_epoch < warmup_epoch: 70 | lr = lr_max * current_epoch / warmup_epoch 71 | elif loss_mode == 'cos': 72 | lr = lr_min + (lr_max-lr_min)*(1 + cos(pi * (current_epoch - warmup_epoch) / (max_epoch - warmup_epoch))) / 2 73 | elif loss_mode == 'step': 74 | lr = lr_max * (0.1 ** np.sum(current_epoch >= np.array(step))) 75 | else: 76 | raise Exception('Please check loss_mode!') 77 | 78 | for param_group in optimizer.param_groups: 79 | param_group['lr'] = lr 80 | # if i == 0: 81 | # param_group['lr'] = lr * 0.1 82 | # else: 83 | # param_group['lr'] = lr 84 | 85 | def layernorm(self, feature): 86 | 87 | num = feature.shape[0] 88 | mean = torch.mean(feature, dim=1).reshape(num, -1) 89 | var = torch.var(feature, dim=1).reshape(num, -1) 90 | out = (feature-mean) / torch.sqrt(var) 91 | 92 | return out 93 | 94 | @ex.capture 95 | def load_model(self,in_channels,hidden_channels,hidden_dim, 96 | dropout,graph_args,edge_importance_weighting, visual_size, language_size, weight_path, loss_type): 97 | self.encoder = Model(in_channels=in_channels, hidden_channels=hidden_channels, 98 | hidden_dim=hidden_dim,dropout=dropout, 99 | graph_args=graph_args, 100 | edge_importance_weighting=edge_importance_weighting, 101 | ) 102 | self.encoder = self.encoder.cuda() 103 | self.adapter = Linear().cuda() 104 | if loss_type == "kl" or loss_type == "klv2" or loss_type == "kl+cosface" or loss_type == "kl+sphereface" or loss_type == "kl+margin": 105 | self.loss = KLLoss().cuda() 106 | elif loss_type == "mse": 107 | self.loss = nn.MSELoss().cuda() 108 | elif loss_type == "kl+mse": 109 | self.loss_kl = KLLoss().cuda() 110 | self.loss_mse = nn.MSELoss().cuda() 111 | elif loss_type == "kl+kd": 112 | self.loss = KLLoss().cuda() 113 | self.kd_loss = KDLoss().cuda() 114 | else: 115 | raise Exception('loss_type Error!') 116 | self.logit_scale = self.adapter.get_logit_scale() 117 | self.logit_scale_v2 = self.adapter.get_logit_scale_v2() 118 | 119 | # self.model = MI(visual_size, language_size).cuda() 120 | # print(weight_path) 121 | pretrained_dict = torch.load(weight_path) 122 | # print(pretrained_dict['encoder']) 123 | self.encoder.load_state_dict(pretrained_dict['encoder']) 124 | self.adapter.load_state_dict(pretrained_dict['adapter']) 125 | 126 | @ex.capture 127 | def load_optim(self, lr, epoch_num, weight_decay): 128 | # self.optimizer = torch.optim.Adam([ 129 | # {'params': self.encoder.parameters()}, 130 | # {'params': self.model.parameters()}], 131 | # lr=lr, 132 | # weight_decay=weight_decay, 133 | # ) 134 | # self.optimizer = torch.optim.Adam([ 135 | # {'params': self.encoder.parameters()}, 136 | # {'params': self.adapter.parameters()}], 137 | # lr=lr, 138 | # weight_decay=weight_decay, 139 | # ) 140 | self.optimizer = torch.optim.SGD([ 141 | {'params': self.encoder.parameters()}, 142 | {'params': self.adapter.parameters()}], 143 | lr=lr, 144 | weight_decay=weight_decay, 145 | momentum=0.9, 146 | nesterov=False 147 | ) 148 | # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, 100) 149 | 150 | @ex.capture
151 | def optimize(self, epoch_num, DA): # print -> log.info 152 | self.log.info("main track") 153 | epoch = 0 154 | with torch.no_grad(): 155 | self.test_epoch(epoch=epoch) 156 | self.log.info("epoch [{}] test acc: {}".format(epoch,self.test_acc)) 157 | # self.log.info("epoch [{}] gets the best acc: {}".format(self.best_epoch,self.best_acc)) 158 | if DA: 159 | self.log.info("epoch [{}] DA test acc: {}".format(epoch,self.test_aug_acc)) 160 | # self.log.info("epoch [{}] gets the best DA acc: {}".format(self.best_aug_epoch,self.best_aug_acc)) 161 | # if epoch > 5: 162 | # self.log.info("epoch [{}] test acc: {}".format(epoch,self.test_acc)) 163 | # self.log.info("epoch [{}] gets the best acc: {}".format(self.best_epoch,self.best_acc)) 164 | # else: 165 | # self.log.info("epoch [{}] : warm up epoch.".format(epoch)) 166 | 167 | @ex.capture 168 | def train_epoch(self, epoch, lr, loss_mode, step, loss_type, alpha, beta, m): 169 | self.encoder.train() # eval -> train 170 | self.adapter.train() 171 | self.adjust_learning_rate(self.optimizer, current_epoch=epoch, max_epoch=100, lr_max=lr, warmup_epoch=5, loss_mode=loss_mode, step=step) 172 | running_loss = [] 173 | loader = self.data_loader['train'] 174 | for data, label in tqdm(loader): 175 | data = data.type(torch.FloatTensor).cuda() 176 | # print(data.shape) #128,3,50,25,2 177 | # label = label.type(torch.LongTensor).cuda() 178 | label_g = gen_label(label) 179 | label = label.type(torch.LongTensor).cuda() 180 | # print(label.shape) # 128 181 | # print(label) # int 182 | seen_language = self.full_language[label] # 128, 768 183 | # print(seen_language.shape) 184 | 185 | feat = self.encoder(data) 186 | skleton_feat = self.adapter(feat) 187 | if loss_type == "kl": 188 | logits_per_skl, logits_per_text = create_logits(skleton_feat, seen_language, self.logit_scale, exp=True) 189 | ground_truth = torch.tensor(label_g, dtype=skleton_feat.dtype).cuda() 190 | # ground_truth = gen_label_from_text_sim(seen_language) 191 | loss_skls = self.loss(logits_per_skl, ground_truth) 192 | loss_texts = self.loss(logits_per_text, ground_truth) 193 | loss = (loss_skls + loss_texts) / 2 194 | elif loss_type == "kl+margin": 195 | logits_per_skl, logits_per_text = create_logits(skleton_feat, seen_language, self.logit_scale, exp=True) 196 | ground_truth = torch.tensor(label_g, dtype=skleton_feat.dtype).cuda() 197 | ones = torch.ones_like(ground_truth).cuda() 198 | ones -= m 199 | logits_per_skl = torch.where(ones self.best_acc: 313 | self.best_acc = acc 314 | self.best_epoch = epoch 315 | # self.save_model() 316 | # y_true = np.array(y_true) 317 | # y_pred = np.array(y_pred) 318 | # self.log.info('true_labels:{}'.format(y_true)) 319 | # self.log.info('predicted_labels:{}'.format(y_pred)) 320 | self.log.info('unseen_label:{}'.format(unseen_label)) 321 | conf_matrix = confusion_matrix(y_true, y_pred, labels=unseen_label) 322 | self.log.info('Confusion Matrix:{}{}'.format('\n',conf_matrix)) 323 | # np.save("y_true_3.npy",y_true) 324 | # np.save("y_pred_3.npy",y_pred) 325 | # print("save ok!") 326 | self.test_acc = acc 327 | 328 | if DA: 329 | ent_all = torch.cat(ent_list) 330 | feat_all = torch.cat(feat_list) 331 | old_pred_all = torch.cat(old_pred_list) 332 | mean_old_ent = torch.mean(ent_all) 333 | z_list = [] 334 | for i in range(len(unseen_label)): 335 | mask = old_pred_all == i 336 | class_support_set = feat_all[mask] 337 | class_ent = ent_all[mask] 338 | class_len = class_ent.shape[0] 339 | if int(class_len*support_factor) < 1: 340 | z = 
self.full_language[unseen_label[i:i+1]] 341 | else: 342 | _, indices = torch.topk(-class_ent, int(class_len*support_factor)) 343 | z = torch.mean(class_support_set[indices], dim=0, keepdim=True) 344 | z_list.append(z) 345 | 346 | z_tensor = torch.cat(z_list) 347 | aug_acc_list = [] 348 | DA_ent_list = [] 349 | for data, label in tqdm(loader): 350 | # y_t = label.numpy().tolist() 351 | # y_true += y_t 352 | 353 | data = data.type(torch.FloatTensor).cuda() 354 | label = label.type(torch.LongTensor).cuda() 355 | unseen_language = z_tensor 356 | # inference 357 | feature = self.encoder(data) 358 | feature = self.adapter(feature) 359 | # acc_batch, pred = get_acc(feature, unseen_language, unseen_label, label) 360 | # acc_batch, pred = get_acc(feature, unseen_language, unseen_label, label) 361 | acc_batch, pred, ent = get_acc_v3(feature, unseen_language, unseen_label, label) 362 | DA_ent_list.append(ent) 363 | 364 | y_p = pred.cpu().numpy().tolist() 365 | y_pred_DA += y_p 366 | 367 | aug_acc_list.append(acc_batch) 368 | mean_DA_ent = torch.mean(torch.cat(DA_ent_list)) 369 | 370 | aug_acc = torch.tensor(aug_acc_list).mean() 371 | if aug_acc > self.best_aug_acc: 372 | self.best_aug_acc = aug_acc 373 | self.best_aug_epoch = epoch 374 | # self.log.info('true_labels:{}'.format(y_true)) 375 | # self.log.info('DA_predicted_labels:{}'.format(y_pred_DA)) 376 | self.log.info('Unseen label:{}'.format(unseen_label)) 377 | conf_matrix = confusion_matrix(y_true, y_pred_DA, labels=unseen_label) 378 | self.log.info('DA Confusion Matrix:{}{}'.format('\n',conf_matrix)) 379 | self.log.info('Mean old entropy:{}'.format(mean_old_ent)) 380 | self.log.info('Mean DA entropy:{}'.format(mean_DA_ent)) 381 | self.test_aug_acc = aug_acc 382 | 383 | 384 | 385 | def initialize(self): 386 | self.load_data() 387 | self.load_model() 388 | self.load_optim() 389 | self.log = Log() 390 | 391 | @ex.capture 392 | def save_model(self, save_path): 393 | torch.save({'encoder':self.encoder.state_dict(), 'adapter':self.adapter.state_dict()}, save_path) 394 | 395 | def start(self): 396 | self.initialize() 397 | self.optimize() 398 | # self.save_model() 399 | 400 | class SotaProcessor: 401 | 402 | @ex.capture 403 | def load_data(self, sota_train_list, sota_train_label, 404 | sota_test_list, sota_test_label, batch_size, language_path): 405 | self.dataset = dict() 406 | self.data_loader = dict() 407 | self.best_epoch = -1 408 | self.best_acc = -1 409 | self.dim_loss = -1 410 | self.test_acc = -1 411 | 412 | self.full_language = np.load(language_path) 413 | self.full_language = torch.Tensor(self.full_language) 414 | self.full_language = F.normalize(self.full_language,dim=-1) 415 | self.full_language = self.full_language.cuda() 416 | 417 | self.dataset['train'] = DataSet(sota_train_list, sota_train_label) 418 | self.dataset['test'] = DataSet(sota_test_list, sota_test_label) 419 | 420 | self.data_loader['train'] = torch.utils.data.DataLoader( 421 | dataset=self.dataset['train'], 422 | batch_size=batch_size, 423 | num_workers=16, 424 | shuffle=True) 425 | 426 | self.data_loader['test'] = torch.utils.data.DataLoader( 427 | dataset=self.dataset['test'], 428 | batch_size=64, 429 | num_workers=16, 430 | shuffle=False) 431 | 432 | def adjust_learning_rate(self,optimizer,current_epoch, max_epoch,lr_min=0,lr_max=0.1,warmup_epoch=15): 433 | 434 | if current_epoch < warmup_epoch: 435 | lr = lr_max * current_epoch / warmup_epoch 436 | else: 437 | lr = lr_min + (lr_max-lr_min)*(1 + cos(pi * (current_epoch - warmup_epoch) / (max_epoch - warmup_epoch))) / 2 
438 | for param_group in optimizer.param_groups: 439 | param_group['lr'] = lr 440 | 441 | @ex.capture 442 | def load_model(self,in_channels,hidden_channels,hidden_dim, 443 | dropout,graph_args,edge_importance_weighting, visual_size, language_size, weight_path): 444 | self.model = MI(visual_size, language_size).cuda() 445 | 446 | @ex.capture 447 | def load_optim(self, lr, epoch_num, weight_decay): 448 | self.optimizer = torch.optim.Adam([ 449 | {'params': self.model.parameters()}], 450 | lr=lr, 451 | weight_decay=weight_decay, 452 | ) 453 | 454 | @ex.capture 455 | def optimize(self, epoch_num): 456 | print("sota track") 457 | for epoch in range(epoch_num): 458 | self.train_epoch(epoch) 459 | with torch.no_grad(): 460 | self.test_epoch(epoch=epoch) 461 | print("epoch [{}] dim loss: {}".format(epoch,self.dim_loss)) 462 | print("epoch [{}] test acc: {}".format(epoch,self.test_acc)) 463 | print("epoch [{}] gets the best acc: {}".format(self.best_epoch,self.best_acc)) 464 | 465 | @ex.capture 466 | def train_epoch(self, epoch, lr): 467 | self.model.train() 468 | self.adjust_learning_rate(self.optimizer, current_epoch=epoch, max_epoch=100, lr_max=lr) 469 | running_loss = [] 470 | loader = self.data_loader['train'] 471 | for data, label in tqdm(loader): 472 | 473 | data = data.type(torch.FloatTensor).cuda() 474 | label = label.type(torch.LongTensor).cuda() 475 | seen_language = self.full_language[label] 476 | 477 | # Global 478 | feat0 = data.clone() 479 | dim = self.model(feat0, seen_language) 480 | 481 | # Loss 482 | loss = -dim 483 | 484 | running_loss.append(loss) 485 | self.optimizer.zero_grad() 486 | loss.backward() 487 | self.optimizer.step() 488 | 489 | running_loss = torch.tensor(running_loss) 490 | self.dim_loss = running_loss.mean().item() 491 | 492 | @ex.capture 493 | def test_epoch(self, sota_unseen, epoch): 494 | self.model.eval() 495 | 496 | total = 0 497 | correct = 0 498 | loader = self.data_loader['test'] 499 | acc_list = [] 500 | for data, label in tqdm(loader): 501 | feature = data.type(torch.FloatTensor).cuda() 502 | label = label.type(torch.LongTensor).cuda() 503 | unseen_language = self.full_language[sota_unseen] 504 | # inference 505 | acc_batch = self.model.get_acc(feature, unseen_language, label, sota_unseen) 506 | acc_list.append(acc_batch) 507 | acc_list = torch.tensor(acc_list) 508 | acc = acc_list.mean() 509 | if acc > self.best_acc: 510 | self.best_acc = acc 511 | self.best_epoch = epoch 512 | self.test_acc = acc 513 | 514 | def initialize(self): 515 | self.load_data() 516 | self.load_model() 517 | self.load_optim() 518 | self.log = Log() 519 | 520 | def start(self): 521 | self.initialize() 522 | self.optimize() 523 | 524 | # %% 525 | @ex.automain 526 | def main(track): 527 | if "sota" in track: 528 | p = SotaProcessor() 529 | elif "main" in track: 530 | p = Processor() 531 | p.start() 532 | --------------------------------------------------------------------------------
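Note: the "kl" branch of Processor.train_epoch in test_main.py aligns skeleton features from the ST-GCN encoder with Sentence-BERT text embeddings through helpers defined in tool.py and KLLoss.py, which are not included in this excerpt. Below is a minimal, self-contained sketch of that alignment step under stated assumptions; gen_label_sketch, create_logits_sketch and kl_loss_sketch are hypothetical stand-ins written for illustration only and are not the repository's actual implementations.

import torch
import torch.nn.functional as F

def gen_label_sketch(labels: torch.Tensor) -> torch.Tensor:
    # Target similarity matrix: entry (i, j) is 1 when samples i and j share a class label.
    return (labels.unsqueeze(0) == labels.unsqueeze(1)).float()

def create_logits_sketch(skel_feat, text_feat, logit_scale):
    # L2-normalize both modalities, then compute scaled cosine-similarity logits.
    skel_feat = F.normalize(skel_feat, dim=-1)
    text_feat = F.normalize(text_feat, dim=-1)
    logits_per_skl = logit_scale * skel_feat @ text_feat.t()
    return logits_per_skl, logits_per_skl.t()

def kl_loss_sketch(logits, ground_truth):
    # KL divergence between softmaxed logits and the row-normalized target matrix.
    log_pred = F.log_softmax(logits, dim=-1)
    target = ground_truth / ground_truth.sum(dim=-1, keepdim=True)
    return F.kl_div(log_pred, target, reduction="batchmean")

# Toy batch: 4 samples with 768-dim skeleton features and per-sample text embeddings.
labels = torch.tensor([0, 1, 0, 2])
skel_feat = torch.randn(4, 768)
text_feat = torch.randn(4, 768)
gt = gen_label_sketch(labels)
logits_skl, logits_txt = create_logits_sketch(skel_feat, text_feat, logit_scale=100.0)
loss = (kl_loss_sketch(logits_skl, gt) + kl_loss_sketch(logits_txt, gt)) / 2
print(loss.item())

In the actual script the symmetric loss (skeleton-to-text plus text-to-skeleton, averaged) plays the same role as in CLIP-style training, with the logit scale supplied by the adapter module.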