├── opengait ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── dataset.cpython-38.pyc │ │ ├── sampler.cpython-38.pyc │ │ ├── collate_fn.cpython-38.pyc │ │ └── transform.cpython-38.pyc │ ├── sampler.py │ ├── collate_fn.py │ ├── dataset.py │ └── transform.py ├── evaluation │ ├── __pycache__ │ │ ├── metric.cpython-38.pyc │ │ ├── re_rank.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ └── evaluator.cpython-38.pyc │ ├── __init__.py │ ├── re_rank.py │ ├── metric.py │ └── evaluator.py ├── utils │ ├── __init__.py │ ├── msg_manager.py │ └── common.py ├── modeling │ ├── backbones │ │ ├── __init__.py │ │ └── GLGait.py │ ├── losses │ │ ├── __init__.py │ │ ├── ce.py │ │ ├── base.py │ │ └── triplet.py │ ├── models │ │ ├── __init__.py │ │ └── baseline_trans.py │ ├── loss_aggregator.py │ ├── modules.py │ └── base_model.py └── main.py ├── output ├── GREW │ └── Baseline_trans │ │ └── GLGait-L │ │ └── checkpoints │ │ └── readme.md └── Gait3D │ └── Baseline_trans │ └── GLGait-L │ └── checkpoints │ └── readme.md ├── configs ├── default.yaml └── GLGait │ ├── GLGait_Gait3D.yaml │ └── GLGait_GREW.yaml ├── README.md └── datasets ├── Gait3D ├── merge_two_modality.py └── README.md └── GREW ├── README.md └── rearrange_GREW.py /opengait/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/GREW/Baseline_trans/GLGait-L/checkpoints/readme.md: -------------------------------------------------------------------------------- 1 | Download the checkpoint in this file. 2 | -------------------------------------------------------------------------------- /output/Gait3D/Baseline_trans/GLGait-L/checkpoints/readme.md: -------------------------------------------------------------------------------- 1 | Download the checkpoints in this file. 
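(The BaiduNetdisk and Hugging Face download links for the Gait3D and GREW checkpoints are listed in the CheckPoints section of the top-level README.md; the downloaded checkpoint is meant to be placed in this directory.)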
2 | 
--------------------------------------------------------------------------------
/opengait/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .metric import mean_iou
2 | from numpy import set_printoptions
3 | set_printoptions(suppress=True, formatter={'float': '{:0.2f}'.format})
4 | 
--------------------------------------------------------------------------------
/opengait/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import get_ddp_module, ddp_all_gather
2 | from .common import Odict, Ntuple
3 | from .common
import get_valid_args 4 | from .common import is_list_or_tuple, is_bool, is_str, is_list, is_dict, is_tensor, is_array, config_loader, init_seeds, handler, params_count 5 | from .common import ts2np, ts2var, np2var, list2var 6 | from .common import mkdir, clones 7 | from .common import MergeCfgsDict 8 | from .common import get_attr_from 9 | from .common import NoOp 10 | from .msg_manager import get_msg_mgr -------------------------------------------------------------------------------- /opengait/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from inspect import isclass 2 | from pkgutil import iter_modules 3 | from pathlib import Path 4 | from importlib import import_module 5 | 6 | # iterate through the modules in the current package 7 | package_dir = Path(__file__).resolve().parent 8 | for (_, module_name, _) in iter_modules([str(package_dir)]): 9 | 10 | # import the module and iterate through its attributes 11 | module = import_module(f"{__name__}.{module_name}") 12 | for attribute_name in dir(module): 13 | attribute = getattr(module, attribute_name) 14 | 15 | if isclass(attribute): 16 | # Add the class to this package's variables 17 | globals()[attribute_name] = attribute -------------------------------------------------------------------------------- /opengait/modeling/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from inspect import isclass 2 | from pkgutil import iter_modules 3 | from pathlib import Path 4 | from importlib import import_module 5 | 6 | # iterate through the modules in the current package 7 | package_dir = Path(__file__).resolve().parent 8 | for (_, module_name, _) in iter_modules([str(package_dir)]): 9 | 10 | # import the module and iterate through its attributes 11 | module = import_module(f"{__name__}.{module_name}") 12 | for attribute_name in dir(module): 13 | attribute = getattr(module, attribute_name) 14 | 15 | if isclass(attribute): 16 | # Add the class to this package's variables 17 | globals()[attribute_name] = attribute -------------------------------------------------------------------------------- /opengait/modeling/models/__init__.py: -------------------------------------------------------------------------------- 1 | from inspect import isclass 2 | from pkgutil import iter_modules 3 | from pathlib import Path 4 | from importlib import import_module 5 | 6 | # iterate through the modules in the current package 7 | package_dir = Path(__file__).resolve().parent 8 | for (_, module_name, _) in iter_modules([str(package_dir)]): 9 | 10 | # import the module and iterate through its attributes 11 | module = import_module(f"{__name__}.{module_name}") 12 | for attribute_name in dir(module): 13 | attribute = getattr(module, attribute_name) 14 | 15 | if isclass(attribute): 16 | # Add the class to this package's variables 17 | globals()[attribute_name] = attribute -------------------------------------------------------------------------------- /opengait/modeling/losses/ce.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | from .base import BaseLoss, gather_and_scale_wrapper 6 | 7 | class CrossEntropyLoss(BaseLoss): 8 | def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weight=1.0, log_accuracy=False): 9 | super(CrossEntropyLoss, self).__init__(loss_term_weight) 10 | self.scale = scale 11 | 
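        # `scale` multiplies the logits before cross-entropy in forward() below;
        # labels are repeated across the p (parts) dimension of the [n, c, p] logits.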
self.label_smooth = label_smooth 12 | self.eps = eps 13 | self.log_accuracy = log_accuracy 14 | 15 | def forward(self, logits, labels): 16 | """ 17 | logits: [n, c, p] 18 | labels: [n] 19 | """ 20 | n, c, p = logits.size() 21 | logits = logits.float() 22 | labels = labels.unsqueeze(1) 23 | if self.label_smooth: 24 | loss = F.cross_entropy( 25 | logits*self.scale, labels.repeat(1, p), label_smoothing=self.eps) 26 | else: 27 | loss = F.cross_entropy(logits*self.scale, labels.repeat(1, p)) 28 | self.info.update({'loss': loss.detach().clone()}) 29 | if self.log_accuracy: 30 | pred = logits.argmax(dim=1) # [n, p] 31 | accu = (pred == labels).float().mean() 32 | self.info.update({'accuracy': accu}) 33 | return loss, self.info 34 | 35 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | data_cfg: 2 | dataset_name: CASIA-B 3 | dataset_root: your_path 4 | num_workers: 1 5 | dataset_partition: ./datasets/CASIA-B/CASIA-B.json 6 | remove_no_gallery: false 7 | cache: false 8 | test_dataset_name: CASIA-B 9 | 10 | evaluator_cfg: 11 | enable_float16: false 12 | restore_ckpt_strict: true 13 | restore_hint: 80000 14 | save_name: tmp 15 | eval_func: evaluate_indoor_dataset 16 | sampler: 17 | batch_size: 4 18 | sample_type: all_ordered 19 | type: InferenceSampler 20 | transform: 21 | - type: BaseSilCuttingTransform 22 | metric: euc # cos 23 | cross_view_gallery: false 24 | 25 | loss_cfg: 26 | loss_term_weight: 1.0 27 | margin: 0.2 28 | type: TripletLoss 29 | log_prefix: triplet 30 | 31 | model_cfg: 32 | model: Baseline 33 | 34 | #optimizer_cfg: 35 | # lr: 0.1 36 | # momentum: 0.9 37 | # solver: SGD 38 | # weight_decay: 0.0005 39 | # 40 | #scheduler_cfg: 41 | # gamma: 0.1 42 | # milestones: 43 | # - 20000 44 | # - 40000 45 | # - 60000 46 | # scheduler: MultiStepLR 47 | 48 | trainer_cfg: 49 | enable_float16: true 50 | with_test: false 51 | fix_BN: false 52 | log_iter: 100 53 | restore_ckpt_strict: true 54 | optimizer_reset: false 55 | scheduler_reset: false 56 | restore_hint: 0 57 | save_iter: 2000 58 | save_name: tmp 59 | sync_BN: false 60 | total_iter: 80000 61 | sampler: 62 | batch_shuffle: false 63 | batch_size: 64 | - 8 65 | - 16 66 | frames_num_fixed: 30 67 | frames_num_max: 50 68 | frames_num_min: 25 69 | sample_type: fixed_unordered 70 | type: TripletSampler 71 | transform: 72 | - type: BaseSilCuttingTransform 73 | -------------------------------------------------------------------------------- /opengait/modeling/losses/base.py: -------------------------------------------------------------------------------- 1 | from ctypes import ArgumentError 2 | import torch.nn as nn 3 | import torch 4 | from utils import Odict 5 | import functools 6 | from utils import ddp_all_gather 7 | 8 | 9 | def gather_and_scale_wrapper(func): 10 | """Internal wrapper: gather the input from multple cards to one card, and scale the loss by the number of cards. 11 | """ 12 | 13 | @functools.wraps(func) 14 | def inner(*args, **kwds): 15 | try: 16 | 17 | for k, v in kwds.items(): 18 | kwds[k] = ddp_all_gather(v) 19 | 20 | loss, loss_info = func(*args, **kwds) 21 | loss *= torch.distributed.get_world_size() 22 | return loss, loss_info 23 | except: 24 | raise ArgumentError 25 | return inner 26 | 27 | 28 | class BaseLoss(nn.Module): 29 | """ 30 | Base class for all losses. 31 | 32 | Your loss should also subclass this class. 
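    A subclass typically overrides forward(logits, labels) and returns a
    (loss, self.info) tuple, as CrossEntropyLoss in ce.py does above.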
33 | """ 34 | 35 | def __init__(self, loss_term_weight=1.0): 36 | """ 37 | Initialize the base class. 38 | 39 | Args: 40 | loss_term_weight: the weight of the loss term. 41 | """ 42 | super(BaseLoss, self).__init__() 43 | self.loss_term_weight = loss_term_weight 44 | self.info = Odict() 45 | 46 | def forward(self, logits, labels): 47 | """ 48 | The default forward function. 49 | 50 | This function should be overridden by the subclass. 51 | 52 | Args: 53 | logits: the logits of the model. 54 | labels: the labels of the data. 55 | 56 | Returns: 57 | tuple of loss and info. 58 | """ 59 | return .0, self.info 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GLGait: A Global Local Temporal Receptive Field Network for Gait Recognition in the Wild 2 | [Paper](https://arxiv.org/abs/2408.06834) has been accepted in ACM MM 2024. This is the code for it. 3 | # Operating Environments 4 | ## Pytorch Environment 5 | * Pytorch=1.11.0 6 | * Python=3.8 7 | # CheckPoints 8 | * The checkpoint for Gait3D BaiduNetdisk [link](https://pan.baidu.com/s/1quNAQ1pTOHUa3tpfGCQ7IQ?pwd=fue3), huggingface [link](https://huggingface.co/bgdpgz/GLGait/tree/main/GLGait). 9 | * The checkpoint for GREW BaiduNetdisk [link](https://pan.baidu.com/s/1H41p_FQjSkL8Jn_2xWWsLA?pwd=soci), huggingface [link](https://huggingface.co/bgdpgz/GLGait/tree/main/GLGait). 10 | # Train and Test 11 | ## Train 12 | ``` 13 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --master_port=1354 opengait/main.py --cfgs ./configs/GLGait/GLGait_Gait3D.yaml --phase train 14 | ``` 15 | ## Test 16 | ``` 17 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --master_port=1354 opengait/main.py --cfgs ./configs/GLGait/GLGait_Gait3D.yaml --phase test 18 | ``` 19 | * python -m torch.distributed.launch: DDP launch instruction. 20 | * --nproc_per_node: The number of gpus to use, and it must equal the length of CUDA_VISIBLE_DEVICES. 21 | * --cfgs: The path to config file. 22 | * --phase: Specified as train or test. 23 | # Acknowledge 24 | The codebase is based on [OpenGait](https://github.com/ShiqiYu/OpenGait). 25 | # Citation 26 | ``` 27 | @inproceedings{peng2024glgait, 28 | title={GLGait: A Global-Local Temporal Receptive Field Network for Gait Recognition in the Wild}, 29 | author={Peng, Guozhen and Wang, Yunhong and Zhao, Yuwei and Zhang, Shaoxiong and Li, Annan}, 30 | booktitle={Proceedings of the 32nd ACM International Conference on Multimedia}, 31 | pages={826--835}, 32 | year={2024} 33 | } 34 | 35 | ``` 36 | Note: This code is only used for academic purposes, people cannot use this code for anything that might be considered commercial use. 
37 | -------------------------------------------------------------------------------- /datasets/Gait3D/merge_two_modality.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | import shutil 5 | 6 | 7 | def merge(sils_path, smpls_path, output_path, link): 8 | if link == 'hard': 9 | link_method = os.link 10 | elif link == 'soft': 11 | link_method = os.symlink 12 | else: 13 | link_method = shutil.copyfile 14 | for _id in os.listdir(sils_path): 15 | id_path = os.path.join(sils_path, _id) 16 | for _type in os.listdir(id_path): 17 | type_path = os.path.join(id_path, _type) 18 | for _view in os.listdir(type_path): 19 | view_path = os.path.join(type_path, _view) 20 | for _seq in os.listdir(view_path): 21 | sils_seq_path = os.path.join(view_path, _seq) 22 | smpls_seq_path = os.path.join( 23 | smpls_path, _id, _type, _view, _seq) 24 | output_seq_path = os.path.join(output_path, _id, _type, _view) 25 | os.makedirs(output_seq_path, exist_ok=True) 26 | link_method(sils_seq_path, os.path.join( 27 | output_seq_path, "sils-"+_seq)) 28 | link_method(smpls_seq_path, os.path.join( 29 | output_seq_path, "smpls-"+_seq)) 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser(description='Gait3D dataset mergence.') 34 | parser.add_argument('--sils_path', default='', type=str, 35 | help='Root path of raw silhs dataset.') 36 | parser.add_argument('--smpls_path', default='', type=str, 37 | help='Root path of raw smpls dataset.') 38 | parser.add_argument('-o', '--output_path', default='', 39 | type=str, help='Output path of pickled dataset.') 40 | parser.add_argument('-l', '--link', default='hard', type=str, 41 | choices=['hard', 'soft', 'copy'], help='Link type of output data.') 42 | args = parser.parse_args() 43 | 44 | merge(sils_path=Path(args.sils_path), smpls_path=Path( 45 | args.smpls_path), output_path=Path(args.output_path), link=args.link) 46 | -------------------------------------------------------------------------------- /datasets/Gait3D/README.md: -------------------------------------------------------------------------------- 1 | # Gait3D 2 | This is the pre-processing instructions for the Gait3D dataset. The original dataset can be found [here](https://gait3d.github.io/). The original dataset is not publicly available. You need to request access to the dataset in order to download it. This README explains how to extract the original dataset and convert it to a format suitable for OpenGait. 
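After the pretreatment and merging steps below, each sequence directory is expected to hold both modalities side by side, roughly as follows (illustrative names; the `sils-`/`smpls-` prefixes are added by `merge_two_modality.py`):
```
Gait3D-merged-pkl/<id>/<type>/<view>/
├── sils-<seq>.pkl
└── smpls-<seq>.pkl
```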
3 | ## Data Preparation
4 | https://github.com/Gait3D/Gait3D-Benchmark#data-preparation
5 | ## Data Pretreatment
6 | ```shell
7 | python datasets/pretreatment.py --input_path 'Gait3D/2D_Silhouettes' --output_path 'Gait3D-sils-64-64-pkl'
8 | python datasets/pretreatment_smpl.py --input_path 'Gait3D/3D_SMPLs' --output_path 'Gait3D-smpls-pkl'
9 | 
10 | (optional) python datasets/pretreatment.py --input_path 'Gait3D/2D_Silhouettes' --img_size 128 --output_path 'Gait3D-sils-128-128-pkl'
11 | 
12 | python datasets/Gait3D/merge_two_modality.py --sils_path 'Gait3D-sils-64-64-pkl' --smpls_path 'Gait3D-smpls-pkl' --output_path 'Gait3D-merged-pkl' --link 'hard'
13 | ```
14 | 
15 | ## Train
16 | ### Baseline model:
17 | `CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 opengait/main.py --cfgs ./configs/baseline/baseline_Gait3D.yaml --phase train`
18 | ### SMPLGait model:
19 | `CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 opengait/main.py --cfgs ./configs/smplgait/smplgait.yaml --phase train`
20 | 
21 | ## Citation
22 | If you use this dataset in your research, please cite the following paper:
23 | ```
24 | @inproceedings{zheng2022gait3d,
25 |   title={Gait Recognition in the Wild with Dense 3D Representations and A Benchmark},
26 |   author={Jinkai Zheng and Xinchen Liu and Wu Liu and Lingxiao He and Chenggang Yan and Tao Mei},
27 |   booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
28 |   year={2022}
29 | }
30 | ```
31 | If you think the re-implementation of OpenGait is useful, please cite the following paper:
32 | ```
33 | @misc{fan2022opengait,
34 |   title={OpenGait: Revisiting Gait Recognition Toward Better Practicality},
35 |   author={Chao Fan and Junhao Liang and Chuanfu Shen and Saihui Hou and Yongzhen Huang and Shiqi Yu},
36 |   year={2022},
37 |   eprint={2211.06597},
38 |   archivePrefix={arXiv},
39 |   primaryClass={cs.CV}
40 | }
41 | ```
42 | ## Acknowledgements
43 | This dataset was collected by [Zheng et al.](https://gait3d.github.io/). The pre-processing instructions are modified from the [Gait3D-Benchmark repository](https://github.com/Gait3D/Gait3D-Benchmark).
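## Sanity check (optional)
A quick way to confirm that the pretreatment output is loadable is to open one sequence pickle directly. The path below is only a placeholder; the loaded object is typically a [T, H, W] uint8 silhouette array:
```python
import pickle

pkl_path = 'Gait3D-sils-64-64-pkl/0000/camid0_videoid0/seq0/seq0.pkl'  # placeholder path
with open(pkl_path, 'rb') as f:
    frames = pickle.load(f)
print(type(frames), getattr(frames, 'shape', None))
```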
44 | -------------------------------------------------------------------------------- /opengait/modeling/models/baseline_trans.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..base_model import BaseModel 4 | from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks 5 | 6 | 7 | class Baseline_trans(BaseModel): 8 | 9 | def build_network(self, model_cfg): 10 | self.Backbone = self.get_backbone(model_cfg['backbone_cfg']) 11 | self.Backbone = SetBlockWrapper(self.Backbone) 12 | self.FCs = SeparateFCs(**model_cfg['SeparateFCs']) 13 | self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks']) 14 | self.TP = PackSequenceWrapper(torch.max) 15 | self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num']) 16 | def forward(self, inputs): 17 | 18 | ipts, labs, _, _, _, _, seqL = inputs 19 | sils = ipts[0] 20 | 21 | if len(sils.size()) == 4: 22 | sils = sils.unsqueeze(1) 23 | if sils.size()[2]==1: 24 | sils = torch.cat((sils,sils,sils),dim=2) 25 | if sils.size()[2]==2: 26 | sils = torch.cat((sils,sils[:,:,-1,:,:].unsqueeze(2)),dim=2) 27 | # 28 | if sils.size()[2]%3!=0: 29 | num = sils.size()[2]//3 30 | sils = sils[:,:,:num*3,:,:] 31 | 32 | 33 | del ipts 34 | outs = self.Backbone(sils) # [n, c, s, h, w] 35 | 36 | seqL[0] = sils.size()[2] 37 | 38 | outs_tp, indice = self.TP(outs, seqL, options={"dim": 2}) # [n, c, h, w] 39 | 40 | 41 | feat = self.HPP(outs_tp) # [n, c, p] 42 | 43 | 44 | embed_1 = self.FCs(feat) # [n, c, p] 45 | embed = embed_1 46 | 47 | n, _, s, h, w = sils.size() 48 | 49 | bnn = self.BNNecks.fc_bin[:, :, labs].permute(2, 1, 0).contiguous().float() # [n,c,p] 50 | 51 | if self.training: 52 | embed_2, logits = self.BNNecks(embed_1) # [n, c, p] 53 | retval = { 54 | 'training_feat': { 55 | #'triplet': {'embeddings': embed_1, 'labels': labs}, 56 | 'ctl': {'embeddings': embed_1, 'labels': labs, 'bnn': bnn}, 57 | 'softmax': {'logits': logits, 'labels': labs, }, 58 | }, 59 | 'visual_summary': { 60 | 'image/sils': sils.reshape(n*s, 1, h, w) 61 | }, 62 | 'inference_feat': { 63 | 'embeddings': embed 64 | } 65 | } 66 | else: 67 | retval = { 68 | 'visual_summary': { 69 | 'image/sils': sils.view(n * s, 1, h, w) 70 | }, 71 | 'inference_feat': { 72 | 'embeddings': embed 73 | } 74 | } 75 | return retval 76 | -------------------------------------------------------------------------------- /opengait/main.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | from modeling import models 7 | from utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr 8 | 9 | parser = argparse.ArgumentParser(description='Main program for opengait.') 10 | parser.add_argument('--local_rank', type=int, default=0, 11 | help="passed by torch.distributed.launch module") 12 | parser.add_argument('--cfgs', type=str, 13 | default='E:/Desktop/gait recognition/OpenGait-master/configs/baseline/baseline.yaml', help="path of config file") 14 | parser.add_argument('--phase', default='train', 15 | choices=['train', 'test'], help="choose train or test phase") 16 | parser.add_argument('--log_to_file', action='store_true', 17 | help="log to file, default path is: output/////.txt") 18 | parser.add_argument('--iter', default=0, help="iter to restore") 19 | opt = parser.parse_args() 20 | 21 | 22 | def initialization(cfgs, training): 23 | 
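    # Set up the message manager (logger/TensorBoard summary) under
    # output/<dataset_name>/<model>/<save_name>/ and seed each process
    # with its distributed rank (see the init_seeds call below).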
msg_mgr = get_msg_mgr() 24 | engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg'] 25 | output_path = os.path.join('output/', cfgs['data_cfg']['dataset_name'], 26 | cfgs['model_cfg']['model'], engine_cfg['save_name']) 27 | if training: 28 | msg_mgr.init_manager(output_path, opt.log_to_file, engine_cfg['log_iter'], 29 | engine_cfg['restore_hint'] if isinstance(engine_cfg['restore_hint'], (int)) else 0) 30 | else: 31 | msg_mgr.init_logger(output_path, opt.log_to_file) 32 | 33 | msg_mgr.log_info(engine_cfg) 34 | 35 | seed = torch.distributed.get_rank() 36 | init_seeds(seed) 37 | 38 | 39 | def run_model(cfgs, training): 40 | msg_mgr = get_msg_mgr() 41 | model_cfg = cfgs['model_cfg'] 42 | msg_mgr.log_info(model_cfg) 43 | Model = getattr(models, model_cfg['model']) 44 | model = Model(cfgs, training) 45 | 46 | 47 | 48 | if training and cfgs['trainer_cfg']['sync_BN']: 49 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 50 | if cfgs['trainer_cfg']['fix_BN']: 51 | model.fix_BN() 52 | model = get_ddp_module(model) 53 | msg_mgr.log_info(params_count(model)) 54 | msg_mgr.log_info("Model Initialization Finished!") 55 | 56 | if training: 57 | Model.run_train(model) 58 | else: 59 | Model.run_test(model) 60 | 61 | 62 | if __name__ == '__main__': 63 | torch.distributed.init_process_group('nccl', init_method='env://', timeout = datetime.timedelta(seconds=18000)) 64 | if torch.distributed.get_world_size() != torch.cuda.device_count(): 65 | raise ValueError("Expect number of availuable GPUs({}) equals to the world size({}).".format( 66 | torch.cuda.device_count(), torch.distributed.get_world_size())) 67 | 68 | cfgs = config_loader(opt.cfgs) 69 | if opt.iter != 0: 70 | cfgs['evaluator_cfg']['restore_hint'] = int(opt.iter) 71 | cfgs['trainer_cfg']['restore_hint'] = int(opt.iter) 72 | 73 | 74 | 75 | training = (opt.phase == 'train') 76 | initialization(cfgs, training) 77 | run_model(cfgs, training) 78 | -------------------------------------------------------------------------------- /opengait/evaluation/re_rank.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def re_ranking(original_dist, query_num, k1, k2, lambda_value): 5 | # Modified from https://github.com/michuanhaohao/reid-strong-baseline/blob/master/utils/re_ranking.py 6 | all_num = original_dist.shape[0] 7 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) 8 | V = np.zeros_like(original_dist).astype(np.float16) 9 | initial_rank = np.argsort(original_dist).astype(np.int32) 10 | 11 | for i in range(all_num): 12 | # k-reciprocal neighbors 13 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 14 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 15 | fi = np.where(backward_k_neigh_index == i)[0] 16 | k_reciprocal_index = forward_k_neigh_index[fi] 17 | k_reciprocal_expansion_index = k_reciprocal_index 18 | for j in range(len(k_reciprocal_index)): 19 | candidate = k_reciprocal_index[j] 20 | candidate_forward_k_neigh_index = initial_rank[candidate, :int( 21 | np.around(k1 / 2)) + 1] 22 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 23 | :int(np.around(k1 / 2)) + 1] 24 | fi_candidate = np.where( 25 | candidate_backward_k_neigh_index == candidate)[0] 26 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 27 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 28 | candidate_k_reciprocal_index): 29 | 
k_reciprocal_expansion_index = np.append( 30 | k_reciprocal_expansion_index, candidate_k_reciprocal_index) 31 | 32 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 33 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 34 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) 35 | original_dist = original_dist[:query_num, ] 36 | if k2 != 1: 37 | V_qe = np.zeros_like(V, dtype=np.float16) 38 | for i in range(all_num): 39 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 40 | V = V_qe 41 | del V_qe 42 | del initial_rank 43 | invIndex = [] 44 | for i in range(all_num): 45 | invIndex.append(np.where(V[:, i] != 0)[0]) 46 | 47 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) 48 | 49 | for i in range(query_num): 50 | temp_min = np.zeros(shape=[1, all_num], dtype=np.float16) 51 | indNonZero = np.where(V[i, :] != 0)[0] 52 | indImages = [invIndex[ind] for ind in indNonZero] 53 | for j in range(len(indNonZero)): 54 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 55 | V[indImages[j], indNonZero[j]]) 56 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 57 | 58 | final_dist = jaccard_dist * (1 - lambda_value) + \ 59 | original_dist * lambda_value 60 | del original_dist 61 | del V 62 | del jaccard_dist 63 | final_dist = final_dist[:query_num, query_num:] 64 | return final_dist 65 | -------------------------------------------------------------------------------- /datasets/GREW/README.md: -------------------------------------------------------------------------------- 1 | # GREW Tutorial 2 | 3 | This is for [GREW-Benchmark](https://github.com/GREW-Benchmark/GREW-Benchmark). We report our result of 48% using the baseline model. In order for participants to better start the first step, we provide a tutorial on how to use OpenGait for GREW. 4 | 5 | ## Preprocess the dataset 6 | Download the raw dataset from the [official link](https://www.grew-benchmark.org/download.html). You will get three compressed files, i.e. `train.zip`, `test.zip` and `distractor.zip`. 7 | 8 | Step 1: Unzip train and test: 9 | ```shell 10 | unzip -P password train.zip (password is the obtained password) 11 | tar -xzvf train.tgz 12 | cd train 13 | ls *.tgz | xargs -n1 tar xzvf 14 | ``` 15 | 16 | ```shell 17 | unzip -P password test.zip (password is the obtained password) 18 | tar -xzvf test.tgz 19 | cd test & cd gallery 20 | ls *.tgz | xargs -n1 tar xzvf 21 | cd .. & cd probe 22 | ls *.tgz | xargs -n1 tar xzvf 23 | ``` 24 | 25 | After unpacking these compressed files, run this command: 26 | 27 | Step2 : To rearrange directory of GREW dataset, turning to id-type-view structure, Run 28 | ``` 29 | python datasets/GREW/rearrange_GREW.py --input_path Path_of_GREW-raw --output_path Path_of_GREW-rearranged 30 | ``` 31 | 32 | Step3: Transforming images to pickle file, run 33 | ``` 34 | python datasets/pretreatment.py --input_path Path_of_GREW-rearranged --output_path Path_of_GREW-pkl --dataset GREW 35 | ``` 36 | Then you will see the structure like: 37 | 38 | - Processed 39 | ``` 40 | GREW-pkl 41 | ├── 00001train (subject in training set) 42 | ├── 00 43 | ├── 4XPn5Z28 44 | ├── 4XPn5Z28.pkl 45 | ├──5TXe8svE 46 | ├── 5TXe8svE.pkl 47 | ...... 48 | ├── 00001 (subject in testing set) 49 | ├── 01 50 | ├── 79XJefi8 51 | ├── 79XJefi8.pkl 52 | ├── 02 53 | ├── t16VLaQf 54 | ├── t16VLaQf.pkl 55 | ├── probe 56 | ├── etaGVnWf 57 | ├── etaGVnWf.pkl 58 | ├── eT1EXpgZ 59 | ├── eT1EXpgZ.pkl 60 | ... 61 | ... 
62 | ```
63 | 
64 | ## Train the dataset
65 | Modify the `dataset_root` in `./config/baseline/baseline_GREW.yaml`, and then run this command:
66 | ```shell
67 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 opengait/main.py --cfgs ./config/baseline/baseline_GREW.yaml --phase train
68 | ```
69 | 
70 | ## Get the submission file
71 | ```shell
72 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 opengait/main.py --cfgs ./config/baseline/baseline_GREW.yaml --phase test
73 | ```
74 | The result file will be generated in your working directory; rename and compress it as required before submitting.
75 | 
76 | ## Evaluation locally
77 | The original GREW protocol treats both seq_01 and seq_02 as gallery, and no ground truth is provided for the probe, so the submission file has to be uploaded to the GREW competition site for official scoring. For local evaluation, we split the test set so that seq_01 serves as gallery and seq_02 as probe. If you change `eval_func` in `./config/baseline/baseline_GREW.yaml` to `identification_real_scene`, you can obtain results locally, similar to the OUMVLP setting.
78 | 
--------------------------------------------------------------------------------
/opengait/modeling/loss_aggregator.py:
--------------------------------------------------------------------------------
1 | """The loss aggregator."""
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from . import losses
6 | from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
7 | from utils import Odict
8 | from utils import get_msg_mgr
9 | 
10 | 
11 | class LossAggregator(nn.Module):
12 |     """The loss aggregator.
13 | 
14 |     This class is used to aggregate the losses.
15 |     For example, if you have two losses, one is triplet loss, the other is cross entropy loss,
16 |     you can aggregate them as follows:
17 |     loss_sum = triplet_loss + cross_entropy_loss
18 | 
19 |     Attributes:
20 |         losses: A dict of losses.
21 |     """
22 |     def __init__(self, loss_cfg) -> None:
23 |         """
24 |         Initialize the loss aggregator.
25 | 
26 |         LossAggregator can be indexed like a regular Python dictionary,
27 |         but the modules it contains are properly registered and visible to all Module methods.
28 |         All parameters registered in losses can be accessed by the method 'self.parameters()',
29 |         thus they can be trained properly.
30 | 
31 |         Args:
32 |             loss_cfg: Config of losses. List for multiple losses.
33 |         """
34 |         super().__init__()
35 |         self.losses = nn.ModuleDict({loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \
36 |             else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg})
37 | 
38 |     def _build_loss_(self, loss_cfg):
39 |         """Build the losses from loss_cfg.
40 | 
41 |         Args:
42 |             loss_cfg: Config of loss.
43 |         """
44 |         Loss = get_attr_from([losses], loss_cfg['type'])
45 |         valid_loss_arg = get_valid_args(
46 |             Loss, loss_cfg, ['type', 'gather_and_scale'])
47 |         loss = get_ddp_module(Loss(**valid_loss_arg).cuda())
48 |         return loss
49 | 
50 |     def forward(self, training_feats):
51 |         """Compute the sum of all losses.
52 | 
53 |         The input is a dict of features. Each key is a loss name and its value holds the corresponding features and labels. If a key is not
54 |         among the built losses and its value is a torch.Tensor, it is treated as an already-computed loss and added to loss_sum.
55 | 
56 |         Args:
57 |             training_feats: A dict of features. The same as the output["training_feat"] of the model.
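        Example (hypothetical tensors; keys must match the `log_prefix` entries
        in loss_cfg, e.g. 'triplet' and 'softmax' in the GLGait configs):
            training_feats = {
                'triplet': {'embeddings': embed, 'labels': labs},
                'softmax': {'logits': logits, 'labels': labs},
            }
            loss_sum, loss_info = loss_aggregator(training_feats)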
58 | """ 59 | loss_sum = .0 60 | loss_info = Odict() 61 | 62 | for k, v in training_feats.items(): 63 | if k in self.losses: 64 | loss_func = self.losses[k] 65 | loss, info = loss_func(**v) 66 | for name, value in info.items(): 67 | loss_info['scalar/%s/%s' % (k, name)] = value 68 | loss = loss.mean() * loss_func.loss_term_weight 69 | loss_sum += loss 70 | 71 | else: 72 | if isinstance(v, dict): 73 | raise ValueError( 74 | "The key %s in -Trainng-Feat- should be stated in your loss_cfg as log_prefix."%k 75 | ) 76 | elif is_tensor(v): 77 | _ = v.mean() 78 | loss_info['scalar/%s' % k] = _ 79 | loss_sum += _ 80 | get_msg_mgr().log_debug( 81 | "Please check whether %s needed in training." % k) 82 | else: 83 | raise ValueError( 84 | "Error type for -Trainng-Feat-, supported: A feature dict or loss tensor.") 85 | 86 | return loss_sum, loss_info 87 | -------------------------------------------------------------------------------- /configs/GLGait/GLGait_Gait3D.yaml: -------------------------------------------------------------------------------- 1 | data_cfg: 2 | dataset_name: Gait3D 3 | dataset_root: Your Path 4 | dataset_partition: ./datasets/Gait3D/Gait3D.json 5 | num_workers: 8 6 | remove_no_gallery: false # Remove probe if no gallery for it 7 | test_dataset_name: Gait3D 8 | 9 | evaluator_cfg: 10 | enable_float16: true 11 | restore_ckpt_strict: true 12 | restore_hint: 120000 13 | save_name: your path 14 | eval_func: evaluate_Gait3D 15 | sampler: 16 | batch_shuffle: false 17 | batch_size: 1 18 | sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered 19 | frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory 20 | metric: euc # cos 21 | transform: 22 | - type: BaseSilCuttingTransform 23 | 24 | loss_cfg: 25 | - loss_term_weight: 1.0 26 | margin: 0.2 27 | type: TripletLoss 28 | log_prefix: triplet 29 | - loss_term_weight: 1.0 30 | margin: 0.2 31 | start: 30000 32 | type: CTL 33 | log_prefix: ctl 34 | - loss_term_weight: 1.0 35 | scale: 16 36 | type: CrossEntropyLoss 37 | log_prefix: softmax 38 | log_accuracy: true 39 | 40 | 41 | 42 | model_cfg: 43 | model: Baseline_trans 44 | backbone_cfg: 45 | type: GLGait 46 | block: BasicBlock 47 | channels: # Layers configuration for automatically model construction 48 | - 64 49 | - 128 50 | - 256 51 | - 512 52 | layers: 53 | - 1 54 | - 4 55 | - 4 56 | - 1 57 | strides: 58 | - 1 59 | - 2 60 | - 2 61 | - 1 62 | maxpool: false 63 | SeparateFCs: 64 | in_channels: 512 65 | out_channels: 256 66 | parts_num: 16 67 | SeparateBNNecks: 68 | class_num: 3000 69 | in_channels: 256 70 | parts_num: 16 71 | bin_num: 72 | - 16 73 | 74 | 75 | optimizer_cfg: 76 | lr: 0.1 77 | momentum: 0.9 78 | solver: SGD 79 | weight_decay: 0.0005 80 | 81 | scheduler_cfg: 82 | gamma: 0.1 83 | milestones: # Learning Rate Reduction at each milestones 84 | - 40000 85 | - 80000 86 | - 100000 87 | scheduler: MultiStepLR 88 | 89 | trainer_cfg: 90 | enable_float16: true # half_percesion float for memory reduction and speedup 91 | fix_BN: false 92 | with_test: false 93 | log_iter: 100 94 | restore_ckpt_strict: true 95 | restore_hint: 0 96 | save_iter: 10000 97 | save_name: your name 98 | sync_BN: true 99 | total_iter: 120000 100 | sampler: 101 | batch_shuffle: true 102 | batch_size: 103 | - 32 # TripletSampler, batch_size[0] indicates Number of Identity 104 | - 4 # batch_size[1] indicates Samples sequqnce for each Identity 105 | frames_num_fixed: 30 # fixed frames number 
for training 106 | frames_num_max: 40 # max frames number for unfixed training 107 | frames_num_min: 20 # min frames number for unfixed traing 108 | frames_skip_num: 0 109 | sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered 110 | type: TripletSampler 111 | transform: 112 | - type: Compose 113 | trf_cfg: 114 | - type: RandomPerspective 115 | prob: 0.2 116 | - type: BaseSilCuttingTransform 117 | - type: RandomHorizontalFlip 118 | prob: 0.2 119 | - type: RandomRotate 120 | prob: 0.2 121 | -------------------------------------------------------------------------------- /datasets/GREW/rearrange_GREW.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from pathlib import Path 5 | 6 | from tqdm import tqdm 7 | 8 | TOTAL_Test = 24000 9 | TOTAL_Train = 20000 10 | 11 | def rearrange_train(train_path: Path, output_path: Path) -> None: 12 | progress = tqdm(total=TOTAL_Train) 13 | for sid in train_path.iterdir(): 14 | if not sid.is_dir(): 15 | continue 16 | for sub_seq in sid.iterdir(): 17 | if not sub_seq.is_dir(): 18 | continue 19 | for subfile in os.listdir(sub_seq): 20 | src = os.path.join(train_path, sid.name, sub_seq.name) 21 | dst = os.path.join(output_path, sid.name+'train', '00', sub_seq.name) 22 | os.makedirs(dst,exist_ok=True) 23 | if subfile not in os.listdir(dst) and subfile.endswith('.png'): 24 | os.symlink(os.path.join(src, subfile), 25 | os.path.join(dst, subfile)) 26 | progress.update(1) 27 | 28 | def rearrange_test(test_path: Path, output_path: Path) -> None: 29 | # for gallery 30 | gallery = Path(os.path.join(test_path, 'gallery')) 31 | probe = Path(os.path.join(test_path, 'probe')) 32 | progress = tqdm(total=TOTAL_Test) 33 | for sid in gallery.iterdir(): 34 | if not sid.is_dir(): 35 | continue 36 | cnt = 1 37 | for sub_seq in sid.iterdir(): 38 | if not sub_seq.is_dir(): 39 | continue 40 | for subfile in sorted(os.listdir(sub_seq)): 41 | src = os.path.join(gallery, sid.name, sub_seq.name) 42 | dst = os.path.join(output_path, sid.name, '%02d'%cnt, sub_seq.name) 43 | os.makedirs(dst,exist_ok=True) 44 | if subfile not in os.listdir(dst) and subfile.endswith('.png'): 45 | os.symlink(os.path.join(src, subfile), 46 | os.path.join(dst, subfile)) 47 | cnt += 1 48 | progress.update(1) 49 | # for probe 50 | for sub_seq in probe.iterdir(): 51 | if not sub_seq.is_dir(): 52 | continue 53 | for subfile in os.listdir(sub_seq): 54 | src = os.path.join(probe, sub_seq.name) 55 | dst = os.path.join(output_path, 'probe', '03', sub_seq.name) 56 | os.makedirs(dst,exist_ok=True) 57 | if subfile not in os.listdir(dst) and subfile.endswith('.png'): 58 | os.symlink(os.path.join(src, subfile), 59 | os.path.join(dst, subfile)) 60 | progress.update(1) 61 | 62 | def rearrange_GREW(input_path: Path, output_path: Path) -> None: 63 | os.makedirs(output_path, exist_ok=True) 64 | 65 | for folder in input_path.iterdir(): 66 | if not folder.is_dir(): 67 | continue 68 | 69 | print(f'Rearranging {folder}') 70 | if folder.name == 'train': 71 | rearrange_train(folder,output_path) 72 | if folder.name == 'test': 73 | rearrange_test(folder, output_path) 74 | if folder.name == 'distractor': 75 | pass 76 | 77 | 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser(description='GREW rearrange tool') 80 | parser.add_argument('-i', '--input_path', required=True, type=str, 81 | help='Root path of raw dataset.') 82 | 
parser.add_argument('-o', '--output_path', default='GREW_rearranged', type=str, 83 | help='Root path for output.') 84 | 85 | args = parser.parse_args() 86 | 87 | input_path = Path(args.input_path).resolve() 88 | output_path = Path(args.output_path).resolve() 89 | rearrange_GREW(input_path, output_path) 90 | -------------------------------------------------------------------------------- /configs/GLGait/GLGait_GREW.yaml: -------------------------------------------------------------------------------- 1 | data_cfg: 2 | dataset_name: GREW 3 | dataset_root: your path 4 | dataset_partition: ./datasets/GREW/GREW.json 5 | num_workers: 4 6 | remove_no_gallery: false # Remove probe if no gallery for it 7 | test_dataset_name: GREW 8 | 9 | 10 | evaluator_cfg: 11 | enable_float16: true 12 | restore_ckpt_strict: true 13 | restore_hint: 180000 14 | save_name: GLGait-L 15 | eval_func: GREW_submission 16 | sampler: 17 | batch_shuffle: false 18 | batch_size: 1 19 | sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered 20 | frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory 21 | metric: euc # cos 22 | transform: 23 | - type: BaseSilCuttingTransform 24 | 25 | loss_cfg: 26 | - loss_term_weight: 1.0 27 | margin: 0.2 28 | type: TripletLoss 29 | log_prefix: triplet 30 | - loss_term_weight: 1.0 31 | margin: 0.2 32 | start: 50000 33 | type: CTL 34 | log_prefix: ctl 35 | - loss_term_weight: 1.0 36 | scale: 16 37 | type: CrossEntropyLoss 38 | log_prefix: softmax 39 | log_accuracy: true 40 | 41 | model_cfg: 42 | model: Baseline_trans 43 | backbone_cfg: 44 | type: GLGait 45 | block: BasicBlock 46 | channels: # Layers configuration for automatically model construction 47 | - 64 48 | - 128 49 | - 256 50 | - 512 51 | layers: 52 | - 1 53 | - 4 54 | - 4 55 | - 1 56 | strides: 57 | - 1 58 | - 2 59 | - 2 60 | - 1 61 | maxpool: false 62 | SeparateFCs: 63 | in_channels: 512 64 | out_channels: 256 65 | parts_num: 16 66 | SeparateBNNecks: 67 | class_num: 20000 68 | in_channels: 256 69 | parts_num: 16 70 | bin_num: 71 | - 16 72 | max_num: 73 | - 2 74 | 75 | #optimizer_cfg: 76 | # lr: 0.0003 77 | # solver: AdamW 78 | # weight_decay: 0.02 79 | # 80 | #scheduler_cfg: 81 | # T_max: 120000 82 | # eta_min: 0.00003 83 | # scheduler: CosineAnnealingLR 84 | 85 | optimizer_cfg: 86 | lr: 0.05 87 | momentum: 0.9 88 | solver: SGD 89 | weight_decay: 0.0005 90 | 91 | scheduler_cfg: 92 | gamma: 0.2 93 | milestones: # Learning Rate Reduction at each milestones 94 | - 60000 95 | - 120000 96 | - 150000 97 | scheduler: MultiStepLR 98 | 99 | trainer_cfg: 100 | enable_float16: true # half_percesion float for memory reduction and speedup 101 | fix_BN: false 102 | with_test: false 103 | log_iter: 100 104 | restore_ckpt_strict: true 105 | restore_hint: 0 106 | save_iter: 10000 107 | save_name: your name 108 | sync_BN: true 109 | total_iter: 180000 110 | sampler: 111 | batch_shuffle: true 112 | batch_size: 113 | - 32 # TripletSampler, batch_size[0] indicates Number of Identity 114 | - 4 # batch_size[1] indicates Samples sequqnce for each Identity 115 | frames_num_fixed: 30 # fixed frames number for training 116 | frames_num_max: 40 # max frames number for unfixed training 117 | frames_num_min: 20 # min frames number for unfixed traing 118 | frames_skip_num: 0 119 | sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered 
120 | type: TripletSampler 121 | transform: 122 | - type: Compose 123 | trf_cfg: 124 | - type: RandomPerspective 125 | prob: 0.2 126 | - type: BaseSilCuttingTransform 127 | - type: RandomHorizontalFlip 128 | prob: 0.2 129 | - type: RandomRotate 130 | prob: 0.2 131 | -------------------------------------------------------------------------------- /opengait/data/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | import torch.distributed as dist 5 | import torch.utils.data as tordata 6 | 7 | 8 | class TripletSampler(tordata.sampler.Sampler): 9 | def __init__(self, dataset, batch_size, batch_shuffle=False): 10 | self.dataset = dataset 11 | self.batch_size = batch_size 12 | if len(self.batch_size) != 2: 13 | raise ValueError( 14 | "batch_size should be (P x K) not {}".format(batch_size)) 15 | self.batch_shuffle = batch_shuffle 16 | 17 | self.world_size = dist.get_world_size() 18 | if (self.batch_size[0]*self.batch_size[1]) % self.world_size != 0: 19 | raise ValueError("World size ({}) is not divisible by batch_size ({} x {})".format( 20 | self.world_size, batch_size[0], batch_size[1])) 21 | self.rank = dist.get_rank() 22 | 23 | def __iter__(self): 24 | while True: 25 | sample_indices = [] 26 | pid_list = sync_random_sample_list( 27 | self.dataset.label_set, self.batch_size[0]) 28 | 29 | for pid in pid_list: 30 | indices = self.dataset.indices_dict[pid] 31 | indices = sync_random_sample_list( 32 | indices, k=self.batch_size[1]) 33 | sample_indices += indices 34 | 35 | if self.batch_shuffle: 36 | sample_indices = sync_random_sample_list( 37 | sample_indices, len(sample_indices)) 38 | 39 | total_batch_size = self.batch_size[0] * self.batch_size[1] 40 | total_size = int(math.ceil(total_batch_size / 41 | self.world_size)) * self.world_size 42 | sample_indices += sample_indices[:( 43 | total_batch_size - len(sample_indices))] 44 | 45 | sample_indices = sample_indices[self.rank:total_size:self.world_size] 46 | yield sample_indices 47 | 48 | def __len__(self): 49 | return len(self.dataset) 50 | 51 | 52 | def sync_random_sample_list(obj_list, k): 53 | if len(obj_list) < k: 54 | idx = random.choices(range(len(obj_list)), k=k) 55 | idx = torch.tensor(idx) 56 | else: 57 | idx = torch.randperm(len(obj_list))[:k] 58 | if torch.cuda.is_available(): 59 | idx = idx.cuda() 60 | torch.distributed.broadcast(idx, src=0) 61 | idx = idx.tolist() 62 | return [obj_list[i] for i in idx] 63 | 64 | 65 | class InferenceSampler(tordata.sampler.Sampler): 66 | def __init__(self, dataset, batch_size): 67 | self.dataset = dataset 68 | self.batch_size = batch_size 69 | 70 | self.size = len(dataset) 71 | indices = list(range(self.size)) 72 | 73 | world_size = dist.get_world_size() 74 | rank = dist.get_rank() 75 | 76 | if batch_size % world_size != 0: 77 | raise ValueError("World size ({}) is not divisible by batch_size ({})".format( 78 | world_size, batch_size)) 79 | 80 | if batch_size != 1: 81 | complement_size = math.ceil(self.size / batch_size) * \ 82 | batch_size 83 | indices += indices[:(complement_size - self.size)] 84 | self.size = complement_size 85 | 86 | batch_size_per_rank = int(self.batch_size / world_size) 87 | indx_batch_per_rank = [] 88 | 89 | for i in range(int(self.size / batch_size_per_rank)): 90 | indx_batch_per_rank.append( 91 | indices[i*batch_size_per_rank:(i+1)*batch_size_per_rank]) 92 | 93 | self.idx_batch_this_rank = indx_batch_per_rank[rank::world_size] 94 | 95 | def __iter__(self): 96 | yield from 
self.idx_batch_this_rank 97 | 98 | def __len__(self): 99 | return len(self.dataset) 100 | -------------------------------------------------------------------------------- /opengait/utils/msg_manager.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | import numpy as np 5 | import torchvision.utils as vutils 6 | import os.path as osp 7 | from time import strftime, localtime 8 | 9 | from torch.utils.tensorboard import SummaryWriter 10 | from .common import is_list, is_tensor, ts2np, mkdir, Odict, NoOp 11 | import logging 12 | 13 | 14 | class MessageManager: 15 | def __init__(self): 16 | self.info_dict = Odict() 17 | self.writer_hparams = ['image', 'scalar'] 18 | self.time = time.time() 19 | 20 | def init_manager(self, save_path, log_to_file, log_iter, iteration=0): 21 | self.iteration = iteration 22 | self.log_iter = log_iter 23 | mkdir(osp.join(save_path, "summary/")) 24 | self.writer = SummaryWriter( 25 | osp.join(save_path, "summary/"), purge_step=self.iteration) 26 | self.init_logger(save_path, log_to_file) 27 | 28 | def init_logger(self, save_path, log_to_file): 29 | # init logger 30 | self.logger = logging.getLogger('opengait') 31 | self.logger.setLevel(logging.INFO) 32 | self.logger.propagate = False 33 | formatter = logging.Formatter( 34 | fmt='[%(asctime)s] [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 35 | if log_to_file: 36 | mkdir(osp.join(save_path, "logs/")) 37 | vlog = logging.FileHandler( 38 | osp.join(save_path, "logs/", strftime('%Y-%m-%d-%H-%M-%S', localtime())+'.txt')) 39 | vlog.setLevel(logging.INFO) 40 | vlog.setFormatter(formatter) 41 | self.logger.addHandler(vlog) 42 | 43 | console = logging.StreamHandler() 44 | console.setFormatter(formatter) 45 | console.setLevel(logging.DEBUG) 46 | self.logger.addHandler(console) 47 | 48 | def append(self, info): 49 | for k, v in info.items(): 50 | v = [v] if not is_list(v) else v 51 | v = [ts2np(_) if is_tensor(_) else _ for _ in v] 52 | info[k] = v 53 | self.info_dict.append(info) 54 | 55 | def flush(self): 56 | self.info_dict.clear() 57 | self.writer.flush() 58 | 59 | def write_to_tensorboard(self, summary): 60 | 61 | for k, v in summary.items(): 62 | module_name = k.split('/')[0] 63 | if module_name not in self.writer_hparams: 64 | self.log_warning( 65 | 'Not Expected --Summary-- type [{}] appear!!!{}'.format(k, self.writer_hparams)) 66 | continue 67 | board_name = k.replace(module_name + "/", '') 68 | writer_module = getattr(self.writer, 'add_' + module_name) 69 | v = v.detach() if is_tensor(v) else v 70 | v = vutils.make_grid( 71 | v, normalize=True, scale_each=True) if 'image' in module_name else v 72 | if module_name == 'scalar': 73 | try: 74 | v = v.mean() 75 | except: 76 | v = v 77 | writer_module(board_name, v, self.iteration) 78 | 79 | def log_training_info(self): 80 | now = time.time() 81 | string = "Iteration {:0>5}, Cost {:.2f}s".format( 82 | self.iteration, now-self.time, end="") 83 | for i, (k, v) in enumerate(self.info_dict.items()): 84 | if 'scalar' not in k: 85 | continue 86 | k = k.replace('scalar/', '').replace('/', '_') 87 | end = "\n" if i == len(self.info_dict)-1 else "" 88 | string += ", {0}={1:.4f}".format(k, np.mean(v), end=end) 89 | self.log_info(string) 90 | self.reset_time() 91 | 92 | def reset_time(self): 93 | self.time = time.time() 94 | 95 | def train_step(self, info, summary): 96 | self.iteration += 1 97 | self.append(info) 98 | if self.iteration % self.log_iter == 0: 99 | self.log_training_info() 100 | 
self.flush() 101 | self.write_to_tensorboard(summary) 102 | 103 | def log_debug(self, *args, **kwargs): 104 | self.logger.debug(*args, **kwargs) 105 | 106 | def log_info(self, *args, **kwargs): 107 | self.logger.info(*args, **kwargs) 108 | 109 | def log_warning(self, *args, **kwargs): 110 | self.logger.warning(*args, **kwargs) 111 | 112 | 113 | msg_mgr = MessageManager() 114 | noop = NoOp() 115 | 116 | 117 | def get_msg_mgr(): 118 | if torch.distributed.get_rank() > 0: 119 | return noop 120 | else: 121 | return msg_mgr 122 | -------------------------------------------------------------------------------- /opengait/data/collate_fn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import numpy as np 4 | from utils import get_msg_mgr 5 | 6 | 7 | class CollateFn(object): 8 | def __init__(self, label_set, sample_config): 9 | self.label_set = label_set 10 | sample_type = sample_config['sample_type'] 11 | sample_type = sample_type.split('_') 12 | self.sampler = sample_type[0] 13 | self.ordered = sample_type[1] 14 | if self.sampler not in ['fixed', 'unfixed', 'all']: 15 | raise ValueError 16 | if self.ordered not in ['ordered', 'unordered']: 17 | raise ValueError 18 | self.ordered = sample_type[1] == 'ordered' 19 | 20 | # fixed cases 21 | if self.sampler == 'fixed': 22 | self.frames_num_fixed = sample_config['frames_num_fixed'] 23 | 24 | # unfixed cases 25 | if self.sampler == 'unfixed': 26 | self.frames_num_max = sample_config['frames_num_max'] 27 | self.frames_num_min = sample_config['frames_num_min'] 28 | 29 | if self.sampler != 'all' and self.ordered: 30 | self.frames_skip_num = sample_config['frames_skip_num'] 31 | 32 | self.frames_all_limit = -1 33 | if self.sampler == 'all' and 'frames_all_limit' in sample_config: 34 | self.frames_all_limit = sample_config['frames_all_limit'] 35 | 36 | def __call__(self, batch): 37 | batch_size = len(batch) 38 | # currently, the functionality of feature_num is not fully supported yet, it refers to 1 now. We are supposed to make our framework support multiple source of input data, such as silhouette, or skeleton. 39 | feature_num = len(batch[0][0]) 40 | seqs_batch, labs_batch, typs_batch, vies_batch, ids_batch = [], [], [], [], [] 41 | 42 | for bt in batch: 43 | seqs_batch.append(bt[0]) 44 | labs_batch.append(self.label_set.index(bt[1][0])) 45 | ids_batch.append(bt[1][0]) 46 | typs_batch.append(bt[1][1]) 47 | vies_batch.append(bt[1][2]) 48 | global count 49 | count = 0 50 | 51 | def sample_frames(seqs): 52 | global count 53 | sampled_fras = [[] for i in range(feature_num)] 54 | seq_len = len(seqs[0]) 55 | indices = list(range(seq_len)) 56 | 57 | if self.sampler in ['fixed', 'unfixed']: 58 | if self.sampler == 'fixed': 59 | frames_num = self.frames_num_fixed 60 | else: 61 | frames_num = random.choice( 62 | list(range(self.frames_num_min, self.frames_num_max+1))) 63 | 64 | if self.ordered: 65 | fs_n = frames_num + self.frames_skip_num 66 | if seq_len < fs_n: 67 | it = math.ceil(fs_n / seq_len) 68 | seq_len = seq_len * it 69 | indices = indices * it 70 | 71 | start = random.choice(list(range(0, seq_len - fs_n + 1))) 72 | end = start + fs_n 73 | idx_lst = list(range(seq_len)) 74 | idx_lst = idx_lst[start:end] 75 | idx_lst = sorted(np.random.choice( 76 | idx_lst, frames_num, replace=False)) 77 | indices = [indices[i] for i in idx_lst] 78 | else: 79 | replace = seq_len < frames_num 80 | 81 | if seq_len == 0: 82 | get_msg_mgr().log_debug('Find no frames in the sequence %s-%s-%s.' 
83 | % (str(labs_batch[count]), str(typs_batch[count]), str(vies_batch[count]))) 84 | 85 | count += 1 86 | indices = np.random.choice( 87 | indices, frames_num, replace=replace) 88 | 89 | for i in range(feature_num): 90 | for j in indices[:self.frames_all_limit] if self.frames_all_limit > -1 and len(indices) > self.frames_all_limit else indices: 91 | sampled_fras[i].append(seqs[i][j]) 92 | return [sampled_fras, indices] 93 | 94 | # f: feature_num 95 | # b: batch_size 96 | # p: batch_size_per_gpu 97 | # g: gpus_num 98 | fras_index_batch = [sample_frames(seqs) for seqs in seqs_batch] # [b, f] 99 | fras_batch = [i[0] for i in fras_index_batch] 100 | index_batch = [i[1] for i in fras_index_batch] 101 | batch = [fras_batch, labs_batch, typs_batch, vies_batch, ids_batch, index_batch, None] 102 | 103 | if self.sampler == "fixed": 104 | fras_batch = [[np.asarray(fras_batch[i][j]) for i in range(batch_size)] 105 | for j in range(feature_num)] # [f, b] 106 | else: 107 | seqL_batch = [[len(fras_batch[i][0]) 108 | for i in range(batch_size)]] # [1, p] 109 | 110 | def my_cat(k): return np.concatenate( 111 | [fras_batch[i][k] for i in range(batch_size)], 0) 112 | fras_batch = [[my_cat(k)] for k in range(feature_num)] # [f, g] 113 | batch[-1] = np.asarray(seqL_batch) 114 | 115 | batch[0] = fras_batch 116 | return batch 117 | -------------------------------------------------------------------------------- /opengait/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import torch.utils.data as tordata 7 | import json 8 | from utils import get_msg_mgr 9 | import cv2 10 | 11 | class DataSet(tordata.Dataset): 12 | def __init__(self, data_cfg, training): 13 | """ 14 | seqs_info: the list with each element indicating 15 | a certain gait sequence presented as [label, type, view, paths]; 16 | """ 17 | self.__dataset_parser(data_cfg, training) 18 | self.cache = data_cfg['cache'] 19 | self.label_list = [seq_info[0] for seq_info in self.seqs_info] 20 | self.types_list = [seq_info[1] for seq_info in self.seqs_info] 21 | self.views_list = [seq_info[2] for seq_info in self.seqs_info] 22 | 23 | self.label_set = sorted(list(set(self.label_list))) 24 | self.types_set = sorted(list(set(self.types_list))) 25 | self.views_set = sorted(list(set(self.views_list))) 26 | self.seqs_data = [None] * len(self) 27 | self.indices_dict = {label: [] for label in self.label_set} 28 | for i, seq_info in enumerate(self.seqs_info): 29 | self.indices_dict[seq_info[0]].append(i) 30 | if self.cache: 31 | self.__load_all_data() 32 | 33 | def __len__(self): 34 | return len(self.seqs_info) 35 | 36 | def __loader__(self, paths): 37 | paths = sorted(paths) 38 | data_list = [] 39 | for pth in paths: 40 | if pth.endswith('.pkl'): 41 | with open(pth, 'rb') as f: 42 | _ = pickle.load(f) 43 | f.close() 44 | else: 45 | raise ValueError('- Loader - just support .pkl !!!') 46 | data_list.append(_) 47 | for idx, data in enumerate(data_list): 48 | if len(data) != len(data_list[0]): 49 | raise ValueError( 50 | 'Each input data({}) should have the same length.'.format(paths[idx])) 51 | if len(data) == 0: 52 | raise ValueError( 53 | 'Each input data({}) should have at least one element.'.format(paths[idx])) 54 | return data_list 55 | 56 | def __getitem__(self, idx): 57 | if not self.cache: 58 | data_list = self.__loader__(self.seqs_info[idx][-1]) 59 | elif self.seqs_data[idx] is None: 60 | data_list = 
self.__loader__(self.seqs_info[idx][-1]) 61 | self.seqs_data[idx] = data_list 62 | else: 63 | data_list = self.seqs_data[idx] 64 | seq_info = self.seqs_info[idx] 65 | return data_list, seq_info 66 | 67 | def __load_all_data(self): 68 | for idx in range(len(self)): 69 | self.__getitem__(idx) 70 | 71 | def __dataset_parser(self, data_config, training): 72 | dataset_root = data_config['dataset_root'] 73 | try: 74 | data_in_use = data_config['data_in_use'] # [n], true or false 75 | except: 76 | data_in_use = None 77 | 78 | with open(data_config['dataset_partition'], "rb") as f: 79 | partition = json.load(f) 80 | train_set = partition["TRAIN_SET"] 81 | test_set = partition["TEST_SET"] 82 | label_list = os.listdir(dataset_root) 83 | train_set = [label for label in train_set if label in label_list] 84 | test_set = [label for label in test_set if label in label_list] 85 | miss_pids = [label for label in label_list if label not in ( 86 | train_set + test_set)] 87 | msg_mgr = get_msg_mgr() 88 | 89 | def log_pid_list(pid_list): 90 | if len(pid_list) >= 3: 91 | msg_mgr.log_info('[%s, %s, ..., %s]' % 92 | (pid_list[0], pid_list[1], pid_list[-1])) 93 | else: 94 | msg_mgr.log_info(pid_list) 95 | 96 | if len(miss_pids) > 0: 97 | msg_mgr.log_debug('-------- Miss Pid List --------') 98 | msg_mgr.log_debug(miss_pids) 99 | if training: 100 | msg_mgr.log_info("-------- Train Pid List --------") 101 | log_pid_list(train_set) 102 | else: 103 | msg_mgr.log_info("-------- Test Pid List --------") 104 | log_pid_list(test_set) 105 | 106 | def get_seqs_info_list(label_set): 107 | seqs_info_list = [] 108 | for lab in label_set: 109 | for typ in sorted(os.listdir(osp.join(dataset_root, lab))): 110 | for vie in sorted(os.listdir(osp.join(dataset_root, lab, typ))): 111 | seq_info = [lab, typ, vie] 112 | seq_path = osp.join(dataset_root, *seq_info) 113 | seq_dirs = sorted(os.listdir(seq_path)) 114 | if seq_dirs != []: 115 | seq_dirs = [osp.join(seq_path, dir) 116 | for dir in seq_dirs] 117 | if data_in_use is not None: 118 | seq_dirs = [dir for dir, use_bl in zip( 119 | seq_dirs, data_in_use) if use_bl] 120 | seqs_info_list.append([*seq_info, seq_dirs]) 121 | else: 122 | msg_mgr.log_debug( 123 | 'Find no .pkl file in %s-%s-%s.' 
% (lab, typ, vie)) 124 | return seqs_info_list 125 | 126 | self.seqs_info = get_seqs_info_list( 127 | train_set) if training else get_seqs_info_list(test_set) 128 | -------------------------------------------------------------------------------- /opengait/evaluation/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | from utils import is_tensor 6 | 7 | 8 | def cuda_dist(x, y, metric='euc'): 9 | x = torch.from_numpy(x).cuda() 10 | y = torch.from_numpy(y).cuda() 11 | if metric == 'cos': 12 | x = F.normalize(x, p=2, dim=1) # n c p 13 | y = F.normalize(y, p=2, dim=1) # n c p 14 | num_bin = x.size(2) 15 | n_x = x.size(0) 16 | n_y = y.size(0) 17 | dist = torch.zeros(n_x, n_y).cuda() 18 | for i in range(num_bin): 19 | _x = x[:, :, i] 20 | _y = y[:, :, i] 21 | if metric == 'cos': 22 | dist += torch.matmul(_x, _y.transpose(0, 1)) 23 | else: 24 | _dist = torch.sum(_x ** 2, 1).unsqueeze(1) + torch.sum(_y ** 2, 1).unsqueeze( 25 | 0) - 2 * torch.matmul(_x, _y.transpose(0, 1)) 26 | dist += torch.sqrt(F.relu(_dist)) 27 | return 1 - dist/num_bin if metric == 'cos' else dist / num_bin 28 | 29 | 30 | def mean_iou(msk1, msk2, eps=1.0e-9): 31 | if not is_tensor(msk1): 32 | msk1 = torch.from_numpy(msk1).cuda() 33 | if not is_tensor(msk2): 34 | msk2 = torch.from_numpy(msk2).cuda() 35 | n = msk1.size(0) 36 | inter = msk1 * msk2 37 | union = ((msk1 + msk2) > 0.).float() 38 | miou = inter.view(n, -1).sum(-1) / (union.view(n, -1).sum(-1) + eps) 39 | return miou 40 | 41 | 42 | def compute_ACC_mAP(distmat, q_pids, g_pids, q_views=None, g_views=None, rank=1): 43 | num_q, _ = distmat.shape 44 | # indices = np.argsort(distmat, axis=1) 45 | # matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 46 | 47 | 48 | all_ACC = [] 49 | all_AP = [] 50 | num_valid_q = 0. # number of valid query 51 | for q_idx in range(num_q): 52 | q_idx_dist = distmat[q_idx] 53 | q_idx_glabels = g_pids 54 | if q_views is not None and g_views is not None: 55 | q_idx_mask = np.isin(g_views, q_views[q_idx], invert=True) | np.isin( 56 | g_pids, q_pids[q_idx], invert=True) 57 | q_idx_dist = q_idx_dist[q_idx_mask] 58 | q_idx_glabels = q_idx_glabels[q_idx_mask] 59 | 60 | assert(len(q_idx_glabels) > 61 | 0), "No gallery after excluding identical-view cases!" 62 | q_idx_indices = np.argsort(q_idx_dist) 63 | q_idx_matches = (q_idx_glabels[q_idx_indices] 64 | == q_pids[q_idx]).astype(np.int32) 65 | 66 | # binary vector, positions with value 1 are correct matches 67 | # orig_cmc = matches[q_idx] 68 | orig_cmc = q_idx_matches 69 | cmc = orig_cmc.cumsum() 70 | cmc[cmc > 1] = 1 71 | all_ACC.append(cmc[rank-1]) 72 | 73 | # compute average precision 74 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 75 | num_rel = orig_cmc.sum() 76 | 77 | if num_rel > 0: 78 | num_valid_q += 1. 79 | tmp_cmc = orig_cmc.cumsum() 80 | tmp_cmc = [x / (i + 1.) 
for i, x in enumerate(tmp_cmc)] 81 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 82 | AP = tmp_cmc.sum() / num_rel 83 | all_AP.append(AP) 84 | 85 | # all_ACC = np.asarray(all_ACC).astype(np.float32) 86 | ACC = np.mean(all_ACC) 87 | mAP = np.mean(all_AP) 88 | 89 | return ACC, mAP 90 | 91 | 92 | def evaluate_rank(distmat, p_lbls, g_lbls, max_rank=50): 93 | ''' 94 | Copy from https://github.com/Gait3D/Gait3D-Benchmark/blob/72beab994c137b902d826f4b9f9e95b107bebd78/lib/utils/rank.py#L12-L63 95 | ''' 96 | num_p, num_g = distmat.shape 97 | 98 | if num_g < max_rank: 99 | max_rank = num_g 100 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 101 | 102 | indices = np.argsort(distmat, axis=1) 103 | # np.save('indices_p3datt.npy', indices)  # debug dump left from development, disabled 104 | matches = (g_lbls[indices] == p_lbls[:, np.newaxis]).astype(np.int32) 105 | # np.save('match_p3datt.npy', matches)  # debug dump left from development, disabled 106 | # compute cmc curve for each probe 107 | all_cmc = [] 108 | all_AP = [] 109 | all_INP = [] 110 | num_valid_p = 0. # number of valid probe 111 | 112 | for p_idx in range(num_p): 113 | # compute cmc curve 114 | # binary vector, positions with value 1 are correct matches 115 | raw_cmc = matches[p_idx] 116 | if not np.any(raw_cmc): 117 | # this condition is true when probe identity does not appear in gallery 118 | continue 119 | 120 | cmc = raw_cmc.cumsum() 121 | 122 | pos_idx = np.where(raw_cmc == 1) # positions of the hits; raw_cmc is 1-D here, so these are plain indices 123 | max_pos_idx = np.max(pos_idx) 124 | inp = cmc[max_pos_idx] / (max_pos_idx + 1.0) 125 | all_INP.append(inp) 126 | 127 | cmc[cmc > 1] = 1 128 | 129 | all_cmc.append(cmc[:max_rank]) 130 | num_valid_p += 1. 131 | 132 | # compute average precision 133 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 134 | num_rel = raw_cmc.sum() 148 | tmp_cmc = raw_cmc.cumsum() 149 | tmp_cmc = [x / (i + 1.)
for i, x in enumerate(tmp_cmc)] 150 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 151 | AP = tmp_cmc.sum() / num_rel 152 | all_AP.append(AP) 153 | 154 | assert num_valid_p > 0, 'Error: all probe identities do not appear in gallery' 155 | 156 | all_cmc = np.asarray(all_cmc).astype(np.float32) 157 | all_cmc = all_cmc.sum(0) / num_valid_p 158 | 159 | return all_cmc, all_AP, all_INP 160 | -------------------------------------------------------------------------------- /opengait/modeling/losses/triplet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from .base import BaseLoss, gather_and_scale_wrapper 6 | 7 | 8 | class TripletLoss(BaseLoss): 9 | def __init__(self, margin, loss_term_weight=1.0): 10 | super(TripletLoss, self).__init__(loss_term_weight) 11 | self.margin = margin 12 | @gather_and_scale_wrapper 13 | def forward(self, embeddings, labels): 14 | # embeddings: [n, c, p], label: [n] 15 | 16 | embeddings = embeddings.permute( 17 | 2, 0, 1).contiguous().float() # [n, c, p] -> [p, n, c] 18 | 19 | ref_embed, ref_label = embeddings, labels 20 | dist = self.ComputeDistance(embeddings, ref_embed) # [p, n1, n2] 21 | mean_dist = dist.mean((1, 2)) # [p] 22 | ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist) 23 | mean_ap_dist = ap_dist.mean((1,2,3)) 24 | mean_an_dist = an_dist.mean((1,2,3)) 25 | dist_diff = (ap_dist - an_dist).view(dist.size(0), -1) 26 | loss = F.relu(dist_diff + self.margin) 27 | 28 | hard_loss = torch.max(loss, -1)[0] 29 | loss_avg, loss_num = self.AvgNonZeroReducer(loss) 30 | 31 | self.info.update({ 32 | 'loss': loss_avg.detach().clone(), 33 | 'hard_loss': hard_loss.detach().clone(), 34 | 'loss_num': loss_num.detach().clone(), 35 | 'mean_dist': mean_dist.detach().clone(), 36 | 'mean_ap_dist': mean_ap_dist.detach().clone(), 37 | 'mean_an_dist': mean_an_dist.detach().clone(),}) 38 | 39 | return loss_avg, self.info 40 | 41 | def AvgNonZeroReducer(self, loss): 42 | eps = 1.0e-9 43 | loss_sum = loss.sum(-1) 44 | loss_num = (loss != 0).sum(-1).float() 45 | 46 | loss_avg = loss_sum / (loss_num + eps) 47 | loss_avg[loss_num == 0] = 0 48 | return loss_avg, loss_num 49 | 50 | def ComputeDistance(self, x, y): 51 | """ 52 | x: [p, n_x, c] 53 | y: [p, n_y, c] 54 | """ 55 | x2 = torch.sum(x ** 2, -1).unsqueeze(2) # [p, n_x, 1] 56 | y2 = torch.sum(y ** 2, -1).unsqueeze(1) # [p, 1, n_y] 57 | inner = x.matmul(y.transpose(1, 2)) # [p, n_x, n_y] 58 | dist = x2 + y2 - 2 * inner 59 | dist = torch.sqrt(F.relu(dist)) # [p, n_x, n_y] 60 | return dist 61 | 62 | def Convert2Triplets(self, row_labels, clo_label, dist): 63 | """ 64 | row_labels: tensor with size [n_r] 65 | clo_label : tensor with size [n_c] 66 | """ 67 | matches = (row_labels.unsqueeze(1) == 68 | clo_label.unsqueeze(0)).bool() # [n_r, n_c] 69 | diffenc = torch.logical_not(matches) # [n_r, n_c] 70 | p, n, _ = dist.size() 71 | ap_dist = dist[:, matches].view(p, n, -1, 1) # [n, p, postive, 1] 72 | an_dist = dist[:, diffenc].view(p, n, 1, -1) # [n, p ,1, negative] 73 | return ap_dist, an_dist 74 | 75 | 76 | 77 | class CTL(BaseLoss): 78 | def __init__(self, margin, loss_term_weight=1.0, start=30000,): 79 | super(CTL, self).__init__(loss_term_weight) 80 | self.margin = margin 81 | self.count = 0 82 | self.start = start 83 | self.eps = 1e-3 84 | 85 | @gather_and_scale_wrapper 86 | def forward(self, embeddings, labels, bnn): 87 | # embeddings: [n, c, p], label: [n], bnn: [p, n, c] 88 | 
self.centers = bnn.permute(2, 0, 1).contiguous().float() 89 | embeddings = embeddings.permute(2, 0, 1).contiguous().float() # [n, c, p] -> [p, n, c] 90 | 91 | ref_embed, ref_label = embeddings, labels 92 | 93 | distmat = torch.pow(embeddings, 2).sum(dim=2) + torch.pow(self.centers, 2).sum(dim=2) - ( 94 | 2 * embeddings * self.centers).sum(dim=2) # [p,n] 95 | distmat = torch.sqrt(F.relu(distmat)) 96 | dist_d = distmat.mean() 97 | 98 | # embeddings ? 99 | if self.count >= self.start: 100 | dist = self.ComputeDistance(embeddings, ref_embed) 101 | mean_dist = dist.mean((1, 2)) 102 | ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist, distmat) # 8, 124 103 | 104 | else: 105 | self.count += 1 106 | if self.count == self.start: 107 | print("---------------------------starting!---------------------------") 108 | dist = self.ComputeDistance(embeddings, ref_embed) # [p, n1, n2] 109 | mean_dist = dist.mean((1, 2)) # [p] 110 | ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist) 111 | 112 | dist_diff = (ap_dist - an_dist).view(dist.size(0), -1) # dist_diff 113 | dist_d /= 128 114 | loss = F.relu(dist_diff + self.margin) #+ dist_d) 115 | 116 | hard_loss = torch.max(loss, -1)[0] 117 | loss_avg, loss_num = self.AvgNonZeroReducer(loss) 118 | 119 | self.info.update({ 120 | 'loss': loss_avg.detach().clone(), 121 | 'hard_loss': hard_loss.detach().clone(), 122 | 'center_loss': dist_d.detach().clone(), 123 | 'loss_num': loss_num.detach().clone(), 124 | 'mean_dist': mean_dist.detach().clone()}) 125 | 126 | return loss_avg, self.info 127 | 128 | def AvgNonZeroReducer(self, loss): 129 | eps = 1.0e-9 130 | loss_sum = loss.sum(-1) 131 | loss_num = (loss != 0).sum(-1).float() 132 | 133 | loss_avg = loss_sum / (loss_num + eps) 134 | loss_avg[loss_num == 0] = 0 135 | return loss_avg, loss_num 136 | 137 | def ComputeDistance(self, x, y): 138 | """ 139 | x: [p, n_x, c] embeddings 140 | y: [p, n_y, c] 141 | """ 142 | 143 | x2 = torch.sum(x ** 2, -1).unsqueeze(2) # [p, n_x, 1] 144 | y2 = torch.sum(y ** 2, -1).unsqueeze(1) # [p, 1, n_y] 145 | inner = x.matmul(y.transpose(1, 2)) # [p, n_x, n_y] 146 | dist = x2 + y2 - 2 * inner 147 | dist = torch.sqrt(F.relu(dist)) # [p, n_x, n_y] 148 | return dist 149 | def Convert2Triplets(self, row_labels, clo_label, dist, dist_d=None): 150 | """ 151 | row_labels: tensor with size [n_r] 152 | clo_label : tensor with size [n_c] 153 | """ 154 | # 155 | matches = (row_labels.unsqueeze(1) == 156 | clo_label.unsqueeze(0)).bool() # [n_r, n_c] [128, 128] 157 | diffenc = torch.logical_not(matches) # [n_r, n_c] 158 | p, n, _ = dist.size() # [p, n, n] 159 | ap_dist = dist[:, matches].view(p, n, -1, 1) # [p, n, 4, 1] 160 | an_dist = dist[:, diffenc].view(p, n, 1, -1) # [p, n, 1, 124] 161 | if dist_d != None: 162 | ap_dist = torch.cat([ap_dist, dist_d.reshape(p,n,1,1)], dim=2) # # [p, n, 5, 1] 163 | return ap_dist, an_dist 164 | -------------------------------------------------------------------------------- /opengait/utils/common.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import inspect 4 | import logging 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.autograd as autograd 9 | import yaml 10 | import random 11 | from torch.nn.parallel import DistributedDataParallel as DDP 12 | from collections import OrderedDict, namedtuple 13 | # from fvcore.nn import FlopCountAnalysis, parameter_count_table 14 | # from thop import profile 15 | # from ptflops import 
get_model_complexity_info 16 | class NoOp: 17 | def __getattr__(self, *args): 18 | def no_op(*args, **kwargs): pass 19 | return no_op 20 | 21 | 22 | class Odict(OrderedDict): 23 | def append(self, odict): 24 | dst_keys = self.keys() 25 | for k, v in odict.items(): 26 | if not is_list(v): 27 | v = [v] 28 | if k in dst_keys: 29 | if is_list(self[k]): 30 | self[k] += v 31 | else: 32 | self[k] = [self[k]] + v 33 | else: 34 | self[k] = v 35 | 36 | 37 | def Ntuple(description, keys, values): 38 | if not is_list_or_tuple(keys): 39 | keys = [keys] 40 | values = [values] 41 | Tuple = namedtuple(description, keys) 42 | return Tuple._make(values) 43 | 44 | 45 | def get_valid_args(obj, input_args, free_keys=[]): 46 | if inspect.isfunction(obj): 47 | expected_keys = inspect.getfullargspec(obj)[0] 48 | elif inspect.isclass(obj): 49 | expected_keys = inspect.getfullargspec(obj.__init__)[0] 50 | else: 51 | raise ValueError('Just support function and class object!') 52 | unexpect_keys = list() 53 | expected_args = {} 54 | for k, v in input_args.items(): 55 | if k in expected_keys: 56 | expected_args[k] = v 57 | elif k in free_keys: 58 | pass 59 | else: 60 | unexpect_keys.append(k) 61 | if unexpect_keys != []: 62 | logging.info("Find Unexpected Args(%s) in the Configuration of - %s -" % 63 | (', '.join(unexpect_keys), obj.__name__)) 64 | return expected_args 65 | 66 | 67 | def get_attr_from(sources, name): 68 | try: 69 | return getattr(sources[0], name) 70 | except: 71 | return get_attr_from(sources[1:], name) if len(sources) > 1 else getattr(sources[0], name) 72 | 73 | 74 | def is_list_or_tuple(x): 75 | return isinstance(x, (list, tuple)) 76 | 77 | 78 | def is_bool(x): 79 | return isinstance(x, bool) 80 | 81 | 82 | def is_str(x): 83 | return isinstance(x, str) 84 | 85 | 86 | def is_list(x): 87 | return isinstance(x, list) or isinstance(x, nn.ModuleList) 88 | 89 | 90 | def is_dict(x): 91 | return isinstance(x, dict) or isinstance(x, OrderedDict) or isinstance(x, Odict) 92 | 93 | 94 | def is_tensor(x): 95 | return isinstance(x, torch.Tensor) 96 | 97 | 98 | def is_array(x): 99 | return isinstance(x, np.ndarray) 100 | 101 | 102 | def ts2np(x): 103 | return x.cpu().data.numpy() 104 | 105 | 106 | def ts2var(x, **kwargs): 107 | return autograd.Variable(x, **kwargs).cuda() 108 | 109 | 110 | def np2var(x, **kwargs): 111 | return ts2var(torch.from_numpy(x), **kwargs) 112 | 113 | 114 | def list2var(x, **kwargs): 115 | return np2var(np.array(x), **kwargs) 116 | 117 | 118 | def mkdir(path): 119 | if not os.path.exists(path): 120 | os.makedirs(path) 121 | 122 | 123 | def MergeCfgsDict(src, dst): 124 | for k, v in src.items(): 125 | if (k not in dst.keys()) or (type(v) != type(dict())): 126 | dst[k] = v 127 | else: 128 | if is_dict(src[k]) and is_dict(dst[k]): 129 | MergeCfgsDict(src[k], dst[k]) 130 | else: 131 | dst[k] = v 132 | 133 | 134 | def clones(module, N): 135 | "Produce N identical layers." 
136 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 137 | 138 | 139 | def config_loader(path): 140 | with open(path, 'r') as stream: 141 | src_cfgs = yaml.safe_load(stream) 142 | with open("./configs/default.yaml", 'r') as stream: 143 | dst_cfgs = yaml.safe_load(stream) 144 | MergeCfgsDict(src_cfgs, dst_cfgs) 145 | return dst_cfgs 146 | 147 | 148 | def init_seeds(seed=0, cuda_deterministic=True): 149 | random.seed(seed) 150 | np.random.seed(seed) 151 | torch.manual_seed(seed) 152 | torch.cuda.manual_seed_all(seed) 153 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 154 | if cuda_deterministic: # slower, more reproducible 155 | torch.backends.cudnn.deterministic = True 156 | torch.backends.cudnn.benchmark = False 157 | else: # faster, less reproducible 158 | torch.backends.cudnn.deterministic = False 159 | torch.backends.cudnn.benchmark = True 160 | 161 | 162 | def handler(signum, frame): 163 | logging.info('Ctrl+c/z pressed') 164 | os.system( 165 | "kill $(ps aux | grep main.py | grep -v grep | awk '{print $2}') ") 166 | logging.info('process group flush!') 167 | 168 | 169 | def ddp_all_gather(features, dim=0, requires_grad=True): 170 | ''' 171 | inputs: [n, ...] 172 | ''' 173 | 174 | world_size = torch.distributed.get_world_size() 175 | rank = torch.distributed.get_rank() 176 | feature_list = [torch.ones_like(features) for _ in range(world_size)] 177 | torch.distributed.all_gather(feature_list, features.contiguous()) 178 | 179 | if requires_grad: 180 | feature_list[rank] = features 181 | feature = torch.cat(feature_list, dim=dim) 182 | return feature 183 | 184 | 185 | # https://github.com/pytorch/pytorch/issues/16885 186 | class DDPPassthrough(DDP): 187 | def __getattr__(self, name): 188 | try: 189 | return super().__getattr__(name) 190 | except AttributeError: 191 | return getattr(self.module, name) 192 | 193 | 194 | def get_ddp_module(module, **kwargs): 195 | if len(list(module.parameters())) == 0: 196 | # for the case that loss module has not parameters. 
197 | return module 198 | device = torch.cuda.current_device() 199 | module = DDPPassthrough(module, device_ids=[device], output_device=device, 200 | find_unused_parameters=True, **kwargs) 201 | return module 202 | 203 | 204 | def params_count(net): 205 | n_parameters = sum(p.numel() for p in net.parameters()) 206 | tensor = torch.randn(1, 1, 1, 64, 44) 207 | # tensor = torch.rand(1, 1, 64, 44).cuda() 208 | # tensor = ([[tensor], None,None,None,None]) 209 | 210 | # 分析FLOPs 211 | #flops = FlopCountAnalysis(net.Backbone.cuda(), tensor.cuda()) 212 | 213 | #flops = FlopCountAnalysis(net.cuda(), tensor) 214 | #macs, params = profile(net.Backbone.cuda(), inputs=(tensor.cuda(),)) 215 | #return 'Parameters Count: {:.5f}M MACs Count: {:.5f}G'.format(n_parameters / 1e6, macs/1e9) 216 | return 'Parameters Count: {:.5f}M MACs Count: {:.5f}G'.format(n_parameters / 1e6, 0) 217 | -------------------------------------------------------------------------------- /opengait/data/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torchvision.transforms as T 4 | import cv2 5 | import math 6 | from data import transform as base_transform 7 | from utils import is_list, is_dict, get_valid_args 8 | 9 | 10 | class NoOperation(): 11 | def __call__(self, x): 12 | return x 13 | 14 | 15 | class BaseSilTransform(): 16 | def __init__(self, divsor=255.0, img_shape=None): 17 | self.divsor = divsor 18 | self.img_shape = img_shape 19 | 20 | def __call__(self, x): 21 | if self.img_shape is not None: 22 | s = x.shape[0] 23 | _ = [s] + [*self.img_shape] 24 | x = x.reshape(*_) 25 | return x / self.divsor 26 | 27 | class BaseParsingCuttingTransform(): 28 | def __init__(self, divsor=255.0, cutting=None): 29 | self.divsor = divsor 30 | self.cutting = cutting 31 | 32 | def __call__(self, x): 33 | if self.cutting is not None: 34 | cutting = self.cutting 35 | else: 36 | cutting = int(x.shape[-1] // 64) * 10 37 | if cutting != 0: 38 | x = x[..., cutting:-cutting] 39 | if x.max() == 255 or x.max() == 255.: 40 | return x / self.divsor 41 | else: 42 | return x / 1.0 43 | 44 | class BaseSilCuttingTransform(): 45 | def __init__(self, divsor=255.0, cutting=None): 46 | self.divsor = divsor 47 | self.cutting = cutting 48 | 49 | def __call__(self, x): 50 | if self.cutting is not None: 51 | cutting = self.cutting 52 | else: 53 | cutting = int(x.shape[-1] // 64) * 10 54 | if x.shape[-1]!=44: 55 | x = x[..., cutting:-cutting] 56 | return x / self.divsor 57 | 58 | class BaseRgbTransform(): 59 | def __init__(self, mean=None, std=None): 60 | if mean is None: 61 | mean = [0.485*255, 0.456*255, 0.406*255] 62 | if std is None: 63 | std = [0.229*255, 0.224*255, 0.225*255] 64 | self.mean = np.array(mean).reshape((1, 3, 1, 1)) 65 | self.std = np.array(std).reshape((1, 3, 1, 1)) 66 | 67 | def __call__(self, x): 68 | return (x - self.mean) / self.std 69 | 70 | 71 | # **************** Data Agumentation **************** 72 | 73 | 74 | class RandomHorizontalFlip(object): 75 | def __init__(self, prob=0.5): 76 | self.prob = prob 77 | 78 | def __call__(self, seq): 79 | if random.uniform(0, 1) >= self.prob: 80 | return seq 81 | else: 82 | return seq[..., ::-1] 83 | 84 | 85 | class RandomErasing(object): 86 | def __init__(self, prob=0.5, sl=0.05, sh=0.2, r1=0.3, per_frame=False): 87 | self.prob = prob 88 | self.sl = sl 89 | self.sh = sh 90 | self.r1 = r1 91 | self.per_frame = per_frame 92 | 93 | def __call__(self, seq): 94 | if not self.per_frame: 95 | if random.uniform(0, 
1) >= self.prob: 96 | return seq 97 | else: 98 | for _ in range(100): 99 | seq_size = seq.shape 100 | area = seq_size[1] * seq_size[2] 101 | 102 | target_area = random.uniform(self.sl, self.sh) * area 103 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 104 | 105 | h = int(round(math.sqrt(target_area * aspect_ratio))) 106 | w = int(round(math.sqrt(target_area / aspect_ratio))) 107 | 108 | if w < seq_size[2] and h < seq_size[1]: 109 | x1 = random.randint(0, seq_size[1] - h) 110 | y1 = random.randint(0, seq_size[2] - w) 111 | seq[:, x1:x1+h, y1:y1+w] = 0. 112 | return seq 113 | return seq 114 | else: 115 | self.per_frame = False 116 | frame_num = seq.shape[0] 117 | ret = [self.__call__(seq[k][np.newaxis, ...]) 118 | for k in range(frame_num)] 119 | self.per_frame = True 120 | return np.concatenate(ret, 0) 121 | 122 | 123 | class RandomRotate(object): 124 | def __init__(self, prob=0.5, degree=10): 125 | self.prob = prob 126 | self.degree = degree 127 | 128 | def __call__(self, seq): 129 | if random.uniform(0, 1) >= self.prob: 130 | return seq 131 | else: 132 | _, dh, dw = seq.shape 133 | # rotation 134 | degree = random.uniform(-self.degree, self.degree) 135 | M1 = cv2.getRotationMatrix2D((dh // 2, dw // 2), degree, 1) 136 | # affine 137 | seq = [cv2.warpAffine(_[0, ...], M1, (dw, dh)) 138 | for _ in np.split(seq, seq.shape[0], axis=0)] 139 | seq = np.concatenate([np.array(_)[np.newaxis, ...] 140 | for _ in seq], 0) 141 | return seq 142 | 143 | 144 | class RandomPerspective(object): 145 | def __init__(self, prob=0.5): 146 | self.prob = prob 147 | 148 | def __call__(self, seq): 149 | if random.uniform(0, 1) >= self.prob: 150 | return seq 151 | else: 152 | _, h, w = seq.shape 153 | cutting = int(w // 44) * 10 154 | x_left = list(range(0, cutting)) 155 | x_right = list(range(w - cutting, w)) 156 | TL = (random.choice(x_left), 0) 157 | TR = (random.choice(x_right), 0) 158 | BL = (random.choice(x_left), h) 159 | BR = (random.choice(x_right), h) 160 | srcPoints = np.float32([TL, TR, BR, BL]) 161 | canvasPoints = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) 162 | perspectiveMatrix = cv2.getPerspectiveTransform( 163 | np.array(srcPoints), np.array(canvasPoints)) 164 | seq = [cv2.warpPerspective(_[0, ...], perspectiveMatrix, (w, h)) 165 | for _ in np.split(seq, seq.shape[0], axis=0)] 166 | seq = np.concatenate([np.array(_)[np.newaxis, ...] 167 | for _ in seq], 0) 168 | return seq 169 | 170 | class RandomPerspective_128(object): 171 | def __init__(self, prob=0.5): 172 | self.prob = prob 173 | 174 | def __call__(self, seq): 175 | if random.uniform(0, 1) >= self.prob: 176 | return seq 177 | else: 178 | _, h, w = seq.shape 179 | cutting = int(w // 88) * 10 180 | x_left = list(range(0, cutting)) 181 | x_right = list(range(w - cutting, w)) 182 | TL = (random.choice(x_left), 0) 183 | TR = (random.choice(x_right), 0) 184 | BL = (random.choice(x_left), h) 185 | BR = (random.choice(x_right), h) 186 | srcPoints = np.float32([TL, TR, BR, BL]) 187 | canvasPoints = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) 188 | perspectiveMatrix = cv2.getPerspectiveTransform( 189 | np.array(srcPoints), np.array(canvasPoints)) 190 | seq = [cv2.warpPerspective(_[0, ...], perspectiveMatrix, (w, h)) 191 | for _ in np.split(seq, seq.shape[0], axis=0)] 192 | seq = np.concatenate([np.array(_)[np.newaxis, ...] 
193 | for _ in seq], 0) 194 | return seq 195 | 196 | class RandomAffine(object): 197 | def __init__(self, prob=0.5, degree=10): 198 | self.prob = prob 199 | self.degree = degree 200 | 201 | def __call__(self, seq): 202 | if random.uniform(0, 1) >= self.prob: 203 | return seq 204 | else: 205 | _, dh, dw = seq.shape 206 | # random corner shifts for the affine transform 207 | max_shift = int(dh // 64 * 10) 208 | shift_range = list(range(0, max_shift)) 209 | pts1 = np.float32([[random.choice(shift_range), random.choice(shift_range)], [ 210 | dh-random.choice(shift_range), random.choice(shift_range)], [random.choice(shift_range), dw-random.choice(shift_range)]]) 211 | pts2 = np.float32([[random.choice(shift_range), random.choice(shift_range)], [ 212 | dh-random.choice(shift_range), random.choice(shift_range)], [random.choice(shift_range), dw-random.choice(shift_range)]]) 213 | M1 = cv2.getAffineTransform(pts1, pts2) 214 | # affine 215 | seq = [cv2.warpAffine(_[0, ...], M1, (dw, dh)) 216 | for _ in np.split(seq, seq.shape[0], axis=0)] 217 | seq = np.concatenate([np.array(_)[np.newaxis, ...] 218 | for _ in seq], 0) 219 | return seq 220 | 221 | # ****************************************** 222 | 223 | def Compose(trf_cfg): 224 | assert is_list(trf_cfg) 225 | transform = T.Compose([get_transform(cfg) for cfg in trf_cfg]) 226 | return transform 227 | 228 | 229 | def get_transform(trf_cfg=None): 230 | if is_dict(trf_cfg): 231 | transform = getattr(base_transform, trf_cfg['type']) 232 | valid_trf_arg = get_valid_args(transform, trf_cfg, ['type']) 233 | return transform(**valid_trf_arg) 234 | if trf_cfg is None: 235 | return lambda x: x 236 | if is_list(trf_cfg): 237 | transform = [get_transform(cfg) for cfg in trf_cfg] 238 | return transform 239 | raise ValueError('Error type for -Transform-Cfg-') 240 | -------------------------------------------------------------------------------- /opengait/evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import strftime, localtime 3 | import numpy as np 4 | from utils import get_msg_mgr, mkdir 5 | 6 | from .metric import mean_iou, cuda_dist, compute_ACC_mAP, evaluate_rank 7 | from .re_rank import re_ranking 8 | 9 | 10 | def de_diag(acc, each_angle=False): 11 | # Exclude identical-view cases 12 | dividend = acc.shape[1] - 1. 13 | result = np.sum(acc - np.diag(np.diag(acc)), 1) / dividend 14 | if not each_angle: 15 | result = np.mean(result) 16 | return result 17 | 18 | 19 | def cross_view_gallery_evaluation(feature, label, seq_type, view, dataset, metric): 20 | '''More details can be found in 21 | [A Comprehensive Study on the Evaluation of Silhouette-based Gait Recognition](https://ieeexplore.ieee.org/document/9928336). 22 | ''' 23 | probe_seq_dict = {'CASIA-B': {'NM': ['nm-01'], 'BG': ['bg-01'], 'CL': ['cl-01']}, 24 | 'OUMVLP': {'NM': ['00']}} 25 | 26 | gallery_seq_dict = {'CASIA-B': ['nm-02', 'bg-02', 'cl-02'], 27 | 'OUMVLP': ['01']} 28 | 29 | msg_mgr = get_msg_mgr() 30 | acc = {} 31 | mean_ap = {} 32 | view_list = sorted(np.unique(view)) 33 | for (type_, probe_seq) in probe_seq_dict[dataset].items(): 34 | acc[type_] = np.zeros(len(view_list)) - 1. 35 | mean_ap[type_] = np.zeros(len(view_list)) - 1.
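# --- Editor's sketch (illustrative only, not part of evaluator.py) ---
# de_diag() above averages a view-by-view Rank-1 accuracy matrix while
# excluding identical-view probe/gallery pairs. The 3x3 matrix below is
# made-up toy data, used only to show the arithmetic.
import numpy as np

acc_toy = np.array([[99.0, 90.0, 80.0],
                    [88.0, 98.0, 85.0],
                    [70.0, 82.0, 97.0]])  # acc_toy[v1, v2]: probe view v1 vs gallery view v2
dividend = acc_toy.shape[1] - 1.          # two cross-view galleries per probe view
per_view = np.sum(acc_toy - np.diag(np.diag(acc_toy)), 1) / dividend
print(per_view)                           # [85.  86.5 76. ], like de_diag(acc_toy, each_angle=True)
print(per_view.mean())                    # 82.5, like de_diag(acc_toy)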
36 | for (v1, probe_view) in enumerate(view_list): 37 | pseq_mask = np.isin(seq_type, probe_seq) & np.isin( 38 | view, probe_view) 39 | probe_x = feature[pseq_mask, :] 40 | probe_y = label[pseq_mask] 41 | gseq_mask = np.isin(seq_type, gallery_seq_dict[dataset]) 42 | gallery_y = label[gseq_mask] 43 | gallery_x = feature[gseq_mask, :] 44 | dist = cuda_dist(probe_x, gallery_x, metric) 45 | eval_results = compute_ACC_mAP( 46 | dist.cpu().numpy(), probe_y, gallery_y, view[pseq_mask], view[gseq_mask]) 47 | acc[type_][v1] = np.round(eval_results[0] * 100, 2) 48 | mean_ap[type_][v1] = np.round(eval_results[1] * 100, 2) 49 | 50 | result_dict = {} 51 | msg_mgr.log_info( 52 | '===Cross View Gallery Evaluation (Excluded identical-view cases)===') 53 | out_acc_str = "========= Rank@1 Acc =========\n" 54 | out_map_str = "============= mAP ============\n" 55 | for type_ in probe_seq_dict[dataset].keys(): 56 | avg_acc = np.mean(acc[type_]) 57 | avg_map = np.mean(mean_ap[type_]) 58 | result_dict[f'scalar/test_accuracy/{type_}-Rank@1'] = avg_acc 59 | result_dict[f'scalar/test_accuracy/{type_}-mAP'] = avg_map 60 | out_acc_str += f"{type_}:\t{acc[type_]}, mean: {avg_acc:.2f}%\n" 61 | out_map_str += f"{type_}:\t{mean_ap[type_]}, mean: {avg_map:.2f}%\n" 62 | # msg_mgr.log_info(f'========= Rank@1 Acc =========') 63 | msg_mgr.log_info(f'{out_acc_str}') 64 | # msg_mgr.log_info(f'========= mAP =========') 65 | msg_mgr.log_info(f'{out_map_str}') 66 | return result_dict 67 | 68 | # Modified From https://github.com/AbnerHqC/GaitSet/blob/master/model/utils/evaluator.py 69 | 70 | 71 | def single_view_gallery_evaluation(feature, label, seq_type, view, dataset, metric): 72 | probe_seq_dict = {'CASIA-B': {'NM': ['nm-05', 'nm-06'], 'BG': ['bg-01', 'bg-02'], 'CL': ['cl-01', 'cl-02']}, 73 | 'OUMVLP': {'NM': ['00']}, 74 | 'CASIA-E': {'NM': ['H-scene2-nm-1', 'H-scene2-nm-2', 'L-scene2-nm-1', 'L-scene2-nm-2', 'H-scene3-nm-1', 'H-scene3-nm-2', 'L-scene3-nm-1', 'L-scene3-nm-2', 'H-scene3_s-nm-1', 'H-scene3_s-nm-2', 'L-scene3_s-nm-1', 'L-scene3_s-nm-2',], 75 | 'BG': ['H-scene2-bg-1', 'H-scene2-bg-2', 'L-scene2-bg-1', 'L-scene2-bg-2', 'H-scene3-bg-1', 'H-scene3-bg-2', 'L-scene3-bg-1', 'L-scene3-bg-2', 'H-scene3_s-bg-1', 'H-scene3_s-bg-2', 'L-scene3_s-bg-1', 'L-scene3_s-bg-2'], 76 | 'CL': ['H-scene2-cl-1', 'H-scene2-cl-2', 'L-scene2-cl-1', 'L-scene2-cl-2', 'H-scene3-cl-1', 'H-scene3-cl-2', 'L-scene3-cl-1', 'L-scene3-cl-2', 'H-scene3_s-cl-1', 'H-scene3_s-cl-2', 'L-scene3_s-cl-1', 'L-scene3_s-cl-2'] 77 | } 78 | 79 | } 80 | gallery_seq_dict = {'CASIA-B': ['nm-01', 'nm-02', 'nm-03', 'nm-04'], 81 | 'OUMVLP': ['01'], 82 | 'CASIA-E': ['H-scene1-nm-1', 'H-scene1-nm-2', 'L-scene1-nm-1', 'L-scene1-nm-2']} 83 | msg_mgr = get_msg_mgr() 84 | acc = {} 85 | view_list = sorted(np.unique(view)) 86 | if dataset == 'CASIA-E': 87 | view_list.remove("270") 88 | view_num = len(view_list) 89 | num_rank = 1 90 | for (type_, probe_seq) in probe_seq_dict[dataset].items(): 91 | acc[type_] = np.zeros((view_num, view_num)) - 1. 
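# --- Editor's sketch (illustrative only, not part of evaluator.py) ---
# Every evaluation routine in this file selects probe and gallery samples with
# boolean masks built by np.isin over the per-sample sequence-type and view
# metadata. The toy arrays below are invented, purely to show the idiom.
import numpy as np

seq_type_toy = np.array(['nm-01', 'nm-05', 'bg-01', 'nm-02'])
view_toy     = np.array(['000',   '090',   '090',   '180'])
feature_toy  = np.arange(8, dtype=float).reshape(4, 2)   # 4 samples, 2-dim embeddings

gallery_seqs = ['nm-01', 'nm-02', 'nm-03', 'nm-04']      # CASIA-B style gallery sequences
probe_seqs   = ['nm-05', 'nm-06']

gseq_mask = np.isin(seq_type_toy, gallery_seqs)                            # [ True False False  True]
pseq_mask = np.isin(seq_type_toy, probe_seqs) & np.isin(view_toy, ['090'])
gallery_x, probe_x = feature_toy[gseq_mask], feature_toy[pseq_mask]
print(gallery_x.shape, probe_x.shape)                                      # (2, 2) (1, 2)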
92 | for (v1, probe_view) in enumerate(view_list): 93 | pseq_mask = np.isin(seq_type, probe_seq) & np.isin( 94 | view, probe_view) 95 | probe_x = feature[pseq_mask, :] 96 | probe_y = label[pseq_mask] 97 | 98 | for (v2, gallery_view) in enumerate(view_list): 99 | gseq_mask = np.isin(seq_type, gallery_seq_dict[dataset]) & np.isin( 100 | view, [gallery_view]) 101 | gallery_y = label[gseq_mask] 102 | gallery_x = feature[gseq_mask, :] 103 | dist = cuda_dist(probe_x, gallery_x, metric) 104 | idx = dist.topk(num_rank, largest=False)[1].cpu().numpy() 105 | acc[type_][v1, v2] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx], 1) > 0, 106 | 0) * 100 / dist.shape[0], 2) 107 | 108 | result_dict = {} 109 | msg_mgr.log_info('===Rank-1 (Exclude identical-view cases)===') 110 | out_str = "" 111 | for type_ in probe_seq_dict[dataset].keys(): 112 | sub_acc = de_diag(acc[type_], each_angle=True) 113 | msg_mgr.log_info(f'{type_}: {sub_acc}') 114 | result_dict[f'scalar/test_accuracy/{type_}'] = np.mean(sub_acc) 115 | out_str += f"{type_}: {np.mean(sub_acc):.2f}%\t" 116 | msg_mgr.log_info(out_str) 117 | return result_dict 118 | 119 | 120 | def evaluate_indoor_dataset(data, dataset, metric='euc', cross_view_gallery=False): 121 | feature, label, seq_type, view = data['embeddings'], data['labels'], data['types'], data['views'] 122 | label = np.array(label) 123 | view = np.array(view) 124 | 125 | if dataset not in ('CASIA-B', 'OUMVLP', 'CASIA-E'): 126 | raise KeyError("DataSet %s hasn't been supported !" % dataset) 127 | if cross_view_gallery: 128 | return cross_view_gallery_evaluation( 129 | feature, label, seq_type, view, dataset, metric) 130 | else: 131 | return single_view_gallery_evaluation( 132 | feature, label, seq_type, view, dataset, metric) 133 | 134 | 135 | def evaluate_real_scene(data, dataset, metric='euc'): 136 | msg_mgr = get_msg_mgr() 137 | feature, label, seq_type = data['embeddings'], data['labels'], data['types'] 138 | label = np.array(label) 139 | 140 | gallery_seq_type = {'0001-1000': ['1', '2'], 141 | "HID2021": ['0'], '0001-1000-test': ['0'], 142 | 'GREW': ['01'], 'TTG-200': ['1']} 143 | probe_seq_type = {'0001-1000': ['3', '4', '5', '6'], 144 | "HID2021": ['1'], '0001-1000-test': ['1'], 145 | 'GREW': ['02'], 'TTG-200': ['2', '3', '4', '5', '6']} 146 | 147 | num_rank = 20 148 | acc = np.zeros([num_rank]) - 1. 
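# --- Editor's sketch (illustrative only, not part of evaluator.py) ---
# The Rank-k numbers above come from torch.topk(largest=False) on the distance
# matrix followed by a cumulative-sum hit test: a probe counts as correct at
# rank k if any of its k nearest gallery samples shares its label. Toy data
# below, assuming a small CPU tensor for the distance matrix.
import numpy as np
import torch

dist_toy  = torch.tensor([[0.1, 0.9, 0.4],    # probe 0 vs 3 gallery samples
                          [0.8, 0.2, 0.3]])   # probe 1 vs 3 gallery samples
probe_y   = np.array(['A', 'B'])
gallery_y = np.array(['B', 'B', 'A'])

num_rank = 2
idx = dist_toy.topk(num_rank, largest=False)[1].numpy()   # indices of nearest gallery samples
hits = np.reshape(probe_y, [-1, 1]) == gallery_y[idx]     # [num_probe, num_rank] hit matrix
acc = np.sum(np.cumsum(hits, 1) > 0, 0) * 100 / dist_toy.shape[0]
print(acc)                                                # [ 50. 100.] -> Rank-1 = 50%, Rank-2 = 100%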
149 | gseq_mask = np.isin(seq_type, gallery_seq_type[dataset]) 150 | gallery_x = feature[gseq_mask, :] 151 | gallery_y = label[gseq_mask] 152 | pseq_mask = np.isin(seq_type, probe_seq_type[dataset]) 153 | probe_x = feature[pseq_mask, :] 154 | probe_y = label[pseq_mask] 155 | 156 | dist = cuda_dist(probe_x, gallery_x, metric) 157 | idx = dist.topk(num_rank, largest=False)[1].cpu().numpy() 158 | acc = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0, 159 | 0) * 100 / dist.shape[0], 2) 160 | msg_mgr.log_info('==Rank-1==') 161 | msg_mgr.log_info('%.3f' % (np.mean(acc[0]))) 162 | msg_mgr.log_info('==Rank-5==') 163 | msg_mgr.log_info('%.3f' % (np.mean(acc[4]))) 164 | msg_mgr.log_info('==Rank-10==') 165 | msg_mgr.log_info('%.3f' % (np.mean(acc[9]))) 166 | msg_mgr.log_info('==Rank-20==') 167 | msg_mgr.log_info('%.3f' % (np.mean(acc[19]))) 168 | return {"scalar/test_accuracy/Rank-1": np.mean(acc[0]), "scalar/test_accuracy/Rank-5": np.mean(acc[4])} 169 | 170 | 171 | def GREW_submission(data, dataset, metric='euc'): 172 | get_msg_mgr().log_info("Evaluating GREW") 173 | feature, label, seq_type, view = data['embeddings'], data['labels'], data['types'], data['views'] 174 | label = np.array(label) 175 | view = np.array(view) 176 | gallery_seq_type = {'GREW': ['01', '02']} 177 | probe_seq_type = {'GREW': ['03']} 178 | gseq_mask = np.isin(seq_type, gallery_seq_type[dataset]) 179 | gallery_x = feature[gseq_mask, :] 180 | gallery_y = label[gseq_mask] 181 | pseq_mask = np.isin(seq_type, probe_seq_type[dataset]) 182 | probe_x = feature[pseq_mask, :] 183 | probe_y = view[pseq_mask] 184 | 185 | num_rank = 20 186 | dist = cuda_dist(probe_x, gallery_x, metric) 187 | idx = dist.topk(num_rank, largest=False)[1].cpu().numpy() 188 | 189 | save_path = os.path.join( 190 | "GREW_result/"+strftime('%Y-%m%d-%H%M%S', localtime())+".csv") 191 | mkdir("GREW_result") 192 | with open(save_path, "w") as f: 193 | f.write("videoId,rank1,rank2,rank3,rank4,rank5,rank6,rank7,rank8,rank9,rank10,rank11,rank12,rank13,rank14,rank15,rank16,rank17,rank18,rank19,rank20\n") 194 | for i in range(len(idx)): 195 | r_format = [int(idx) for idx in gallery_y[idx[i, 0:num_rank]]] 196 | output_row = '{}'+',{}'*num_rank+'\n' 197 | f.write(output_row.format(probe_y[i], *r_format)) 198 | print("GREW result saved to {}/{}".format(os.getcwd(), save_path)) 199 | return 200 | 201 | 202 | def HID_submission(data, dataset, rerank=True, metric='euc'): 203 | msg_mgr = get_msg_mgr() 204 | msg_mgr.log_info("Evaluating HID") 205 | feature, label, seq_type = data['embeddings'], data['labels'], data['views'] 206 | label = np.array(label) 207 | seq_type = np.array(seq_type) 208 | probe_mask = (label == "probe") 209 | gallery_mask = (label != "probe") 210 | gallery_x = feature[gallery_mask, :] 211 | gallery_y = label[gallery_mask] 212 | probe_x = feature[probe_mask, :] 213 | probe_y = seq_type[probe_mask] 214 | if rerank: 215 | feat = np.concatenate([probe_x, gallery_x]) 216 | dist = cuda_dist(feat, feat, metric).cpu().numpy() 217 | msg_mgr.log_info("Starting Re-ranking") 218 | re_rank = re_ranking( 219 | dist, probe_x.shape[0], k1=6, k2=6, lambda_value=0.3) 220 | idx = np.argsort(re_rank, axis=1) 221 | else: 222 | dist = cuda_dist(probe_x, gallery_x, metric) 223 | idx = dist.cpu().sort(1)[1].numpy() 224 | 225 | save_path = os.path.join( 226 | "HID_result/"+strftime('%Y-%m%d-%H%M%S', localtime())+".csv") 227 | mkdir("HID_result") 228 | with open(save_path, "w") as f: 229 | f.write("videoID,label\n") 230 | for i in 
range(len(idx)): 231 | f.write("{},{}\n".format(probe_y[i], gallery_y[idx[i, 0]])) 232 | print("HID result saved to {}/{}".format(os.getcwd(), save_path)) 233 | return 234 | 235 | 236 | def evaluate_segmentation(data, dataset): 237 | labels = data['mask'] 238 | pred = data['pred'] 239 | miou = mean_iou(pred, labels) 240 | get_msg_mgr().log_info('mIOU: %.3f' % (miou.mean())) 241 | return {"scalar/test_accuracy/mIOU": miou} 242 | 243 | 244 | def evaluate_Gait3D(data, dataset, metric='euc'): 245 | msg_mgr = get_msg_mgr() 246 | 247 | features, labels, cams, time_seqs = data['embeddings'], data['labels'], data['types'], data['views'] 248 | import json 249 | 250 | probe_sets = json.load( 251 | open('./datasets/Gait3D/Gait3D.json', 'rb'))['PROBE_SET'] 252 | 253 | probe_mask = [] 254 | 255 | for id, ty, sq in zip(labels, cams, time_seqs): 256 | if '-'.join([id, ty, sq]) in probe_sets: 257 | probe_mask.append(True) 258 | else: 259 | probe_mask.append(False) 260 | probe_mask = np.array(probe_mask) 261 | 262 | # probe_features = features[:probe_num] 263 | probe_features = features[probe_mask] 264 | # gallery_features = features[probe_num:] 265 | gallery_features = features[~probe_mask] 266 | # probe_lbls = np.asarray(labels[:probe_num]) 267 | # gallery_lbls = np.asarray(labels[probe_num:]) 268 | probe_lbls = np.asarray(labels)[probe_mask] 269 | gallery_lbls = np.asarray(labels)[~probe_mask] 270 | 271 | results = {} 272 | msg_mgr.log_info(f"The test metric you choose is {metric}.") 273 | dist = cuda_dist(probe_features, gallery_features, metric).cpu().numpy() 274 | cmc, all_AP, all_INP = evaluate_rank(dist, probe_lbls, gallery_lbls) 275 | 276 | mAP = np.mean(all_AP) 277 | mINP = np.mean(all_INP) 278 | for r in [1, 5, 10]: 279 | results['scalar/test_accuracy/Rank-{}'.format(r)] = cmc[r - 1] * 100 280 | results['scalar/test_accuracy/mAP'] = mAP * 100 281 | results['scalar/test_accuracy/mINP'] = mINP * 100 282 | 283 | # print_csv_format(dataset_name, results) 284 | msg_mgr.log_info(results) 285 | return results 286 | 287 | 288 | -------------------------------------------------------------------------------- /opengait/modeling/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils import clones, is_list_or_tuple 6 | from torchvision.ops import RoIAlign 7 | 8 | class BasicLinear(nn.Module): 9 | def __init__(self, in_channels, out_channels,**kwargs): 10 | super(BasicLinear, self).__init__() 11 | self.l1 = nn.Linear(in_channels, out_channels, bias=False, **kwargs) 12 | 13 | def forward(self, x): 14 | x = x.permute(2, 0, 1).contiguous() 15 | x = self.l1(x) 16 | return x.permute(1, 2, 0).contiguous() 17 | 18 | class HorizontalPoolingPyramid(): 19 | """ 20 | Horizontal Pyramid Matching for Person Re-identification 21 | Arxiv: https://arxiv.org/abs/1804.05275 22 | Github: https://github.com/SHI-Labs/Horizontal-Pyramid-Matching 23 | """ 24 | 25 | def __init__(self, bin_num=None): 26 | if bin_num is None: 27 | bin_num = [16, 8, 4, 2, 1] 28 | self.bin_num = bin_num 29 | 30 | def __call__(self, x): 31 | """ 32 | x : [n, c, h, w] 33 | ret: [n, c, p] 34 | """ 35 | n, c = x.size()[:2] 36 | features = [] 37 | for b in self.bin_num: 38 | z = x.view(n, c, b, -1) 39 | z = z.mean(-1) + z.max(-1)[0] 40 | features.append(z) 41 | return torch.cat(features, -1) 42 | 43 | 44 | 45 | class SetBlockWrapper(nn.Module): 46 | def __init__(self, forward_block): 47 | 
super(SetBlockWrapper, self).__init__() 48 | self.forward_block = forward_block 49 | 50 | def forward(self, x, *args, **kwargs): 51 | """ 52 | In x: [n, c_in, s, h_in, w_in] 53 | Out x: [n, c_out, s, h_out, w_out] 54 | """ 55 | n, c, s, h, w = x.size() 56 | #print(x.size()) 57 | x= self.forward_block(x.transpose( 58 | 1, 2).reshape(-1, c, h, w), n, s, *args, **kwargs) 59 | output_size = x.size() 60 | #print(output_size) 61 | return x.reshape(n, s, *output_size[1:]).transpose(1, 2).contiguous() 62 | 63 | 64 | class GetAttention(nn.Module): 65 | def __init__(self, forward_block): 66 | super(GetAttention, self).__init__() 67 | self.forward_block = forward_block 68 | 69 | def forward(self, x, *args, **kwargs): 70 | """ 71 | In x: [n, c_in, s, h_in, w_in] 72 | Out x: [n, c_out, s, h_out, w_out] 73 | """ 74 | n, c, s, h, w = x.size() 75 | x = self.forward_block(x.transpose( 76 | 1, 2).reshape(-1, c, h, w), n, s, *args, **kwargs) 77 | return x.contiguous() 78 | 79 | class PackSequenceWrapper(nn.Module): 80 | def __init__(self, pooling_func): 81 | super(PackSequenceWrapper, self).__init__() 82 | self.pooling_func = pooling_func 83 | 84 | def forward(self, seqs, seqL, dim=2, options={}): 85 | """ 86 | In seqs: [n, c, s, ...] 87 | Out rets: [n, ...] 88 | """ 89 | if seqL is None: 90 | return self.pooling_func(seqs, **options) 91 | seqL = seqL[0].data.cpu().numpy().tolist() 92 | start = [0] + np.cumsum(seqL).tolist()[:-1] 93 | 94 | rets = [] 95 | for curr_start, curr_seqL in zip(start, seqL): 96 | narrowed_seq = seqs.narrow(dim, curr_start, curr_seqL) 97 | rets.append(self.pooling_func(narrowed_seq, **options)) 98 | if len(rets) > 0 and is_list_or_tuple(rets[0]): 99 | return [torch.cat([ret[j] for ret in rets]) 100 | for j in range(len(rets[0]))] 101 | return torch.cat(rets) 102 | 103 | class PackSequenceWrapper2D_remove(nn.Module): 104 | def __init__(self, pooling_func): 105 | super(PackSequenceWrapper2D_remove, self).__init__() 106 | self.pooling_func = pooling_func 107 | 108 | def forward(self, seqs, seqL, dim=2, options={}): 109 | """ 110 | In seqs: [n, c, s, ...] 111 | Out rets: [n, ...] 
112 | """ 113 | if seqL is None: 114 | return self.pooling_func(seqs, **options) 115 | n,c,s,h,w = seqs.size() 116 | rets = [] 117 | for i in range(n): 118 | mask = seqL[i] 119 | #print(30-np.sum(np.array(mask))) 120 | narrowed_seq = seqs[i][:,mask,:,:].unsqueeze(0) 121 | #print(narrowed_seq.size()) 122 | rets.append(self.pooling_func(narrowed_seq, **options)) 123 | if len(rets) > 0 and is_list_or_tuple(rets[0]): 124 | return [torch.cat([ret[j] for ret in rets]) 125 | for j in range(len(rets[0]))] 126 | return torch.cat(rets) 127 | 128 | class BasicConv2d(nn.Module): 129 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, **kwargs): 130 | super(BasicConv2d, self).__init__() 131 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, 132 | stride=stride, padding=padding, bias=False, **kwargs) 133 | 134 | def forward(self, x): 135 | x = self.conv(x) 136 | return x 137 | 138 | 139 | class SeparateFCs(nn.Module): 140 | def __init__(self, parts_num, in_channels, out_channels, norm=False): 141 | super(SeparateFCs, self).__init__() 142 | self.p = parts_num 143 | self.fc_bin = nn.Parameter( 144 | nn.init.xavier_uniform_( 145 | torch.zeros(parts_num, in_channels, out_channels))) 146 | self.norm = norm 147 | 148 | def forward(self, x): 149 | """ 150 | x: [n, c_in, p] 151 | out: [n, c_out, p] 152 | """ 153 | x = x.permute(2, 0, 1).contiguous() 154 | if self.norm: 155 | out = x.matmul(F.normalize(self.fc_bin, dim=1)) 156 | else: 157 | out = x.matmul(self.fc_bin) 158 | return out.permute(1, 2, 0).contiguous() 159 | 160 | class SeparateBNNecks(nn.Module): 161 | """ 162 | Bag of Tricks and a Strong Baseline for Deep Person Re-Identification 163 | CVPR Workshop: https://openaccess.thecvf.com/content_CVPRW_2019/papers/TRMTMCT/Luo_Bag_of_Tricks_and_a_Strong_Baseline_for_Deep_Person_CVPRW_2019_paper.pdf 164 | Github: https://github.com/michuanhaohao/reid-strong-baseline 165 | """ 166 | 167 | def __init__(self, parts_num, in_channels, class_num, norm=True, parallel_BN1d=True): 168 | super(SeparateBNNecks, self).__init__() 169 | self.p = parts_num 170 | self.class_num = class_num 171 | self.norm = norm 172 | self.fc_bin = nn.Parameter( 173 | nn.init.xavier_uniform_( 174 | torch.zeros(parts_num, in_channels, class_num))) 175 | if parallel_BN1d: 176 | self.bn1d = nn.BatchNorm1d(in_channels * parts_num) 177 | else: 178 | self.bn1d = clones(nn.BatchNorm1d(in_channels), parts_num) 179 | self.parallel_BN1d = parallel_BN1d 180 | 181 | def forward(self, x): 182 | """ 183 | x: [n, c, p] 184 | """ 185 | if self.parallel_BN1d: 186 | n, c, p = x.size() 187 | x = x.view(n, -1) # [n, c*p] 188 | x = self.bn1d(x) 189 | x = x.view(n, c, p) 190 | else: 191 | x = torch.cat([bn(_x) for _x, bn in zip( 192 | x.split(1, 2), self.bn1d)], 2) # [p, n, c] 193 | feature = x.permute(2, 0, 1).contiguous() 194 | if self.norm: 195 | feature = F.normalize(feature, dim=-1) # [p, n, c] 196 | logits = feature.matmul(F.normalize( 197 | self.fc_bin, dim=1)) # [p, n, c],c 中的值代表了cos,即步态向量和中心向量的cos值 198 | else: 199 | logits = feature.matmul(self.fc_bin) 200 | return feature.permute(1, 2, 0).contiguous(), logits.permute(1, 2, 0).contiguous() 201 | 202 | class SeparateBNNecks_deleteCommon(nn.Module): 203 | """ 204 | Bag of Tricks and a Strong Baseline for Deep Person Re-Identification 205 | CVPR Workshop: https://openaccess.thecvf.com/content_CVPRW_2019/papers/TRMTMCT/Luo_Bag_of_Tricks_and_a_Strong_Baseline_for_Deep_Person_CVPRW_2019_paper.pdf 206 | Github: https://github.com/michuanhaohao/reid-strong-baseline 207 | 
""" 208 | 209 | def __init__(self, parts_num, in_channels, parallel_BN1d=True): 210 | super(SeparateBNNecks_deleteCommon, self).__init__() 211 | self.p = parts_num 212 | self.fc_common = nn.Parameter( 213 | nn.init.xavier_uniform_( 214 | torch.zeros(parts_num, 1, in_channels))) 215 | if parallel_BN1d: 216 | self.bn1d = nn.BatchNorm1d(in_channels * parts_num) 217 | else: 218 | self.bn1d = clones(nn.BatchNorm1d(in_channels), parts_num) 219 | self.parallel_BN1d = parallel_BN1d 220 | 221 | def ComputeDistance(self, x, y): 222 | """ 223 | x: [p, n_x, c] 224 | y: [p, n_y, c] 225 | """ 226 | x2 = torch.sum(x ** 2, -1).unsqueeze(2) # [p, n_x, 1] 227 | y2 = torch.sum(y ** 2, -1).unsqueeze(1) # [p, 1, n_y] 228 | inner = x.matmul(y.transpose(1, 2)) # [p, n_x, n_y] 229 | dist = x2 + y2 - 2 * inner 230 | dist = torch.sqrt(F.relu(dist)) # [p, n_x, n_y] 231 | return dist 232 | 233 | 234 | def forward(self, x): 235 | """ 236 | x: [n, c, p] 237 | """ 238 | if self.parallel_BN1d: 239 | n, c, p = x.size() 240 | x = x.view(n, -1) # [n, c*p] 241 | x = self.bn1d(x) 242 | x = x.view(n, c, p) 243 | else: 244 | x = torch.cat([bn(_x) for _x, bn in zip( 245 | x.split(1, 2), self.bn1d)], 2) 246 | feature = x.permute(2, 0, 1).contiguous() # [p, n, c] 247 | common_dist = torch.mean(self.ComputeDistance(feature, self.fc_common)) 248 | feature = feature-self.fc_common 249 | return feature.permute(1, 2, 0).contiguous(), common_dist 250 | 251 | 252 | class FocalConv2d(nn.Module): 253 | """ 254 | GaitPart: Temporal Part-based Model for Gait Recognition 255 | CVPR2020: https://openaccess.thecvf.com/content_CVPR_2020/papers/Fan_GaitPart_Temporal_Part-Based_Model_for_Gait_Recognition_CVPR_2020_paper.pdf 256 | Github: https://github.com/ChaoFan96/GaitPart 257 | """ 258 | def __init__(self, in_channels, out_channels, kernel_size, halving, **kwargs): 259 | super(FocalConv2d, self).__init__() 260 | self.halving = halving 261 | self.conv = nn.Conv2d(in_channels, out_channels, 262 | kernel_size, bias=False, **kwargs) 263 | 264 | def forward(self, x): 265 | if self.halving == 0: 266 | z = self.conv(x) 267 | else: 268 | h = x.size(2) 269 | split_size = int(h // 2**self.halving) 270 | z = x.split(split_size, 2) 271 | z = torch.cat([self.conv(_) for _ in z], 2) 272 | return z 273 | 274 | 275 | class BasicConv3d(nn.Module): 276 | def __init__(self, in_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False, **kwargs): 277 | super(BasicConv3d, self).__init__() 278 | self.conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, 279 | stride=stride, padding=padding, bias=bias, **kwargs) 280 | 281 | def forward(self, ipts): 282 | ''' 283 | ipts: [n, c, s, h, w] 284 | outs: [n, c, s, h, w] 285 | ''' 286 | outs = self.conv3d(ipts) 287 | return outs 288 | 289 | 290 | class GaitAlign(nn.Module): 291 | """ 292 | GaitEdge: Beyond Plain End-to-end Gait Recognition for Better Practicality 293 | ECCV2022: https://arxiv.org/pdf/2203.03972v2.pdf 294 | Github: https://github.com/ShiqiYu/OpenGait/tree/master/configs/gaitedge 295 | """ 296 | def __init__(self, H=64, W=44, eps=1, **kwargs): 297 | super(GaitAlign, self).__init__() 298 | self.H, self.W, self.eps = H, W, eps 299 | self.Pad = nn.ZeroPad2d((int(self.W / 2), int(self.W / 2), 0, 0)) 300 | self.RoiPool = RoIAlign((self.H, self.W), 1, sampling_ratio=-1) 301 | 302 | def forward(self, feature_map, binary_mask, w_h_ratio): 303 | """ 304 | In sils: [n, c, h, w] 305 | w_h_ratio: [n, 1] 306 | Out aligned_sils: [n, c, H, W] 307 | """ 308 | n, c, h, w = 
feature_map.size() 309 | # w_h_ratio = w_h_ratio.repeat(1, 1) # [n, 1] 310 | w_h_ratio = w_h_ratio.view(-1, 1) # [n, 1] 311 | 312 | h_sum = binary_mask.sum(-1) # [n, c, h] 313 | _ = (h_sum >= self.eps).float().cumsum(axis=-1) # [n, c, h] 314 | h_top = (_ == 0).float().sum(-1) # [n, c] 315 | h_bot = (_ != torch.max(_, dim=-1, keepdim=True) 316 | [0]).float().sum(-1) + 1. # [n, c] 317 | 318 | w_sum = binary_mask.sum(-2) # [n, c, w] 319 | w_cumsum = w_sum.cumsum(axis=-1) # [n, c, w] 320 | w_h_sum = w_sum.sum(-1).unsqueeze(-1) # [n, c, 1] 321 | w_center = (w_cumsum < w_h_sum / 2.).float().sum(-1) # [n, c] 322 | 323 | p1 = self.W - self.H * w_h_ratio 324 | p1 = p1 / 2. 325 | p1 = torch.clamp(p1, min=0) # [n, c] 326 | t_w = w_h_ratio * self.H / w 327 | p2 = p1 / t_w # [n, c] 328 | 329 | height = h_bot - h_top # [n, c] 330 | width = height * w / h # [n, c] 331 | width_p = int(self.W / 2) 332 | 333 | feature_map = self.Pad(feature_map) 334 | w_center = w_center + width_p # [n, c] 335 | 336 | w_left = w_center - width / 2 - p2 # [n, c] 337 | w_right = w_center + width / 2 + p2 # [n, c] 338 | 339 | w_left = torch.clamp(w_left, min=0., max=w+2*width_p) 340 | w_right = torch.clamp(w_right, min=0., max=w+2*width_p) 341 | 342 | boxes = torch.cat([w_left, h_top, w_right, h_bot], dim=-1) 343 | # index of bbox in batch 344 | box_index = torch.arange(n, device=feature_map.device) 345 | rois = torch.cat([box_index.view(-1, 1), boxes], -1) 346 | crops = self.RoiPool(feature_map, rois) # [n, c, H, W] 347 | return crops 348 | 349 | 350 | def RmBN2dAffine(model): 351 | for m in model.modules(): 352 | if isinstance(m, nn.BatchNorm2d): 353 | m.weight.requires_grad = False 354 | m.bias.requires_grad = False 355 | -------------------------------------------------------------------------------- /opengait/modeling/backbones/GLGait.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | import torch.nn as nn 3 | import torch 4 | import math 5 | from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet 6 | # from ..modules import BasicConv2d 7 | from typing import Tuple, Optional, Callable, List, Type, Any, Union 8 | from torch import Tensor 9 | import numpy as np 10 | from einops import rearrange 11 | from einops.layers.torch import Reduce 12 | 13 | class BasicConv2d(nn.Module): 14 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, **kwargs): 15 | super(BasicConv2d, self).__init__() 16 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, 17 | stride=stride, padding=padding, bias=False, **kwargs) 18 | 19 | def forward(self, x): 20 | x = self.conv(x) 21 | return x 22 | 23 | class Conv3DNoSpatial(nn.Conv3d): 24 | 25 | def __init__( 26 | self, 27 | in_planes: int, 28 | out_planes: int, 29 | stride: int = 1, 30 | padding: int = 1, 31 | group: int = 1, 32 | ) -> None: 33 | super(Conv3DNoSpatial, self).__init__( 34 | in_channels=in_planes, 35 | out_channels=out_planes, 36 | kernel_size=(3, 1, 1), 37 | stride=(stride, 1, 1), 38 | padding=(padding, 0, 0), 39 | groups=group, 40 | bias=False) 41 | 42 | @staticmethod 43 | def get_downsample_stride(stride: int) -> Tuple[int, int, int]: 44 | return stride, 1, 1 45 | 46 | 47 | class Conv3DNoTemporal(nn.Conv3d): 48 | 49 | def __init__( 50 | self, 51 | in_planes: int, 52 | out_planes: int, 53 | stride: int = 1, 54 | padding: int = 1, 55 | group: int = 1, 56 | ) -> None: 57 | super(Conv3DNoTemporal, self).__init__( 58 | in_channels=in_planes, 59 | 
out_channels=out_planes, 60 | kernel_size=(1, 3, 3), 61 | stride=(1, stride, stride), 62 | padding=(0, padding, padding), 63 | groups=group, 64 | bias=False) 65 | 66 | @staticmethod 67 | def get_downsample_stride(stride: int) -> Tuple[int, int, int]: 68 | return 1, stride, stride 69 | 70 | 71 | class Conv3D1x1(nn.Conv3d): 72 | 73 | def __init__( 74 | self, 75 | in_planes: int, 76 | out_planes: int, 77 | stride: int = 1, 78 | padding: int = 1, 79 | group: int = 1, 80 | ) -> None: 81 | super(Conv3D1x1, self).__init__( 82 | in_channels=in_planes, 83 | out_channels=out_planes, 84 | kernel_size=(1, 1, 1), 85 | stride=(1, stride, stride), 86 | padding=(0, 0, 0), 87 | groups=group, 88 | bias=False) 89 | 90 | @staticmethod 91 | def get_downsample_stride(stride: int) -> Tuple[int, int, int]: 92 | return 1, 1, 1 93 | 94 | 95 | class Conv3DSimple(nn.Conv3d): 96 | def __init__( 97 | self, 98 | in_planes: int, 99 | out_planes: int, 100 | stride: int = 1, 101 | padding: int = 1, 102 | group: int = 1, 103 | ) -> None: 104 | super(Conv3DSimple, self).__init__( 105 | in_channels=in_planes, 106 | out_channels=out_planes, 107 | kernel_size=(3, 3, 3), 108 | stride=(1, stride, stride), 109 | padding=padding, 110 | groups=group, 111 | bias=False) 112 | 113 | @staticmethod 114 | def get_downsample_stride(stride: int) -> Tuple[int, int, int]: 115 | return 1, stride, stride 116 | 117 | 118 | class FeedForward(nn.Module): 119 | def __init__(self, dim, hidden_dim, dropout=0.): 120 | super().__init__() 121 | self.net = nn.Sequential( 122 | nn.LayerNorm(dim), 123 | nn.Linear(dim, hidden_dim), 124 | nn.GELU(), 125 | nn.Dropout(dropout), 126 | nn.Linear(hidden_dim, dim), 127 | nn.Dropout(dropout) 128 | ) 129 | 130 | def forward(self, x): 131 | return self.net(x) 132 | 133 | class Attention(nn.Module): 134 | def __init__(self, dim, heads=4, dim_head=64, dropout=0.): 135 | super().__init__() 136 | inner_dim = dim_head * heads 137 | self.heads = heads 138 | self.scale = dim_head ** -0.5 139 | 140 | self.norm = nn.LayerNorm(dim) 141 | self.attend = nn.Softmax(dim=-1) 142 | self.dropout = nn.Dropout(dropout) 143 | 144 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 145 | 146 | self.to_out = nn.Sequential( 147 | nn.Linear(inner_dim, dim), 148 | nn.Dropout(dropout) 149 | ) 150 | 151 | def forward(self, x): 152 | x = self.norm(x) 153 | qkv = self.to_qkv(x) 154 | b, p, n, d = qkv.size() 155 | h = self.heads 156 | d1 = d//(h*3) 157 | qkv = qkv.reshape(b,p,n,h,3,d1).permute(4,0,1,3,2,5) 158 | q, k ,v = qkv #[b,p,h,n,d1] 159 | 160 | 161 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale #[b,p,h,n,n] 162 | 163 | attn = self.attend(dots) 164 | attn = self.dropout(attn) 165 | 166 | out = torch.matmul(attn, v) #[b,p,h,n,d1] 167 | out = out.permute(0,1,3,2,4).reshape(b,p,n,h*d1) 168 | return self.to_out(out) 169 | 170 | 171 | class BasicBlock2D_trans3(nn.Module): 172 | expansion = 1 173 | 174 | def __init__( 175 | self, 176 | inplanes: int, 177 | planes: int, 178 | stride: int = 1, 179 | downsample: Optional[nn.Module] = None, 180 | tem_nums: int = 3, 181 | seq: int = 1, 182 | ) -> None: 183 | 184 | super(BasicBlock2D_trans3, self).__init__() 185 | 186 | self.dims = planes 187 | self.conv1 = nn.Sequential( 188 | Conv3DNoTemporal(inplanes, planes, stride), 189 | nn.BatchNorm3d(planes), 190 | nn.ReLU(inplace=True) 191 | ) 192 | 193 | self.conv11 = nn.Sequential( 194 | Conv3D1x1(planes, self.dims), 195 | nn.BatchNorm3d(self.dims), 196 | nn.ReLU(inplace=True) 197 | ) 198 | 199 | self.tem_nums = tem_nums 200 | self.pt = 3 
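# --- Editor's sketch (illustrative only, not part of GLGait.py) ---
# The Attention module defined above expects [b, p, n, dim] input (batch,
# part/position groups, temporal tokens, channels) and attends over the n
# tokens within each group; BasicBlock2D_trans3 feeds it t1 temporal chunks
# per spatial position (see its forward() further below). Sizes here are made
# up, and the snippet assumes it runs in this module's namespace so that
# Attention is in scope.
import torch

b, p, n, dim = 2, 6, 10, 128                   # hypothetical sizes
attn = Attention(dim=dim, heads=4, dim_head=64)
x = torch.randn(b, p, n, dim)
print(attn(x).shape)                           # torch.Size([2, 6, 10, 128]), shape preserved
# Inside forward(): qkv is [b, p, n, 3*heads*dim_head]; after reshape/permute,
# q, k, v are each [b, p, heads, n, dim_head], so the q @ k^T attention map is
# [b, p, heads, n, n] over the temporal tokens.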
201 | 202 | self.basic_layers = Attention(dim=self.dims, heads=4, dim_head = 64) 203 | self.fc = nn.Linear(self.dims,self.dims) 204 | #self.fc = Conv3D1x1(self.dims, self.dims) 205 | self.conv12 =nn.Sequential( 206 | Conv3DNoSpatial(self.dims, self.dims), 207 | nn.BatchNorm3d(self.dims), 208 | nn.ReLU(inplace=True), 209 | ) 210 | self.conv13 = nn.Sequential( 211 | Conv3D1x1(self.dims, planes), 212 | nn.BatchNorm3d(planes), 213 | nn.ReLU(inplace=True), 214 | ) 215 | 216 | self.conv3 = nn.Sequential( 217 | Conv3DNoTemporal(planes, planes), 218 | nn.BatchNorm3d(planes) 219 | ) 220 | self.relu1 = nn.ReLU(inplace=True) 221 | self.relu2 = nn.ReLU(inplace=True) 222 | self.downsample = downsample 223 | self.stride = stride 224 | 225 | def forward(self, x: Tensor) -> Tensor: 226 | residual = x 227 | #spatial 228 | out1 = self.conv1(x) 229 | #temporal 230 | outx = self.conv11(out1) 231 | n,c,t,h,w = outx.size() 232 | t1 = t//self.pt 233 | outx = outx.reshape(n, c, self.pt, t1, h, w) 234 | outx = outx.permute(0, 2, 4, 5, 3, 1).reshape(n, self.pt*h*w, t1, c) #[n, pt*h*w, t1, c] 235 | outx = self.basic_layers(outx)+outx # [n,pt*h*w, t1, c] 236 | outx = outx.reshape(n, self.pt, h, w, t1, c).permute(0, 5, 1, 4, 2, 3) 237 | outx0 = outx.reshape(n, c, t, h, w) 238 | outx = self.conv12(outx0) 239 | outx = self.fc(outx.permute(0,2,3,4,1)).permute(0,4,1,2,3) 240 | #outx = self.fc(outx) 241 | outx = outx+outx0 242 | 243 | outx = self.conv13(outx) 244 | out = self.relu1(outx+out1) 245 | #spatial 246 | out = self.conv3(out) 247 | if self.downsample is not None: 248 | residual = self.downsample(x) 249 | 250 | out += residual 251 | out = self.relu2(out) 252 | 253 | return out 254 | 255 | 256 | class BasicBlockP3D(nn.Module): 257 | expansion = 1 258 | 259 | def __init__( 260 | self, 261 | inplanes: int, 262 | planes: int, 263 | stride: int = 1, 264 | downsample: Optional[nn.Module] = None, 265 | tem_nums: int = 3, 266 | seq: int = 1, 267 | ) -> None: 268 | super(BasicBlockP3D, self).__init__() 269 | self.conv1 = nn.Sequential( 270 | Conv3DNoTemporal(inplanes, planes, stride), 271 | nn.BatchNorm3d(planes), 272 | nn.ReLU(inplace=True) 273 | ) 274 | self.conv2 = nn.Sequential( 275 | Conv3DNoSpatial(planes, planes), 276 | nn.BatchNorm3d(planes), 277 | ) 278 | self.conv3 = nn.Sequential( 279 | Conv3DNoTemporal(planes, planes), 280 | nn.BatchNorm3d(planes), 281 | ) 282 | self.relu1 = nn.ReLU(inplace=True) 283 | self.relu2 = nn.ReLU(inplace=True) 284 | self.relu3 = nn.ReLU(inplace=True) 285 | self.downsample = downsample 286 | self.stride = stride 287 | 288 | def forward(self, x: Tensor) -> Tensor: 289 | residual = x 290 | 291 | out1 = self.conv1(x) 292 | out2 = self.conv2(out1) 293 | out = self.relu1(out1 + out2) 294 | out = self.conv3(out) 295 | if self.downsample is not None: 296 | residual = self.downsample(x) 297 | out += residual 298 | out = self.relu3(out) 299 | 300 | return out 301 | 302 | class BasicBlock_3D(nn.Module): 303 | 304 | expansion = 1 305 | 306 | def __init__( 307 | self, 308 | inplanes: int, 309 | planes: int, 310 | stride: int = 1, 311 | downsample: Optional[nn.Module] = None, 312 | tem_nums: int = 3, 313 | seq: int = 1, 314 | ) -> None: 315 | super(BasicBlock_3D, self).__init__() 316 | self.conv1 = nn.Sequential( 317 | Conv3DSimple(inplanes, planes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) 318 | ) 319 | self.conv2 = nn.Sequential(Conv3DSimple(planes, planes), nn.BatchNorm3d(planes)) 320 | self.relu = nn.ReLU(inplace=True) 321 | self.downsample = downsample 322 | self.stride = stride 
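    # BasicBlock_3D is a plain residual block built from full 3x3x3 convolutions.
    # It is registered in block_map below as an alternative, but the GLGait backbone
    # in this file builds layer1/layer2 from BasicBlockP3D and layer3/layer4 from
    # BasicBlock2D_trans3.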
323 | 324 | def forward(self, x: Tensor) -> Tensor: 325 | residual = x 326 | 327 | out = self.conv1(x) 328 | out = self.conv2(out) 329 | if self.downsample is not None: 330 | residual = self.downsample(x) 331 | 332 | out += residual 333 | out = self.relu(out) 334 | 335 | return out 336 | 337 | 338 | block_map = {'BasicBlock': BasicBlock, 339 | 'Bottleneck': Bottleneck, 340 | 'BasicBlockP3D': BasicBlockP3D, 341 | 'BasicBlock_3D': BasicBlock_3D, 342 | 'BasicBlock2D_trans3': BasicBlock2D_trans3} 343 | 344 | class GLGait(nn.Module): 345 | def __init__(self, block, channels=[32, 64, 128, 256], in_channel=1, layers=[1, 2, 2, 1], strides=[1, 2, 2, 1], 346 | maxpool=True): 347 | if block in block_map.keys(): 348 | block = block_map[block] 349 | else: 350 | raise ValueError( 351 | "Error type for -block-Cfg-, supported: 'BasicBlock' or 'Bottleneck'.") 352 | block3D = block_map['BasicBlock2D_trans3'] 353 | self.maxpool_flag = maxpool 354 | super(GLGait, self).__init__() 355 | self._norm_layer = nn.BatchNorm2d 356 | self.dilation = 1 357 | self.groups = 1 358 | self.base_width = 64 359 | self.x_att = None 360 | self.upsample = nn.UpsamplingBilinear2d(size=(32, 24)) 361 | # Not used # 362 | # self.fc = nn.Linear(1,1) 363 | self.fc = None 364 | ############ 365 | self.inplanes = channels[0] 366 | self.bn1 = nn.BatchNorm2d(self.inplanes) 367 | self.relu = nn.ReLU(inplace=True) 368 | self.conv1 = BasicConv2d(in_channel, self.inplanes, 3, 1, 1) 369 | # 370 | # self.layer1 = self._make_layer(block, channels[0], layers[0], stride=strides[0], dilate=False) 371 | # self.layer2 = self._make_layer(block, channels[1], layers[1], stride=strides[1], dilate=False) 372 | self.layer1 = self._make_layer_P3D(block_map['BasicBlockP3D'], channels[0], layers[0], strides[0], tem_nums=3) 373 | self.layer2 = self._make_layer_P3D(block_map['BasicBlockP3D'], channels[1], layers[1], strides[1], tem_nums=3) 374 | self.layer3 = self._make_layer_P3D(block3D, channels[2], layers[2], strides[2], tem_nums=3) 375 | self.layer4 = self._make_layer_P3D(block3D, channels[3], layers[3], strides[3], tem_nums=3) 376 | self._initialize_weights() 377 | 378 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 379 | norm_layer = self._norm_layer 380 | downsample = None 381 | previous_dilation = self.dilation 382 | if dilate: 383 | self.dilation *= stride 384 | stride = 1 385 | if stride != 1 or self.inplanes != planes * block.expansion: 386 | downsample = nn.Sequential( 387 | self.conv1x1(self.inplanes, planes * block.expansion, stride), 388 | norm_layer(planes * block.expansion), 389 | ) 390 | 391 | layers = [] 392 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 393 | self.base_width, previous_dilation, norm_layer)) 394 | self.inplanes = planes * block.expansion 395 | for _ in range(1, blocks): 396 | layers.append(block(self.inplanes, planes, groups=self.groups, 397 | base_width=self.base_width, dilation=self.dilation, 398 | norm_layer=norm_layer)) 399 | 400 | return nn.Sequential(*layers) 401 | 402 | def _make_layer_P3D( 403 | self, 404 | block: Type[Union[BasicBlock, Bottleneck]], 405 | planes: int, 406 | blocks: int, 407 | stride: int = 1, 408 | tem_nums: int = 1, 409 | ) -> nn.Sequential: 410 | downsample = None 411 | 412 | if stride != 1 or self.inplanes != planes * block.expansion: 413 | ds_stride = Conv3DSimple.get_downsample_stride(stride) 414 | downsample = nn.Sequential( 415 | nn.Conv3d(self.inplanes, planes * block.expansion, 416 | kernel_size=1, stride=ds_stride, bias=False), 417 | 
nn.BatchNorm3d(planes * block.expansion) 418 | ) 419 | layers = [] 420 | layers.append(block(self.inplanes, planes, stride, downsample)) 421 | 422 | self.inplanes = planes * block.expansion 423 | for i in range(1, blocks): 424 | layers.append(block(self.inplanes, planes)) 425 | 426 | return nn.Sequential(*layers) 427 | 428 | def forward(self, x, n=None, s=30): 429 | x = self.conv1(x) 430 | x = self.bn1(x) 431 | x = self.relu(x) 432 | if self.maxpool_flag: 433 | x = self.maxpool(x) 434 | 435 | bs = x.shape[0] // s 436 | x = x.view(bs, x.shape[0] // bs, x.shape[1], x.shape[2], x.shape[3]) 437 | x = x.permute(0, 2, 1, 3, 4) 438 | x = self.layer1(x) 439 | x = self.layer2(x) 440 | 441 | x = self.layer3(x) 442 | x = self.layer4(x) 443 | x = x.permute(0, 2, 1, 3, 4) 444 | x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4]) 445 | return x 446 | 447 | def conv1x1(self, in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 448 | """1x1 convolution""" 449 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 450 | 451 | def _initialize_weights(self) -> None: 452 | for m in self.modules(): 453 | if isinstance(m, nn.Conv3d): 454 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 455 | if m.bias is not None: 456 | nn.init.constant_(m.bias, 0) 457 | elif isinstance(m, nn.Conv2d): 458 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 459 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 460 | nn.init.constant_(m.weight, 1) 461 | nn.init.constant_(m.bias, 0) 462 | elif isinstance(m, nn.BatchNorm1d): 463 | nn.init.constant_(m.weight, 1) 464 | nn.init.constant_(m.bias, 0) 465 | elif isinstance(m, nn.BatchNorm3d): 466 | nn.init.constant_(m.weight, 1) 467 | nn.init.constant_(m.bias, 0) 468 | elif isinstance(m, nn.Linear): 469 | nn.init.normal_(m.weight, 0, 0.01) 470 | if m.bias is not None: 471 | nn.init.constant_(m.bias.data, 0.0) 472 | -------------------------------------------------------------------------------- /opengait/modeling/base_model.py: -------------------------------------------------------------------------------- 1 | """The base model definition. 2 | 3 | This module defines the abstract meta model class and base model class. In the base model, 4 | we define the basic model functions, like get_loader, build_network, and run_train, etc. 5 | The api of the base model is run_train and run_test, they are used in `opengait/main.py`. 6 | 7 | Typical usage: 8 | 9 | BaseModel.run_train(model) 10 | BaseModel.run_test(model) 11 | """ 12 | import torch 13 | import numpy as np 14 | import os.path as osp 15 | import torch.nn as nn 16 | import torch.optim as optim 17 | import torch.utils.data as tordata 18 | 19 | from tqdm import tqdm 20 | from torch.cuda.amp import autocast 21 | from torch.cuda.amp import GradScaler 22 | from abc import ABCMeta 23 | from abc import abstractmethod 24 | 25 | from . 
import backbones 26 | from .loss_aggregator import LossAggregator 27 | from data.transform import get_transform 28 | from data.collate_fn import CollateFn 29 | from data.dataset import DataSet 30 | import data.sampler as Samplers 31 | from utils import Odict, mkdir, ddp_all_gather 32 | from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from 33 | from evaluation import evaluator as eval_functions 34 | from utils import NoOp 35 | from utils import get_msg_mgr 36 | from collections import Counter 37 | 38 | __all__ = ['BaseModel'] 39 | 40 | 41 | 42 | class MetaModel(metaclass=ABCMeta): 43 | """The necessary functions for the base model. 44 | 45 | This class defines the necessary functions for the base model; they are implemented in BaseModel. 46 | """ 47 | @abstractmethod 48 | def get_loader(self, data_cfg): 49 | """Based on the given data_cfg, we get the data loader.""" 50 | raise NotImplementedError 51 | 52 | @abstractmethod 53 | def build_network(self, model_cfg): 54 | """Build your network here.""" 55 | raise NotImplementedError 56 | 57 | @abstractmethod 58 | def init_parameters(self): 59 | """Initialize the parameters of your network.""" 60 | raise NotImplementedError 61 | 62 | @abstractmethod 63 | def get_optimizer(self, optimizer_cfg): 64 | """Based on the given optimizer_cfg, we get the optimizer.""" 65 | raise NotImplementedError 66 | 67 | @abstractmethod 68 | def get_scheduler(self, scheduler_cfg): 69 | """Based on the given scheduler_cfg, we get the scheduler.""" 70 | raise NotImplementedError 71 | 72 | @abstractmethod 73 | def save_ckpt(self, iteration): 74 | """Save the checkpoint, including model parameter, optimizer and scheduler.""" 75 | raise NotImplementedError 76 | 77 | @abstractmethod 78 | def resume_ckpt(self, restore_hint): 79 | """Resume the model from the checkpoint, including model parameter, optimizer and scheduler.""" 80 | raise NotImplementedError 81 | 82 | @abstractmethod 83 | def inputs_pretreament(self, inputs): 84 | """Transform the input data based on the transform setting.""" 85 | raise NotImplementedError 86 | 87 | @abstractmethod 88 | def train_step(self, loss_num) -> bool: 89 | """Do one training step.""" 90 | raise NotImplementedError 91 | 92 | @abstractmethod 93 | def inference(self): 94 | """Do inference (calculate features).""" 95 | raise NotImplementedError 96 | 97 | @abstractmethod 98 | def run_train(model): 99 | """Run a whole train schedule.""" 100 | raise NotImplementedError 101 | 102 | @abstractmethod 103 | def run_test(model): 104 | """Run a whole test schedule.""" 105 | raise NotImplementedError 106 | 107 | 108 | class BaseModel(MetaModel, nn.Module): 109 | """Base model. 110 | 111 | This class inherits the MetaModel class and implements the basic model functions, like get_loader, build_network, etc. 112 | 113 | Attributes: 114 | msg_mgr: the message manager. 115 | cfgs: the configs. 116 | iteration: the current iteration of the model. 117 | engine_cfg: the configs of the engine (train or test). 118 | save_path: the path to save the checkpoints. 119 | 120 | """ 121 | 122 | def __init__(self, cfgs, training): 123 | """Initialize the base model. 124 | 125 | Complete the model initialization, including the data loader, the network, the optimizer, the scheduler, and the loss. 126 | 127 | Args: 128 | cfgs: 129 | All of the configs. 130 | training: 131 | Whether the model is in training mode.
132 | """ 133 | 134 | super(BaseModel, self).__init__() 135 | self.msg_mgr = get_msg_mgr() 136 | self.cfgs = cfgs 137 | self.iteration = 0 138 | self.engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg'] 139 | if self.engine_cfg is None: 140 | raise Exception("Initialize a model without -Engine-Cfgs-") 141 | 142 | if training and self.engine_cfg['enable_float16']: 143 | self.Scaler = GradScaler() 144 | self.save_path = osp.join('output/', cfgs['data_cfg']['dataset_name'], 145 | cfgs['model_cfg']['model'], self.engine_cfg['save_name']) 146 | 147 | self.build_network(cfgs['model_cfg']) 148 | self.init_parameters() 149 | 150 | self.msg_mgr.log_info(cfgs['data_cfg']) 151 | if training: 152 | self.train_loader = self.get_loader( 153 | cfgs['data_cfg'], train=True) 154 | if not training or self.engine_cfg['with_test']: 155 | self.test_loader = self.get_loader( 156 | cfgs['data_cfg'], train=False) 157 | 158 | self.device = torch.distributed.get_rank() 159 | torch.cuda.set_device(self.device) 160 | self.to(device=torch.device( 161 | "cuda", self.device)) 162 | 163 | if training: 164 | self.loss_aggregator = LossAggregator(cfgs['loss_cfg']) 165 | self.optimizer = self.get_optimizer(self.cfgs['optimizer_cfg']) 166 | self.scheduler = self.get_scheduler(cfgs['scheduler_cfg']) 167 | self.train(training) 168 | restore_hint = self.engine_cfg['restore_hint'] 169 | if restore_hint != 0: 170 | self.resume_ckpt(restore_hint) 171 | 172 | def get_backbone(self, backbone_cfg): 173 | """Get the backbone of the model.""" 174 | if is_dict(backbone_cfg): 175 | Backbone = get_attr_from([backbones], backbone_cfg['type']) 176 | valid_args = get_valid_args(Backbone, backbone_cfg, ['type']) 177 | return Backbone(**valid_args) 178 | if is_list(backbone_cfg): 179 | Backbone = nn.ModuleList([self.get_backbone(cfg) 180 | for cfg in backbone_cfg]) 181 | return Backbone 182 | raise ValueError( 183 | "Error type for -Backbone-Cfg-, supported: (A list of) dict.") 184 | 185 | def build_network(self, model_cfg): 186 | if 'backbone_cfg' in model_cfg.keys(): 187 | self.Backbone = self.get_backbone(model_cfg['backbone_cfg']) 188 | 189 | def init_parameters(self): 190 | for m in self.modules(): 191 | if isinstance(m, (nn.Conv3d, nn.Conv2d)): 192 | nn.init.xavier_uniform_(m.weight.data) 193 | if m.bias is not None: 194 | nn.init.constant_(m.bias.data, 0.0) 195 | elif isinstance(m, nn.Linear): 196 | nn.init.xavier_uniform_(m.weight.data) 197 | if m.bias is not None: 198 | nn.init.constant_(m.bias.data, 0.0) 199 | elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)): 200 | if m.affine: 201 | nn.init.normal_(m.weight.data, 1.0, 0.02) 202 | nn.init.constant_(m.bias.data, 0.0) 203 | 204 | def get_loader(self, data_cfg, train=True): 205 | sampler_cfg = self.cfgs['trainer_cfg']['sampler'] if train else self.cfgs['evaluator_cfg']['sampler'] 206 | dataset = DataSet(data_cfg, train) 207 | 208 | Sampler = get_attr_from([Samplers], sampler_cfg['type']) 209 | vaild_args = get_valid_args(Sampler, sampler_cfg, free_keys=[ 210 | 'sample_type', 'type']) 211 | sampler = Sampler(dataset, **vaild_args) 212 | 213 | loader = tordata.DataLoader( 214 | dataset=dataset, 215 | batch_sampler=sampler, 216 | collate_fn=CollateFn(dataset.label_set, sampler_cfg), 217 | num_workers=data_cfg['num_workers']) 218 | return loader 219 | 220 | def get_optimizer(self, optimizer_cfg): 221 | self.msg_mgr.log_info(optimizer_cfg) 222 | optimizer = get_attr_from([optim], optimizer_cfg['solver']) 223 | valid_arg = get_valid_args(optimizer, 
optimizer_cfg, ['solver']) 224 | 225 | # param_conv1 = [] 226 | # for p in self.parameters(): 227 | # if p.size() == torch.tensor([[[0,1,0]]]).size(): 228 | # param_conv1.append(p) 229 | # print(len(param_conv1)) 230 | # param_groups = [ 231 | # {'params': filter(lambda p: p.requires_grad and p not in param_groups, self.parameters()), 'lr': 0.1}, # parameters of the first linear layer use a learning rate of 0.1 232 | # {'params': filter(lambda p: p in param_groups, self.parameters()), 'lr': 1e-7}, # parameters of the second linear layer use a learning rate of 0.01 233 | # ] 234 | 235 | optimizer = optimizer( 236 | filter(lambda p: p.requires_grad, self.parameters()), **valid_arg) 237 | #optimizer = optimizer(param_groups, **valid_arg) 238 | 239 | 240 | 241 | # optimizer = optimizer( 242 | # {[filter(lambda p: p.requires_grad, self.parameters()), ''],}, **valid_arg) 243 | 244 | return optimizer 245 | 246 | def get_scheduler(self, scheduler_cfg): 247 | self.msg_mgr.log_info(scheduler_cfg) 248 | Scheduler = get_attr_from( 249 | [optim.lr_scheduler], scheduler_cfg['scheduler']) 250 | valid_arg = get_valid_args(Scheduler, scheduler_cfg, ['scheduler']) 251 | scheduler = Scheduler(self.optimizer, **valid_arg) 252 | return scheduler 253 | 254 | def save_ckpt(self, iteration): 255 | if torch.distributed.get_rank() == 0: 256 | mkdir(osp.join(self.save_path, "checkpoints/")) 257 | save_name = self.engine_cfg['save_name'] 258 | checkpoint = { 259 | 'model': self.state_dict(), 260 | 'optimizer': self.optimizer.state_dict(), 261 | 'scheduler': self.scheduler.state_dict(), 262 | 'iteration': iteration} 263 | torch.save(checkpoint, 264 | osp.join(self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, iteration))) 265 | 266 | def _load_ckpt(self, save_name): 267 | load_ckpt_strict = self.engine_cfg['restore_ckpt_strict'] 268 | 269 | checkpoint = torch.load(save_name, map_location=torch.device( 270 | "cuda", self.device)) 271 | model_state_dict = checkpoint['model'] 272 | model_state_dict_1 = {} 273 | for i in model_state_dict: 274 | if i != "loss_aggregator.losses.triplet_fq.module.fc.weight" and i!="loss_aggregator.losses.triplet_fq.module.fc.bias": 275 | model_state_dict_1[i] = model_state_dict[i] 276 | 277 | 278 | if not load_ckpt_strict: 279 | self.msg_mgr.log_info("-------- Restored Params List --------") 280 | self.msg_mgr.log_info(sorted(set(model_state_dict.keys()).intersection( 281 | set(self.state_dict().keys())))) 282 | 283 | self.load_state_dict(model_state_dict_1, strict=load_ckpt_strict) 284 | if self.training: 285 | if not self.engine_cfg["optimizer_reset"] and 'optimizer' in checkpoint: 286 | self.optimizer.load_state_dict(checkpoint['optimizer']) 287 | else: 288 | self.msg_mgr.log_warning( 289 | "Restore NO Optimizer from %s !!!" % save_name) 290 | if not self.engine_cfg["scheduler_reset"] and 'scheduler' in checkpoint: 291 | self.scheduler.load_state_dict( 292 | checkpoint['scheduler']) 293 | print(self.scheduler.milestones) 294 | # self.scheduler.milestones = Counter({80000:1, 120000:1, 150000:1}) 295 | # print(self.scheduler.milestones) 296 | else: 297 | self.msg_mgr.log_warning( 298 | "Restore NO Scheduler from %s !!!" % save_name) 299 | self.msg_mgr.log_info("Restore Parameters from %s !!!"
% save_name) 300 | 301 | def resume_ckpt(self, restore_hint): 302 | if isinstance(restore_hint, int): 303 | save_name = self.engine_cfg['save_name'] 304 | save_name = osp.join( 305 | self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, restore_hint)) 306 | self.iteration = restore_hint 307 | elif isinstance(restore_hint, str): 308 | save_name = restore_hint 309 | self.iteration = 0 310 | else: 311 | raise ValueError( 312 | "Error type for -Restore_Hint-, supported: int or string.") 313 | self._load_ckpt(save_name) 314 | 315 | def fix_BN(self): 316 | for module in self.modules(): 317 | classname = module.__class__.__name__ 318 | if classname.find('BatchNorm') != -1: 319 | module.eval() 320 | 321 | def inputs_pretreament(self, inputs): 322 | """Conduct transforms on input data. 323 | 324 | Args: 325 | inputs: the input data. 326 | Returns: 327 | tuple: training data including inputs, labels, and some meta data. 328 | """ 329 | seqs_batch, labs_batch, typs_batch, vies_batch, ids_batch, index_batch, seqL_batch = inputs 330 | trf_cfgs = self.engine_cfg['transform'] 331 | seq_trfs = get_transform(trf_cfgs) 332 | if len(seqs_batch) != len(seq_trfs): 333 | raise ValueError( 334 | "The number of types of input data and transforms should be the same, but got {} and {}".format(len(seqs_batch), len(seq_trfs))) 335 | requires_grad = bool(self.training) 336 | seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float() 337 | for trf, seq in zip(seq_trfs, seqs_batch)] 338 | 339 | typs = typs_batch 340 | vies = vies_batch 341 | ids = ids_batch 342 | index = index_batch 343 | labs = list2var(labs_batch).long() 344 | 345 | if seqL_batch is not None: 346 | seqL_batch = np2var(seqL_batch).int() 347 | seqL = seqL_batch 348 | 349 | if seqL is not None: 350 | seqL_sum = int(seqL.sum().data.cpu().numpy()) 351 | ipts = [_[:, :seqL_sum] for _ in seqs] 352 | else: 353 | ipts = seqs 354 | del seqs 355 | return ipts, labs, typs, vies, ids, index, seqL 356 | 357 | def train_step(self, loss_sum) -> bool: 358 | """Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step(). 359 | 360 | Args: 361 | loss_sum: The loss of the current batch. 362 | Returns: 363 | bool: True if the optimization step was applied, False if it was skipped. 364 | """ 365 | 366 | self.optimizer.zero_grad() 367 | if loss_sum <= 1e-9: 368 | self.msg_mgr.log_warning( 369 | "Found the loss sum less than 1e-9, but the training process will continue!") 370 | 371 | if self.engine_cfg['enable_float16']: 372 | self.Scaler.scale(loss_sum).backward() 373 | self.Scaler.step(self.optimizer) 374 | scale = self.Scaler.get_scale() 375 | self.Scaler.update() 376 | # Warning caused by optimizer skip when NaN 377 | # https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/5 378 | if scale != self.Scaler.get_scale(): 379 | self.msg_mgr.log_debug("Training step skipped. Expected the former scale to equal the present one, got {} and {}".format( 380 | scale, self.Scaler.get_scale())) 381 | return False 382 | else: 383 | loss_sum.backward() 384 | self.optimizer.step() 385 | 386 | self.iteration += 1 387 | self.scheduler.step() 388 | return True 389 | 390 | def inference(self, rank): 391 | """Run inference on all the test data. 392 | 393 | Args: 394 | rank: the rank of the current process. 395 | Returns: 396 | Odict: contains the inference results.
397 | """ 398 | total_size = len(self.test_loader) 399 | if rank == 0: 400 | pbar = tqdm(total=total_size, desc='Transforming') 401 | else: 402 | pbar = NoOp() 403 | batch_size = self.test_loader.batch_sampler.batch_size 404 | rest_size = total_size 405 | info_dict = Odict() 406 | for inputs in self.test_loader: 407 | ipts = self.inputs_pretreament(inputs) 408 | with autocast(enabled=self.engine_cfg['enable_float16']): 409 | retval = self.forward(ipts) 410 | inference_feat = retval['inference_feat'] 411 | for k, v in inference_feat.items(): 412 | inference_feat[k] = ddp_all_gather(v, requires_grad=False) 413 | del retval 414 | for k, v in inference_feat.items(): 415 | inference_feat[k] = ts2np(v) 416 | info_dict.append(inference_feat) 417 | rest_size -= batch_size 418 | if rest_size >= 0: 419 | update_size = batch_size 420 | else: 421 | update_size = total_size % batch_size 422 | pbar.update(update_size) 423 | pbar.close() 424 | for k, v in info_dict.items(): 425 | v = np.concatenate(v)[:total_size] 426 | info_dict[k] = v 427 | return info_dict 428 | 429 | @ staticmethod 430 | def run_train(model): 431 | """Accept the instance object(model) here, and then run the train loop.""" 432 | 433 | for inputs in model.train_loader: 434 | ipts = model.inputs_pretreament(inputs) 435 | with autocast(enabled=model.engine_cfg['enable_float16']): 436 | retval = model(ipts) 437 | training_feat, visual_summary = retval['training_feat'], retval['visual_summary'] 438 | del retval 439 | loss_sum, loss_info = model.loss_aggregator(training_feat) 440 | ok = model.train_step(loss_sum) 441 | if not ok: 442 | continue 443 | 444 | visual_summary.update(loss_info) 445 | visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr'] 446 | 447 | model.msg_mgr.train_step(loss_info, visual_summary) 448 | if model.iteration % model.engine_cfg['save_iter'] == 0: 449 | # save the checkpoint 450 | model.save_ckpt(model.iteration) 451 | 452 | 453 | # run test if with_test = true 454 | if model.engine_cfg['with_test']: 455 | model.msg_mgr.log_info("Running test...") 456 | model.eval() 457 | result_dict = BaseModel.run_test(model) 458 | model.train() 459 | if model.cfgs['trainer_cfg']['fix_BN']: 460 | model.fix_BN() 461 | if result_dict: 462 | model.msg_mgr.write_to_tensorboard(result_dict) 463 | model.msg_mgr.reset_time() 464 | if model.iteration >= model.engine_cfg['total_iter']: 465 | break 466 | 467 | @ staticmethod 468 | def run_test(model): 469 | """Accept the instance object(model) here, and then run the test loop.""" 470 | 471 | rank = torch.distributed.get_rank() 472 | with torch.no_grad(): 473 | info_dict = model.inference(rank) 474 | if rank == 0: 475 | loader = model.test_loader 476 | label_list = loader.dataset.label_list 477 | types_list = loader.dataset.types_list 478 | views_list = loader.dataset.views_list 479 | 480 | info_dict.update({ 481 | 'labels': label_list, 'types': types_list, 'views': views_list}) 482 | 483 | if 'eval_func' in model.cfgs["evaluator_cfg"].keys(): 484 | eval_func = model.cfgs['evaluator_cfg']["eval_func"] 485 | else: 486 | eval_func = 'identification' 487 | print(eval_func) 488 | eval_func = getattr(eval_functions, eval_func) 489 | valid_args = get_valid_args( 490 | eval_func, model.cfgs["evaluator_cfg"], ['metric']) 491 | try: 492 | dataset_name = model.cfgs['data_cfg']['test_dataset_name'] 493 | except: 494 | dataset_name = model.cfgs['data_cfg']['dataset_name'] 495 | return eval_func(info_dict, dataset_name, **valid_args) 496 | 
--------------------------------------------------------------------------------
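A short usage sketch for the GLGait backbone defined in opengait/modeling/backbones/GLGait.py may help when reading the code above. It is not part of the repository: it assumes the opengait/ directory is on PYTHONPATH (the repository's own modules import data.transform, utils, etc. that way) and that its package __init__ files import cleanly; the tensor sizes and constructor arguments below are illustrative defaults, not necessarily the released GLGait-L settings in configs/GLGait/*.yaml; torch, torchvision, numpy and einops must be installed because GLGait.py imports them.

import torch
from modeling.backbones.GLGait import GLGait  # assumes opengait/ is on PYTHONPATH

# Constructor defaults from the signature above; the GLGait-L configs may differ.
backbone = GLGait(block='BasicBlock2D_trans3',
                  channels=[32, 64, 128, 256],
                  in_channel=1,
                  layers=[1, 2, 2, 1],
                  strides=[1, 2, 2, 1],
                  maxpool=False)  # __init__ never creates self.maxpool, so keep this False

n, s, h, w = 2, 30, 64, 44        # two silhouette clips of 30 frames; s must be divisible
                                  # by 3, the temporal patch count of BasicBlock2D_trans3
x = torch.randn(n * s, 1, h, w)   # frames stacked along the batch axis, as forward() expects
feat = backbone(x, s=s)           # reshaped internally to [n, c, s, h, w]
print(feat.shape)                 # torch.Size([60, 256, 16, 11]) with these defaults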