├── .gitignore
├── .vscode
│   └── settings.json
├── README.md
├── README_origin.md
├── __init__.py
├── config
│   ├── __init__.py
│   ├── config.py
│   ├── config_full_loss.py
│   ├── config_hard_full_loss.py
│   └── config_v2.py
├── demo.py
├── demo2.py
├── model
│   ├── __init__.py
│   ├── initialization.py
│   ├── model.py
│   ├── network
│   │   ├── __init__.py
│   │   ├── basic_blocks.py
│   │   ├── basic_blocks_dyrelu.py
│   │   ├── dyrelu.py
│   │   ├── gaitset.py
│   │   └── triplet.py
│   └── utils
│       ├── __init__.py
│       ├── data_loader.py
│       ├── data_set.py
│       ├── evaluator.py
│       ├── sampler.py
│       └── tensorboardDraw.py
├── model2
│   ├── __init__.py
│   ├── initialization.py
│   ├── model.py
│   ├── network
│   │   ├── __init__.py
│   │   ├── basic_blocks.py
│   │   ├── basic_blocks_dyrelu.py
│   │   ├── dyrelu.py
│   │   ├── gaitset.py
│   │   └── triplet.py
│   └── utils
│       ├── __init__.py
│       ├── data_loader.py
│       ├── data_set.py
│       ├── evaluator.py
│       ├── sampler.py
│       └── tensorboardDraw.py
├── pretreatment.py
├── requirements.txt
├── test
│   ├── __init__.py
│   ├── test.py
│   └── test_all.py
├── train
│   ├── __init__.py
│   ├── train.py
│   ├── train_full.py
│   └── train_hard_full.py
└── work
    ├── OUMVLP_network
    │   ├── basic_blocks.py
    │   └── gaitset.py
    ├── log
    │   ├── tensorboardV2_log
    │   │   ├── events.out.tfevents.1601021568.lab206-Server
    │   │   ├── events.out.tfevents.1601021725.lab206-Server
    │   │   ├── events.out.tfevents.1601021764.lab206-Server
    │   │   ├── events.out.tfevents.1601022036.lab206-Server
    │   │   ├── events.out.tfevents.1601022451.lab206-Server
    │   │   ├── events.out.tfevents.1601022615.lab206-Server
    │   │   ├── events.out.tfevents.1601040881.lab206-Server
    │   │   ├── events.out.tfevents.1601082118.lab206-Server
    │   │   ├── events.out.tfevents.1601084376.lab206-Server
    │   │   └── events.out.tfevents.1601084535.lab206-Server
    │   ├── tensorboardX_log
    │   │   ├── events.out.tfevents.1600595800.lab206-Server
    │   │   ├── events.out.tfevents.1600595961.lab206-Server
    │   │   ├── events.out.tfevents.1600596086.lab206-Server
    │   │   └── events.out.tfevents.1600608993.lab206-Server
    │   └── visdom
    │       ├── acc_array_exclude.txt
    │       ├── acc_array_include.txt
    │       └── iter_list.txt
    └── partition
        └── CASIA-B_73_False.npy

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/
*.ptm
work/checkpoint/

--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "python.pythonPath": "/data/public/env/anaconda3/bin/python",
    "python.autoComplete.extraPaths": [
        "/data/public/env/anaconda3/bin/python"
    ],
    "python.analysis.memory.keepLibraryLocalVariables": true
}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GaitSet
> An annotated walkthrough of the original repository: explanatory comments have been added throughout the code. My understanding is limited and some interpretations may be wrong; discussion and corrections are welcome.

## Start
1. Pick the loss you want to train with. For example, for the full loss, edit the settings in config/config_full_loss.py.
2. Run train/train_full.py to start training.

## Dependencies
See requirements.txt.
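The two steps above, as a minimal sketch. train/train_full.py itself is not reproduced in this listing, so the exact wiring is an assumption based on model/initialization.py and model/model.py (the full-loss variant presumably does the same with config_full_loss and the model2 package):

```python
# Hypothetical launcher (assumption: this mirrors what the train scripts do).
from config.config import conf
from model.initialization import initialization

# Builds the DataSet objects and the Model; train=True preloads the data cache.
m, save_name = initialization(conf, train=True)
m.fit()  # trains until conf['model']['total_iter'] iterations
```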
--------------------------------------------------------------------------------
/README_origin.md:
--------------------------------------------------------------------------------
# GaitSet

[![LICENSE](https://img.shields.io/badge/license-NPL%20(The%20996%20Prohibited%20License)-blue.svg)](https://github.com/996icu/996.ICU/blob/master/LICENSE)
[![996.icu](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu)

[GaitSet](https://arxiv.org/abs/1811.06186) is a **flexible**, **effective** and **fast** network for cross-view gait recognition.

#### Flexible
The input of GaitSet is a set of silhouettes.

- There are **NO constraints** on an input,
which means it can contain **any number** of **non-consecutive** silhouettes filmed under **different viewpoints**
with **different walking conditions**.

- As the input is a set, the **permutation** of the elements in the input
will **NOT change** the output at all.

#### Effective
It achieves **Rank@1=95.0%** on [CASIA-B](http://www.cbsr.ia.ac.cn/english/Gait%20Databases.asp)
and **Rank@1=87.1%** on [OU-MVLP](http://www.am.sanken.osaka-u.ac.jp/BiometricDB/GaitMVLP.html),
excluding identical-view cases.

#### Fast
With 8 NVIDIA 1080TI GPUs, it takes only **7 minutes** to conduct an evaluation on
[OU-MVLP](http://www.am.sanken.osaka-u.ac.jp/BiometricDB/GaitMVLP.html), which contains 133,780 sequences
with an average of 70 frames per sequence.

## What's new
The code and checkpoint for the OUMVLP dataset have been released.
See [OUMVLP](#oumvlp) for details.

## Prerequisites

- Python 3.6
- PyTorch 0.4+
- GPU

## Getting started
### Installation

- (Optional) Install [Anaconda3](https://www.anaconda.com/download/)
- Install [CUDA 9.0](https://developer.nvidia.com/cuda-90-download-archive)
- Install [cuDNN 7.0](https://developer.nvidia.com/cudnn)
- Install [PyTorch](http://pytorch.org/)

Note that our code is tested with [PyTorch 0.4](http://pytorch.org/).

### Dataset & Preparation
Download the [CASIA-B Dataset](http://www.cbsr.ia.ac.cn/english/Gait%20Databases.asp).

**!!! ATTENTION !!! ATTENTION !!! ATTENTION !!!**

Before training or testing, please make sure you have prepared the dataset
with these two steps:
- **Step 1:** Organize the directory as
`your_dataset_path/subject_ids/walking_conditions/views`,
e.g. `CASIA-B/001/nm-01/000/`.
- **Step 2:** Cut and align the raw silhouettes with `pretreatment.py`.
(See [pretreatment](#pretreatment) for details.)
Feel free to try different ways of pretreatment, but note that
the silhouettes after pretreatment **MUST have a size of 64x64**.

Furthermore, you can also test our code on the [OU-MVLP Dataset](http://www.am.sanken.osaka-u.ac.jp/BiometricDB/GaitMVLP.html).
The number of channels and the training batch size are slightly different for this dataset.
For more detail, please refer to [our paper](https://arxiv.org/abs/1811.06186).

#### Pretreatment
`pretreatment.py` uses the alignment method from
[this paper](https://ipsjcva.springeropen.com/articles/10.1186/s41074-018-0039-6).
Pretreat your dataset with
```
python pretreatment.py --input_path='root_path_of_raw_dataset' --output_path='root_path_for_output'
```
- `--input_path` **(NECESSARY)** Root path of the raw dataset.
- `--output_path` **(NECESSARY)** Root path for the output.
- `--log_file` Log file path. #Default: './pretreatment.log'
- `--log` If set to True, all logs will be saved.
Otherwise, only warnings and errors will be saved. #Default: False
- `--worker_num` How many subprocesses to use for data pretreatment. #Default: 1
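As a quick way to confirm Step 2 succeeded, a small check like the following can be run over the output directory (this snippet is not part of the repository; the path and file extension are placeholders):

```python
import glob

import cv2

# Every pretreated silhouette must be exactly 64x64 before training.
for p in glob.glob('root_path_for_output/*/*/*/*.png'):
    h, w = cv2.imread(p).shape[:2]
    assert (h, w) == (64, 64), f'{p} has size {h}x{w}'
```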
### Configuration

In `config.py`, you might want to change the following settings:
- `dataset_path` **(NECESSARY)** root path of the dataset
(for the above example, it is "gaitdata")
- `WORK_PATH` path to save/load checkpoints
- `CUDA_VISIBLE_DEVICES` indices of GPUs

### Train
Train a model with
```bash
python train.py
```
- `--cache` If set to TRUE, all the training data will be loaded at once before training starts.
This will accelerate training.
**Note that** if this arg is set to FALSE, samples will NOT be kept in memory
even if they have been used in former iterations. #Default: TRUE

### Evaluation
Evaluate the trained model with
```bash
python test.py
```
- `--iter` iteration of the checkpoint to load. #Default: 80000
- `--batch_size` batch size of the parallel test. #Default: 1
- `--cache` If set to TRUE, all the test data will be loaded at once before the transform starts.
This might accelerate testing. #Default: FALSE

It will output Rank@1 accuracy for all three walking conditions.
Note that the test is **parallelizable**.
To conduct a faster evaluation, you can use `--batch_size` to change the batch size for the test.

#### OUMVLP
Because of the huge differences between OUMVLP and CASIA-B, the network settings for OUMVLP are slightly different.
- The alternative network code can be found at `./work/OUMVLP_network`. Use it to replace the corresponding files in `./model/network`.
- The checkpoint can be found [here](https://1drv.ms/u/s!AurT2TsSKdxQuWN8drzIv_phTR5m?e=Gfbl3m).
- In `./config.py`, modify `'batch_size': (8, 16)` into `'batch_size': (32, 16)`.
- Prepare your OUMVLP dataset according to the instructions in [Dataset & Preparation](#dataset--preparation).

## To Do List
- Transformation: the script for transforming a set of silhouettes into a discriminative representation.

## Authors & Contributors
GaitSet is authored by
[Hanqing Chao](https://www.linkedin.com/in/hanqing-chao-9aa42412b/),
[Yiwei He](https://www.linkedin.com/in/yiwei-he-4a6a6bbb/),
[Junping Zhang](http://www.pami.fudan.edu.cn/~jpzhang/)
and JianFeng Feng from Fudan University.
[Junping Zhang](http://www.pami.fudan.edu.cn/~jpzhang/)
is the corresponding author.
The code was developed by
[Hanqing Chao](https://www.linkedin.com/in/hanqing-chao-9aa42412b/)
and [Yiwei He](https://www.linkedin.com/in/yiwei-he-4a6a6bbb/).
Currently, it is being maintained by
[Hanqing Chao](https://www.linkedin.com/in/hanqing-chao-9aa42412b/)
and Kun Wang.

## Citation
Please cite these papers in your publications if this helps your research:
```
@inproceedings{chao2019gaitset,
  author = {Chao, Hanqing and He, Yiwei and Zhang, Junping and Feng, Jianfeng},
  booktitle = {AAAI},
  title = {{GaitSet}: Regarding Gait as a Set for Cross-View Gait Recognition},
  year = {2019}
}
```
Link to paper:
- [GaitSet: Regarding Gait as a Set for Cross-View Gait Recognition](https://arxiv.org/abs/1811.06186)

## License
GaitSet is freely available for non-commercial use, and may be redistributed under these conditions.
For commercial queries, contact [Junping Zhang](http://www.pami.fudan.edu.cn/~jpzhang/).
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Author : admin
# @Time   : 2018/11/16

--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author  : wanglin
# Created : 2020/9/26 10:54
# File    : __init__.py
# IDE     : PyCharm

--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
conf = {
    "WORK_PATH": "./work",
    "CUDA_VISIBLE_DEVICES": "0",
    "data": {
        'dataset_path': "/data/lwl/Gait_experiment/gait_data",
        'resolution': '64',
        'dataset': 'CASIA-B',
        # In CASIA-B, data of subject #5 is incomplete.
        # Thus, we ignore it in training.
        # For more detail, please refer to
        # function: utils.data_loader.load_data
        'pid_num': 73,  # LT split: the first 74 subjects are used for training
                        # (subject #5 is incomplete, so 73 remain); the rest are for testing.
        'pid_shuffle': False,
        # Whether to split the dataset randomly. If False, subjects 1-74 form
        # the training set and the remaining subjects form the test set.
    },
    "model": {
        'hidden_dim': 256,
        'lr': 1e-4,
        'hard_or_full_trip': 'full',
        # TODO note: the batch size was changed here
        'batch_size': (8, 4),
        # TODO note: the iteration to resume training from was changed here
        'restore_iter': 7690,
        'total_iter': 80000,
        'margin': 0.2,
        'num_workers': 8,
        'frame_num': 30,
        'model_name': 'GaitSet',
    },
}

--------------------------------------------------------------------------------
/config/config_full_loss.py:
--------------------------------------------------------------------------------
conf = {
    "WORK_PATH": "/data/lwl/Gait_experiment/GaitSet/work",
    "CUDA_VISIBLE_DEVICES": "0",
    "data": {
        'dataset_path': "/data/lwl/Gait_experiment/gait_data",
        'resolution': '64',
        'dataset': 'CASIA-B',
        # In CASIA-B, data of subject #5 is incomplete.
        # Thus, we ignore it in training.
        # For more detail, please refer to
        # function: utils.data_loader.load_data
        'pid_num': 73,  # LT split: the first 74 subjects are used for training
                        # (subject #5 is incomplete, so 73 remain); the rest are for testing.
        'pid_shuffle': False,
        # Whether to split the dataset randomly. If False, subjects 1-74 form
        # the training set and the remaining subjects form the test set.
    },
    "model": {
        'hidden_dim': 256,
        'lr': 1e-4,
        'hard_or_full_trip': 'full',
        # TODO note: the batch size was changed here
        'batch_size': (8, 16),
        # TODO note: the iteration to resume training from was changed here
        'restore_iter': 0,
        'total_iter': 80000,
        'margin': 0.2,
        'num_workers': 8,
        'frame_num': 30,
        'model_name': 'GaitSet_full_loss',
        'model_save_dir': "GaitSet/full_loss",
        'logdir': './log/visdom/full_loss'
    },
}
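In every config above, `batch_size` is a pair (P, K): P subjects per batch and K sequences per subject (see model/utils/sampler.py), so the effective batch holds P * K sequences. A quick check, for illustration only:

```python
import numpy as np

P, K = (8, 16)               # subjects per batch, sequences per subject
print(int(np.prod((P, K))))  # 128 sequences per iteration; this product is
                             # also what model/initialization.py passes to
                             # TripletLoss as its batch size
```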
--------------------------------------------------------------------------------
/config/config_hard_full_loss.py:
--------------------------------------------------------------------------------
conf = {
    "WORK_PATH": "/data/lwl/Gait_experiment/GaitSet/work",
    "CUDA_VISIBLE_DEVICES": "0",
    "data": {
        'dataset_path': "/data/lwl/Gait_experiment/gait_data",
        'resolution': '64',
        'dataset': 'CASIA-B',
        # In CASIA-B, data of subject #5 is incomplete.
        # Thus, we ignore it in training.
        # For more detail, please refer to
        # function: utils.data_loader.load_data
        'pid_num': 73,  # LT split: the first 74 subjects are used for training
                        # (subject #5 is incomplete, so 73 remain); the rest are for testing.
        'pid_shuffle': False,
        # Whether to split the dataset randomly. If False, subjects 1-74 form
        # the training set and the remaining subjects form the test set.
    },
    "model": {
        'hidden_dim': 256,
        'lr': 1e-5,
        'hard_or_full_trip': 'full',
        # TODO note: the batch size was changed here
        'batch_size': (8, 16),
        # TODO note: the iteration to resume training from was changed here
        'restore_iter': 100000,
        'total_iter': 200000,
        'margin': 0.2,
        'num_workers': 8,
        'frame_num': 30,
        'model_name': 'GaitSet_hard_full_loss',
        'model_save_dir': "GaitSet/hard_full_loss",
        'logdir': './log/visdom/hard_full_loss'
    },
}

--------------------------------------------------------------------------------
/config/config_v2.py:
--------------------------------------------------------------------------------
conf = {
    "WORK_PATH": "/data/lwl/Gait_experiment/GaitSet/work",
    "CUDA_VISIBLE_DEVICES": "0",
    "data": {
        'dataset_path': "/data/lwl/Gait_experiment/gait_data",
        'resolution': '64',
        'dataset': 'CASIA-B',
        # In CASIA-B, data of subject #5 is incomplete.
        # Thus, we ignore it in training.
        # For more detail, please refer to
        # function: utils.data_loader.load_data
        'pid_num': 73,  # LT split: the first 74 subjects are used for training
                        # (subject #5 is incomplete, so 73 remain); the rest are for testing.
        'pid_shuffle': False,
        # Whether to split the dataset randomly. If False, subjects 1-74 form
        # the training set and the remaining subjects form the test set.
    },
    "model": {
        'hidden_dim': 256,
        'lr': 1e-5,
        'hard_or_full_trip': 'full',
        # TODO note: the batch size was changed here
        'batch_size': (8, 4),
        # TODO note: the iteration to resume training from was changed here
        'restore_iter': 86800,
        'total_iter': 160000,
        'margin': 0.2,
        'num_workers': 8,
        'frame_num': 30,
        'model_name': 'GaitSet',
        'logdir': './log/tensorboardV2_log'
    },
}

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author  : wanglin
# Created : 2020/8/29 16:40
# File    : demo
# IDE     : PyCharm
import numpy
from visdom import Visdom

if __name__ == '__main__':
    '''vis = Visdom(env="demo", log_to_filename="./visdom.log")
    print(len(np.random.rand(20, 3)), len(np.arange(0, 20)))
    vis.bar(
        X=np.random.rand(20, 3),
        Y=np.arange(0, 20),
        opts=dict(
            stacked=False,
            legend=['The Netherlands', 'France', 'United States']
        )
    )'''
    # Visdom.replay_log(log_filename="./visdom.log")

    vis = Visdom(env="test")
    acc_array_exclude = numpy.array([[87.61818181818184, 78.15718181818183, 64.21818181818182],
                                     [87.55454545454546, 78.20354545454546, 64.36363636363637]])
    iter_list = [88800, 88900]
    numpy.savetxt("work/log/test/visdom/acc.txt", acc_array_exclude)
    numpy.savetxt("work/log/test/visdom/iter_list.txt", iter_list)
    vis.bar(X=acc_array_exclude,
            opts=dict(
                stacked=False,
                legend=['NM', 'BG', 'CL'],
                rownames=iter_list,
                title='Test_acc',
                ylabel='rank-1 accuracy',  # y-axis label
                xtickmin=0.4  # left end of the x-axis
                # xtickstep=0.4  # spacing between bars
            ), win="acc_array_exclude")
    vis.bar(X=acc_array_exclude,
            opts=dict(
                stacked=False,
                legend=['NM', 'BG', 'CL'],
                rownames=[88901, 88902],
                title='Test_acc',
                ylabel='rank-1 accuracy',  # y-axis label
                xtickmin=0.4,  # left end of the x-axis
                # xtickstep=0.4  # spacing between bars
            ), win="acc_array_exclude", )
--------------------------------------------------------------------------------
/demo2.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author  : wanglin
# Created : 2020/9/28 9:38
# File    : demo2
# IDE     : PyCharm
import numpy as np
import visdom

if __name__ == '__main__':
    """vis = Visdom(env="GaitSet_test", log_to_filename="work/log/visdom/test_all_acc.log")
    acc_array_exclude = np.loadtxt("work/log/visdom/acc_array_exclude.txt")
    acc_array_include = np.loadtxt("work/log/visdom/acc_array_include.txt")
    iter_list = np.loadtxt("work/log/visdom/iter_list.txt")
    print(acc_array_exclude.shape)
    print(iter_list.shape)
    iter_list_new = list(filter(lambda x: 0 == x % 100, iter_list.tolist()))
    print(iter_list_new.__len__())
    vis.bar(X=acc_array_exclude,
            opts=dict(
                stacked=False,
                legend=['NM', 'BG', 'CL'],
                rownames=iter_list_new,
                title='Test_acc exclude',
                ylabel='rank-1 accuracy',  # y-axis label
                xtickmin=0.4  # left end of the x-axis
                # xtickstep=0.4  # spacing between bars
            ), win="acc_array_exclude")
    vis.bar(X=acc_array_include,
            opts=dict(
                stacked=False,
                legend=['NM', 'BG', 'CL'],
                rownames=iter_list_new,
                title='Test_acc include',
                ylabel='rank-1 accuracy',  # y-axis label
                xtickmin=0.4,  # left end of the x-axis
                # xtickstep=0.4  # spacing between bars
                append=True
            ), win="acc_array_include")"""

    track_loss = 0  # for drawing the graph
    global_step = 0
    vis = visdom.Visdom(env=u"train_loss")
    win = vis.line(X=np.array([global_step]), Y=np.array([track_loss]))

    for epoch in range(10):
        # code omitted here

        for iter_num, dial_batch in enumerate(range(20)):
            # code omitted here
            loss = np.random.random()
            vis.line(X=np.array([global_step]), Y=np.array([loss]), win=win,
                     update='append', opts=dict(title="demo"))  # for drawing the graph
            global_step += 1

--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/model/__init__.py
--------------------------------------------------------------------------------
/model/initialization.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Author : admin
# @Time   : 2018/11/15
import os
from copy import deepcopy

import numpy as np

from .model import Model
from .utils import load_data


def initialize_data(config, train=False, test=False):
    # train/test here indicate whether the data caches should be filled
    print("Initializing data source...")
    # build the DataSet objects
    train_source, test_source = load_data(**config['data'], cache=(train or test))
    if train:
        print("Loading training data...")
        train_source.load_all_data()
    if test:
        print("Loading test data...")
        test_source.load_all_data()
    print("Data initialization complete.")
    return train_source, test_source


def initialize_model(config, train_source, test_source):
    print("Initializing model...")
    data_config = config['data']
    model_config = config['model']
    model_param = deepcopy(model_config)
    model_param['train_source'] = train_source
    model_param['test_source'] = test_source
    model_param['train_pid_num'] = data_config['pid_num']
    batch_size = int(np.prod(model_config['batch_size']))  # np.prod multiplies all elements
    model_param['save_name'] = '_'.join(map(str, [
        model_config['model_name'],
        data_config['dataset'],
        data_config['pid_num'],
        data_config['pid_shuffle'],
        model_config['hidden_dim'],
        model_config['margin'],
        batch_size,
        model_config['hard_or_full_trip'],
        model_config['frame_num'],
    ]))

    m = Model(**model_param)
    print("Model initialization complete.")
    return m, model_param['save_name']


def initialization(config, train=False, test=False):
    print("Initializing...")
    WORK_PATH = config['WORK_PATH']
    os.chdir(WORK_PATH)  # change the current working directory to WORK_PATH
    os.environ["CUDA_VISIBLE_DEVICES"] = config["CUDA_VISIBLE_DEVICES"]
    train_source, test_source = initialize_data(config, train, test)
    return initialize_model(config, train_source, test_source)

--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
import math
import os
import os.path as osp
import random
import sys
from datetime import datetime

import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as tordata
from tensorboardX import SummaryWriter

from .network import TripletLoss, SetNet
from .utils import TripletSampler


class Model:
    def __init__(self,
                 hidden_dim,
                 lr,
                 hard_or_full_trip,
                 margin,
                 num_workers,
                 batch_size,
                 restore_iter,
                 total_iter,
                 save_name,
                 train_pid_num,
                 frame_num,
                 model_name,
                 train_source,
                 test_source,
                 img_size=64,
                 logdir="./log"):

        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source

        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        self.P, self.M = batch_size

        self.restore_iter = restore_iter
        self.total_iter = total_iter

        self.img_size = img_size

        self.encoder = SetNet(self.hidden_dim).float()
        # TODO: the multi-GPU (DataParallel) wrapping was changed here
        self.encoder = nn.DataParallel(self.encoder)
        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
        self.triplet_loss = nn.DataParallel(self.triplet_loss)
        self.encoder.cuda()
        self.triplet_loss.cuda()

        self.optimizer = optim.Adam([
            {'params': self.encoder.parameters()},
        ], lr=self.lr)

        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01

        self.sample_type = 'all'
        self.logdir = logdir

    def collate_fn(self, batch):

        batch_size = len(batch)  # size of this batch
        """
        data = [self.__loader__(_path) for _path in self.seq_dir[index]]
        feature_num is the number of sets contained in `data`; it is always 1
        here, because what gets loaded is
        _seq_dir = osp.join(seq_type_path, _view)
        seqs = os.listdir(_seq_dir)  # all silhouettes of one sequence
        """
        feature_num = len(batch[0][0])

        seqs = [batch[i][0] for i in range(batch_size)]  # corresponds to data
        frame_sets = [batch[i][1] for i in range(batch_size)]  # corresponds to frame_set
        view = [batch[i][2] for i in range(batch_size)]  # corresponds to view[index]
        seq_type = [batch[i][3] for i in range(batch_size)]  # corresponds to self.seq_type[index]
        label = [batch[i][4] for i in range(batch_size)]  # corresponds to self.label[index]
        batch = [seqs, view, seq_type, label, None]

        '''
        One sample consists of data, frame_set,
        self.view[index], self.seq_type[index] and self.label[index].
        '''

        def select_frame(index):
            sample = seqs[index]
            frame_set = frame_sets[index]
            if self.sample_type == 'random':
                # random.choices samples with replacement
                frame_id_list = random.choices(frame_set, k=self.frame_num)
                _ = [feature.loc[frame_id_list].values for feature in sample]
                # passing a list to feature.loc[] selects a whole group of frames
            else:
                # otherwise take all frames
                _ = [feature.values for feature in sample]
            return _

        # extract the frames of every sample; the result is a list of arrays
        seqs = list(map(select_frame, range(len(seqs))))
        # print(self.sample_type, "sampling mode")
        if self.sample_type == 'random':
            seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]

        else:
            # TODO: the GPU count used here was changed
            gpu_num = min(torch.cuda.device_count(), batch_size)
            # gpu_num = min(4, batch_size)
            batch_per_gpu = math.ceil(batch_size / gpu_num)

            # batch_frames layout:
            # [[gpu1_sample_1_frameNumbers, gpu1_sample_2_frameNumbers, ...],
            #  [gpu2_sample_1_frameNumbers, gpu2_sample_2_frameNumbers, ...], ...]
            batch_frames = [[  # split the data across the GPUs
                len(frame_sets[i])  # total number of frames of each sample
                for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                if i < batch_size
            ] for _ in range(gpu_num)]
            if len(batch_frames[-1]) != batch_per_gpu:
                for _ in range(batch_per_gpu - len(batch_frames[-1])):
                    batch_frames[-1].append(0)  # pad with 0 when the last GPU's batch is short
            max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])  # largest total frame count over the GPUs

            # concatenate each GPU's samples into one big array:
            # seqs = [[gpu1_batch, gpu2_batch, gpu3_batch, ...]]
            seqs = [[
                np.concatenate([  # concatenate all frames of this GPU's batch
                    seqs[i][j]
                    for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                    if i < batch_size
                ], 0) for _ in range(gpu_num)]
                for j in range(feature_num)]
            # TODO print the shape of seqs[j][_]
            # print("seqs[j][_] shape:", seqs[0][0].shape)
            seqs = [np.asarray([
                np.pad(seqs[j][_],  # seqs[j][_] has shape (gpu_batch_size * frame_number, 64, 44)
                       ((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
                       'constant',  # pad each GPU batch up to the largest total frame count with zeros
                       constant_values=0)
                for _ in range(gpu_num)])
                for j in range(feature_num)]

            batch[4] = np.asarray(batch_frames)

        batch[0] = seqs
        # TODO print the shape of seqs
        # print("seqs shape:", seqs[0].shape)
        return batch

    def fit(self):
        torch.backends.cudnn.benchmark = True
        writer = SummaryWriter(self.logdir)
        if self.restore_iter != 0:
            self.load(self.restore_iter)

        self.encoder.train()
        self.sample_type = 'random'
        # TODO: the sampling mode was changed here
        # self.sample_type = 'all'
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        triplet_sampler = TripletSampler(self.train_source, self.batch_size)  # custom batch sampler
        train_loader = tordata.DataLoader(
            dataset=self.train_source,
            batch_sampler=triplet_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)
        train_label_set = list(self.train_source.label_set)
        train_label_set.sort()

        _time1 = datetime.now()
        for seq, view, seq_type, label, batch_frame in train_loader:
            self.restore_iter += 1
            self.optimizer.zero_grad()
            # TODO: changed whether the data stays on the CPU here
            for i in range(len(seq)):
                seq[i] = self.np2var(seq[i]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()
            # with SummaryWriter(comment="encoder") as w:
            #     self.encoder.cpu()
            #     # seq.cpu()
            #     # batch_frame.cpu()
            #     w.add_graph(self.encoder, (seq[0],))

            # TODO: exit the program here when visualizing the encoder structure
            # sys.exit()
            # print("seq:", seq)
            # feature: 128*62*256
            feature, label_prob = self.encoder(*seq, batch_frame)
            # positions of each label within train_label_set = list(self.train_source.label_set)
            target_label = [train_label_set.index(l) for l in label]
            target_label = self.np2var(np.array(target_label)).long()

            # after this permutation the shape is 62*128*256
            triplet_feature = feature.permute(1, 0, 2).contiguous()
            # triplet_label: 62*128
            triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
            (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
             ) = self.triplet_loss(triplet_feature, triplet_label)
            if self.hard_or_full_trip == 'hard':
                loss = hard_loss_metric.mean()
            elif self.hard_or_full_trip == 'full':
                loss = full_loss_metric.mean()

            self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
            self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
            self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
            self.dist_list.append(mean_dist.mean().data.cpu().numpy())

            if loss > 1e-9:
                loss.backward()
                self.optimizer.step()

            if self.restore_iter % 1000 == 0:  # print the elapsed time every 1000 iterations
                print(datetime.now() - _time1)
                _time1 = datetime.now()
            # TODO: changed to log once every 10 batches
            if self.restore_iter % 10 == 0:
                print('iter {}:'.format(self.restore_iter), end='')
                print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
                writer.add_scalar("hard_loss_metric", np.mean(self.hard_loss_metric), self.restore_iter)

                print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
                writer.add_scalar("full_loss_metric", np.mean(self.full_loss_metric), self.restore_iter)

                print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
                writer.add_scalar("full_loss_num", np.mean(self.full_loss_num), self.restore_iter)

                self.mean_dist = np.mean(self.dist_list)
                print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
                writer.add_scalar("mean_dist", self.mean_dist, self.restore_iter)

                print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
                writer.add_scalar("lr", self.optimizer.param_groups[0]['lr'], self.restore_iter)
                print(', hard or full=%r' % self.hard_or_full_trip)
                sys.stdout.flush()
                self.hard_loss_metric = []
                self.full_loss_metric = []
                self.full_loss_num = []
                self.dist_list = []
            if self.restore_iter % 100 == 0:
                self.save()

            # Visualization using t-SNE
            # if self.restore_iter % 500 == 0:
            #     pca = TSNE(2)
            #     pca_feature = pca.fit_transform(feature.view(feature.size(0), -1).data.cpu().numpy())
            #     for i in range(self.P):
            #         plt.scatter(pca_feature[self.M * i:self.M * (i + 1), 0],
            #                     pca_feature[self.M * i:self.M * (i + 1), 1], label=label[self.M * i])
            #
            #     plt.show()

            if self.restore_iter == self.total_iter:
                break
            del loss
            torch.cuda.empty_cache()

    def ts2var(self, x):
        # TODO: changed whether the data is moved to the GPU here
        return autograd.Variable(x).cuda()
        # return autograd.Variable(x)

    def np2var(self, x):
        return self.ts2var(torch.from_numpy(x))

    def transform(self, flag, batch_size=1):
        self.encoder.eval()
        source = self.test_source if flag == 'test' else self.train_source
        self.sample_type = 'all'
        data_loader = tordata.DataLoader(
            dataset=source,
            batch_size=batch_size,
            sampler=tordata.sampler.SequentialSampler(source),
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        feature_list = list()
        view_list = list()
        seq_type_list = list()
        label_list = list()

        for i, x in enumerate(data_loader):
            seq, view, seq_type, label, batch_frame = x
            for j in range(len(seq)):
                seq[j] = self.np2var(seq[j]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()
            # print(batch_frame, np.sum(batch_frame))

            feature, _ = self.encoder(*seq, batch_frame)
            n, num_bin, _ = feature.size()
            feature_list.append(feature.view(n, -1).data.cpu().numpy())
            view_list += view
            seq_type_list += seq_type
            label_list += label
        # each element of feature_list has shape 128*(62x256) = 128*15872;
        # returns the feature vectors of all samples, shape: (num_samples, 15872)
        return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list

    def save(self):
        os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
        torch.save(self.encoder.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-encoder.ptm'.format(
                                self.save_name, self.restore_iter)))
        torch.save(self.optimizer.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-optimizer.ptm'.format(
                                self.save_name, self.restore_iter)))

    # restore_iter: iteration index of the checkpoint to load
    def load(self, restore_iter):
        self.encoder.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
        self.optimizer.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))

--------------------------------------------------------------------------------
/model/network/__init__.py:
--------------------------------------------------------------------------------
from .gaitset import SetNet
from .triplet import TripletLoss
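test/test.py is not included in this listing, but based on Model.transform above and model/utils/evaluator.py further down, the evaluation flow is roughly the following (a sketch only, assuming a Model instance `m` and a config dict `conf` as in the config files):

```python
# Extract features for the whole test split, then score cross-view accuracy.
data = m.transform('test', batch_size=1)  # (features, views, seq_types, labels)
acc = evaluation(data, conf['data'])      # shape: [3 probe sets, views, views, 5 ranks]
for p, name in enumerate(['NM', 'BG', 'CL']):
    print(name, 'Rank@1 (incl. identical-view):', acc[p, :, :, 0].mean())
```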
--------------------------------------------------------------------------------
/model/network/basic_blocks.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


# Basic convolution + activation; note that the activation is leaky_relu
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)

    def forward(self, x):
        x = self.conv(x)
        return F.leaky_relu(x, inplace=True)


class SetBlock(nn.Module):
    def __init__(self, forward_block, pooling=False):
        super(SetBlock, self).__init__()
        self.forward_block = forward_block
        self.pooling = pooling
        if pooling:
            self.pool2d = nn.MaxPool2d(2)

    def forward(self, x):
        n, s, c, h, w = x.size()
        x = self.forward_block(x.view(-1, c, h, w))
        if self.pooling:
            x = self.pool2d(x)
        _, c, h, w = x.size()
        return x.view(n, s, c, h, w)

--------------------------------------------------------------------------------
/model/network/basic_blocks_dyrelu.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F

from .dyrelu import DyReLUA


# Basic convolution + activation; the activation is DyReLU-A when requested,
# otherwise leaky_relu
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dy_relu=False, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
        if dy_relu is True:
            self.dy_relu = DyReLUA(out_channels)

    def forward(self, x):
        x = self.conv(x)
        if hasattr(self, "dy_relu"):
            return self.dy_relu(x)
        else:
            return F.leaky_relu(x, inplace=True)


class SetBlock(nn.Module):
    def __init__(self, forward_block, pooling=False):
        super(SetBlock, self).__init__()
        self.forward_block = forward_block
        self.pooling = pooling
        if pooling:
            self.pool2d = nn.MaxPool2d(2)

    def forward(self, x):
        n, s, c, h, w = x.size()
        x = self.forward_block(x.view(-1, c, h, w))
        if self.pooling:
            x = self.pool2d(x)
        _, c, h, w = x.size()
        return x.view(n, s, c, h, w)

--------------------------------------------------------------------------------
/model/network/dyrelu.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn


# coefficients shared across spatial positions and channels
class DyReLU(nn.Module):
    def __init__(self, channels, reduction=4, k=2, conv_type='2d'):
        super(DyReLU, self).__init__()
        self.channels = channels
        self.k = k
        self.conv_type = conv_type
        assert self.conv_type in ['1d', '2d']

        self.fc1 = nn.Linear(channels, channels // reduction)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(channels // reduction, 2 * k)
        self.sigmoid = nn.Sigmoid()

        self.register_buffer('lambdas', torch.Tensor([1.] * k + [0.5] * k).float())
        self.register_buffer('init_v', torch.Tensor([1.] + [0.] * (2 * k - 1)).float())

    def get_relu_coefs(self, x):
        theta = torch.mean(x, axis=-1)
        if self.conv_type == '2d':
            theta = torch.mean(theta, axis=-1)
        theta = self.fc1(theta)
        theta = self.relu(theta)
        theta = self.fc2(theta)
        theta = 2 * self.sigmoid(theta) - 1
        return theta

    def forward(self, x):
        raise NotImplementedError


class DyReLUA(DyReLU):
    def __init__(self, channels, reduction=4, k=2, conv_type='2d'):
        super(DyReLUA, self).__init__(channels, reduction, k, conv_type)
        self.fc2 = nn.Linear(channels // reduction, 2 * k)

    def forward(self, x):
        assert x.shape[1] == self.channels
        theta = self.get_relu_coefs(x)

        relu_coefs = theta.view(-1, 2 * self.k) * self.lambdas + self.init_v
        # BxCxL -> LxCxBx1
        x_perm = x.transpose(0, -1).unsqueeze(-1)
        output = x_perm * relu_coefs[:, :self.k] + relu_coefs[:, self.k:]
        # LxCxBx2 -> BxCxL
        result = torch.max(output, dim=-1)[0].transpose(0, -1)
        del x_perm, output, relu_coefs, theta
        return result


class DyReLUB(DyReLU):
    def __init__(self, channels, reduction=4, k=2, conv_type='2d'):
        super(DyReLUB, self).__init__(channels, reduction, k, conv_type)
        self.fc2 = nn.Linear(channels // reduction, 2 * k * channels)

    def forward(self, x):
        assert x.shape[1] == self.channels
        theta = self.get_relu_coefs(x)

        relu_coefs = theta.view(-1, self.channels, 2 * self.k) * self.lambdas + self.init_v

        if self.conv_type == '1d':
            # BxCxL -> LxBxCx1
            x_perm = x.permute(2, 0, 1).unsqueeze(-1)
            output = x_perm * relu_coefs[:, :, :self.k] + relu_coefs[:, :, self.k:]
            # LxBxCx2 -> BxCxL
            result = torch.max(output, dim=-1)[0].permute(1, 2, 0)

        elif self.conv_type == '2d':
            # BxCxHxW -> HxWxBxCx1
            x_perm = x.permute(2, 3, 0, 1).unsqueeze(-1)
            output = x_perm * relu_coefs[:, :, :self.k] + relu_coefs[:, :, self.k:]
            # HxWxBxCx2 -> BxCxHxW
            result = torch.max(output, dim=-1)[0].permute(2, 3, 0, 1)
        del x_perm, output, relu_coefs, theta
        return result

--------------------------------------------------------------------------------
/model/network/gaitset.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn

from .basic_blocks_dyrelu import SetBlock, BasicConv2d


class SetNet(nn.Module):
    def __init__(self, hidden_dim):
        super(SetNet, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_frame = None

        _set_in_channels = 1
        _set_channels = [32, 64, 128]
        self.set_layer1 = SetBlock(BasicConv2d(_set_in_channels, _set_channels[0], 5, dy_relu=True, padding=2))
        self.set_layer2 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[0], 3, dy_relu=True, padding=1), True)
        # set_layer1, set_layer2 correspond to C1, C2 + pooling
        self.set_layer3 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[1], 3, dy_relu=True, padding=1))
        self.set_layer4 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[1], 3, dy_relu=True, padding=1), True)
        # set_layer3, set_layer4 correspond to C3, C4 + pooling
        self.set_layer5 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[2], 3, dy_relu=True, padding=1))
        self.set_layer6 = SetBlock(BasicConv2d(_set_channels[2], _set_channels[2], 3, dy_relu=True, padding=1))
        # set_layer5, set_layer6 correspond to C5, C6 + pooling

        _gl_in_channels = 32
        _gl_channels = [64, 128]
        # same structure as above: two 3x3 convolutions plus pooling
        self.gl_layer1 = BasicConv2d(_gl_in_channels, _gl_channels[0], 3, dy_relu=True, padding=1)
        self.gl_layer2 = BasicConv2d(_gl_channels[0], _gl_channels[0], 3, dy_relu=True, padding=1)
        # likewise here
        self.gl_layer3 = BasicConv2d(_gl_channels[0], _gl_channels[1], 3, dy_relu=True, padding=1)
        self.gl_layer4 = BasicConv2d(_gl_channels[1], _gl_channels[1], 3, dy_relu=True, padding=1)
        self.gl_pooling = nn.MaxPool2d(2)

        self.bin_num = [1, 2, 4, 8, 16]  # the five HPM scales from the paper
        self.fc_bin = nn.ParameterList([
            nn.Parameter(
                nn.init.xavier_uniform_(  # parameter shape: 62*128*256
                    torch.zeros(sum(self.bin_num) * 2, 128, hidden_dim)))])

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight.data)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0.0)

    def frame_max(self, x):
        if self.batch_frame is None:
            return torch.max(x, 1)
        else:
            _tmp = [
                torch.max(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
                for i in range(len(self.batch_frame) - 1)
            ]
            max_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
            arg_max_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
            return max_list, arg_max_list

    def frame_median(self, x):
        if self.batch_frame is None:
            return torch.median(x, 1)
        else:
            _tmp = [
                torch.median(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
                for i in range(len(self.batch_frame) - 1)
            ]
            median_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
            arg_median_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
            return median_list, arg_median_list

    def forward(self, silho, batch_frame=None):

        # TODO: parts of this were changed to visualize the network structure
        # batch_frame = silho[1]
        # silho = silho[0]
        # n: batch_size, s: frame_num, k: keypoints_num, c: channel

        if batch_frame is not None:
            # take the batch_frame list of the first GPU
            batch_frame = batch_frame[0].data.cpu().numpy().tolist()
            _ = len(batch_frame)
            for i in range(len(batch_frame)):
                # find the last sample with a nonzero frame count,
                # since batch_frame was zero-padded earlier
                if batch_frame[-(i + 1)] != 0:
                    break
                else:
                    _ -= 1
            batch_frame = batch_frame[:_]  # drop the padded entries
            frame_sum = np.sum(batch_frame)
            if frame_sum < silho.size(1):
                silho = silho[:, :frame_sum, :, :]
            self.batch_frame = [0] + np.cumsum(batch_frame).tolist()
        # silho: 128*30*64*44
        n = silho.size(0)
        # x: 128*30*1*64*44
        x = silho.unsqueeze(2)

        x = self.set_layer1(x)
        x = self.set_layer2(x)
        # frame_max returns torch.Size([128, 32, 32, 22]);
        # it acts as Set Pooling with the max statistic
        gl = self.gl_layer1(self.frame_max(x)[0])
        gl = self.gl_layer2(gl)
        gl = self.gl_pooling(gl)

        x = self.set_layer3(x)
        x = self.set_layer4(x)
        gl = self.gl_layer3(gl + self.frame_max(x)[0])
        gl = self.gl_layer4(gl)

        x = self.set_layer5(x)
        x = self.set_layer6(x)
        x = self.frame_max(x)[0]
        gl = gl + x

        feature = list()
        n, c, h, w = gl.size()
        for num_bin in self.bin_num:  # this loop applies HPP to the feature maps
            z = x.view(n, c, num_bin, -1)  # split into horizontal strips
            z = z.mean(3) + z.max(3)[0]  # apply avg- and max-pooling
            feature.append(z)  # z has shape (n, c, num_bin)
            z = gl.view(n, c, num_bin, -1)  # apply HPP to gl as well
            z = z.mean(3) + z.max(3)[0]
            feature.append(z)  # collect the strips of both x and gl
        feature = torch.cat(feature, 2).permute(2, 0, 1).contiguous()

        # Strips at different scales capture different receptive fields, and
        # different strips at one scale capture different spatial locations,
        # so an independent FC per strip is natural.
        # feature: 62*128*128, self.fc_bin: 62*128*256 —
        # 62 strips of 128 dimensions each, each mapped by its own FC
        feature = feature.matmul(self.fc_bin[0])
        # after the FC the shape is 62*128*256
        feature = feature.permute(1, 0, 2).contiguous()
        # permuted to 128*62*256

        return feature, None
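A side note on where the 62 in the comments above comes from: the five HPP scales are applied to both the set-pooled map `x` and the global map `gl`, so

```python
bin_num = [1, 2, 4, 8, 16]  # strips per scale
print(sum(bin_num))         # 31 strips per feature map
print(sum(bin_num) * 2)     # 62 strips total (x and gl), matching
                            # fc_bin's shape (62, 128, hidden_dim)
```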
--------------------------------------------------------------------------------
/model/network/triplet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class TripletLoss(nn.Module):
    def __init__(self, batch_size, hard_or_full, margin):
        super(TripletLoss, self).__init__()
        self.batch_size = batch_size
        self.margin = margin

    def forward(self, feature, label):
        # feature: [n, m, d], label: [n, m]
        n, m, d = feature.size()
        # hp_mask marks sample pairs with the same label: True for same, False for different
        hp_mask = (label.unsqueeze(1) == label.unsqueeze(2)).bool().view(-1)
        # hn_mask is the opposite: it marks pairs with different labels
        hn_mask = (label.unsqueeze(1) != label.unsqueeze(2)).bool().view(-1)
        # 62*128*128
        dist = self.batch_dist(feature)  # Euclidean distances between all samples in the batch, per strip
        # mean_dist: 62
        mean_dist = dist.mean(1).mean(1)
        dist = dist.view(-1)
        # Hard example mining: for each anchor, take the largest distance among its
        # positive pairs and the smallest distance among its negative pairs.
        # hard
        hard_hp_dist = torch.max(torch.masked_select(dist, hp_mask).view(n, m, -1), 2)[0]
        hard_hn_dist = torch.min(torch.masked_select(dist, hn_mask).view(n, m, -1), 2)[0]
        hard_loss_metric = F.relu(self.margin + hard_hp_dist - hard_hn_dist).view(n, -1)
        # mean hard loss per strip
        hard_loss_metric_mean = torch.mean(hard_loss_metric, 1)

        # Full variant: the loss over all positive/negative pairs, without mining.
        # non-zero full
        full_hp_dist = torch.masked_select(dist, hp_mask).view(n, m, -1, 1)
        full_hn_dist = torch.masked_select(dist, hn_mask).view(n, m, 1, -1)
        full_loss_metric = F.relu(self.margin + full_hp_dist - full_hn_dist).view(n, -1)
        # triplet loss between every positive pair and every negative pair
        # full_loss_metric_sum: 62
        full_loss_metric_sum = full_loss_metric.sum(1)
        # count the terms with nonzero loss in each strip
        full_loss_num = (full_loss_metric != 0).sum(1).float()
        # average triplet loss per strip: only nonzero terms contribute to the
        # loss, so the mean is taken over those terms only
        full_loss_metric_mean = full_loss_metric_sum / full_loss_num
        full_loss_metric_mean[full_loss_num == 0] = 0
        # shapes of the return values: 62, 62, 62, 62
        return full_loss_metric_mean, hard_loss_metric_mean, mean_dist, full_loss_num

    def batch_dist(self, x):
        # x: [62, 128, 256]
        # uses d(A,B)^2 = |A|^2 + |B|^2 - 2*A.B, so all pairwise distances
        # are computed in one batched operation
        x2 = torch.sum(x ** 2, 2)
        dist = x2.unsqueeze(2) + x2.unsqueeze(2).transpose(1, 2) - 2 * torch.matmul(x, x.transpose(1, 2))
        dist = torch.sqrt(F.relu(dist))
        # 62*128*128
        return dist
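The `batch_dist` trick above relies on the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2&lt;a, b&gt;. A tiny self-contained check with random tensors (illustrative only):

```python
import torch
import torch.nn.functional as F

x = torch.randn(62, 128, 256)  # strips x samples x dims
x2 = (x ** 2).sum(2)
d2 = x2.unsqueeze(2) + x2.unsqueeze(1) - 2 * torch.matmul(x, x.transpose(1, 2))
d = torch.sqrt(F.relu(d2))     # relu clamps tiny negative values before sqrt
print(torch.allclose(d, torch.cdist(x, x), atol=1e-3))  # True up to float error
```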
--------------------------------------------------------------------------------
/model/utils/__init__.py:
--------------------------------------------------------------------------------
from .data_loader import load_data
from .data_set import DataSet
from .evaluator import evaluation
from .sampler import TripletSampler

--------------------------------------------------------------------------------
/model/utils/data_loader.py:
--------------------------------------------------------------------------------
import os
import os.path as osp

import numpy as np

from .data_set import DataSet


def load_data(dataset_path, resolution, dataset, pid_num, pid_shuffle, cache=True):
    seq_dir = list()  # path of each sample (in GaitSet one sample is a whole set of
    # silhouettes), e.g. /data/lwl/Gait_experiment/gait_data/001/bg-01/000
    view = list()  # view label of each sample (000, 018, ..., 180), aligned with seq_dir
    seq_type = list()  # walking-condition label of each sample, e.g. bg-01, aligned as well
    label = list()  # subject ID of each sample, aligned as well

    for _label in sorted(list(os.listdir(dataset_path))):  # iterate over subject IDs
        # In CASIA-B, data of subject #5 is incomplete.
        # Thus, we ignore it in training.
        if dataset == 'CASIA-B' and _label == '005':
            continue
        label_path = osp.join(dataset_path, _label)
        for _seq_type in sorted(list(os.listdir(label_path))):  # iterate over walking conditions
            seq_type_path = osp.join(label_path, _seq_type)
            for _view in sorted(list(os.listdir(seq_type_path))):  # iterate over views
                _seq_dir = osp.join(seq_type_path, _view)
                seqs = os.listdir(_seq_dir)  # all silhouettes of this sequence
                if len(seqs) > 0:
                    seq_dir.append([_seq_dir])
                    label.append(_label)
                    seq_type.append(_seq_type)
                    view.append(_view)

    pid_fname = osp.join('partition', '{}_{}_{}.npy'.format(
        dataset, pid_num, pid_shuffle))
    if not osp.exists(pid_fname):
        pid_list = sorted(list(set(label)))
        if pid_shuffle:
            np.random.shuffle(pid_list)  # randomize the split; note subject #5 is already excluded
        pid_list = [pid_list[0:pid_num], pid_list[pid_num:]]
        os.makedirs('partition', exist_ok=True)
        np.save(pid_fname, pid_list)
        # saves the train/test split: the first part holds the training IDs,
        # the second part the test IDs

    pid_list = np.load(pid_fname, allow_pickle=True)
    train_list = pid_list[0]
    test_list = pid_list[1]
    train_source = DataSet(
        # paths of the training samples
        [seq_dir[i] for i, l in enumerate(label) if l in train_list],
        # labels of the training samples
        [label[i] for i, l in enumerate(label) if l in train_list],
        # walking conditions of the training samples, e.g. bg-01
        [seq_type[i] for i, l in enumerate(label) if l in train_list],
        # views of the training samples
        [view[i] for i, l in enumerate(label) if l in train_list],
        cache,
        resolution)
    # same as above, but for the test samples
    test_source = DataSet(
        [seq_dir[i] for i, l in enumerate(label) if l in test_list],
        [label[i] for i, l in enumerate(label) if l in test_list],
        [seq_type[i] for i, l in enumerate(label) if l in test_list],
        [view[i] for i, l in enumerate(label) if l in test_list],
        cache, resolution)

    return train_source, test_source

--------------------------------------------------------------------------------
/model/utils/data_set.py:
--------------------------------------------------------------------------------
import os
import os.path as osp

import cv2
import numpy as np
import torch.utils.data as tordata
import xarray as xr


class DataSet(tordata.Dataset):
    def __init__(self, seq_dir, label, seq_type, view, cache, resolution):
        self.seq_dir = seq_dir
        self.view = view
        self.seq_type = seq_type
        self.label = label
        self.cache = cache
        self.resolution = int(resolution)
        self.cut_padding = int(float(resolution) / 64 * 10)  # 10
        self.data_size = len(self.label)  # number of samples in the dataset
        self.data = [None] * self.data_size
        self.frame_set = [None] * self.data_size

        self.label_set = set(self.label)  # deduplicated set of all subject labels
        self.seq_type_set = set(self.seq_type)  # deduplicated set of walking conditions (bg-01, ...)
        self.view_set = set(self.view)  # deduplicated set of views
        _ = np.zeros((len(self.label_set),
                      len(self.seq_type_set),
                      len(self.view_set))).astype('int')
        _ -= 1  # missing sequences are marked with -1 in index_dict
        self.index_dict = xr.DataArray(
            _,
            coords={'label': sorted(list(self.label_set)),
                    'seq_type': sorted(list(self.seq_type_set)),
                    'view': sorted(list(self.view_set))},
            dims=['label', 'seq_type', 'view'])
        # maps each (label, seq_type, view) combination to the sample's flat index

        for i in range(self.data_size):
            _label = self.label[i]
            _seq_type = self.seq_type[i]
            _view = self.view[i]
            self.index_dict.loc[_label, _seq_type, _view] = i
            # record each sample's position in self.label / self.seq_type / self.view

    def load_all_data(self):
        for i in range(self.data_size):
            self.load_data(i)

    def load_data(self, index):
        return self.__getitem__(index)

    def __loader__(self, path):
        """
        A raw sample has size frames * 64 * 64; the width is then cropped
        by cut_padding on both sides, giving frames * 64 * 44.
        """
        return self.img2xarray(path)[:, :, self.cut_padding:-self.cut_padding]. \
            astype('float32') / 255.0

    def __getitem__(self, index: int):
        # pose sequence sampling
        # Without cache, load the sample at `index` directly from disk each time.
        # With cache, load it once into self.data and reuse it on later accesses.
        if not self.cache:
            # load all silhouettes of the sample, e.g. _path: /data/lwl/Gait_experiment/gait_data/002/bg-01/000
            data = [self.__loader__(_path) for _path in self.seq_dir[index]]
            frame_set = [set(feature.coords['frame'].values.tolist()) for feature in data]  # the frame numbers of each feature
            frame_set = list(set.intersection(*frame_set))  # intersection of the frame sets
        elif self.data[index] is None:
            data = [self.__loader__(_path) for _path in self.seq_dir[index]]
            frame_set = [set(feature.coords['frame'].values.tolist()) for feature in data]
            frame_set = list(set.intersection(*frame_set))
            self.data[index] = data
            self.frame_set[index] = frame_set
        else:
            data = self.data[index]
            frame_set = self.frame_set[index]
        # TODO print the size of data and the true frame count
        # print("data size:", len(data))
        # print(data[0].shape)
        return data, frame_set, self.view[index], self.seq_type[index], self.label[index],

    def img2xarray(self, file_path):
        imgs = sorted(list(os.listdir(file_path)))
        # read every silhouette under the path, reshape it to 64 x 64 x (-1),
        # and take channel 0 with [:, :, 0], leaving a 64 x 64 matrix per frame
        frame_list = [np.reshape(
            cv2.imread(osp.join(file_path, _img_path)), [self.resolution, self.resolution, -1])[:, :, 0]
            for _img_path in imgs
            if osp.isfile(osp.join(file_path, _img_path))]

        num_list = list(range(len(frame_list)))
        data_dict = xr.DataArray(
            frame_list,
            coords={'frame': num_list},
            dims=['frame', 'img_y', 'img_x'],  # frame index, frame height, frame width
        )
        return data_dict

    def __len__(self):
        return len(self.label)
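To make the index_dict bookkeeping above concrete, looking up one sequence works roughly like this (a sketch with made-up label/condition/view values):

```python
# index_dict maps (label, seq_type, view) -> position in the flat lists,
# with -1 marking a missing sequence.
i = int(dataset.index_dict.loc['001', 'bg-01', '000'])
if i >= 0:
    data, frame_set, view, seq_type, label = dataset[i]
```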
--------------------------------------------------------------------------------
/model/utils/evaluator.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F


def cuda_dist(x, y):
    # distance between every sample in x and every sample in y
    x = torch.from_numpy(x).cuda()
    y = torch.from_numpy(y).cuda()
    dist = torch.sum(x ** 2, 1).unsqueeze(1) + torch.sum(y ** 2, 1).unsqueeze(
        1).transpose(0, 1) - 2 * torch.matmul(x, y.transpose(0, 1))
    dist = torch.sqrt(F.relu(dist))
    # returned shape: x.size(0) * y.size(0)
    return dist


def evaluation(data, config):
    # data: np.concatenate(feature_list, 0), view_list, seq_type_list, label_list
    dataset = config['dataset'].split('-')[0]
    feature, view, seq_type, label = data
    label = np.array(label)
    view_list = list(set(view))
    view_list.sort()
    view_num = len(view_list)
    sample_num = len(feature)

    probe_seq_dict = {'CASIA': [['nm-05', 'nm-06'], ['bg-01', 'bg-02'], ['cl-01', 'cl-02']],
                      'OUMVLP': [['00']]}
    gallery_seq_dict = {'CASIA': [['nm-01', 'nm-02', 'nm-03', 'nm-04']],
                        'OUMVLP': [['01']]}

    num_rank = 5
    # The loops below compute, for every probe_seq/gallery_seq combination, the
    # accuracy of probes under probe_view against galleries under gallery_view;
    # probe_seq covers the three walking conditions.
    #               probe sets                     views     views     top-5
    acc = np.zeros([len(probe_seq_dict[dataset]), view_num, view_num, num_rank])
    for (p, probe_seq) in enumerate(probe_seq_dict[dataset]):  # probe sets
        for gallery_seq in gallery_seq_dict[dataset]:  # gallery sets
            for (v1, probe_view) in enumerate(view_list):  # probe views
                for (v2, gallery_view) in enumerate(view_list):  # gallery views
                    # select sequences (nm-01, nm-02, ...) whose type is in gallery_seq and
                    # whose view is gallery_view, since accuracy is computed per view;
                    # gallery_seq and probe_seq are both lists
                    gseq_mask = np.isin(seq_type, gallery_seq) & np.isin(view, [gallery_view])
                    gallery_x = feature[gseq_mask, :]  # features of the matching gallery samples
                    gallery_y = label[gseq_mask]  # labels of the matching gallery samples
                    # likewise: features and labels of the matching probe samples
                    pseq_mask = np.isin(seq_type, probe_seq) & np.isin(view, [probe_view])
                    probe_x = feature[pseq_mask, :]
                    probe_y = label[pseq_mask]

                    dist = cuda_dist(probe_x, gallery_x)
                    idx = dist.sort(1)[1].cpu().numpy()  # sort the gallery per probe sample; returns indices into the gallery
                    acc[p, v1, v2, :] = np.round(  # computes the top-(num_rank) accuracy:
                        # acc[p, v1, v2, 0] holds the top-1 accuracy and acc[p, v1, v2, num_rank-1]
                        # the top-5 accuracy (num_rank = 5 here);
                        # gallery_y[idx[:, 0:num_rank]] takes the labels of the num_rank nearest samples,
                        # and np.cumsum accumulates the matches into top-1, top-2, ..., top-num_rank
                        np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0,
                               0) * 100 / dist.shape[0], 2)

    return acc

--------------------------------------------------------------------------------
/model/utils/sampler.py:
--------------------------------------------------------------------------------
import random

import torch.utils.data as tordata


class TripletSampler(tordata.sampler.Sampler):
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            sample_indices = list()
            pid_list = random.sample(  # pick P = batch_size[0] subjects; 8 in this setup
                list(self.dataset.label_set),
                self.batch_size[0])
            for pid in pid_list:
                _index = self.dataset.index_dict.loc[pid, :, :].values
                # keep the indices of existing sequences; -1 marks a missing one
                # (hence >= 0, since 0 itself is a valid sample index)
                _index = _index[_index >= 0].flatten().tolist()
                _index = random.choices(
                    _index,  # pick K = batch_size[1] sequences per subject (16 here), with replacement
                    k=self.batch_size[1])
                sample_indices += _index
            yield sample_indices

    def __len__(self):
        return self.dataset.data_size
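As a quick illustration of the sampler's contract (hypothetical sizes; the real values come from config['model']['batch_size']):

```python
# TripletSampler yields endless lists of P*K dataset indices; DataLoader's
# batch_sampler argument turns each list into one training batch.
sampler = TripletSampler(train_source, (8, 16))
indices = next(iter(sampler))
print(len(indices))  # 128 = 8 subjects x 16 sequences each
```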
-------------------------------------------------------------------------------- /model/utils/tensorboardDraw.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author  : wanglin
# Created : 2020/9/19 15:21
# File    : tensorboardDraw
# IDE     : PyCharm

import random

import torch
from tensorboardX import SummaryWriter

# NOTE: this logger appears to be adapted from NVIDIA's Tacotron2 code base. The three
# plot_*_to_numpy helpers it calls live in that project's plotting_utils module, which
# is not included in this repository, so log_validation cannot run as-is:
# from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy, plot_gate_outputs_to_numpy


class Tacotron2Logger(SummaryWriter):
    def __init__(self, logdir):
        super(Tacotron2Logger, self).__init__(logdir)

    def log_training(self, reduced_loss, grad_norm, learning_rate, duration,
                     iteration):
        self.add_scalar("training.loss", reduced_loss, iteration)
        self.add_scalar("grad.norm", grad_norm, iteration)
        self.add_scalar("learning.rate", learning_rate, iteration)
        self.add_scalar("duration", duration, iteration)

    def log_validation(self, reduced_loss, model, y, y_pred, iteration):
        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        # plot distribution of parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            self.add_histogram(tag, value.data.cpu().numpy(), iteration)

        # plot alignment, mel target and predicted, gate target and predicted
        idx = random.randint(0, alignments.size(0) - 1)
        self.add_image(
            "alignment",
            plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T),
            iteration)
        self.add_image(
            "mel_target",
            plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()),
            iteration)
        self.add_image(
            "mel_predicted",
            plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()),
            iteration)
        self.add_image(
            "gate",
            plot_gate_outputs_to_numpy(
                gate_targets[idx].data.cpu().numpy(),
                torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()),  # F.sigmoid is deprecated
            iteration)
-------------------------------------------------------------------------------- /model2/__init__.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author  : wanglin
# Created : 2020/9/26 17:08
# File    : __init__.py
# IDE     : PyCharm
-------------------------------------------------------------------------------- /model2/initialization.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Author : admin
# @Time : 2018/11/15
import os
from copy import deepcopy

import numpy as np

from .model import Model
from .utils import load_data


def initialize_data(config, train=False, test=False):
    # `train` and `test` indicate whether the corresponding source should be cached.
    print("Initializing data source...")
    # Build the DataSet objects.
    train_source, test_source = load_data(**config['data'], cache=(train or test))
    if train:
        print("Loading training data...")
        train_source.load_all_data()
    if test:
        print("Loading test data...")
        test_source.load_all_data()
    print("Data initialization complete.")
    return train_source, test_source


def initialize_model(config, train_source, test_source):
    print("Initializing model...")
    data_config = config['data']
    model_config = config['model']
    model_param = deepcopy(model_config)
    model_param['train_source'] = train_source
    model_param['test_source'] = test_source
    model_param['train_pid_num'] = data_config['pid_num']
    batch_size = int(np.prod(model_config['batch_size']))  # np.prod multiplies all elements
    model_param['save_name'] = '_'.join(map(str, [
        model_config['model_name'],
        data_config['dataset'],
        data_config['pid_num'],
        data_config['pid_shuffle'],
        model_config['hidden_dim'],
        model_config['margin'],
        batch_size,
        model_config['hard_or_full_trip'],
        model_config['frame_num'],
    ]))

    m = Model(**model_param)
    print("Model initialization complete.")
    return m, model_param['save_name']


def initialization(config, train=False, test=False):
    print("Initializing...")
    WORK_PATH = config['WORK_PATH']
    os.chdir(WORK_PATH)  # switch the current working directory to the configured path
    os.environ["CUDA_VISIBLE_DEVICES"] = config["CUDA_VISIBLE_DEVICES"]
    train_source, test_source = initialize_data(config, train, test)
    return initialize_model(config, train_source, test_source)
-------------------------------------------------------------------------------- /model2/model.py: --------------------------------------------------------------------------------
import math
import os
import os.path as osp
import random
import sys
from datetime import datetime

import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as tordata
from tensorboardX import SummaryWriter
from visdom import Visdom

from .network import TripletLoss, SetNet
from .utils import TripletSampler


class Model:
    def __init__(self,
                 hidden_dim,
                 lr,
                 hard_or_full_trip,
                 margin,
                 num_workers,
                 batch_size,
                 restore_iter,
                 total_iter,
                 save_name,
                 train_pid_num,
                 frame_num,
                 model_name: str,
                 train_source,
                 test_source,
                 img_size=64,
                 logdir="./log",
                 model_save_dir="GaitSet"):

        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source
        self.model_save_dir = model_save_dir

        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        self.P, self.M = batch_size

        self.restore_iter = restore_iter
        self.total_iter = total_iter

        self.img_size = img_size

        self.encoder = SetNet(self.hidden_dim).float()
        # TODO: the multi-GPU wrapping was modified here
        self.encoder = nn.DataParallel(self.encoder)
        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
        self.triplet_loss = nn.DataParallel(self.triplet_loss)
        self.encoder.cuda()
        self.triplet_loss.cuda()

        self.optimizer = optim.Adam([
            {'params': self.encoder.parameters()},
        ], lr=self.lr)

        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01

        self.sample_type = 'all'
        self.logdir = logdir

    def collate_fn(self, batch):

        batch_size = len(batch)  # number of samples in this batch
        """
        data = [self.__loader__(_path) for _path in self.seq_dir[index]]
        feature_num is the number of feature sets contained in `data`; it is always 1 here,
        because each sample is loaded from a single directory:
            _seq_dir = osp.join(seq_type_path, _view)
            seqs = os.listdir(_seq_dir)  # all silhouettes of that sequence
        """
        feature_num = len(batch[0][0])

        seqs = [batch[i][0] for i in range(batch_size)]  # corresponds to data
        frame_sets = [batch[i][1] for i in range(batch_size)]  # corresponds to frame_set
        view = [batch[i][2] for i in range(batch_size)]  # corresponds to self.view[index]
        seq_type = [batch[i][3] for i in range(batch_size)]  # corresponds to self.seq_type[index]
        label = [batch[i][4] for i in range(batch_size)]  # corresponds to self.label[index]
        batch = [seqs, view, seq_type, label, None]

        '''
        Each sample is the tuple
            data, frame_set, self.view[index], self.seq_type[index], self.label[index]
        '''

        def select_frame(index):
            sample = seqs[index]
            frame_set = frame_sets[index]
            if self.sample_type == 'random':
                # random.choices samples WITH replacement
                frame_id_list = random.choices(frame_set, k=self.frame_num)
                # passing a list to feature.loc[] selects that whole group of frames
                _ = [feature.loc[frame_id_list].values for feature in sample]
            else:
                # otherwise keep every frame
                _ = [feature.values for feature in sample]
            return _

        # Extract the chosen frames of every sample; the result is a list of arrays.
        seqs = list(map(select_frame, range(len(seqs))))
        # print(self.sample_type, "sampling mode")
        if self.sample_type == 'random':
            seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]

        else:
            # TODO: the GPU count was changed here
            gpu_num = min(torch.cuda.device_count(), batch_size)
            # gpu_num = min(4, batch_size)
            batch_per_gpu = math.ceil(batch_size / gpu_num)

            # batch_frames layout:
            # [[gpu1_sample_1_frameNumbers, gpu1_sample_2_frameNumbers, ...],
            #  [gpu2_sample_1_frameNumbers, gpu2_sample_2_frameNumbers, ...], ...]
            batch_frames = [[  # split the samples across the GPUs
                len(frame_sets[i])  # total frame count of each sample
                for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                if i < batch_size
            ] for _ in range(gpu_num)]
            if len(batch_frames[-1]) != batch_per_gpu:
                for _ in range(batch_per_gpu - len(batch_frames[-1])):
                    batch_frames[-1].append(0)  # pad with 0 when the last GPU gets a short batch
            # Largest total frame count held by any single GPU.
            max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])

            # Concatenate each GPU's samples into one big array:
            # seqs = [[gpu1_batch, gpu2_batch, gpu3_batch, ...]]
            seqs = [[
                np.concatenate([  # concatenate all frames of this GPU's batch
                    seqs[i][j]
                    for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                    if i < batch_size
                ], 0) for _ in range(gpu_num)]
                for j in range(feature_num)]
            # TODO: print the shape of seqs[j][_]
            # print("seqs shape:", seqs[0][0].shape)
            # Zero-pad every per-GPU batch up to the largest total frame count;
            # seqs[j][_] has shape (GPU_batch_size * frame_number) * 64 * 44.
            seqs = [np.asarray([
                np.pad(seqs[j][_],
                       ((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
                       'constant',
                       constant_values=0)
                for _ in range(gpu_num)])
                for j in range(feature_num)]

            batch[4] = np.asarray(batch_frames)

        batch[0] = seqs
        # TODO: print the shape of seqs
        # print("seqs shape:", seqs[0].shape)
        return batch
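A toy walk-through (made-up frame counts, not from the repo) of the per-GPU bookkeeping in `collate_fn` above: samples are split across GPUs, the last GPU is padded with zero-length samples, and every per-GPU batch is later zero-padded to the largest total frame count:

```python
import math

import numpy as np

frame_counts = [70, 65, 80, 72, 68]  # frames per sample in one batch
gpu_num = 2
batch_per_gpu = math.ceil(len(frame_counts) / gpu_num)  # 3 samples per GPU

batch_frames = [[frame_counts[i]
                 for i in range(batch_per_gpu * g, batch_per_gpu * (g + 1))
                 if i < len(frame_counts)]
                for g in range(gpu_num)]
# Pad the last GPU with empty samples so every GPU sees batch_per_gpu entries.
while len(batch_frames[-1]) < batch_per_gpu:
    batch_frames[-1].append(0)

print(batch_frames)                               # [[70, 65, 80], [72, 68, 0]]
print(np.max([np.sum(g) for g in batch_frames]))  # 215: frame budget each GPU is padded to
```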
    def fit(self):
        torch.backends.cudnn.benchmark = True
        writer = SummaryWriter(self.logdir)
        env_name = self.hard_or_full_trip
        vis = Visdom(env=env_name, log_to_filename=osp.join(self.logdir, self.hard_or_full_trip + ".log"))
        if self.restore_iter != 0:
            self.load(self.restore_iter)

        self.encoder.train()
        self.sample_type = 'random'
        # todo: the sampling mode was changed here
        # self.sample_type = 'all'
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        triplet_sampler = TripletSampler(self.train_source, self.batch_size)  # custom batch sampler
        train_loader = tordata.DataLoader(
            dataset=self.train_source,
            batch_sampler=triplet_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        train_label_set = list(self.train_source.label_set)
        train_label_set.sort()

        _time1 = datetime.now()
        for seq, view, seq_type, label, batch_frame in train_loader:
            self.restore_iter += 1
            self.optimizer.zero_grad()
            # TODO: changed here where the tensors are placed (CPU vs. GPU)
            for i in range(len(seq)):
                seq[i] = self.np2var(seq[i]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()
            # with SummaryWriter(comment="encoder") as w:
            #     self.encoder.cpu()
            #     # seq.cpu()
            #     # batch_frame.cpu()
            #     w.add_graph(self.encoder, (seq[0],))

            # todo: exit here when visualizing the encoder network structure
            # sys.exit()
            # print("seq:", seq)
            # feature: 128*62*256
            feature, label_prob = self.encoder(*seq, batch_frame)
            # target_label stores each label's index within train_label_set = list(self.train_source.label_set)
            target_label = [train_label_set.index(l) for l in label]
            target_label = self.np2var(np.array(target_label)).long()

            # after this permutation the features have shape 62*128*256
            triplet_feature = feature.permute(1, 0, 2).contiguous()
            # triplet_label: 62*128
            triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
            (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
             ) = self.triplet_loss(triplet_feature, triplet_label)
            loss = 0
            if self.hard_or_full_trip == 'hard':
                loss = hard_loss_metric.mean()
            elif self.hard_or_full_trip == 'full':
                loss = full_loss_metric.mean()
            else:
                # todo: combined (hard + full) loss added here
                loss = hard_loss_metric.mean() + full_loss_metric.mean()

            self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
            self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
            self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
            self.dist_list.append(mean_dist.mean().data.cpu().numpy())

            if loss > 1e-9:
                loss.backward()
                self.optimizer.step()

            if self.restore_iter % 1000 == 0:  # print the wall-clock time of every 1000 iterations
                print(datetime.now() - _time1)
                _time1 = datetime.now()
            # TODO: changed to log every 10 batches
            if self.restore_iter % 10 == 0:
                print('iter {}:'.format(self.restore_iter), end='')
                print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
                writer.add_scalar("hard_loss_metric", np.mean(self.hard_loss_metric), self.restore_iter)
                vis.line(X=np.array([self.restore_iter]), Y=np.array([np.mean(self.hard_loss_metric)]),
                         win="hard_loss_metric",
                         update="append",
                         opts=dict(title="hard_loss_metric"))

                print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
                writer.add_scalar("full_loss_metric", np.mean(self.full_loss_metric), self.restore_iter)
                vis.line(X=np.array([self.restore_iter]), Y=np.array([np.mean(self.full_loss_metric)]),
                         win="full_loss_metric",
                         update="append",
                         opts=dict(title="full_loss_metric"))

                print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
                writer.add_scalar("full_loss_num", np.mean(self.full_loss_num), self.restore_iter)
                vis.line(X=np.array([self.restore_iter]), Y=np.array([np.mean(self.full_loss_num)]),
                         win="full_loss_num",
                         update="append",
                         opts=dict(title="full_loss_num"))

                self.mean_dist = np.mean(self.dist_list)
                print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
                writer.add_scalar("mean_dist",
                                  self.mean_dist, self.restore_iter)
                vis.line(X=np.array([self.restore_iter]), Y=np.array([self.mean_dist]), win="mean_dist",
                         update="append",
                         opts=dict(title="mean_dist"))

                print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
                writer.add_scalar("lr", self.optimizer.param_groups[0]['lr'], self.restore_iter)
                vis.line(X=np.array([self.restore_iter]), Y=np.array([self.optimizer.param_groups[0]['lr']]), win="lr",
                         update="append",
                         opts=dict(title="lr"))
                vis.line(X=np.array([self.restore_iter]), Y=np.array([loss.data.cpu().numpy()]), win="all_loss",
                         update="append",
                         opts=dict(title="all_loss"))
                print(', hard or full=%r' % self.hard_or_full_trip)
                sys.stdout.flush()
                self.hard_loss_metric = []
                self.full_loss_metric = []
                self.full_loss_num = []
                self.dist_list = []
            if self.restore_iter % 500 == 0:
                self.save()

            # Visualization using t-SNE
            # if self.restore_iter % 500 == 0:
            #     pca = TSNE(2)
            #     pca_feature = pca.fit_transform(feature.view(feature.size(0), -1).data.cpu().numpy())
            #     for i in range(self.P):
            #         plt.scatter(pca_feature[self.M * i:self.M * (i + 1), 0],
            #                     pca_feature[self.M * i:self.M * (i + 1), 1], label=label[self.M * i])
            #
            #     plt.show()

            if self.restore_iter == self.total_iter:
                vis.save([env_name])
                break

    def ts2var(self, x):
        # TODO: toggle here whether the data is moved to the GPU
        return autograd.Variable(x).cuda()
        # return autograd.Variable(x)

    def np2var(self, x):
        return self.ts2var(torch.from_numpy(x))

    def transform(self, flag, batch_size=1):
        with torch.no_grad():
            self.encoder.eval()
            source = self.test_source if flag == 'test' else self.train_source
            self.sample_type = 'all'
            data_loader = tordata.DataLoader(
                dataset=source,
                batch_size=batch_size,
                sampler=tordata.sampler.SequentialSampler(source),
                collate_fn=self.collate_fn,
                num_workers=self.num_workers)

            feature_list = list()
            view_list = list()
            seq_type_list = list()
            label_list = list()

            for i, x in enumerate(data_loader):
                seq, view, seq_type, label, batch_frame = x
                for j in range(len(seq)):
                    seq[j] = self.np2var(seq[j]).float()
                if batch_frame is not None:
                    batch_frame = self.np2var(batch_frame).int()
                # print(batch_frame, np.sum(batch_frame))

                feature, _ = self.encoder(*seq, batch_frame)
                n, num_bin, _ = feature.size()
                feature_list.append(feature.view(n, -1).data.cpu().numpy())
                view_list += view
                seq_type_list += seq_type
                label_list += label
            # Each element of feature_list has shape 128*(62x256) = 128*15872; the return
            # value stacks the features of all samples: total_samples * 15872.
            return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list

    def save(self):
        os.makedirs(osp.join('checkpoint', self.model_save_dir), exist_ok=True)
        torch.save(self.encoder.state_dict(),
                   osp.join('checkpoint', self.model_save_dir,
                            '{}-{:0>5}-encoder.ptm'.format(
                                self.save_name, self.restore_iter)))
        torch.save(self.optimizer.state_dict(),
                   osp.join('checkpoint', self.model_save_dir,
                            '{}-{:0>5}-optimizer.ptm'.format(
                                self.save_name, self.restore_iter)))

    # restore_iter: iteration index of the checkpoint to load
    def load(self, restore_iter):
        self.encoder.load_state_dict(torch.load(osp.join(
            'checkpoint',
            self.model_save_dir,
            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
        self.optimizer.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_save_dir,
            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))
-------------------------------------------------------------------------------- /model2/network/__init__.py: --------------------------------------------------------------------------------
from .gaitset import SetNet
from .triplet import TripletLoss
-------------------------------------------------------------------------------- /model2/network/basic_blocks.py: --------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


# Basic convolution + activation; note that the activation is leaky_relu.
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)

    def forward(self, x):
        x = self.conv(x)
        return F.leaky_relu(x, inplace=True)


class SetBlock(nn.Module):
    def __init__(self, forward_block, pooling=False):
        super(SetBlock, self).__init__()
        self.forward_block = forward_block
        self.pooling = pooling
        if pooling:
            self.pool2d = nn.MaxPool2d(2)

    def forward(self, x):
        n, s, c, h, w = x.size()
        x = self.forward_block(x.view(-1, c, h, w))
        if self.pooling:
            x = self.pool2d(x)
        _, c, h, w = x.size()
        return x.view(n, s, c, h, w)
-------------------------------------------------------------------------------- /model2/network/basic_blocks_dyrelu.py: --------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F

from .dyrelu import DyReLUA


# Basic convolution + activation; leaky_relu by default, DyReLU-A when dy_relu=True.
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dy_relu=False, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
        if dy_relu is True:
            self.dy_relu = DyReLUA(out_channels)

    def forward(self, x):
        x = self.conv(x)
        if hasattr(self, "dy_relu"):
            return self.dy_relu(x)
        else:
            return F.leaky_relu(x, inplace=True)


class SetBlock(nn.Module):
    def __init__(self, forward_block, pooling=False):
        super(SetBlock, self).__init__()
        self.forward_block = forward_block
        self.pooling = pooling
        if pooling:
            self.pool2d = nn.MaxPool2d(2)

    def forward(self, x):
        n, s, c, h, w = x.size()
        x = self.forward_block(x.view(-1, c, h, w))
        if self.pooling:
            x = self.pool2d(x)
        _, c, h, w = x.size()
        return x.view(n, s, c, h, w)
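A quick shape check for the two blocks above (a sketch that assumes both classes are in scope): `SetBlock` folds the set dimension `s` into the batch before the 2-D convolution and restores it afterwards:

```python
import torch

# n=8 sequences, s=30 frames, 1-channel 64x44 silhouettes
block = SetBlock(BasicConv2d(1, 32, 5, padding=2), pooling=True)
x = torch.randn(8, 30, 1, 64, 44)
y = block(x)
print(y.shape)  # torch.Size([8, 30, 32, 32, 22]): conv keeps HxW, MaxPool2d(2) halves it
```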
-------------------------------------------------------------------------------- /model2/network/dyrelu.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn


# Coefficients are shared across space and channels.
class DyReLU(nn.Module):
    def __init__(self, channels, reduction=4, k=2, conv_type='2d'):
        super(DyReLU, self).__init__()
        self.channels = channels
        self.k = k
        self.conv_type = conv_type
        assert self.conv_type in ['1d', '2d']

        self.fc1 = nn.Linear(channels, channels // reduction)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(channels // reduction, 2 * k)
        self.sigmoid = nn.Sigmoid()

        self.register_buffer('lambdas', torch.Tensor([1.] * k + [0.5] * k).float())
        self.register_buffer('init_v', torch.Tensor([1.] + [0.] * (2 * k - 1)).float())

    def get_relu_coefs(self, x):
        theta = torch.mean(x, dim=-1)
        if self.conv_type == '2d':
            theta = torch.mean(theta, dim=-1)
        theta = self.fc1(theta)
        theta = self.relu(theta)
        theta = self.fc2(theta)
        theta = 2 * self.sigmoid(theta) - 1
        return theta

    def forward(self, x):
        raise NotImplementedError


class DyReLUA(DyReLU):
    def __init__(self, channels, reduction=4, k=2, conv_type='2d'):
        super(DyReLUA, self).__init__(channels, reduction, k, conv_type)
        self.fc2 = nn.Linear(channels // reduction, 2 * k)

    def forward(self, x):
        assert x.shape[1] == self.channels
        theta = self.get_relu_coefs(x)

        relu_coefs = theta.view(-1, 2 * self.k) * self.lambdas + self.init_v
        # BxCxL -> LxCxBx1
        x_perm = x.transpose(0, -1).unsqueeze(-1)
        output = x_perm * relu_coefs[:, :self.k] + relu_coefs[:, self.k:]
        # LxCxBx2 -> BxCxL
        result = torch.max(output, dim=-1)[0].transpose(0, -1)
        del x_perm, output, relu_coefs, theta
        return result


class DyReLUB(DyReLU):
    def __init__(self, channels, reduction=4, k=2, conv_type='2d'):
        super(DyReLUB, self).__init__(channels, reduction, k, conv_type)
        self.fc2 = nn.Linear(channels // reduction, 2 * k * channels)

    def forward(self, x):
        assert x.shape[1] == self.channels
        theta = self.get_relu_coefs(x)

        relu_coefs = theta.view(-1, self.channels, 2 * self.k) * self.lambdas + self.init_v

        if self.conv_type == '1d':
            # BxCxL -> LxBxCx1
            x_perm = x.permute(2, 0, 1).unsqueeze(-1)
            output = x_perm * relu_coefs[:, :, :self.k] + relu_coefs[:, :, self.k:]
            # LxBxCx2 -> BxCxL
            result = torch.max(output, dim=-1)[0].permute(1, 2, 0)

        elif self.conv_type == '2d':
            # BxCxHxW -> HxWxBxCx1
            x_perm = x.permute(2, 3, 0, 1).unsqueeze(-1)
            output = x_perm * relu_coefs[:, :, :self.k] + relu_coefs[:, :, self.k:]
            # HxWxBxCx2 -> BxCxHxW
            result = torch.max(output, dim=-1)[0].permute(2, 3, 0, 1)
        del x_perm, output, relu_coefs, theta
        return result
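A quick shape check for DyReLU-A above (a sketch that assumes the class is in scope): the 2k piecewise-linear coefficients are predicted from globally pooled features and shared across space and channels, so the output keeps the input shape:

```python
import torch

act = DyReLUA(channels=32)
x = torch.randn(4, 32, 16, 11)  # B x C x H x W
y = act(x)
print(y.shape)                  # torch.Size([4, 32, 16, 11])
```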
-------------------------------------------------------------------------------- /model2/network/gaitset.py: --------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn

from .basic_blocks_dyrelu import SetBlock, BasicConv2d


class SetNet(nn.Module):
    def __init__(self, hidden_dim):
        super(SetNet, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_frame = None

        _set_in_channels = 1
        _set_channels = [32, 64, 128]
        self.set_layer1 = SetBlock(BasicConv2d(_set_in_channels, _set_channels[0], 5, padding=2))
        self.set_layer2 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[0], 3, padding=1), True)
        # set_layer1, set_layer2 correspond to C1, C2 and P in the paper
        self.set_layer3 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[1], 3, padding=1))
        self.set_layer4 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[1], 3, padding=1), True)
        # set_layer3, set_layer4 correspond to C3, C4 and P
        self.set_layer5 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[2], 3, padding=1))
        self.set_layer6 = SetBlock(BasicConv2d(_set_channels[2], _set_channels[2], 3, padding=1))
        # set_layer5, set_layer6 correspond to C5, C6

        _gl_in_channels = 32
        _gl_channels = [64, 128]
        # Same structure as above: two 3x3 convolutions followed by pooling.
        self.gl_layer1 = BasicConv2d(_gl_in_channels, _gl_channels[0], 3, padding=1)
        self.gl_layer2 = BasicConv2d(_gl_channels[0], _gl_channels[0], 3, padding=1)
        # Likewise here.
        self.gl_layer3 = BasicConv2d(_gl_channels[0], _gl_channels[1], 3, padding=1)
        self.gl_layer4 = BasicConv2d(_gl_channels[1], _gl_channels[1], 3, padding=1)
        self.gl_pooling = nn.MaxPool2d(2)

        self.bin_num = [1, 2, 4, 8, 16]  # the five HPM scales from the paper
        self.fc_bin = nn.ParameterList([
            nn.Parameter(
                nn.init.xavier_uniform_(  # parameter shape: 62*128*256
                    torch.zeros(sum(self.bin_num) * 2, 128, hidden_dim)))])

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight.data)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                nn.init.constant_(m.bias.data, 0.0)  # nn.init.constant (no underscore) is deprecated
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.normal_(m.weight.data, 1.0, 0.02)  # likewise nn.init.normal
                nn.init.constant_(m.bias.data, 0.0)

    def frame_max(self, x):
        if self.batch_frame is None:
            return torch.max(x, 1)
        else:
            _tmp = [
                torch.max(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
                for i in range(len(self.batch_frame) - 1)
            ]
            max_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
            arg_max_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
            return max_list, arg_max_list

    def frame_median(self, x):
        if self.batch_frame is None:
            return torch.median(x, 1)
        else:
            _tmp = [
                torch.median(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
                for i in range(len(self.batch_frame) - 1)
            ]
            median_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
            arg_median_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
            return median_list, arg_median_list

    def forward(self, silho, batch_frame=None):

        # TODO: parts of this function were changed while visualizing the network structure
        # batch_frame = silho[1]
        # silho = silho[0]
        # n: batch_size, s: frame_num, k: keypoints_num, c: channel

        if batch_frame is not None:
            # Take the batch meta data of the first GPU.
            batch_frame = batch_frame[0].data.cpu().numpy().tolist()
            _ = len(batch_frame)
            for i in range(len(batch_frame)):
                # Scan from the back for the first sample whose frame count is non-zero,
                # because batch_frame was zero-padded earlier.
                if batch_frame[-(i + 1)] != 0:
                    break
                else:
                    _ -= 1
            batch_frame = batch_frame[:_]  # drop the padded entries
            frame_sum = np.sum(batch_frame)
            if frame_sum < silho.size(1):
                silho = silho[:, :frame_sum, :, :]
            self.batch_frame = [0] + np.cumsum(batch_frame).tolist()
        # silho: 128*30*64*44
        n = silho.size(0)
        # x: 128*30*1*64*44
        x = silho.unsqueeze(2)

        x = self.set_layer1(x)
        x = self.set_layer2(x)
        # self.frame_max() returns torch.Size([128, 32, 32, 22]);
        # frame_max acts as Set Pooling with the max statistic.
        gl = self.gl_layer1(self.frame_max(x)[0])
        gl = self.gl_layer2(gl)
        gl = self.gl_pooling(gl)

        x = self.set_layer3(x)
        x = self.set_layer4(x)
        gl = self.gl_layer3(gl + self.frame_max(x)[0])
        gl = self.gl_layer4(gl)

        x = self.set_layer5(x)
        x = self.set_layer6(x)
        x = self.frame_max(x)[0]
        gl = gl + x
        feature = list()
        n, c, h, w = gl.size()
        for num_bin in self.bin_num:  # this loop applies HPP to the feature maps
            z = x.view(n, c, num_bin, -1)  # split the map into horizontal strips
            z = z.mean(3) + z.max(3)[0]  # apply average pooling plus max pooling
            feature.append(z)  # z has shape n, c, num_bin
            z = gl.view(n, c, num_bin, -1)  # apply HPP to gl as well
            z = z.mean(3) + z.max(3)[0]
            feature.append(z)  # collect both the gl strips and the x strips
        feature = torch.cat(feature, 2).permute(2, 0, 1).contiguous()

        # Strips at different scales describe different receptive fields, and different strips
        # at the same scale describe different spatial positions, so mapping each strip with
        # its own independent FC layer is the natural choice.
        # feature: 62*128*128, self.fc_bin: 62*128*256
        # In effect there are 62 strips of 128 dimensions each, and every strip gets its own FC mapping.
        feature = feature.matmul(self.fc_bin[0])
        # After the per-strip FC layers the shape becomes 62*128*256.
        feature = feature.permute(1, 0, 2).contiguous()
        # Final permutation: 128*62*256.

        return feature, None
-------------------------------------------------------------------------------- /model2/network/triplet.py: --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class TripletLoss(nn.Module):
    def __init__(self, batch_size, hard_or_full, margin):
        super(TripletLoss, self).__init__()
        self.batch_size = batch_size
        self.margin = margin
        # Note: hard_or_full is accepted for API compatibility but not stored;
        # forward() always returns both the hard and the full variant.

    def forward(self, feature, label):
        # feature: [n, m, d], label: [n, m]
        n, m, d = feature.size()
        # hp_mask marks all sample pairs that share the same label (True for same, False otherwise).
        hp_mask = (label.unsqueeze(1) == label.unsqueeze(2)).bool().view(-1)
        # hn_mask is the opposite: pairs with different labels.
        hn_mask = (label.unsqueeze(1) != label.unsqueeze(2)).bool().view(-1)
        # 62*128*128
        dist = self.batch_dist(feature)  # Euclidean distances between all samples, per strip
        # mean_dist: 62
        mean_dist = dist.mean(1).mean(1)
        dist = dist.view(-1)
        # Hard example mining: for each anchor take the largest positive-pair distance
        # and the smallest negative-pair distance.
        # hard
        hard_hp_dist = torch.max(torch.masked_select(dist, hp_mask).view(n, m, -1), 2)[0]
        hard_hn_dist = torch.min(torch.masked_select(dist, hn_mask).view(n, m, -1), 2)[0]
        hard_loss_metric = F.relu(self.margin + hard_hp_dist - hard_hn_dist).view(n, -1)
        # Mean hard loss per strip.
        hard_loss_metric_mean = torch.mean(hard_loss_metric, 1)

        # The full variant uses every positive/negative pair, without hard mining.
        # non-zero full
        full_hp_dist = torch.masked_select(dist, hp_mask).view(n, m, -1, 1)
        full_hn_dist = torch.masked_select(dist, hn_mask).view(n, m, 1, -1)
        full_loss_metric = F.relu(self.margin + full_hp_dist - full_hn_dist).view(n, -1)
        # Triplet loss over every positive/negative pair combination.
        # full_loss_metric_sum: 62
        full_loss_metric_sum = full_loss_metric.sum(1)
        # Count the non-zero loss terms per strip.
        full_loss_num = (full_loss_metric != 0).sum(1).float()
        # Mean triplet loss per strip: only the non-zero terms contribute, so average over those.
        full_loss_metric_mean = full_loss_metric_sum / full_loss_num
        full_loss_metric_mean[full_loss_num == 0] = 0
        # The return shapes are 62, 62, 62, 62.
        return full_loss_metric_mean, hard_loss_metric_mean, mean_dist, full_loss_num

    def batch_dist(self, x):
        # x: [62, 128, 256]
        # Uses d(A, B)^2 = A^2 + B^2 - 2*A*B to compute all pairwise distances in one batch.
        x2 = torch.sum(x ** 2, 2)
        dist = x2.unsqueeze(2) + x2.unsqueeze(2).transpose(1, 2) - 2 * torch.matmul(x, x.transpose(1, 2))
        dist = torch.sqrt(F.relu(dist))
        # 62*128*128
        return dist
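A hand-worked toy contrast (made-up distances, not repo data) between the hard and full mining used above, for a single anchor with margin 0.2:

```python
import torch
import torch.nn.functional as F

margin = 0.2
pos_dist = torch.tensor([0.3, 0.9])        # anchor-positive distances
neg_dist = torch.tensor([0.8, 1.5, 2.0])   # anchor-negative distances

# hard: hardest positive vs. hardest negative -> one loss term
hard = F.relu(margin + pos_dist.max() - neg_dist.min())
# full: every positive paired with every negative -> 2 x 3 loss terms
full = F.relu(margin + pos_dist.view(-1, 1) - neg_dist.view(1, -1))

print(hard.item())                               # ~0.3  (0.2 + 0.9 - 0.8)
print(full.sum().item(), (full != 0).sum().item())  # ~0.3, 1 non-zero term
```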
-------------------------------------------------------------------------------- /model2/utils/__init__.py: --------------------------------------------------------------------------------
from .data_loader import load_data
from .data_set import DataSet
from .evaluator import evaluation
from .sampler import TripletSampler
-------------------------------------------------------------------------------- /model2/utils/data_loader.py: --------------------------------------------------------------------------------
import os
import os.path as osp

import numpy as np

from .data_set import DataSet


def load_data(dataset_path, resolution, dataset, pid_num, pid_shuffle, cache=True):
    # seq_dir holds the directory of each sample (a GaitSet sample is a set of silhouettes),
    # e.g. /data/lwl/Gait_experiment/gait_data/001/bg-01/000
    seq_dir = list()
    view = list()  # view label of each sample (000, 018, ..., 180), aligned with seq_dir
    seq_type = list()  # sequence-type label of each sample (e.g. bg-01), aligned as well
    label = list()  # subject ID of each sample, aligned as well

    for _label in sorted(list(os.listdir(dataset_path))):  # iterate over subject IDs
        # In CASIA-B, data of subject #5 is incomplete.
        # Thus, we ignore it in training.
        if dataset == 'CASIA-B' and _label == '005':
            continue
        label_path = osp.join(dataset_path, _label)
        for _seq_type in sorted(list(os.listdir(label_path))):  # iterate over sequence types
            seq_type_path = osp.join(label_path, _seq_type)
            for _view in sorted(list(os.listdir(seq_type_path))):  # iterate over views
                _seq_dir = osp.join(seq_type_path, _view)
                seqs = os.listdir(_seq_dir)  # all silhouettes of this sequence
                if len(seqs) > 0:
                    seq_dir.append([_seq_dir])
                    label.append(_label)
                    seq_type.append(_seq_type)
                    view.append(_view)

    pid_fname = osp.join('partition', '{}_{}_{}.npy'.format(
        dataset, pid_num, pid_shuffle))
    if not osp.exists(pid_fname):
        pid_list = sorted(list(set(label)))
        if pid_shuffle:
            np.random.shuffle(pid_list)  # optionally shuffle the split (subject #5 is already excluded)
        pid_list = [pid_list[0:pid_num], pid_list[pid_num:]]
        os.makedirs('partition', exist_ok=True)
        np.save(pid_fname, pid_list)
        # The saved partition holds the subject IDs of the split: the first part is
        # the training set, the second part the test set.

    pid_list = np.load(pid_fname, allow_pickle=True)
    train_list = pid_list[0]
    test_list = pid_list[1]
    train_source = DataSet(
        # sample directories of the training set
        [seq_dir[i] for i, l in enumerate(label) if l in train_list],
        # labels of the training samples
        [label[i] for i, l in enumerate(label) if l in train_list],
        # sequence types of the training samples, e.g. bg-01
        [seq_type[i] for i, l in enumerate(label) if l in train_list],
        # views of the training samples
        [view[i] for i, l in enumerate(label) if l in train_list],
        cache,
        resolution)
    # The same sample information, but for the test set.
    test_source = DataSet(
        [seq_dir[i] for i, l in enumerate(label) if l in test_list],
        [label[i] for i, l in enumerate(label) if l in test_list],
        [seq_type[i] for i, l in enumerate(label) if l in test_list],
        [view[i] for i, l in enumerate(label)
         if l in test_list],
        cache, resolution)

    return train_source, test_source
-------------------------------------------------------------------------------- /model2/utils/data_set.py: --------------------------------------------------------------------------------
import os
import os.path as osp

import cv2
import numpy as np
import torch.utils.data as tordata
import xarray as xr


class DataSet(tordata.Dataset):
    def __init__(self, seq_dir, label, seq_type, view, cache, resolution):
        self.seq_dir = seq_dir
        self.view = view
        self.seq_type = seq_type
        self.label = label
        self.cache = cache
        self.resolution = int(resolution)
        self.cut_padding = int(float(resolution) / 64 * 10)  # 10
        self.data_size = len(self.label)  # number of samples in the data set
        self.data = [None] * self.data_size
        self.frame_set = [None] * self.data_size

        self.label_set = set(self.label)  # deduplicated set of subject labels
        self.seq_type_set = set(self.seq_type)  # deduplicated set of sequence types (bg-01, ...)
        self.view_set = set(self.view)  # deduplicated set of views
        _ = np.zeros((len(self.label_set),
                      len(self.seq_type_set),
                      len(self.view_set))).astype('int')
        _ -= 1  # missing silhouette sequences are marked with -1 in index_dict
        self.index_dict = xr.DataArray(
            _,
            coords={'label': sorted(list(self.label_set)),
                    'seq_type': sorted(list(self.seq_type_set)),
                    'view': sorted(list(self.view_set))},
            dims=['label', 'seq_type', 'view'])
        # index_dict maps every sample to its flat index through this 3-D array.

        for i in range(self.data_size):
            _label = self.label[i]
            _seq_type = self.seq_type[i]
            _view = self.view[i]
            self.index_dict.loc[_label, _seq_type, _view] = i
            # Store each sample's position within self.label / self.seq_type / self.view;
            # e.g. 012/bg-02/090 (subject 12, condition bg-02, view 90 degrees) maps to index i.

    def load_all_data(self):
        for i in range(self.data_size):
            self.load_data(i)

    def load_data(self, index):
        return self.__getitem__(index)

    def __loader__(self, path):
        """
        A raw sample has size 30 * 64 * 64; the width is then cropped,
        giving 30 * 64 * 44 after the cut.
        """
        return self.img2xarray(path)[:, :, self.cut_padding:-self.cut_padding]. \
            astype('float32') / 255.0

    def __getitem__(self, index: int):
        # pose sequence sampling
        # Without the cache, the data at `index` is loaded from disk every time; with the
        # cache, a sample is loaded into self.data on first access and reused afterwards.
        if not self.cache:
            # Load all silhouettes of sample `index`,
            # e.g. _path: /data/lwl/Gait_experiment/gait_data/002/bg-01/000
            data = [self.__loader__(_path) for _path in self.seq_dir[index]]
            frame_set = [set(feature.coords['frame'].values.tolist()) for feature in data]  # frame indices as sets
            frame_set = list(set.intersection(*frame_set))  # intersection of the frame sets
        elif self.data[index] is None:
            data = [self.__loader__(_path) for _path in self.seq_dir[index]]
            frame_set = [set(feature.coords['frame'].values.tolist()) for feature in data]
            frame_set = list(set.intersection(*frame_set))
            self.data[index] = data
            self.frame_set[index] = frame_set
        else:
            data = self.data[index]
            frame_set = self.frame_set[index]
        # TODO: print the size of data and the real frame count
        # print("data size:", len(data))
        # print(data[0].shape)
        return data, frame_set, self.view[index], self.seq_type[index], self.label[index],

    def img2xarray(self, file_path):
        imgs = sorted(list(os.listdir(file_path)))
        # Read every silhouette under the path, reshape it to resolution x resolution x -1,
        # and keep the first channel, giving one 64 x 64 matrix per frame.
        frame_list = [np.reshape(
            cv2.imread(osp.join(file_path, _img_path)), [self.resolution, self.resolution, -1])[:, :, 0]
            for _img_path in imgs
            if osp.isfile(osp.join(file_path, _img_path))]

        num_list = list(range(len(frame_list)))
        data_dict = xr.DataArray(
            frame_list,
            coords={'frame': num_list},
            dims=['frame', 'img_y', 'img_x'],  # frame index, frame height, frame width
        )
        return data_dict

    def __len__(self):
        return len(self.label)
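A toy illustration (hypothetical coordinates, not repo data) of the `(label, seq_type, view) -> sample index` lookup table that `DataSet` builds with xarray above, where -1 marks a missing sequence:

```python
import numpy as np
import xarray as xr

index_dict = xr.DataArray(
    -np.ones((2, 2, 2), dtype=int),
    coords={'label': ['001', '002'],
            'seq_type': ['bg-01', 'nm-01'],
            'view': ['000', '090']},
    dims=['label', 'seq_type', 'view'])

index_dict.loc['002', 'nm-01', '090'] = 7  # sample #7 lives at this coordinate
print(int(index_dict.loc['002', 'nm-01', '090']))  # 7
print(index_dict.loc['001', :, :].values)  # all -1: subject 001 has no sequences yet
```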
-------------------------------------------------------------------------------- /model2/utils/evaluator.py: --------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F


def cuda_dist(x, y):
    # Distance between every sample in x and every sample in y.
    x = torch.from_numpy(x).cuda()
    y = torch.from_numpy(y).cuda()
    dist = torch.sum(x ** 2, 1).unsqueeze(1) + torch.sum(y ** 2, 1).unsqueeze(
        1).transpose(0, 1) - 2 * torch.matmul(x, y.transpose(0, 1))
    dist = torch.sqrt(F.relu(dist))
    # returned shape: x.size(0) * y.size(0)
    return dist


def evaluation(data, config):
    # data: np.concatenate(feature_list, 0), view_list, seq_type_list, label_list
    dataset = config['dataset'].split('-')[0]
    feature, view, seq_type, label = data
    label = np.array(label)
    view_list = list(set(view))
    view_list.sort()
    view_num = len(view_list)
    sample_num = len(feature)  # total number of samples (not used below)

    probe_seq_dict = {'CASIA': [['nm-05', 'nm-06'], ['bg-01', 'bg-02'], ['cl-01', 'cl-02']],
                      'OUMVLP': [['00']]}
    gallery_seq_dict = {'CASIA': [['nm-01', 'nm-02', 'nm-03', 'nm-04']],
                        'OUMVLP': [['01']]}

    num_rank = 5
    # The loops below compute, for every probe_seq / gallery_seq pair, the accuracy of each
    # probe view against each gallery view (probe_seq covers the three walking conditions:
    # NM, BG and CL).
    # shape: walking conditions x probe views x gallery views x top-5 ranks
    acc = np.zeros([len(probe_seq_dict[dataset]), view_num, view_num, num_rank])
    for (p, probe_seq) in enumerate(probe_seq_dict[dataset]):  # probe subsets
        for gallery_seq in gallery_seq_dict[dataset]:  # gallery subsets
            for (v1, probe_view) in enumerate(view_list):  # probe views
                for (v2, gallery_view) in enumerate(view_list):  # gallery views
                    # Select samples whose seq type (nm-01, nm-02, ...) is in gallery_seq and
                    # whose view equals the current gallery_view, since accuracy is reported
                    # per view; gallery_seq and probe_seq are both lists.
                    gseq_mask = np.isin(seq_type, gallery_seq) & np.isin(view, [gallery_view])
                    gallery_x = feature[gseq_mask, :]  # features of the selected gallery samples
                    gallery_y = label[gseq_mask]  # labels of the selected gallery samples
                    # Likewise, select the matching probe features and labels.
                    pseq_mask = np.isin(seq_type, probe_seq) & np.isin(view, [probe_view])
                    probe_x = feature[pseq_mask, :]
                    probe_y = label[pseq_mask]

                    dist = cuda_dist(probe_x, gallery_x)
                    # Sort the gallery for each probe; this returns indices into the original array.
                    idx = dist.sort(1)[1].cpu().numpy()
                    # Rank-k accuracy: acc[p, v1, v2, 0] holds top-1 and acc[p, v1, v2, num_rank - 1]
                    # holds top-5 (since num_rank = 5). gallery_y[idx[:, 0:num_rank]] picks the labels
                    # of the num_rank nearest gallery samples, and np.cumsum accumulates the matches
                    # so that top-1, top-2, ..., top-num_rank accuracies come out of a single pass.
                    acc[p, v1, v2, :] = np.round(
                        np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0,
                               0) * 100 / dist.shape[0], 2)

    return acc
-------------------------------------------------------------------------------- /model2/utils/sampler.py: --------------------------------------------------------------------------------
import random

import torch.utils.data as tordata


class TripletSampler(tordata.sampler.Sampler):
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        while (True):
            sample_indices = list()
            pid_list = random.sample(  # draw p = batch_size[0] subjects (8 in the default config)
                list(self.dataset.label_set),
                self.batch_size[0])
            for pid in pid_list:
                _index = self.dataset.index_dict.loc[pid, :, :].values
                # Keep only samples whose silhouette sequence exists; an index of -1 marks a
                # missing sequence. Note the comparison must be `>= 0`, not `> 0`, otherwise
                # the sample stored at index 0 would always be dropped.
                _index = _index[_index >= 0].flatten().tolist()
                _index = random.choices(
                    _index,  # draw k = batch_size[1] silhouette sequences per subject (16 in the default config)
                    k=self.batch_size[1])
                sample_indices += _index
            yield sample_indices

    def __len__(self):
        return self.dataset.data_size
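A quick numeric check (toy labels, not repo data) of the cumulative-sum rank-k trick used in `evaluation` in evaluator.py above, with two probes and their three nearest gallery labels:

```python
import numpy as np

probe_y = np.array(['a', 'b'])
# labels of the 3 nearest gallery samples per probe, already sorted by distance
top_labels = np.array([['c', 'a', 'a'],
                       ['b', 'b', 'c']])

match = np.reshape(probe_y, [-1, 1]) == top_labels  # [[F, T, T], [T, T, F]]
hits = np.cumsum(match, 1) > 0                      # a probe counts once it has matched
acc = np.sum(hits, 0) * 100 / len(probe_y)
print(acc)  # [ 50. 100. 100.] -> rank-1, rank-2, rank-3 accuracy
```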
-------------------------------------------------------------------------------- /model2/utils/tensorboardDraw.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author  : wanglin
# Created : 2020/9/19 15:21
# File    : tensorboardDraw
# IDE     : PyCharm

import random

import torch
from tensorboardX import SummaryWriter

# NOTE: identical to model/utils/tensorboardDraw.py and apparently adapted from NVIDIA's
# Tacotron2 code base; the plot_*_to_numpy helpers live in that project's plotting_utils
# module, which is not included in this repository, so log_validation cannot run as-is:
# from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy, plot_gate_outputs_to_numpy


class Tacotron2Logger(SummaryWriter):
    def __init__(self, logdir):
        super(Tacotron2Logger, self).__init__(logdir)

    def log_training(self, reduced_loss, grad_norm, learning_rate, duration,
                     iteration):
        self.add_scalar("training.loss", reduced_loss, iteration)
        self.add_scalar("grad.norm", grad_norm, iteration)
        self.add_scalar("learning.rate", learning_rate, iteration)
        self.add_scalar("duration", duration, iteration)

    def log_validation(self, reduced_loss, model, y, y_pred, iteration):
        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        # plot distribution of parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            self.add_histogram(tag, value.data.cpu().numpy(), iteration)

        # plot alignment, mel target and predicted, gate target and predicted
        idx = random.randint(0, alignments.size(0) - 1)
        self.add_image(
            "alignment",
            plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T),
            iteration)
        self.add_image(
            "mel_target",
            plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()),
            iteration)
        self.add_image(
            "mel_predicted",
            plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()),
            iteration)
        self.add_image(
            "gate",
            plot_gate_outputs_to_numpy(
                gate_targets[idx].data.cpu().numpy(),
                torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()),  # F.sigmoid is deprecated
            iteration)
-------------------------------------------------------------------------------- /pretreatment.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Author : Abner
# @Time : 2018/12/19

import argparse
import os
from multiprocessing import Pool
from multiprocessing import TimeoutError as MP_TimeoutError
from time import sleep
from warnings import warn

import cv2
import numpy as np

START = "START"
FINISH = "FINISH"
WARNING = "WARNING"
FAIL = "FAIL"


def boolean_string(s):
    if s.upper() not in {'FALSE', 'TRUE'}:
        raise ValueError('Not a valid boolean string')
    return s.upper() == 'TRUE'


parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--input_path', default='', type=str,
                    help='Root path of raw dataset.')
parser.add_argument('--output_path', default='', type=str,
                    help='Root path for output.')
parser.add_argument('--log_file', default='./pretreatment.log', type=str,
                    help='Log file path. Default: ./pretreatment.log')
parser.add_argument('--log', default=False, type=boolean_string,
                    help='If set as True, all logs will be saved. '
                         'Otherwise, only warnings and errors will be saved. '
                         'Default: False')
parser.add_argument('--worker_num', default=1, type=int,
                    help='How many subprocesses to use for data pretreatment. '
                         'Default: 1')
opt = parser.parse_args()

INPUT_PATH = opt.input_path
OUTPUT_PATH = opt.output_path
IF_LOG = opt.log
LOG_PATH = opt.log_file
WORKERS = opt.worker_num

T_H = 64
T_W = 64


def log2str(pid, comment, logs):
    str_log = ''
    if type(logs) is str:
        logs = [logs]
    for log in logs:
        str_log += "# JOB %d : --%s-- %s\n" % (
            pid, comment, log)
    return str_log


def log_print(pid, comment, logs):
    str_log = log2str(pid, comment, logs)
    if comment in [WARNING, FAIL]:
        with open(LOG_PATH, 'a') as log_f:
            log_f.write(str_log)
    if comment in [START, FINISH]:
        if pid % 500 != 0:
            return
    print(str_log, end='')


def cut_img(img, seq_info, frame_name, pid):
    # A silhouette contains too little white pixels
    # might be not valid for identification.
    if img.sum() <= 10000:
        message = 'seq:%s, frame:%s, no data, %d.' % (
            '-'.join(seq_info), frame_name, img.sum())
        warn(message)
        log_print(pid, WARNING, message)
        return None
    # Get the top and bottom point
    y = img.sum(axis=1)
    y_top = (y != 0).argmax(axis=0)
    y_btm = (y != 0).cumsum(axis=0).argmax(axis=0)
    img = img[y_top:y_btm + 1, :]
    # As the height of a person is larger than the width,
    # use the height to calculate resize ratio.
    _r = img.shape[1] / img.shape[0]
    _t_w = int(T_H * _r)
    img = cv2.resize(img, (_t_w, T_H), interpolation=cv2.INTER_CUBIC)
    # Get the median of x axis and regard it as the x center of the person.
    sum_point = img.sum()
    sum_column = img.sum(axis=0).cumsum()
    x_center = -1
    for i in range(sum_column.size):
        if sum_column[i] > sum_point / 2:
            x_center = i
            break
    if x_center < 0:
        message = 'seq:%s, frame:%s, no center.' % (
            '-'.join(seq_info), frame_name)
        warn(message)
        log_print(pid, WARNING, message)
        return None
    h_T_W = int(T_W / 2)
    left = x_center - h_T_W
    right = x_center + h_T_W
    if left <= 0 or right >= img.shape[1]:
        left += h_T_W
        right += h_T_W
        _ = np.zeros((img.shape[0], h_T_W))
        img = np.concatenate([_, img, _], axis=1)
    img = img[:, left:right]
    return img.astype('uint8')
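A quick synthetic check of `cut_img` above (a sketch assuming the function and the module-level `T_H = T_W = 64` are in scope): an off-centre silhouette on a 128x96 canvas comes out as a height-normalized, horizontally centred 64x64 crop:

```python
import numpy as np

toy = np.zeros((128, 96), dtype='uint8')
toy[20:120, 40:60] = 255  # a 100x20 white "person", shifted left of centre

out = cut_img(toy, ['001', 'nm-01', '090'], '0001.png', pid=0)
print(out.shape, out.dtype)  # (64, 64) uint8
print(out.sum() > 0)         # True: the silhouette survived the crop
```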
def cut_pickle(seq_info, pid):
    seq_name = '-'.join(seq_info)
    log_print(pid, START, seq_name)
    seq_path = os.path.join(INPUT_PATH, *seq_info)
    out_dir = os.path.join(OUTPUT_PATH, *seq_info)
    frame_list = os.listdir(seq_path)
    frame_list.sort()
    count_frame = 0
    for _frame_name in frame_list:
        frame_path = os.path.join(seq_path, _frame_name)
        img = cv2.imread(frame_path)[:, :, 0]
        img = cut_img(img, seq_info, _frame_name, pid)
        if img is not None:
            # Save the cut img
            save_path = os.path.join(out_dir, _frame_name)
            cv2.imwrite(save_path, img)
            count_frame += 1
    # Warn if the sequence contains less than 5 frames
    if count_frame < 5:
        message = 'seq:%s, less than 5 valid data.' % (
            '-'.join(seq_info))
        warn(message)
        log_print(pid, WARNING, message)

    log_print(pid, FINISH,
              'Contain %d valid frames. Saved to %s.'
              % (count_frame, out_dir))


pool = Pool(WORKERS)
results = list()
pid = 0

print('Pretreatment Start.\n'
      'Input path: %s\n'
      'Output path: %s\n'
      'Log file: %s\n'
      'Worker num: %d' % (
          INPUT_PATH, OUTPUT_PATH, LOG_PATH, WORKERS))

id_list = os.listdir(INPUT_PATH)
id_list.sort()
# Walk the input path
for _id in id_list:
    seq_type = os.listdir(os.path.join(INPUT_PATH, _id))
    seq_type.sort()
    for _seq_type in seq_type:
        view = os.listdir(os.path.join(INPUT_PATH, _id, _seq_type))
        view.sort()
        for _view in view:
            seq_info = [_id, _seq_type, _view]
            out_dir = os.path.join(OUTPUT_PATH, *seq_info)
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            results.append(
                pool.apply_async(
                    cut_pickle,
                    args=(seq_info, pid)))
            sleep(0.02)
            pid += 1

pool.close()
unfinish = 1
while unfinish > 0:
    unfinish = 0
    for i, res in enumerate(results):
        try:
            res.get(timeout=0.1)
        except Exception as e:
            if type(e) == MP_TimeoutError:
                unfinish += 1
                continue
            else:
                # use %-formatting here; the original passed the arguments to print as
                # extra positional parameters, which just printed them after the template
                print('\n\n\nERROR OCCUR: PID ##%d##, ERRORTYPE: %s\n\n\n'
                      % (i, type(e)))
                raise e
pool.join()
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
torch==1.3.1
xarray==0.16.1
opencv_python==4.1.2.30
numpy==1.17.4
tensorboardX==2.1
visdom==0.1.8.9
-------------------------------------------------------------------------------- /test/__init__.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author  : wanglin
# Created : 2020/9/26 10:54
# File    : __init__.py
# IDE     : PyCharm
-------------------------------------------------------------------------------- /test/test.py: --------------------------------------------------------------------------------
import argparse
from datetime import datetime

import numpy as np

from config.config import conf
from model.initialization import initialization
from model.utils import evaluation


def boolean_string(s):
    if s.upper() not in {'FALSE', 'TRUE'}:
        raise ValueError('Not a valid boolean string')
    return s.upper() == 'TRUE'


parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--iter', default='80000', type=int,
                    help='iter: iteration of the checkpoint to load. Default: 80000')
parser.add_argument('--batch_size', default='1', type=int,
                    help='batch_size: batch size for parallel test. Default: 1')
parser.add_argument('--cache', default=False, type=boolean_string,
                    help='cache: if set as TRUE all the test data will be loaded at once'
                         ' before the transforming start. Default: FALSE')
opt = parser.parse_args()


# Exclude identical-view cases
def de_diag(acc, each_angle=False):
    # 11 views in total; excluding the identical view leaves 10 per row.
    result = np.sum(acc - np.diag(np.diag(acc)), 1) / 10.0
    if not each_angle:
        result = np.mean(result)
    return result
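`de_diag` above averages each row of the view-by-view accuracy matrix with its diagonal (the identical-view case) removed. A toy check with 3 views instead of CASIA-B's 11, so the divisor becomes `acc.shape[1] - 1` rather than the hard-coded `10.0`:

```python
import numpy as np

acc = np.array([[90., 80., 70.],
                [60., 95., 50.],
                [40., 30., 85.]])
off_diag_row_mean = np.sum(acc - np.diag(np.diag(acc)), 1) / (acc.shape[1] - 1)
print(off_diag_row_mean)           # [75. 55. 35.]
print(np.mean(off_diag_row_mean))  # 55.0
```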
m = initialization(conf, test=opt.cache)[0]

# load model checkpoint of iteration opt.iter
print('Loading the model of iteration %d...' % opt.iter)
m.load(opt.iter)
print('Transforming...')
time = datetime.now()
test = m.transform('test', opt.batch_size)
print('Evaluating...')
acc = evaluation(test, conf['data'])
print('Evaluation complete. Cost:', datetime.now() - time)

# Print rank-1 accuracy of the best model
# e.g.
# ===Rank-1 (Include identical-view cases)===
# NM: 95.405,     BG: 88.284,     CL: 72.041

# results obtained with my training run:
# ===Rank-1 (Include identical-view cases)===
# NM: 95.744,     BG: 89.143,     CL: 72.554

for i in range(1):
    print('===Rank-%d (Include identical-view cases)===' % (i + 1))
    print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
        np.mean(acc[0, :, :, i]),
        np.mean(acc[1, :, :, i]),
        np.mean(acc[2, :, :, i])))

# Print rank-1 accuracy of the best model, excluding identical-view cases
# e.g.
# ===Rank-1 (Exclude identical-view cases)===
# NM: 94.964,     BG: 87.239,     CL: 70.355

# results obtained with my training run:
# ===Rank-1 (Exclude identical-view cases)===
# NM: 95.327,     BG: 88.221,     CL: 70.745
for i in range(1):
    print('===Rank-%d (Exclude identical-view cases)===' % (i + 1))
    print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
        de_diag(acc[0, :, :, i]),
        de_diag(acc[1, :, :, i]),
        de_diag(acc[2, :, :, i])))

# Print rank-1 accuracy of the best model (Each Angle)
# e.g.
# ===Rank-1 of each angle (Exclude identical-view cases)===
# NM: [90.80 97.90 99.40 96.90 93.60 91.70 95.00 97.80 98.90 96.80 85.80]
# BG: [83.80 91.20 91.80 88.79 83.30 81.00 84.10 90.00 92.20 94.45 79.00]
# CL: [61.40 75.40 80.70 77.30 72.10 70.10 71.50 73.50 73.50 68.40 50.00]

# results obtained with my training run:
# ===Rank-1 of each angle (Exclude identical-view cases)===
# NM: [92.60 97.90 98.30 98.10 94.10 91.90 95.70 97.90 98.80 96.90 86.40]
# BG: [85.40 92.30 93.90 90.61 86.70 82.30 83.90 91.20 93.40 92.12 78.60]
# CL: [64.10 75.20 80.90 78.00 69.50 66.70 69.50 74.10 73.70 70.50 56.00]

np.set_printoptions(precision=2, floatmode='fixed')
for i in range(1):
    print('===Rank-%d of each angle (Exclude identical-view cases)===' % (i + 1))
    print('NM:', de_diag(acc[0, :, :, i], True))
    print('BG:', de_diag(acc[1, :, :, i], True))
    print('CL:', de_diag(acc[2, :, :, i], True))
-------------------------------------------------------------------------------- /test/test_all.py: --------------------------------------------------------------------------------
import argparse
import os
import sys
from datetime import datetime

sys.path.append(r"/data/lwl/Gait_experiment/GaitSet")
# sys.path.append(r"/data/lwl/Gait_experiment/GaitSet/model")
import numpy as np
# from config.config_v2 import conf
from config.config_hard_full_loss import conf
from visdom import Visdom
from model2.initialization import initialization
from model2.utils import evaluation


def boolean_string(s):
    if s.upper() not in {'FALSE', 'TRUE'}:
        raise ValueError('Not a valid boolean string')
    return s.upper() == 'TRUE'


parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--iter', default='80000', type=int,
                    help='iter: iteration of the checkpoint to load. Default: 80000')
parser.add_argument('--batch_size', default='1', type=int,
                    help='batch_size: batch size for parallel test. Default: 1')
parser.add_argument('--cache', default=True, type=boolean_string,
                    help='cache: if set as TRUE all the test data will be loaded at once'
                         ' before the transforming start. Default: TRUE')
opt = parser.parse_args()


# Exclude identical-view cases
def de_diag(acc, each_angle=False):
    # 11 views in total; excluding the identical view leaves 10 per row.
    result = np.sum(acc - np.diag(np.diag(acc)), 1) / 10.0
    if not each_angle:
        result = np.mean(result)
    return result


if __name__ == '__main__':
    log_test_dir = "./log/test/visdom/hard_full_loss"
    m = initialization(conf, test=opt.cache)[0]
    """print(os.listdir(os.path.join(conf['WORK_PATH'], "checkpoint",
                                  conf["model"]["model_name"]))[-1].split("-")[-2])
    test1 = []
    for filename in os.listdir(os.path.join(conf['WORK_PATH'], "checkpoint",
                                            conf["model"]["model_name"])):
        if int(filename.split("-")[-2]) == 60100:
            test1.append(filename)
    print(test1)"""

    # Collect the iteration numbers of all saved checkpoints, deduplicated (they overlap
    # with the earlier batch-size-128 run, which did not use DyReLU).
    iter_list = sorted(list(set(map(lambda a: int(a.split("-")[-2]),
                                    os.listdir(os.path.join("checkpoint",
                                                            conf["model"]["model_save_dir"]))))))
    # writer = SummaryWriter(log_dir="./log/all_test_acc_log")

    # visdom visualization
    vis = Visdom(env="GaitSet_test", log_to_filename=os.path.join(log_test_dir, "test_all_acc.log"))

    # load model checkpoint of iteration opt.iter
    acc_array_include = []
    acc_array_exclude = []
    # Only keep checkpoints after iteration 100000 for testing.
    iter_list = list(filter(lambda x: x > 100000, iter_list))
    for iter_s in iter_list:
        if iter_s % 100 == 0:
            print(iter_s)
        print('Loading the model of iteration %d...' % iter_s)
        m.load(iter_s)
        print('Transforming...')
        time = datetime.now()
        test = m.transform('test', opt.batch_size)
        print('Evaluating...')
        acc = evaluation(test, conf['data'])
        print('Evaluation complete. Cost:', datetime.now() - time)

        # Print rank-1 accuracy of the best model
        # e.g.
        # ===Rank-1 (Include identical-view cases)===
        # NM: 95.405,     BG: 88.284,     CL: 72.041

        # results obtained with my training run:
        # ===Rank-1 (Include identical-view cases)===
        # NM: 95.744,     BG: 89.143,     CL: 72.554

        for i in range(1):
            print('===Rank-%d (Include identical-view cases)===' % (i + 1))
            print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
                np.mean(acc[0, :, :, i]),
                np.mean(acc[1, :, :, i]),
                np.mean(acc[2, :, :, i])))
            acc_array_include.append([np.mean(acc[0, :, :, i]),
                                      np.mean(acc[1, :, :, i]),
                                      np.mean(acc[2, :, :, i])])

        # Print rank-1 accuracy of the best model, excluding identical-view cases
        # e.g.
if __name__ == '__main__':
    log_test_dir = "./log/test/visdom/hard_full_loss"
    m = initialization(conf, test=opt.cache)[0]
    """print(os.listdir(os.path.join(conf['WORK_PATH'], "checkpoint",
                                     conf["model"]["model_name"]))[-1].split("-")[-2])
    test1 = []
    for filename in os.listdir(os.path.join(conf['WORK_PATH'], "checkpoint",
                                            conf["model"]["model_name"])):
        if int(filename.split("-")[-2]) == 60100:
            test1.append(filename)
    print(test1)"""

    # Collect the iteration number of every saved checkpoint and deduplicate
    # (the list overlaps with the earlier batch-size-128 run, which did not use dy-relu)
    iter_list = sorted(list(set(map(lambda a: int(a.split("-")[-2]),
                                    os.listdir(os.path.join("checkpoint",
                                                            conf["model"]["model_save_dir"]))))))
    # writer = SummaryWriter(log_dir="./log/all_test_acc_log")

    # visdom visualization
    vis = Visdom(env="GaitSet_test", log_to_filename=os.path.join(log_test_dir, "test_all_acc.log"))

    # load model checkpoint of iteration opt.iter
    acc_array_include = []
    acc_array_exclude = []
    # Only evaluate checkpoints saved after iteration 100000
    iter_list = list(filter(lambda x: x > 100000, iter_list))
    for iter_s in iter_list:
        if iter_s % 100 == 0:
            print(iter_s)
        print('Loading the model of iteration %d...' % iter_s)
        m.load(iter_s)
        print('Transforming...')
        time = datetime.now()
        test = m.transform('test', opt.batch_size)
        print('Evaluating...')
        acc = evaluation(test, conf['data'])
        print('Evaluation complete. Cost:', datetime.now() - time)

        # Print rank-1 accuracy of the best model
        # e.g.
        # ===Rank-1 (Include identical-view cases)===
        # NM: 95.405, BG: 88.284, CL: 72.041

        # Results from my own training:
        # ===Rank-1 (Include identical-view cases)===
        # NM: 95.744, BG: 89.143, CL: 72.554

        for i in range(1):
            print('===Rank-%d (Include identical-view cases)===' % (i + 1))
            print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
                np.mean(acc[0, :, :, i]),
                np.mean(acc[1, :, :, i]),
                np.mean(acc[2, :, :, i])))
            acc_array_include.append([np.mean(acc[0, :, :, i]),
                                      np.mean(acc[1, :, :, i]),
                                      np.mean(acc[2, :, :, i])])

        # Print rank-1 accuracy of the best model, excluding identical-view cases
        # e.g.
        # ===Rank-1 (Exclude identical-view cases)===
        # NM: 94.964, BG: 87.239, CL: 70.355

        # Results from my own training:
        # ===Rank-1 (Exclude identical-view cases)===
        # NM: 95.327, BG: 88.221, CL: 70.745
        for i in range(1):
            print('===Rank-%d (Exclude identical-view cases)===' % (i + 1))
            print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
                de_diag(acc[0, :, :, i]),
                de_diag(acc[1, :, :, i]),
                de_diag(acc[2, :, :, i])))
            acc_array_exclude.append([de_diag(acc[0, :, :, i]),
                                      de_diag(acc[1, :, :, i]),
                                      de_diag(acc[2, :, :, i])])
    print("Accuracy array (exclude):", acc_array_exclude)
    print("Row names (checkpoint iterations):", iter_list)
    # Save the raw numbers
    np.savetxt(os.path.join(log_test_dir, "acc_array_exclude.txt"), np.array(acc_array_exclude))
    np.savetxt(os.path.join(log_test_dir, "acc_array_include.txt"), np.array(acc_array_include))
    np.savetxt(os.path.join(log_test_dir, "iter_list.txt"), iter_list)
    vis.bar(X=acc_array_exclude,
            opts=dict(
                stacked=False,
                legend=['NM', 'BG', 'CL'],
                rownames=iter_list,
                title='acc_array_exclude',
                ylabel='rank-1 accuracy',  # y-axis label
                xtickmin=0.4  # left edge of the x-axis
                # xtickstep=0.4  # spacing between bars
            ), win="acc_array_exclude")
    vis.bar(X=acc_array_include,
            opts=dict(
                stacked=False,
                legend=['NM', 'BG', 'CL'],
                rownames=iter_list,
                title='acc_array_include',
                ylabel='rank-1 accuracy',  # y-axis label
                xtickmin=0.4  # left edge of the x-axis
                # xtickstep=0.4  # spacing between bars
            ), win="acc_array_include")
--------------------------------------------------------------------------------
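A minimal sketch (not part of the repository) of reading the saved logs back for offline plotting; file names are taken from the script above, and the copies checked in under work/log/visdom/ use the same format:

import numpy as np

acc_exclude = np.loadtxt("work/log/visdom/acc_array_exclude.txt")  # shape (N, 3): NM, BG, CL columns
iters = np.loadtxt("work/log/visdom/iter_list.txt")                # shape (N,): checkpoint iterations
best = int(iters[np.argmax(acc_exclude[:, 0])])
print("best NM rank-1 %.3f at iteration %d" % (acc_exclude[:, 0].max(), best))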
/train/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Author  : wanglin
# Created : 2020/9/26 10:56
# File    : __init__.py
# IDE     : PyCharm
import sys

sys.path.append(r"/data/lwl/Gait_experiment/GaitSet")
--------------------------------------------------------------------------------
/train/train.py:
--------------------------------------------------------------------------------
import argparse

from config.config_v2 import conf
from model.initialization import initialization


def boolean_string(s):
    if s.upper() not in {'FALSE', 'TRUE'}:
        raise ValueError('Not a valid boolean string')
    return s.upper() == 'TRUE'


parser = argparse.ArgumentParser(description='Train')
# TODO: the cache default was changed here; keep it enabled for real training runs
parser.add_argument('--cache', default=True, type=boolean_string,
                    help='cache: if set as TRUE all the training data will be loaded at once'
                         ' before the training start. Default: TRUE')
opt = parser.parse_args()

m = initialization(conf, train=opt.cache)[0]

print("Training START")
m.fit()
print("Training COMPLETE")
--------------------------------------------------------------------------------
/train/train_full.py:
--------------------------------------------------------------------------------
import argparse
import sys

sys.path.append(r"/data/lwl/Gait_experiment/GaitSet")
from config.config_full_loss import conf
from model2.initialization import initialization


def boolean_string(s):
    if s.upper() not in {'FALSE', 'TRUE'}:
        raise ValueError('Not a valid boolean string')
    return s.upper() == 'TRUE'


parser = argparse.ArgumentParser(description='Train')
# TODO: the cache default was changed here; keep it enabled for real training runs
parser.add_argument('--cache', default=True, type=boolean_string,
                    help='cache: if set as TRUE all the training data will be loaded at once'
                         ' before the training start. Default: TRUE')
opt = parser.parse_args()

m = initialization(conf, train=opt.cache)[0]

print("Training START")
m.fit()
print("Training COMPLETE")
--------------------------------------------------------------------------------
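All three train scripts share the boolean_string helper. A quick sketch (not part of the repository) of why it exists: argparse's type=bool would treat any non-empty string, even "False", as truthy, so the flag is parsed explicitly:

def boolean_string(s):
    if s.upper() not in {'FALSE', 'TRUE'}:
        raise ValueError('Not a valid boolean string')
    return s.upper() == 'TRUE'

assert boolean_string('true') is True
assert boolean_string('FALSE') is False
# e.g. python train/train_full.py --cache FALSE  parses to opt.cache == False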
/train/train_hard_full.py:
--------------------------------------------------------------------------------
import argparse
import sys

sys.path.append(r"/data/lwl/Gait_experiment/GaitSet")
from config.config_hard_full_loss import conf
from model2.initialization import initialization


def boolean_string(s):
    if s.upper() not in {'FALSE', 'TRUE'}:
        raise ValueError('Not a valid boolean string')
    return s.upper() == 'TRUE'


parser = argparse.ArgumentParser(description='Train')
# TODO: the cache default was changed here; keep it enabled for real training runs
parser.add_argument('--cache', default=True, type=boolean_string,
                    help='cache: if set as TRUE all the training data will be loaded at once'
                         ' before the training start. Default: TRUE')
opt = parser.parse_args()

m = initialization(conf, train=opt.cache)[0]

print("Training START")
m.fit()
print("Training COMPLETE")
--------------------------------------------------------------------------------
/work/OUMVLP_network/basic_blocks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)

    def forward(self, x):
        x = self.conv(x)
        return F.leaky_relu(x, inplace=True)


class SetBlock(nn.Module):
    def __init__(self, forward_block, pooling=False):
        super(SetBlock, self).__init__()
        self.forward_block = forward_block
        self.pooling = pooling
        if pooling:
            self.pool2d = nn.MaxPool2d(2)

    def forward(self, x):
        # Fold the frame dimension into the batch so the wrapped 2D block
        # is applied to every frame independently
        n, s, c, h, w = x.size()
        x = self.forward_block(x.view(-1, c, h, w))
        if self.pooling:
            x = self.pool2d(x)
        _, c, h, w = x.size()
        return x.view(n, s, c, h, w)


class HPM(nn.Module):
    def __init__(self, in_dim, out_dim, bin_level_num=5):
        super(HPM, self).__init__()
        self.bin_num = [2 ** i for i in range(bin_level_num)]  # [1, 2, 4, 8, 16]
        self.fc_bin = nn.ParameterList([
            nn.Parameter(
                nn.init.xavier_uniform(
                    torch.zeros(sum(self.bin_num), in_dim, out_dim)))])

    def forward(self, x):
        feature = list()
        n, c, h, w = x.size()
        for num_bin in self.bin_num:
            z = x.view(n, c, num_bin, -1)
            z = z.mean(3) + z.max(3)[0]
            feature.append(z)
        feature = torch.cat(feature, 2).permute(2, 0, 1).contiguous()

        # One independent fc per strip: (31, n, c) x (31, c, out) -> (31, n, out)
        feature = feature.matmul(self.fc_bin[0])
        return feature.permute(1, 0, 2).contiguous()
--------------------------------------------------------------------------------
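A small shape walk-through (not from the repository; toy sizes assumed) of how the blocks above are used: SetBlock folds the frame axis into the batch, set pooling collapses it, and HPM cuts the pooled map into 1+2+4+8+16 = 31 horizontal strips:

import torch

n, s, c, h, w = 2, 5, 64, 16, 11     # 2 sequences of 5 frames, 64-channel 16x11 maps
x = torch.randn(n, s, c, h, w)

frames = x.view(-1, c, h, w)         # SetBlock: per-frame 2D conv over (n*s, c, h, w)
v = torch.max(x, 1)[0]               # set pooling: max over the frame axis -> (n, c, h, w)

feats = []
for b in [1, 2, 4, 8, 16]:           # HPM strips
    z = v.view(n, c, b, -1)          # split the h*w map into b bands
    feats.append(z.mean(3) + z.max(3)[0])
print(torch.cat(feats, 2).shape)     # torch.Size([2, 64, 31])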
/work/OUMVLP_network/gaitset.py:
--------------------------------------------------------------------------------
# The snippet as checked in starts at the class definition; the imports below
# are what it needs to run (BasicConv2d, SetBlock and HPM are assumed to come
# from basic_blocks.py in this folder).
import numpy as np
import torch
import torch.nn as nn

from basic_blocks import BasicConv2d, SetBlock, HPM


class SetNet(nn.Module):
    def __init__(self, hidden_dim):
        super(SetNet, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_frame = None

        _in_channels = 1
        _channels = [64, 128, 256]
        self.set_layer1 = SetBlock(BasicConv2d(_in_channels, _channels[0], 5, padding=2))
        self.set_layer2 = SetBlock(BasicConv2d(_channels[0], _channels[0], 3, padding=1), True)
        self.set_layer3 = SetBlock(BasicConv2d(_channels[0], _channels[1], 3, padding=1))
        self.set_layer4 = SetBlock(BasicConv2d(_channels[1], _channels[1], 3, padding=1), True)
        self.set_layer5 = SetBlock(BasicConv2d(_channels[1], _channels[2], 3, padding=1))
        self.set_layer6 = SetBlock(BasicConv2d(_channels[2], _channels[2], 3, padding=1))

        self.gl_layer1 = BasicConv2d(_channels[0], _channels[1], 3, padding=1)
        self.gl_layer2 = BasicConv2d(_channels[1], _channels[1], 3, padding=1)
        self.gl_layer3 = BasicConv2d(_channels[1], _channels[2], 3, padding=1)
        self.gl_layer4 = BasicConv2d(_channels[2], _channels[2], 3, padding=1)
        self.gl_pooling = nn.MaxPool2d(2)

        self.gl_hpm = HPM(_channels[-1], hidden_dim)
        self.x_hpm = HPM(_channels[-1], hidden_dim)

        # PyTorch 0.4-era init names (newer releases use the trailing-underscore variants)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform(m.weight.data)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform(m.weight.data)
                nn.init.constant(m.bias.data, 0.0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.normal(m.weight.data, 1.0, 0.02)
                nn.init.constant(m.bias.data, 0.0)

    def frame_max(self, x):
        # Set pooling: element-wise max over the frame dimension; when batches
        # are packed with variable frame counts, pool each segment separately
        if self.batch_frame is None:
            return torch.max(x, 1)
        else:
            _tmp = [
                torch.max(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
                for i in range(len(self.batch_frame) - 1)
            ]
            max_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
            arg_max_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
            return max_list, arg_max_list

    def forward(self, silho, batch_frame=None):
        silho = silho / 255
        # n: batch_size, s: frame_num, c: channel
        if batch_frame is not None:
            # batch_frame lists the frame count of each packed sequence,
            # zero-padded at the tail; strip the padding first
            batch_frame = batch_frame[0].data.cpu().numpy().tolist()
            _ = len(batch_frame)
            for i in range(len(batch_frame)):
                if batch_frame[-(i + 1)] != 0:
                    break
                else:
                    _ -= 1
            batch_frame = batch_frame[:_]
            frame_sum = np.sum(batch_frame)
            if frame_sum < silho.size(1):
                silho = silho[:, :frame_sum, :, :]
            self.batch_frame = [0] + np.cumsum(batch_frame).tolist()
        n = silho.size(0)
        x = silho.unsqueeze(2)
        del silho

        x = self.set_layer1(x)
        x = self.set_layer2(x)
        gl = self.gl_layer1(self.frame_max(x)[0])
        gl = self.gl_layer2(gl)
        gl = self.gl_pooling(gl)

        x = self.set_layer3(x)
        x = self.set_layer4(x)
        gl = self.gl_layer3(gl + self.frame_max(x)[0])
        gl = self.gl_layer4(gl)

        x = self.set_layer5(x)
        x = self.set_layer6(x)
        x = self.frame_max(x)[0]
        gl = gl + x

        gl_f = self.gl_hpm(gl)
        x_f = self.x_hpm(x)

        return torch.cat([gl_f, x_f], 1), None
--------------------------------------------------------------------------------
/work/log/tensorboardV2_log/events.out.tfevents.1601021568.lab206-Server:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardV2_log/events.out.tfevents.1601021568.lab206-Server
--------------------------------------------------------------------------------
/work/log/tensorboardV2_log/events.out.tfevents.1601021725.lab206-Server:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardV2_log/events.out.tfevents.1601021725.lab206-Server
--------------------------------------------------------------------------------
/work/log/tensorboardV2_log/events.out.tfevents.1601021764.lab206-Server:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardV2_log/events.out.tfevents.1601021764.lab206-Server
--------------------------------------------------------------------------------
/work/log/tensorboardV2_log/events.out.tfevents.1601022036.lab206-Server:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardV2_log/events.out.tfevents.1601022036.lab206-Server
--------------------------------------------------------------------------------
/work/log/tensorboardV2_log/events.out.tfevents.1601022451.lab206-Server: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardV2_log/events.out.tfevents.1601022451.lab206-Server -------------------------------------------------------------------------------- /work/log/tensorboardV2_log/events.out.tfevents.1601022615.lab206-Server: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardV2_log/events.out.tfevents.1601022615.lab206-Server -------------------------------------------------------------------------------- /work/log/tensorboardV2_log/events.out.tfevents.1601040881.lab206-Server: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardV2_log/events.out.tfevents.1601040881.lab206-Server -------------------------------------------------------------------------------- /work/log/tensorboardV2_log/events.out.tfevents.1601082118.lab206-Server: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardV2_log/events.out.tfevents.1601082118.lab206-Server -------------------------------------------------------------------------------- /work/log/tensorboardV2_log/events.out.tfevents.1601084376.lab206-Server: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardV2_log/events.out.tfevents.1601084376.lab206-Server -------------------------------------------------------------------------------- /work/log/tensorboardV2_log/events.out.tfevents.1601084535.lab206-Server: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardV2_log/events.out.tfevents.1601084535.lab206-Server -------------------------------------------------------------------------------- /work/log/tensorboardX_log/events.out.tfevents.1600595800.lab206-Server: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardX_log/events.out.tfevents.1600595800.lab206-Server -------------------------------------------------------------------------------- /work/log/tensorboardX_log/events.out.tfevents.1600595961.lab206-Server: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardX_log/events.out.tfevents.1600595961.lab206-Server -------------------------------------------------------------------------------- /work/log/tensorboardX_log/events.out.tfevents.1600596086.lab206-Server: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardX_log/events.out.tfevents.1600596086.lab206-Server -------------------------------------------------------------------------------- /work/log/tensorboardX_log/events.out.tfevents.1600608993.lab206-Server: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/log/tensorboardX_log/events.out.tfevents.1600608993.lab206-Server -------------------------------------------------------------------------------- /work/log/visdom/acc_array_exclude.txt: -------------------------------------------------------------------------------- 1 | 8.734545454545454390e+01 7.713963636363637022e+01 6.213636363636364734e+01 2 | 8.728181818181818130e+01 7.798554545454545917e+01 6.173636363636362745e+01 3 | 8.666363636363635692e+01 7.627400000000000091e+01 6.052727272727273089e+01 4 | 8.748181818181818414e+01 7.865863636363634726e+01 6.283636363636364308e+01 5 | 8.669090909090908781e+01 7.732181818181817334e+01 5.973636363636362745e+01 6 | 8.655454545454546178e+01 7.687445454545454027e+01 6.089999999999999858e+01 7 | 8.708181818181819267e+01 7.712054545454545007e+01 6.133636363636364308e+01 8 | 8.660909090909090935e+01 7.665472727272727127e+01 6.186363636363636687e+01 9 | 8.670909090909090366e+01 7.627299999999999613e+01 6.277272727272725916e+01 10 | 8.563636363636364024e+01 7.699445454545453060e+01 6.028181818181817420e+01 11 | 8.661818181818181017e+01 7.659309090909090401e+01 6.218181818181817988e+01 12 | 8.630909090909091219e+01 7.748563636363634544e+01 6.235454545454545183e+01 13 | 8.710000000000000853e+01 7.916045454545454163e+01 6.430000000000001137e+01 14 | 8.789999999999999147e+01 7.889590909090908610e+01 6.167272727272727195e+01 15 | 8.731818181818181301e+01 7.818454545454545723e+01 6.279999999999999716e+01 16 | 8.677272727272726627e+01 7.761981818181818937e+01 6.166363636363637113e+01 17 | 8.756363636363636260e+01 7.788418181818182973e+01 6.330909090909091930e+01 18 | 8.795454545454546746e+01 7.936145454545454925e+01 6.326363636363637255e+01 19 | 8.672727272727273373e+01 7.732009090909090787e+01 6.367272727272727195e+01 20 | 8.599090909090908497e+01 7.691081818181818619e+01 6.354545454545454675e+01 21 | 8.607272727272729185e+01 7.691009090909091128e+01 6.309090909090909349e+01 22 | 8.723636363636363455e+01 7.904245454545454663e+01 6.443636363636363740e+01 23 | 8.669090909090908781e+01 7.852145454545454584e+01 6.425454545454545041e+01 24 | 8.631818181818181301e+01 7.660145454545452992e+01 6.263636363636364734e+01 25 | 8.619999999999998863e+01 7.787699999999999534e+01 6.323636363636364166e+01 26 | 8.688181818181817562e+01 7.690045454545453651e+01 6.270909090909091788e+01 27 | 8.644545454545453822e+01 7.749481818181817516e+01 6.305454545454546889e+01 28 | 8.540909090909090651e+01 7.669327272727272771e+01 6.170909090909090366e+01 29 | 8.643636363636363740e+01 7.719172727272726320e+01 6.078181818181818130e+01 30 | 8.737272727272727479e+01 7.823099999999999454e+01 6.360000000000000142e+01 31 | 8.672727272727273373e+01 7.835818181818181927e+01 6.306363636363636971e+01 32 | 8.642727272727272236e+01 7.671854545454543484e+01 6.328181818181818130e+01 33 | 8.740909090909090651e+01 7.760063636363635453e+01 6.158181818181817846e+01 34 | 8.656363636363636260e+01 7.794872727272728241e+01 6.301818181818180875e+01 35 | 8.766363636363635692e+01 7.823090909090909406e+01 
6.210909090909091645e+01 36 | 8.604545454545453254e+01 7.649236363636363478e+01 6.127272727272728048e+01 37 | 8.631818181818181301e+01 7.746445454545454368e+01 6.309090909090907928e+01 38 | 8.699090909090909918e+01 7.839536363636364058e+01 6.310909090909092356e+01 39 | 8.740000000000000568e+01 7.800254545454544086e+01 6.099090909090909207e+01 40 | 8.551818181818181586e+01 7.637290909090910418e+01 6.367272727272727195e+01 41 | 8.727272727272726627e+01 7.785718181818182870e+01 6.370909090909091788e+01 42 | 8.663636363636364024e+01 7.743809090909090287e+01 6.010909090909091645e+01 43 | 8.729090909090909634e+01 7.808636363636362887e+01 6.470909090909090366e+01 44 | 8.668181818181818699e+01 7.702909090909089684e+01 6.155454545454545467e+01 45 | 8.702727272727273089e+01 7.694718181818181790e+01 6.251818181818180875e+01 46 | 8.717272727272728616e+01 7.773763636363636920e+01 6.370909090909091788e+01 47 | 8.711818181818182438e+01 7.699154545454545939e+01 5.984545454545454390e+01 48 | 8.674545454545453538e+01 7.762981818181818028e+01 6.104545454545454675e+01 49 | 8.590000000000000568e+01 7.755645454545455664e+01 6.325454545454547173e+01 50 | 8.623636363636363455e+01 7.734709090909090889e+01 6.414545454545455527e+01 51 | 8.732727272727272805e+01 7.723963636363636454e+01 6.322727272727272663e+01 52 | 8.660909090909090935e+01 7.713845454545453606e+01 6.318181818181817277e+01 53 | 8.627272727272726627e+01 7.712827272727271577e+01 6.307272727272727053e+01 54 | 8.580000000000001137e+01 7.635545454545454902e+01 6.319999999999999574e+01 55 | 8.524545454545454959e+01 7.620945454545454822e+01 6.059999999999999432e+01 56 | 8.688181818181817562e+01 7.698254545454545905e+01 6.099090909090909918e+01 57 | 8.749090909090908497e+01 7.624699999999999989e+01 6.255454545454545467e+01 58 | 8.556363636363636260e+01 7.603427272727273589e+01 6.237272727272726769e+01 59 | 8.666363636363637113e+01 7.671972727272726900e+01 6.035454545454544473e+01 60 | 8.641818181818180733e+01 7.573763636363636920e+01 6.123636363636364166e+01 61 | 8.687272727272726058e+01 7.779472727272727184e+01 6.269999999999999574e+01 62 | 8.754545454545453254e+01 7.790354545454545132e+01 6.300909090909090082e+01 63 | 8.560909090909090935e+01 7.527790909090909111e+01 6.149090909090909207e+01 64 | 8.643636363636363740e+01 7.598990909090910861e+01 6.233636363636364308e+01 65 | 8.582727272727272805e+01 7.677381818181817152e+01 6.262727272727272521e+01 66 | 8.666363636363635692e+01 7.711063636363634544e+01 6.017272727272727195e+01 67 | 8.680909090909091219e+01 7.863845454545455027e+01 6.351818181818180875e+01 68 | 8.516363636363635692e+01 7.520772727272726854e+01 6.114545454545454817e+01 69 | 8.693636363636363740e+01 7.747454545454546349e+01 6.356363636363637681e+01 70 | 8.799090909090908497e+01 7.762054545454545007e+01 6.458181818181817846e+01 71 | 8.402727272727273089e+01 7.488836363636363558e+01 6.293636363636364450e+01 72 | 8.726363636363636545e+01 7.781090909090909236e+01 6.541818181818182154e+01 73 | 8.728181818181816709e+01 7.794727272727271838e+01 6.262727272727273231e+01 74 | 8.658181818181817846e+01 7.801036363636363546e+01 6.385454545454546604e+01 75 | 8.709090909090907928e+01 7.804999999999999716e+01 6.336363636363636687e+01 76 | 8.694545454545453822e+01 7.783154545454546280e+01 6.474545454545453538e+01 77 | 8.743636363636362319e+01 7.777545454545453651e+01 6.359090909090909349e+01 78 | 8.758181818181817846e+01 7.776800000000001489e+01 6.320000000000000284e+01 79 | 8.704545454545456096e+01 7.730227272727273657e+01 6.319090909090910202e+01 80 | 
8.683636363636364308e+01 7.784090909090909349e+01 6.165454545454545610e+01 81 | 8.695454545454545325e+01 7.785736363636362967e+01 6.151818181818180875e+01 82 | 8.835454545454545894e+01 7.977009090909091071e+01 6.214545454545453396e+01 83 | 8.661818181818182438e+01 7.742063636363636192e+01 6.341818181818182154e+01 84 | 8.620909090909090366e+01 7.753990909090909156e+01 6.225454545454545752e+01 85 | 8.490909090909089230e+01 7.600881818181818517e+01 6.125454545454545041e+01 86 | 8.609090909090909349e+01 7.676554545454546030e+01 6.340909090909090651e+01 87 | 8.685454545454545894e+01 7.810490909090907508e+01 6.217272727272725774e+01 88 | 8.737272727272728901e+01 7.926100000000000989e+01 6.230909090909090509e+01 89 | 8.669090909090908781e+01 7.642636363636363228e+01 6.149090909090909207e+01 90 | 8.668181818181818699e+01 7.688127272727273009e+01 6.272727272727274084e+01 91 | 8.628181818181818130e+01 7.643781818181815879e+01 6.339090909090909776e+01 92 | 8.648181818181818414e+01 7.720227272727272805e+01 6.252727272727270957e+01 93 | 8.624545454545453538e+01 7.610172727272725979e+01 6.220000000000000284e+01 94 | 8.520000000000000284e+01 7.537290909090908997e+01 5.981818181818182012e+01 95 | 8.493636363636363740e+01 7.625718181818182018e+01 6.220909090909090366e+01 96 | 8.779090909090909634e+01 7.817609090909090241e+01 6.232727272727272094e+01 97 | 8.644545454545453822e+01 7.759363636363636374e+01 6.276363636363635834e+01 98 | 8.701818181818181586e+01 7.779481818181818653e+01 6.282727272727272094e+01 99 | 8.676363636363636545e+01 7.646399999999999864e+01 6.374545454545454248e+01 100 | 8.742727272727273657e+01 7.875818181818182495e+01 6.343636363636363029e+01 101 | 8.775454545454546462e+01 7.895027272727271850e+01 6.406363636363636260e+01 102 | 8.759090909090907928e+01 7.894963636363635828e+01 6.464545454545454106e+01 103 | 8.771818181818181870e+01 7.915963636363636624e+01 6.422727272727273373e+01 104 | 8.768181818181818699e+01 7.889518181818181120e+01 6.390909090909089940e+01 105 | 8.787272727272726058e+01 7.915045454545453651e+01 6.479999999999999716e+01 106 | 8.786363636363635976e+01 7.902272727272726627e+01 6.427272727272726627e+01 107 | 8.792727272727272236e+01 7.916899999999999693e+01 6.502727272727273089e+01 108 | 8.788181818181818983e+01 7.889600000000000080e+01 6.457272727272727764e+01 109 | 8.718181818181818699e+01 7.832236363636361887e+01 6.467272727272728616e+01 110 | 8.779090909090908212e+01 7.891418181818183086e+01 6.457272727272726343e+01 111 | 8.735454545454545894e+01 7.888709090909090094e+01 6.460909090909090935e+01 112 | 8.739090909090909065e+01 7.846799999999997510e+01 6.400000000000000000e+01 113 | 8.754545454545454675e+01 7.847690909090907496e+01 6.435454545454544473e+01 114 | 8.735454545454544473e+01 7.847718181818180483e+01 6.495454545454545325e+01 115 | 8.756363636363637681e+01 7.874090909090909918e+01 6.459090909090909349e+01 116 | 8.760909090909090935e+01 7.856736363636363762e+01 6.449090909090909918e+01 117 | 8.750000000000000000e+01 7.825809090909091026e+01 6.501818181818181586e+01 118 | 8.764545454545454106e+01 7.866763636363636181e+01 6.488181818181818983e+01 119 | 8.773636363636363455e+01 7.911345454545454459e+01 6.537272727272726058e+01 120 | 8.779090909090908212e+01 7.902281818181816675e+01 6.523636363636363455e+01 121 | 8.766363636363635692e+01 7.869490909090909270e+01 6.486363636363635976e+01 122 | 8.747272727272728332e+01 7.840345454545455084e+01 6.478181818181818130e+01 123 | 8.720909090909091788e+01 7.850445454545454993e+01 6.508181818181817846e+01 124 | 
8.747272727272728332e+01 7.870427272727272339e+01 6.481818181818181301e+01 125 | 8.749999999999998579e+01 7.862245454545454493e+01 6.499090909090909918e+01 126 | 8.723636363636363455e+01 7.818445454545454254e+01 6.426363636363636545e+01 127 | 8.743636363636363740e+01 7.837609090909091947e+01 6.459090909090909349e+01 128 | 8.761818181818181017e+01 7.824863636363636488e+01 6.449090909090909918e+01 129 | 8.761818181818182438e+01 7.806618181818180346e+01 6.467272727272727195e+01 130 | 8.761818181818181017e+01 7.831227272727272748e+01 6.491818181818182154e+01 131 | 8.740909090909092072e+01 7.831199999999999761e+01 6.461818181818181017e+01 132 | 8.750000000000000000e+01 7.815718181818182586e+01 6.449999999999998579e+01 133 | 8.727272727272726627e+01 7.842209090909092595e+01 6.458181818181817846e+01 134 | 8.733636363636362887e+01 7.802981818181817175e+01 6.379090909090908212e+01 135 | 8.778181818181818130e+01 7.852163636363637522e+01 6.396363636363636829e+01 136 | 8.759999999999999432e+01 7.793918181818182234e+01 6.400909090909091503e+01 137 | 8.739090909090910486e+01 7.771081818181818335e+01 6.443636363636363740e+01 138 | 8.747272727272728332e+01 7.803963636363636169e+01 6.487272727272727479e+01 139 | 8.762727272727273942e+01 7.808536363636363831e+01 6.452727272727271668e+01 140 | 8.749090909090909918e+01 7.781190909090909713e+01 6.467272727272725774e+01 141 | 8.736363636363635976e+01 7.778436363636363637e+01 6.438181818181817562e+01 142 | 8.762727272727272521e+01 7.824872727272726536e+01 6.458181818181819267e+01 143 | 8.770909090909090366e+01 7.843954545454545269e+01 6.503636363636363171e+01 144 | 8.757272727272727764e+01 7.827590909090908156e+01 6.481818181818181301e+01 145 | 8.695454545454545325e+01 7.790254545454544655e+01 6.439090909090909065e+01 146 | 8.729999999999999716e+01 7.790245454545453185e+01 6.483636363636362887e+01 147 | 8.740909090909090651e+01 7.790136363636364081e+01 6.466363636363637113e+01 148 | 8.764545454545454106e+01 7.819363636363635806e+01 6.411818181818181017e+01 149 | 8.779999999999999716e+01 7.852118181818181597e+01 6.479999999999999716e+01 150 | 8.770000000000000284e+01 7.843036363636362296e+01 6.447272727272726911e+01 151 | 8.777272727272728048e+01 7.810245454545453470e+01 6.451818181818181586e+01 152 | 8.797272727272728332e+01 7.853090909090907701e+01 6.440000000000000568e+01 153 | 8.770000000000000284e+01 7.838472727272726104e+01 6.448181818181818414e+01 154 | 8.764545454545455527e+01 7.840300000000000580e+01 6.498181818181818414e+01 155 | 8.756363636363636260e+01 7.820263636363635840e+01 6.485454545454545894e+01 156 | 8.782727272727271384e+01 7.800227272727272521e+01 6.359999999999999432e+01 157 | 8.764545454545454106e+01 7.823000000000000398e+01 6.450909090909091503e+01 158 | 8.772727272727273373e+01 7.863081818181817084e+01 6.418181818181817277e+01 159 | 8.782727272727274226e+01 7.820345454545454800e+01 6.440909090909090651e+01 160 | 8.745454545454545325e+01 7.754745454545454209e+01 6.400000000000000000e+01 161 | 8.729090909090906791e+01 7.752872727272728071e+01 6.394545454545455243e+01 162 | 8.766363636363635692e+01 7.793936363636363751e+01 6.336363636363635266e+01 163 | 8.745454545454545325e+01 7.759381818181817891e+01 6.371818181818182580e+01 164 | 8.742727272727273657e+01 7.790272727272726172e+01 6.415454545454547031e+01 165 | 8.740909090909092072e+01 7.744763636363636294e+01 6.392727272727272236e+01 166 | 8.746363636363636829e+01 7.757545454545454788e+01 6.388181818181819693e+01 167 | 8.742727272727273657e+01 7.728318181818180221e+01 
6.400000000000000000e+01 168 | 8.759090909090909349e+01 7.748399999999999466e+01 6.407272727272727764e+01 169 | 8.720000000000000284e+01 7.766581818181818164e+01 6.379090909090909634e+01 170 | 8.691818181818182154e+01 7.777463636363634691e+01 6.415454545454545610e+01 171 | 8.707272727272727764e+01 7.759245454545454379e+01 6.390909090909092072e+01 172 | 8.752727272727273089e+01 7.836727272727272009e+01 6.379090909090909634e+01 173 | 8.741818181818182154e+01 7.796609090909090867e+01 6.407272727272727764e+01 174 | 8.724545454545454959e+01 7.764672727272726149e+01 6.406363636363637681e+01 175 | 8.730909090909091219e+01 7.773863636363635976e+01 6.357272727272727053e+01 176 | 8.716363636363635692e+01 7.794836363636362364e+01 6.378181818181818130e+01 177 | 8.737272727272727479e+01 7.821263636363636351e+01 6.350909090909092214e+01 178 | 8.721818181818183291e+01 7.800354545454545985e+01 6.387272727272726769e+01 179 | 8.723636363636363455e+01 7.783027272727271395e+01 6.339090909090908355e+01 180 | 8.724545454545454959e+01 7.785681818181818414e+01 6.370909090909091788e+01 181 | 8.721818181818183291e+01 7.763872727272728014e+01 6.400909090909091503e+01 182 | 8.728181818181818130e+01 7.767554545454544268e+01 6.337272727272728190e+01 183 | 8.710000000000000853e+01 7.725618181818182961e+01 6.387272727272726769e+01 184 | 8.719999999999998863e+01 7.810299999999999443e+01 6.376363636363637255e+01 185 | 8.722727272727273373e+01 7.778390909090909133e+01 6.337272727272727479e+01 186 | 8.731818181818181301e+01 7.757418181818181324e+01 6.364545454545455527e+01 187 | 8.763636363636362603e+01 7.813881818181818062e+01 6.424545454545454959e+01 188 | 8.761818181818183859e+01 7.815718181818182586e+01 6.421818181818181870e+01 189 | 8.755454545454546178e+01 7.820354545454546269e+01 6.436363636363637397e+01 190 | -------------------------------------------------------------------------------- /work/log/visdom/acc_array_include.txt: -------------------------------------------------------------------------------- 1 | 8.848760330578512878e+01 7.884528925619834183e+01 6.408264462809917461e+01 2 | 8.842975206611569661e+01 7.963090909090908553e+01 6.366115702479338978e+01 3 | 8.786776859504132631e+01 7.811636363636364422e+01 6.254545454545454675e+01 4 | 8.861157024793388359e+01 8.035884297520661335e+01 6.493388429752066315e+01 5 | 8.789256198347106874e+01 7.900280991735537839e+01 6.168595041322313932e+01 6 | 8.776859504132231393e+01 7.864578512396693100e+01 6.300000000000000000e+01 7 | 8.824793388429752383e+01 7.891090909090910088e+01 6.341322314049586595e+01 8 | 8.781818181818181301e+01 7.838834710743803669e+01 6.384297520661156966e+01 9 | 8.790909090909090651e+01 7.804132231404958020e+01 6.480991735537189413e+01 10 | 8.693388429752066315e+01 7.882107438016527112e+01 6.243801652892562259e+01 11 | 8.782644628099173190e+01 7.837347107438014859e+01 6.427272727272726627e+01 12 | 8.754545454545454675e+01 7.927595041322314273e+01 6.419008264462810587e+01 13 | 8.826446280991736160e+01 8.086462809917354377e+01 6.628925619834710403e+01 14 | 8.899173553719008112e+01 8.053314049586778367e+01 6.365289256198347090e+01 15 | 8.845454545454545325e+01 7.991958677685950363e+01 6.476859504132231393e+01 16 | 8.796694214876032447e+01 7.935652892561982696e+01 6.372727272727272663e+01 17 | 8.868595041322313932e+01 7.956388429752065861e+01 6.528925619834710403e+01 18 | 8.904132231404958020e+01 8.103090909090909122e+01 6.520661157024792942e+01 19 | 8.792561983471074427e+01 7.913371900826446392e+01 6.569421487603305820e+01 20 | 
8.725619834710744271e+01 7.871198347107439020e+01 6.566942148760330156e+01 21 | 8.732231404958677956e+01 7.872793388429754202e+01 6.510743801652893126e+01 22 | 8.838842975206611641e+01 8.073264462809918030e+01 6.628925619834710403e+01 23 | 8.789256198347106874e+01 8.025066115702479408e+01 6.624793388429752383e+01 24 | 8.755371900826446563e+01 7.839768595041321930e+01 6.461983471074380248e+01 25 | 8.744628099173553437e+01 7.966471074380164907e+01 6.529752066115702291e+01 26 | 8.806611570247933685e+01 7.866123966942149082e+01 6.464462809917355912e+01 27 | 8.766942148760330156e+01 7.928421487603304740e+01 6.510743801652893126e+01 28 | 8.671900826446281485e+01 7.853082644628098308e+01 6.383471074380165078e+01 29 | 8.766115702479338267e+01 7.897561983471074143e+01 6.266115702479338978e+01 30 | 8.851239669421487122e+01 7.995347107438017531e+01 6.554545454545454675e+01 31 | 8.792561983471074427e+01 8.007743801652891591e+01 6.511570247933883593e+01 32 | 8.765289256198347800e+01 7.845438016528925118e+01 6.519008264462810587e+01 33 | 8.854545454545454675e+01 7.925636363636364479e+01 6.348760330578512168e+01 34 | 8.777685950413223281e+01 7.968859504132231564e+01 6.508264462809917461e+01 35 | 8.877685950413223281e+01 7.982107438016528533e+01 6.394214876033058204e+01 36 | 8.730578512396694180e+01 7.829024793388428805e+01 6.327272727272727337e+01 37 | 8.755371900826446563e+01 7.924834710743800770e+01 6.519008264462810587e+01 38 | 8.816528925619834922e+01 8.008636363636364308e+01 6.504958677685949908e+01 39 | 8.853719008264462786e+01 7.972909090909089969e+01 6.300000000000000000e+01 40 | 8.682644628099173190e+01 7.818165289256198491e+01 6.569421487603305820e+01 41 | 8.842148760330579194e+01 7.958892561983469705e+01 6.568595041322313932e+01 42 | 8.784297520661156966e+01 7.916644628099173531e+01 6.200000000000000000e+01 43 | 8.843801652892561549e+01 7.980545454545455186e+01 6.661983471074380248e+01 44 | 8.788429752066116407e+01 7.874495867768594337e+01 6.343801652892562259e+01 45 | 8.819834710743801054e+01 7.874495867768594337e+01 6.444628099173553437e+01 46 | 8.833057851239669844e+01 7.948016528925620605e+01 6.564462809917355912e+01 47 | 8.828099173553718515e+01 7.878537190082643349e+01 6.195041322314049381e+01 48 | 8.794214876033058204e+01 7.944008264462810587e+01 6.318181818181817988e+01 49 | 8.717355371900826810e+01 7.934016528925619127e+01 6.518181818181818699e+01 50 | 8.747933884297520990e+01 7.914157024793389894e+01 6.610743801652893126e+01 51 | 8.847107438016529102e+01 7.891173553719008282e+01 6.523966942148760495e+01 52 | 8.781818181818181301e+01 7.891057851239669674e+01 6.509917355371901238e+01 53 | 8.751239669421487122e+01 7.888479338842974187e+01 6.495867768595041980e+01 54 | 8.708264462809917461e+01 7.817396694214876618e+01 6.519008264462810587e+01 55 | 8.657851239669420806e+01 7.812396694214875481e+01 6.270247933884297709e+01 56 | 8.806611570247933685e+01 7.875231404958677217e+01 6.281818181818182012e+01 57 | 8.861983471074380248e+01 7.798446280991734625e+01 6.452066115702479010e+01 58 | 8.686776859504132631e+01 7.780735537190081175e+01 6.425619834710744271e+01 59 | 8.786776859504132631e+01 7.855479338842975778e+01 6.225619834710743561e+01 60 | 8.764462809917355912e+01 7.752950413223140913e+01 6.328925619834710403e+01 61 | 8.805785123966941796e+01 7.950719008264462673e+01 6.461983471074380248e+01 62 | 8.866942148760330156e+01 7.957322314049585543e+01 6.489256198347106874e+01 63 | 8.690082644628098762e+01 7.709495867768596611e+01 6.357024793388429629e+01 64 | 8.766115702479338267e+01 
7.779206611570248242e+01 6.428099173553718515e+01 65 | 8.710743801652893126e+01 7.852107438016528818e+01 6.457851239669420806e+01 66 | 8.785950413223140743e+01 7.892652892561983435e+01 6.219008264462809876e+01 67 | 8.800000000000000000e+01 8.034033057851239334e+01 6.550413223140495234e+01 68 | 8.650413223140495234e+01 7.700644628099172451e+01 6.322314049586776719e+01 69 | 8.810743801652893126e+01 7.922462809917355742e+01 6.550413223140495234e+01 70 | 8.907438016528925573e+01 7.929925619834710915e+01 6.650413223140495234e+01 71 | 8.547107438016529102e+01 7.685685950413223111e+01 6.494214876033058204e+01 72 | 8.841322314049587305e+01 7.955512396694216193e+01 6.737190082644627864e+01 73 | 8.842975206611569661e+01 7.964586776859503914e+01 6.466942148760330156e+01 74 | 8.779338842975207058e+01 7.976123966942149934e+01 6.587603305785124519e+01 75 | 8.825619834710744271e+01 7.982206611570248356e+01 6.537190082644627864e+01 76 | 8.812396694214875481e+01 7.962347107438016280e+01 6.673553719008263840e+01 77 | 8.857024793388430339e+01 7.942355371900826810e+01 6.551239669421487122e+01 78 | 8.870247933884297709e+01 7.952438016528925857e+01 6.508264462809917461e+01 79 | 8.821487603305784830e+01 7.903479338842974755e+01 6.508264462809917461e+01 80 | 8.802479338842975665e+01 7.956578512396693270e+01 6.372727272727272663e+01 81 | 8.813223140495867369e+01 7.961388429752065576e+01 6.345454545454545325e+01 82 | 8.940495867768595417e+01 8.139404958677685897e+01 6.414049586776859257e+01 83 | 8.782644628099173190e+01 7.913396694214877414e+01 6.538842975206611641e+01 84 | 8.745454545454545325e+01 7.934173553719008964e+01 6.437190082644627864e+01 85 | 8.627272727272726627e+01 7.787537190082645111e+01 6.342975206611570371e+01 86 | 8.734710743801652200e+01 7.857157024793387734e+01 6.538842975206611641e+01 87 | 8.804132231404958020e+01 7.985537190082644088e+01 6.420661157024792942e+01 88 | 8.851239669421487122e+01 8.091462809917356935e+01 6.421487603305784830e+01 89 | 8.789256198347106874e+01 7.822198347107438110e+01 6.342148760330578483e+01 90 | 8.788429752066116407e+01 7.857768595041324033e+01 6.465289256198347800e+01 91 | 8.752066115702479010e+01 7.826537190082645168e+01 6.539669421487603529e+01 92 | 8.770247933884297709e+01 7.897685950413223566e+01 6.452892561983470898e+01 93 | 8.747933884297520990e+01 7.790181818181817164e+01 6.425619834710744271e+01 94 | 8.653719008264462786e+01 7.717305785123967610e+01 6.185950413223140743e+01 95 | 8.628099173553718515e+01 7.814247933884298902e+01 6.418181818181818699e+01 96 | 8.889256198347106874e+01 7.985388429752067907e+01 6.433057851239669844e+01 97 | 8.766942148760330156e+01 7.933272727272726854e+01 6.478512396694215170e+01 98 | 8.818181818181818699e+01 7.958181818181817846e+01 6.468595041322313932e+01 99 | 8.795041322314050092e+01 7.828099173553718515e+01 6.576033057851239505e+01 100 | 8.856198347107438451e+01 8.046570247933884446e+01 6.539669421487603529e+01 101 | 8.885950413223140743e+01 8.059900826446281030e+01 6.595867768595041980e+01 102 | 8.871074380165289597e+01 8.062322314049586680e+01 6.655371900826446563e+01 103 | 8.882644628099173190e+01 8.079760330578513106e+01 6.618181818181818699e+01 104 | 8.879338842975207058e+01 8.055719008264462389e+01 6.586776859504132631e+01 105 | 8.896694214876032447e+01 8.081404958677686068e+01 6.668595041322313932e+01 106 | 8.895867768595041980e+01 8.068966942148760779e+01 6.622314049586776719e+01 107 | 8.901652892561983776e+01 8.083090909090910259e+01 6.691735537190082539e+01 108 | 8.897520661157024335e+01 8.059099173553718742e+01 
6.649586776859504766e+01 109 | 8.833884297520661733e+01 8.006123966942148229e+01 6.655371900826446563e+01 110 | 8.889256198347106874e+01 8.057438016528925573e+01 6.653719008264462786e+01 111 | 8.849586776859504766e+01 8.054983471074379509e+01 6.652892561983470898e+01 112 | 8.852892561983470898e+01 8.016049586776858860e+01 6.595041322314050092e+01 113 | 8.866942148760330156e+01 8.014380165289256297e+01 6.629752066115702291e+01 114 | 8.849586776859504766e+01 8.019380165289257434e+01 6.686776859504132631e+01 115 | 8.868595041322313932e+01 8.043355371900825901e+01 6.651239669421487122e+01 116 | 8.872727272727273373e+01 8.026743801652892785e+01 6.640495867768595417e+01 117 | 8.862809917355372136e+01 7.996975206611570286e+01 6.689256198347106874e+01 118 | 8.876033057851239505e+01 8.034214876033057351e+01 6.679338842975207058e+01 119 | 8.884297520661156966e+01 8.077223140495868847e+01 6.729752066115702291e+01 120 | 8.889256198347106874e+01 8.063190082644628376e+01 6.713223140495867369e+01 121 | 8.877685950413223281e+01 8.034206611570247958e+01 6.676033057851239505e+01 122 | 8.860330578512396471e+01 8.008537190082644486e+01 6.663636363636364024e+01 123 | 8.836363636363635976e+01 8.019371900826445199e+01 6.693388429752066315e+01 124 | 8.860330578512396471e+01 8.039198347107438281e+01 6.670247933884297709e+01 125 | 8.862809917355372136e+01 8.030107438016528931e+01 6.686776859504132631e+01 126 | 8.838842975206611641e+01 7.989462809917355912e+01 6.616528925619834922e+01 127 | 8.857024793388430339e+01 8.010190082644629683e+01 6.652066115702479010e+01 128 | 8.873553719008263840e+01 7.995297520661158330e+01 6.636363636363635976e+01 129 | 8.873553719008263840e+01 7.977876033057850691e+01 6.654545454545454675e+01 130 | 8.873553719008263840e+01 8.001909090909092015e+01 6.682644628099173190e+01 131 | 8.854545454545454675e+01 8.001876033057851600e+01 6.656198347107438451e+01 132 | 8.862809917355372136e+01 7.985322314049587078e+01 6.641322314049587305e+01 133 | 8.842148760330579194e+01 8.013545454545453595e+01 6.653719008264462786e+01 134 | 8.847107438016529102e+01 7.974561983471073745e+01 6.573553719008263840e+01 135 | 8.888429752066116407e+01 8.021768595041322669e+01 6.588429752066116407e+01 136 | 8.871900826446281485e+01 7.967991735537189868e+01 6.597520661157024335e+01 137 | 8.852892561983470898e+01 7.941446280991736728e+01 6.637190082644627864e+01 138 | 8.860330578512396471e+01 7.978776859504131380e+01 6.678512396694215170e+01 139 | 8.874380165289255729e+01 7.982107438016531376e+01 6.640495867768595417e+01 140 | 8.861983471074380248e+01 7.955595041322314387e+01 6.652892561983470898e+01 141 | 8.850413223140495234e+01 7.952256198347106420e+01 6.628099173553718515e+01 142 | 8.874380165289255729e+01 7.999438016528925743e+01 6.648760330578512878e+01 143 | 8.881818181818181301e+01 8.016785123966943161e+01 6.691735537190082539e+01 144 | 8.869421487603305820e+01 7.998595041322313648e+01 6.673553719008263840e+01 145 | 8.813223140495867369e+01 7.963826446280991433e+01 6.631404958677686068e+01 146 | 8.844628099173553437e+01 7.962983471074379338e+01 6.674380165289255729e+01 147 | 8.854545454545454675e+01 7.963719008264462218e+01 6.657851239669420806e+01 148 | 8.876033057851239505e+01 7.990289256198347800e+01 6.608264462809917461e+01 149 | 8.890082644628098762e+01 8.020892561983471580e+01 6.668595041322313932e+01 150 | 8.880991735537189413e+01 8.014289256198347289e+01 6.640495867768595417e+01 151 | 8.887603305785124519e+01 7.986140495867768152e+01 6.645454545454545325e+01 152 | 8.905785123966941796e+01 
8.019297520661157819e+01 6.629752066115702291e+01 153 | 8.880991735537189413e+01 8.010140495867769062e+01 6.637190082644627864e+01 154 | 8.876033057851239505e+01 8.010148760330577034e+01 6.687603305785124519e+01 155 | 8.868595041322313932e+01 7.993595041322313932e+01 6.677685950413223281e+01 156 | 8.892561983471074427e+01 7.973727272727272464e+01 6.549586776859504766e+01 157 | 8.876033057851239505e+01 7.996082644628098990e+01 6.640495867768595417e+01 158 | 8.883471074380165078e+01 8.034173553719008964e+01 6.611570247933883593e+01 159 | 8.892561983471074427e+01 7.990355371900827208e+01 6.628099173553718515e+01 160 | 8.858677685950412695e+01 7.925752066115701666e+01 6.593388429752066315e+01 161 | 8.843801652892561549e+01 7.923223140495868222e+01 6.585950413223140743e+01 162 | 8.877685950413223281e+01 7.967173553719008794e+01 6.536363636363635976e+01 163 | 8.858677685950412695e+01 7.933280991735537668e+01 6.568595041322313932e+01 164 | 8.856198347107438451e+01 7.960545454545453481e+01 6.607438016528925573e+01 165 | 8.854545454545454675e+01 7.916685950413223338e+01 6.584297520661156966e+01 166 | 8.859504132231404583e+01 7.929958677685949908e+01 6.576859504132231393e+01 167 | 8.856198347107438451e+01 7.903388429752065747e+01 6.587603305785124519e+01 168 | 8.871074380165289597e+01 7.921644628099174668e+01 6.599173553719008112e+01 169 | 8.835537190082644088e+01 7.943966942148760779e+01 6.571074380165289597e+01 170 | 8.809917355371901238e+01 7.954685950413224305e+01 6.609090909090909349e+01 171 | 8.823966942148760495e+01 7.933991735537190948e+01 6.585123966942148854e+01 172 | 8.865289256198347800e+01 8.005256198347107954e+01 6.568595041322313932e+01 173 | 8.855371900826446563e+01 7.970438016528926539e+01 6.595867768595041980e+01 174 | 8.839669421487603529e+01 7.941404958677686921e+01 6.601652892561983776e+01 175 | 8.845454545454545325e+01 7.948107438016529613e+01 6.553719008264462786e+01 176 | 8.832231404958677956e+01 7.968000000000000682e+01 6.580165289256198946e+01 177 | 8.851239669421487122e+01 7.992016528925620378e+01 6.547107438016529102e+01 178 | 8.837190082644627864e+01 7.975495867768594849e+01 6.588429752066116407e+01 179 | 8.838842975206611641e+01 7.957256198347107556e+01 6.537190082644627864e+01 180 | 8.839669421487603529e+01 7.962157024793388871e+01 6.571074380165289597e+01 181 | 8.837190082644627864e+01 7.939016528925620264e+01 6.600000000000000000e+01 182 | 8.842975206611569661e+01 7.939876033057851146e+01 6.530578512396694180e+01 183 | 8.826446280991736160e+01 7.900925619834710290e+01 6.585950413223140743e+01 184 | 8.835537190082644088e+01 7.978735537190082994e+01 6.573553719008263840e+01 185 | 8.838016528925619752e+01 7.949727272727274396e+01 6.536363636363635976e+01 186 | 8.846280991735537214e+01 7.932314049586777571e+01 6.558677685950412695e+01 187 | 8.875206611570247617e+01 7.983652892561983094e+01 6.619008264462810587e+01 188 | 8.873553719008263840e+01 7.987801652892562743e+01 6.616528925619834922e+01 189 | 8.867768595041322044e+01 7.992016528925620378e+01 6.626446280991736160e+01 190 | -------------------------------------------------------------------------------- /work/log/visdom/iter_list.txt: -------------------------------------------------------------------------------- 1 | 7.001000000000000000e+04 2 | 7.002000000000000000e+04 3 | 7.003000000000000000e+04 4 | 7.004000000000000000e+04 5 | 7.005000000000000000e+04 6 | 7.006000000000000000e+04 7 | 7.007000000000000000e+04 8 | 7.008000000000000000e+04 9 | 7.009000000000000000e+04 10 | 7.010000000000000000e+04 11 | 
7.011000000000000000e+04 12 | 7.012000000000000000e+04 13 | 7.013000000000000000e+04 14 | 7.014000000000000000e+04 15 | 7.015000000000000000e+04 16 | 7.016000000000000000e+04 17 | 7.017000000000000000e+04 18 | 7.018000000000000000e+04 19 | 7.019000000000000000e+04 20 | 7.020000000000000000e+04 21 | 7.021000000000000000e+04 22 | 7.022000000000000000e+04 23 | 7.023000000000000000e+04 24 | 7.024000000000000000e+04 25 | 7.025000000000000000e+04 26 | 7.026000000000000000e+04 27 | 7.027000000000000000e+04 28 | 7.028000000000000000e+04 29 | 7.029000000000000000e+04 30 | 7.030000000000000000e+04 31 | 7.031000000000000000e+04 32 | 7.032000000000000000e+04 33 | 7.033000000000000000e+04 34 | 7.034000000000000000e+04 35 | 7.035000000000000000e+04 36 | 7.036000000000000000e+04 37 | 7.037000000000000000e+04 38 | 7.038000000000000000e+04 39 | 7.039000000000000000e+04 40 | 7.040000000000000000e+04 41 | 7.041000000000000000e+04 42 | 7.042000000000000000e+04 43 | 7.043000000000000000e+04 44 | 7.044000000000000000e+04 45 | 7.045000000000000000e+04 46 | 7.046000000000000000e+04 47 | 7.047000000000000000e+04 48 | 7.048000000000000000e+04 49 | 7.049000000000000000e+04 50 | 7.050000000000000000e+04 51 | 7.051000000000000000e+04 52 | 7.052000000000000000e+04 53 | 7.053000000000000000e+04 54 | 7.054000000000000000e+04 55 | 7.055000000000000000e+04 56 | 7.056000000000000000e+04 57 | 7.057000000000000000e+04 58 | 7.058000000000000000e+04 59 | 7.059000000000000000e+04 60 | 7.060000000000000000e+04 61 | 7.061000000000000000e+04 62 | 7.062000000000000000e+04 63 | 7.063000000000000000e+04 64 | 7.064000000000000000e+04 65 | 7.065000000000000000e+04 66 | 7.066000000000000000e+04 67 | 7.067000000000000000e+04 68 | 7.068000000000000000e+04 69 | 7.069000000000000000e+04 70 | 7.070000000000000000e+04 71 | 7.071000000000000000e+04 72 | 7.072000000000000000e+04 73 | 7.073000000000000000e+04 74 | 7.074000000000000000e+04 75 | 7.075000000000000000e+04 76 | 7.076000000000000000e+04 77 | 7.077000000000000000e+04 78 | 7.078000000000000000e+04 79 | 7.079000000000000000e+04 80 | 7.080000000000000000e+04 81 | 7.081000000000000000e+04 82 | 7.082000000000000000e+04 83 | 7.083000000000000000e+04 84 | 7.084000000000000000e+04 85 | 7.085000000000000000e+04 86 | 7.086000000000000000e+04 87 | 7.087000000000000000e+04 88 | 7.088000000000000000e+04 89 | 7.089000000000000000e+04 90 | 7.090000000000000000e+04 91 | 7.091000000000000000e+04 92 | 7.092000000000000000e+04 93 | 7.093000000000000000e+04 94 | 7.094000000000000000e+04 95 | 7.095000000000000000e+04 96 | 7.096000000000000000e+04 97 | 7.097000000000000000e+04 98 | 7.098000000000000000e+04 99 | 7.099000000000000000e+04 100 | 7.100000000000000000e+04 101 | 7.101000000000000000e+04 102 | 7.102000000000000000e+04 103 | 7.103000000000000000e+04 104 | 7.104000000000000000e+04 105 | 7.105000000000000000e+04 106 | 7.106000000000000000e+04 107 | 7.107000000000000000e+04 108 | 7.108000000000000000e+04 109 | 7.109000000000000000e+04 110 | 7.110000000000000000e+04 111 | 7.111000000000000000e+04 112 | 7.112000000000000000e+04 113 | 7.113000000000000000e+04 114 | 7.114000000000000000e+04 115 | 7.115000000000000000e+04 116 | 7.116000000000000000e+04 117 | 7.117000000000000000e+04 118 | 7.118000000000000000e+04 119 | 7.119000000000000000e+04 120 | 7.120000000000000000e+04 121 | 7.121000000000000000e+04 122 | 7.122000000000000000e+04 123 | 7.123000000000000000e+04 124 | 7.124000000000000000e+04 125 | 7.125000000000000000e+04 126 | 7.126000000000000000e+04 127 | 7.127000000000000000e+04 128 | 
[... listing truncated: the values continue in steps of 10 from 7.128000000000000000e+04 up to 8.001000000000000000e+04, then in steps of 100 from 8.010000000000000000e+04 up to the final value 8.890000000000000000e+04 ...]
--------------------------------------------------------------------------------
/work/partition/CASIA-B_73_False.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luwanglin/GaitSet_learning/5c7f63e1b4bf85b3afab6cd9ec00ea36a2a6e4e4/work/partition/CASIA-B_73_False.npy
--------------------------------------------------------------------------------
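Note: the numeric columns above appear to be the logged training-iteration indices saved under `work/visdom/`, alongside the accuracy arrays in the same directory. Below is a minimal sketch of how these text files could be read back and plotted; it assumes (unverified) that `iter_list.txt`, `acc_array_exclude.txt`, and `acc_array_include.txt` each hold one row per logged evaluation and are aligned row-for-row.

```python
# Sketch under the assumptions stated above: plot logged accuracy vs. iteration.
import numpy as np
import matplotlib.pyplot as plt

iters = np.loadtxt('work/visdom/iter_list.txt')             # iteration indices, e.g. 7.128e+04 -> 71280
acc_excl = np.loadtxt('work/visdom/acc_array_exclude.txt')  # accuracy, identical-view cases excluded
acc_incl = np.loadtxt('work/visdom/acc_array_include.txt')  # accuracy, identical-view cases included

# Guard against unequal lengths instead of trusting perfect alignment.
# (If the accuracy files are 2-D, plt.plot draws one curve per column.)
n = min(len(iters), len(acc_excl), len(acc_incl))
plt.plot(iters[:n], acc_excl[:n], label='exclude identical view')
plt.plot(iters[:n], acc_incl[:n], label='include identical view')
plt.xlabel('iteration')
plt.ylabel('rank-1 accuracy')
plt.legend()
plt.savefig('acc_vs_iter.png')
```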