├── requirements.txt ├── .gitignore ├── find_big_frames.py ├── extract_classname.py ├── find_frame_error.py ├── cp_videos.py ├── openpose.py ├── rm_empty_frame.py ├── video.py ├── Readme.md ├── utils.py ├── skating_gendata.py ├── skating_convert.py └── feeder_skating.py /requirements.txt: -------------------------------------------------------------------------------- 1 | sk-video 2 | torch==1.0.1 3 | torchvision==0.2.2 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | *.out 3 | *.csv 4 | tmp* 5 | *.swp 6 | __pycache__ -------------------------------------------------------------------------------- /find_big_frames.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def name_loader(filename): 4 | with open(filename) as f: 5 | for line in f.readlines(): 6 | info=line.strip() 7 | category, video_name, frame_num, label = info.split(",") 8 | yield category, video_name, frame_num, label 9 | 10 | dataset_path = "/skating2.0/skating63/" 11 | 12 | parts = ['label_train_skating63', 'label_val_skating63'] 13 | for part in parts: 14 | csvfile = os.path.join(dataset_path, "{}.csv".format(part)) 15 | for category, video_name, frame_num, label in name_loader(csvfile): 16 | video_path = os.path.join(dataset_path, category, video_name + ".mp4") 17 | if int(frame_num) > 2000: 18 | print("big frame:", video_name, frame_num) -------------------------------------------------------------------------------- /extract_classname.py: -------------------------------------------------------------------------------- 1 | 2 | #csv_path = "/skating2.0/skating2_data/train_skating63.csv" 3 | #output_csv = "/skating2.0/skeleton_script/label_train_skating63.csv" 4 | csv_path = "/skating2.0/skating2_data/val_skating63.csv" 5 | output_csv = "/skating2.0/skeleton_script/label_val_skating63.csv" 6 | with open(csv_path, "r") as f, open(output_csv, "w") as outfile: 7 | for line in f.readlines(): 8 | line = line.strip() 9 | print(line) 10 | full_file_name, frame_num, label = line.split(" ") 11 | file_name = full_file_name.split("/")[-1] 12 | class_name = file_name.split("_n")[0] 13 | save_line = [class_name, file_name, frame_num, label] 14 | save_line = ",".join(save_line) + "\n" 15 | outfile.write(save_line) 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /find_frame_error.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils import name_loader 3 | import video 4 | dataset_path = "/skating2.0/skating63/" 5 | 6 | with open("frame_err.log","w") as log_file: 7 | parts = ['label_train_skating63', 'label_val_skating63'] 8 | for part in parts: 9 | csvfile = os.path.join(dataset_path, "{}.csv".format(part)) 10 | for category, video_name, label in name_loader(csvfile): 11 | video_path = os.path.join(dataset_path, category, video_name + ".mp4") 12 | try: 13 | print("read video: {}".format(video_path)) 14 | video_obj = video.get_video_frames(video_path) 15 | except RuntimeError as e: 16 | print("{} Runtime Error: {}".format(video_path, str(e))) 17 | msg = video_name + " ### " + str(e) + "\n" 18 | log_file.write(msg) 19 | -------------------------------------------------------------------------------- /cp_videos.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | video_src_dir = "/share/skating2.0/removed_empty_frame_videos" 5 | video_dst_dir = "/share_4t/skating63" 6 | 7 | record_file = "/share/skating2.0/skeleton_script/frame_err.log" 8 | 9 | # if not os.path.exists(video_dst_dir): 10 | # os.makedirs(video_dst_dir) 11 | # else: 12 | # os.system("rm -r " + video_dst_dir) 13 | # os.makedirs(video_dst_dir) 14 | 15 | with open(record_file, 'r') as f: 16 | for line in f.readlines(): 17 | video_name = line.strip().split(" ")[0] 18 | class_name = video_name.split("_n")[0] 19 | # src_dir/class_name/video_name.mp4 20 | video_src_path = os.path.join(video_src_dir, video_name + ".mp4") 21 | video_dst_path = os.path.join(video_dst_dir, class_name, video_name + ".mp4") 22 | command = "cp " + video_src_path + " " + video_dst_path 23 | sys.stdout.write(command+"\n") 24 | os.system(command) -------------------------------------------------------------------------------- /openpose.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | 4 | from utils import LOGGER 5 | 6 | def json_pack(snippets_dir, video_name, frame_width, frame_height, label='unknown', label_index=-1): 7 | sequence_info = [] 8 | p = Path(snippets_dir) 9 | LOGGER.info(path) 10 | for path in p.glob(video_name+'*.json'): 11 | json_path = str(path) 12 | LOGGER.info(path) 13 | frame_id = int(path.stem.split('_')[-2]) 14 | frame_data = {'frame_index': frame_id} 15 | data = json.load(open(json_path)) 16 | skeletons = [] 17 | for person in data['people']: 18 | score, coordinates = [], [] 19 | skeleton = {} 20 | keypoints = person['pose_keypoints_2d'] 21 | for i in range(0, len(keypoints), 3): 22 | coordinates += [keypoints[i]/frame_width, keypoints[i + 1]/frame_height] 23 | score += [keypoints[i + 2]] 24 | skeleton['pose'] = coordinates 25 | skeleton['score'] = score 26 | skeletons += [skeleton] 27 | frame_data['skeleton'] = skeletons 28 | sequence_info += [frame_data] 29 | 30 | video_info = dict() 31 | video_info['data'] = sequence_info 32 | video_info['label'] = label 33 | video_info['label_index'] = label_index 34 | 35 | return video_info -------------------------------------------------------------------------------- /rm_empty_frame.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | 4 | video_src_dir = "/skating2.0/extracted_videos" 5 | video_dst_dir = "/skating2.0/removed_empty_frame_videos" 6 | 7 | if os.path.exists(video_dst_dir): 8 | os.system("rm -rf "+video_dst_dir) 9 | os.makedirs(video_dst_dir) 10 | 11 | count = 0 12 | for video_name in os.listdir(video_src_dir): 13 | count += 1 14 | print(video_name) 15 | if video_name.split(".")[-1] == "mp4": 16 | video_path = os.path.join(video_src_dir, video_name) 17 | camera = cv2.VideoCapture(video_path) 18 | fps = camera.get(5) 19 | fps = int(fps) 20 | width = camera.get(3) 21 | height = camera.get(4) 22 | print("video frame: ", camera.get(7)) 23 | size = (int(width), int(height)) 24 | 25 | video_out_path = os.path.join(video_dst_dir, video_name) 26 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") 27 | writer = cv2.VideoWriter(video_out_path, fourcc, 30, size) 28 | print(video_out_path) 29 | ret, img = camera.read() 30 | frame_count = 0 31 | while ret: 32 | # if img.empty(): 33 | # print("empty frame:", frame_count) 34 | frame_count += 1 35 | writer.write(img) 36 | ret, img = camera.read() 37 | 38 | print("video real frame: ", frame_count) 39 | camera.release() 40 | writer.release() 41 | -------------------------------------------------------------------------------- /video.py: -------------------------------------------------------------------------------- 1 | import skvideo.io 2 | import numpy as np 3 | import cv2 4 | 5 | def video_info_parsing(video_info, num_person_in=5, num_person_out=2): 6 | data_numpy = np.zeros((3, len(video_info['data']), 18, num_person_in)) 7 | for frame_info in video_info['data']: 8 | frame_index = frame_info['frame_index'] 9 | for m, skeleton_info in enumerate(frame_info["skeleton"]): 10 | if m >= num_person_in: 11 | break 12 | pose = skeleton_info['pose'] 13 | score = skeleton_info['score'] 14 | data_numpy[0, frame_index, :, m] = pose[0::2] 15 | data_numpy[1, frame_index, :, m] = pose[1::2] 16 | data_numpy[2, frame_index, :, m] = score 17 | 18 | # centralization 19 | data_numpy[0:2] = data_numpy[0:2] - 0.5 20 | data_numpy[0][data_numpy[2] == 0] = 0 21 | data_numpy[1][data_numpy[2] == 0] = 0 22 | 23 | sort_index = (-data_numpy[2, :, :, :].sum(axis=1)).argsort(axis=1) 24 | for t, s in enumerate(sort_index): 25 | data_numpy[:, t, :, :] = data_numpy[:, t, :, s].transpose((1, 2, 26 | 0)) 27 | data_numpy = data_numpy[:, :, :, :num_person_out] 28 | 29 | label = video_info['label_index'] 30 | return data_numpy, label 31 | 32 | def get_video_frames(video_path): 33 | vread = skvideo.io.vread(video_path) 34 | video = [] 35 | for frame in vread: 36 | video.append(frame) 37 | return video 38 | 39 | def video_play(video_path, fps=30): 40 | cap = cv2.VideoCapture(video_path) 41 | 42 | while(cap.isOpened()): 43 | ret, frame = cap.read() 44 | 45 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 46 | 47 | cv2.imshow('frame',gray) 48 | if cv2.waitKey(1000/fps) & 0xFF == ord('q'): 49 | break 50 | 51 | cap.release() 52 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Skeleton Scripts 2 | 3 | ## 预先准备 4 | 5 | ### 安装openpose的docker镜像 6 | docker版本: 19.03.12 7 | 如果版本不对,加载容器的时候会出错 8 | docker仓库:https://hub.docker.com/r/cwaffles/openpose 9 | docker常用指令:https://www.runoob.com/docker/docker-tutorial.html 10 | **step1:安装docker** 11 | 使用官方安装脚本自动安装: 12 | curl -fsSL https://get.docker.com | bash -s docker --mirror Aliyun 13 | 手动安装 14 | 参考链接:https://www.runoob.com/docker/ubuntu-docker-install.html 15 | 16 | **step2:安装NVIDIA CONTAINER RUNTIME** 17 | 新建一个脚本文件 vim nvidia.sh 填入如下内容: 18 | sudo curl -s -L https://nvidia.github.io/nvidia-container-runtime/gpgkey | \ 19 | sudo apt-key add - 20 | distribution=$(. /etc/os-release;echo $ID$VERSION_ID) 21 | sudo curl -s -L https://nvidia.github.io/nvidia-container-runtime/$distribution/nvidia-container-runtime.list | \ 22 | sudo tee /etc/apt/sources.list.d/nvidia-container-runtime.list 23 | sudo apt-get update 24 | 25 | 执行脚本 sh nvidia.sh 26 | 27 | 安装 nvidia-container-runtime: 28 | sudo apt-get install nvidia-container-runtime 29 | 30 | **step3:创建用户组,方便授权** 31 | 32 | 如果没有sudo权限,可以创建dockers权限组 33 | sudo groupadd docker 34 | sudo gpasswd -a ${USER} docker 35 | sudo service docker restart 36 | newgrp - docker //将当前用户以docker用户组的身份再次登录系统 37 | 38 | 通过cat /etc/group可以查看用户组信息 39 | 40 | **step4:下载镜像,对应cuda10.0,cudnn7.0** 41 | 42 | docker pull cwaffles/openpose 43 | 通过镜像创建容器 44 | 45 | sudo docker run --gpus all --name openpose -it cwaffles/openpose:latest /bin/bash 46 | 进入容器内部(创建成功会自动进入容器) 47 | 48 | docker exec -it openpose /bin/bash 49 | 注:还可以使用以下命令一次删除所有停止的容器。docker rm $(docker ps -a -q) 50 | 51 | **step5:测试openpose的demo** 52 | 53 | #only body 54 | ./build/examples/openpose/openpose.bin --video examples/media/video.avi --write_json output/ --display 0 --render_pose 0 55 | #Body + face + hands 56 | ./build/examples/openpose/openpose.bin --video examples/media/video.avi --write_json output/ --display 0 --render_pose 0 --face --hand 57 | ### 将本仓库和视频文件以数据卷方式挂载到openpose docker容器中 58 | 接下来可以看一下docker容器的共享文件夹来拷贝数据集 59 | 60 | docker run -it -v /宿主机绝对路径目录: /容器内目录 镜像名 61 | docker run -idt -v --name openpose /home/$USER/share:/openpose/share cwaffles/openpose:latest //后台运行 62 | docker exec -it openpose /bin/bash //进入容器 63 | 64 | ### 安装python相关包 65 | 66 | pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple 67 | 68 | 69 | ## 主要脚本功能简介 70 | 71 | ### skating_convert.py: 72 | 73 | 调用openpose提取视频中的骨骼点并每个视频的结果打包成json文件 74 | 75 | ### skating_gendata.py 76 | 77 | 将json文件整理为npy文件并保存 78 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | import json 4 | 5 | LOGGER = logging.getLogger(__name__) 6 | LOGGER.setLevel(level = logging.INFO) 7 | handler = logging.FileHandler("convert.log") 8 | handler.setLevel(logging.INFO) 9 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 10 | handler.setFormatter(formatter) 11 | LOGGER.addHandler(handler) 12 | 13 | toolbar_width = 30 14 | 15 | def print_toolbar(rate, annotation=''): 16 | # setup toolbar 17 | LOGGER.info("{}[".format(annotation)) 18 | # for i in range(toolbar_width): 19 | # if i * 1.0 / toolbar_width > rate: 20 | # LOGGER.info(' ') 21 | # else: 22 | # LOGGER.info('-') 23 | # LOGGER.info(']\r') 24 | 25 | 26 | def end_toolbar(): 27 | LOGGER.info("\n") 28 | 29 | def count_lines(filename): 30 | with open(filename) as f: 31 | count=-1 32 | for count,_ in enumerate(f): 33 | pass 34 | count+=1 35 | return count 36 | 37 | 38 | def name_loader(filename): 39 | with open(filename) as f: 40 | for line in f.readlines(): 41 | info=line.strip() 42 | category, video_name, frame_num, label = info.split(",") 43 | yield category, video_name, label 44 | 45 | 46 | def json_pack(snippets_dir, video_name, frame_width, frame_height, label='unknown', label_index=-1): 47 | sequence_info = [] 48 | p = Path(snippets_dir) 49 | for path in p.glob(video_name+'*.json'): 50 | json_path = str(path) 51 | # LOGGER.info(path) 52 | frame_id = int(path.stem.split('_')[-2]) 53 | frame_data = {'frame_index': frame_id} 54 | data = json.load(open(json_path)) 55 | skeletons = [] 56 | for person in data['people']: 57 | score, coordinates = [], [] 58 | skeleton = {} 59 | keypoints = person['pose_keypoints_2d'] 60 | for i in range(0, len(keypoints), 3): 61 | coordinates += [keypoints[i]/frame_width, keypoints[i + 1]/frame_height] 62 | score += [keypoints[i + 2]] 63 | skeleton['pose'] = coordinates 64 | skeleton['score'] = score 65 | skeletons += [skeleton] 66 | frame_data['skeleton'] = skeletons 67 | sequence_info += [frame_data] 68 | 69 | video_info = dict() 70 | video_info['data'] = sequence_info 71 | video_info['label'] = label 72 | video_info['label_index'] = label_index 73 | 74 | return video_info -------------------------------------------------------------------------------- /skating_gendata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | import argparse 5 | 6 | import numpy as np 7 | from numpy.lib.format import open_memmap 8 | 9 | from feeder_skating import Feeder_skating 10 | 11 | 12 | toolbar_width = 30 13 | 14 | def print_toolbar(rate, annotation=''): 15 | # setup toolbar 16 | sys.stdout.write("{}[".format(annotation)) 17 | for i in range(toolbar_width): 18 | if i * 1.0 / toolbar_width > rate: 19 | sys.stdout.write(' ') 20 | else: 21 | sys.stdout.write('-') 22 | sys.stdout.flush() 23 | sys.stdout.write(']\r') 24 | 25 | 26 | def end_toolbar(): 27 | sys.stdout.write("\n") 28 | 29 | 30 | def gendata( 31 | data_path, 32 | label_path, 33 | data_out_path, 34 | label_out_path, 35 | num_person_in=5, #observe the first 5 persons 36 | num_person_out=1, #then choose 2 persons with the highest score 37 | max_frame=2500, 38 | joins_count=25): 39 | 40 | feeder = Feeder_skating( 41 | data_path=data_path, 42 | label_path=label_path, 43 | num_person_in=num_person_in, 44 | num_person_out=num_person_out, 45 | joins_count=joins_count, 46 | debug=False, 47 | window_size=max_frame) 48 | 49 | sample_name = feeder.sample_name 50 | sample_label = [] 51 | 52 | fp = open_memmap( 53 | data_out_path, 54 | dtype='float32', 55 | mode='w+', 56 | shape=(len(sample_name), 3, max_frame, joins_count, num_person_out)) 57 | 58 | for i, s in enumerate(sample_name): 59 | data, label = feeder[i] 60 | print_toolbar(i * 1.0 / len(sample_name), 61 | '({:>5}/{:<5}) Processing data: '.format( 62 | i + 1, len(sample_name))) 63 | fp[i, :, 0:data.shape[1], :, :] = data 64 | sample_label.append(label) 65 | 66 | with open(label_out_path, 'wb') as f: 67 | pickle.dump((sample_name, list(sample_label)), f) 68 | 69 | 70 | if __name__ == '__main__': 71 | parser = argparse.ArgumentParser( 72 | description='Skating-skeleton Data Converter.') 73 | parser.add_argument( 74 | '--data_path', default='/skating2.0/skating63_openpose_result') 75 | parser.add_argument( 76 | '--out_folder', default='/skating2.0/skating63_openpose_result/skeleton_file') 77 | arg = parser.parse_args() 78 | 79 | part = ['label_train_skating63', 'label_val_skating63'] 80 | for p in part: 81 | data_path = '{}/{}_data'.format(arg.data_path, p) 82 | label_path = '{}/{}.csv'.format(arg.data_path, p) 83 | data_out_path = '{}/{}_data.npy'.format(arg.out_folder, p) 84 | label_out_path = '{}/{}_label.pkl'.format(arg.out_folder, p) 85 | 86 | if not os.path.exists(arg.out_folder): 87 | os.makedirs(arg.out_folder) 88 | gendata(data_path, label_path, data_out_path, label_out_path) 89 | -------------------------------------------------------------------------------- /skating_convert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import json 5 | import shutil 6 | 7 | import video 8 | from utils import * 9 | 10 | 11 | def pose_estimation(openpose, out_folder, video_path, model_name, model_folder, info, p): 12 | video_name = video_path.split('/')[-1].split('.')[0] 13 | output_snippets_dir = os.path.join(out_folder,'openpose_estimation/{}/{}'.format(model_name, video_name)) 14 | output_sequence_dir = os.path.join(out_folder,'{}_data/'.format(p)) 15 | if not os.path.exists(output_sequence_dir): 16 | os.makedirs(output_sequence_dir) 17 | output_sequence_path = '{}/{}.json'.format(output_sequence_dir, video_name) 18 | # pose estimation 19 | openpose_args = dict( 20 | video=video_path, 21 | write_json=output_snippets_dir, 22 | display=0, 23 | render_pose=0, 24 | model_pose=model_name, 25 | model_folder=model_folder) 26 | command_line = openpose + ' ' 27 | command_line += ' '.join(['--{} {}'.format(k, v) for k, v in openpose_args.items()]) 28 | shutil.rmtree(output_snippets_dir, ignore_errors=True) 29 | os.makedirs(output_snippets_dir) 30 | LOGGER.info(command_line) 31 | os.system(command_line) 32 | # pack openpose ouputs 33 | video_obj = video.get_video_frames(video_path) 34 | height, width, _ = video_obj[0].shape 35 | video_info = json_pack( 36 | output_snippets_dir, video_name, width, height, label_index=info["label_index"],label=info["label"]) 37 | if not os.path.exists(output_sequence_dir): 38 | os.makedirs(output_sequence_dir) 39 | with open(output_sequence_path, 'w') as outfile: 40 | json.dump(video_info, outfile) 41 | if len(video_info['data']) == 0: 42 | LOGGER.info('Can not find pose estimation results of %s'%(video_name)) 43 | return 44 | else: 45 | LOGGER.info('%s Pose estimation complete.'%(video_name)) 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = argparse.ArgumentParser(description='Skating Data Converter.') 50 | # region arguments yapf: disable 51 | parser.add_argument('--openpose', 52 | default='/openpose/build', 53 | help='Path to openpose') 54 | parser.add_argument( 55 | '--data_path', default='/skating2.0/skating63',help="Path to dataset") 56 | parser.add_argument( 57 | '--out_folder', default='/skating2.0/skating63_openpose_result',help="Path to save files") 58 | parser.add_argument( 59 | '--model_folder', default='/openpose/models',help="Path to model folder") 60 | arg = parser.parse_args() 61 | arg.trainfile=os.path.join(arg.data_path,"label_train_skating63.csv") 62 | arg.testfile=os.path.join(arg.data_path,"label_val_skating63.csv") 63 | openpose='{}/examples/openpose/openpose.bin'.format(arg.openpose) 64 | LOGGER.info(os.getcwd()) 65 | part = ['label_train_skating63', 'label_val_skating63'] 66 | restart_list = { 67 | 'label_train_skating63': 3566, 68 | 'label_val_skating63': 0 69 | } 70 | debug = False 71 | debug_count = 2 72 | for p in part: 73 | csvfile=os.path.join(arg.data_path,"{}.csv".format(p)) 74 | total_count = count_lines(csvfile) 75 | count = 0 76 | restart_count = restart_list[p] 77 | for category, video_name, label in name_loader(csvfile): 78 | if debug and count >= debug_count: 79 | break 80 | if count < restart_count: 81 | count += 1 82 | continue 83 | # try: 84 | video_name = video_name + ".mp4" 85 | info={} 86 | info['label_index']=int(label) 87 | info['has_skeleton']=True 88 | info['label']=category 89 | video_path = os.path.join(arg.data_path, category, video_name) 90 | if not os.path.exists(video_path): 91 | LOGGER.info("%s not exist"%(video_path)) 92 | count+=1 93 | msg = '{}:({:>5}/{:<5}) Processing data: '.format(p, count, total_count) 94 | print_toolbar(count * 100.0 / total_count, msg) 95 | pose_estimation(openpose, arg.out_folder,video_path, "BODY_25",arg.model_folder,info,p) 96 | # except Exception as e: 97 | # LOGGER.warning(e) 98 | -------------------------------------------------------------------------------- /feeder_skating.py: -------------------------------------------------------------------------------- 1 | # sys 2 | import os 3 | import sys 4 | import numpy as np 5 | import random 6 | import pickle 7 | import json 8 | # torch 9 | import torch 10 | import torch.nn as nn 11 | from torchvision import datasets, transforms 12 | 13 | 14 | class Feeder_skating(torch.utils.data.Dataset): 15 | """ Feeder for skeleton-based action recognition in kinetics-skeleton dataset 16 | Arguments: 17 | data_path: the path to '.npy' data, the shape of data should be (N, C, T, V, M) 18 | label_path: the path to label 19 | random_choose: If true, randomly choose a portion of the input sequence 20 | random_shift: If true, randomly pad zeros at the begining or end of sequence 21 | random_move: If true, perform randomly but continuously changed transformation to input sequence 22 | window_size: The length of the output sequence 23 | pose_matching: If ture, match the pose between two frames 24 | num_person_in: The number of people the feeder can observe in the input sequence 25 | num_person_out: The number of people the feeder in the output sequence 26 | debug: If true, only use the first 100 samples 27 | """ 28 | 29 | def __init__(self, 30 | data_path, 31 | label_path, 32 | window_size=-1, 33 | num_person_in=5, 34 | num_person_out=2, 35 | joins_count=25, 36 | debug=False): 37 | self.debug = debug 38 | self.data_path = data_path 39 | self.label_path = label_path 40 | self.window_size = window_size 41 | self.num_person_in = num_person_in 42 | self.num_person_out = num_person_out 43 | self.joins_count = joins_count 44 | 45 | self.load_data() 46 | 47 | def load_data(self): 48 | # load file list 49 | self.sample_name = [] 50 | self.label = [] 51 | # load label 52 | label_path = self.label_path 53 | with open(label_path) as f: 54 | for line in f.readlines(): 55 | info = line.strip() 56 | _, video_name, _, label = info.split(",") 57 | self.sample_name.append(video_name) 58 | self.label.append(int(label)) 59 | 60 | if self.debug: 61 | self.sample_name = self.sample_name[0:2] 62 | 63 | # output data shape (N, C, T, V, M) 64 | self.N = len(self.sample_name) #sample 65 | self.C = 3 #channel 66 | self.T = self.window_size #frame 67 | self.V = self.joins_count #joint 68 | self.M = self.num_person_out #person 69 | 70 | def __len__(self): 71 | return len(self.sample_name) 72 | 73 | def __iter__(self): 74 | return self 75 | 76 | def __getitem__(self, index): 77 | 78 | # output shape (C, T, V, M) 79 | # get data 80 | sample_name = self.sample_name[index] + ".json" 81 | sample_path = os.path.join(self.data_path, sample_name) 82 | with open(sample_path, 'r') as f: 83 | video_info = json.load(f) 84 | 85 | # fill data_numpy 86 | data_numpy = np.zeros((self.C, self.T, self.V, self.num_person_in)) 87 | for frame_info in video_info['data']: 88 | frame_index = frame_info['frame_index'] 89 | for m, skeleton_info in enumerate(frame_info["skeleton"]): 90 | if m >= self.num_person_in: 91 | break 92 | pose = skeleton_info['pose'] 93 | score = skeleton_info['score'] 94 | data_numpy[0, frame_index, :, m] = pose[0::2] 95 | data_numpy[1, frame_index, :, m] = pose[1::2] 96 | data_numpy[2, frame_index, :, m] = score 97 | 98 | # centralization 99 | data_numpy[0:2] = data_numpy[0:2] - 0.5 100 | data_numpy[0][data_numpy[2] == 0] = 0 101 | data_numpy[1][data_numpy[2] == 0] = 0 102 | 103 | # get & check label index 104 | label = video_info['label_index'] 105 | assert (self.label[index] == label) 106 | 107 | # sort by score 108 | sort_index = (-data_numpy[2, :, :, :].sum(axis=1)).argsort(axis=1) 109 | for t, s in enumerate(sort_index): 110 | data_numpy[:, t, :, :] = data_numpy[:, t, :, s].transpose((1, 2, 111 | 0)) 112 | data_numpy = data_numpy[:, :, :, 0:self.num_person_out] 113 | 114 | return data_numpy, label 115 | --------------------------------------------------------------------------------