├── lib
│   ├── data
│   │   ├── __init__.py
│   │   ├── transform.py
│   │   ├── sampler.py
│   │   ├── collate_fn.py
│   │   └── dataset.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── msg_manager.py
│   │   ├── evaluation.py
│   │   └── common.py
│   ├── modeling
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── softmax.py
│   │   │   └── triplet.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── HSTL-CB.py
│   │   │   ├── HSTL-OU.py
│   │   │   └── HSTL-Gait3D.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   └── plain.py
│   │   ├── loss_aggregator.py
│   │   ├── modules.py
│   │   └── base_model.py
│   └── main.py
├── train.sh
├── test.sh
├── config
│   ├── hstl.yaml
│   ├── hstl_oumvlp.yaml
│   └── hstl_gait3d.yaml
├── misc
│   └── partitions
│       ├── CASIA-B.json
│       └── CASIA-B_include_005.json
└── README.md
/lib/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import get_ddp_module, ddp_all_gather 2 | from .common import Odict, Ntuple 3 | from .common import get_valid_args 4 | from .common import is_list_or_tuple, is_bool, is_str, is_list, is_dict, is_tensor, is_array, config_loader, init_seeds, handler, params_count 5 | from .common import ts2np, ts2var, np2var, list2var 6 | from .common import mkdir, clones 7 | from .common import MergeCfgsDict 8 | from .common import get_attr_from 9 | from .common import NoOp 10 | from .msg_manager import get_msg_mgr -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # For CASIA-B 2 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --master_port XXXXX --nproc_per_node=2 lib/main.py --cfgs ./config/hstl.yaml --log_to_file --phase train 3 | # For OUMVLP 4 | #CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port XXXXX --nproc_per_node=4 lib/main.py --cfgs ./config/hstl_oumvlp.yaml --log_to_file --phase train 5 | #CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port XXXXX --nproc_per_node=4 lib/main.py --cfgs ./config/hstl_gait3d.yaml --log_to_file --phase train 6 | 7 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | # For CASIA-B 2 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --master_port XXXXX --nproc_per_node=2 lib/main.py --cfgs ./config/hstl.yaml --log_to_file --phase test 3 | 4 | # For OUMVLP 5 | #CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port XXXXX --nproc_per_node=4 lib/main.py --cfgs ./config/hstl_oumvlp.yaml --log_to_file --phase test 6 | #CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port XXXXX --nproc_per_node=4 lib/main.py --cfgs ./config/hstl_gait3d.yaml --log_to_file --phase test 7 | 8 | 9 | -------------------------------------------------------------------------------- /lib/modeling/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from inspect import isclass 2 | from pkgutil import iter_modules 3 | from pathlib import Path 4 | from importlib import import_module 5 | 6 | # iterate through the modules in the current package 7 | package_dir = Path(__file__).resolve().parent 8 | for (_, module_name, _) in iter_modules([package_dir]): 9 | 10 | # import the module
and iterate through its attributes 11 | module = import_module(f"{__name__}.{module_name}") 12 | for attribute_name in dir(module): 13 | attribute = getattr(module, attribute_name) 14 | 15 | if isclass(attribute): 16 | # Add the class to this package's variables 17 | globals()[attribute_name] = attribute -------------------------------------------------------------------------------- /lib/modeling/models/__init__.py: -------------------------------------------------------------------------------- 1 | from inspect import isclass 2 | from pkgutil import iter_modules 3 | from pathlib import Path 4 | from importlib import import_module 5 | 6 | # iterate through the modules in the current package 7 | package_dir = Path(__file__).resolve().parent 8 | for (_, module_name, _) in iter_modules([package_dir]): 9 | 10 | # import the module and iterate through its attributes 11 | module = import_module(f"{__name__}.{module_name}") 12 | for attribute_name in dir(module): 13 | attribute = getattr(module, attribute_name) 14 | 15 | if isclass(attribute): 16 | # Add the class to this package's variables 17 | globals()[attribute_name] = attribute -------------------------------------------------------------------------------- /lib/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from inspect import isclass 2 | from pkgutil import iter_modules 3 | from pathlib import Path 4 | from importlib import import_module 5 | 6 | # iterate through the modules in the current package 7 | package_dir = Path(__file__).resolve().parent 8 | for (_, module_name, _) in iter_modules([package_dir]): 9 | 10 | # import the module and iterate through its attributes 11 | module = import_module(f"{__name__}.{module_name}") 12 | for attribute_name in dir(module): 13 | attribute = getattr(module, attribute_name) 14 | 15 | if isclass(attribute): 16 | # Add the class to this package's variables 17 | globals()[attribute_name] = attribute -------------------------------------------------------------------------------- /lib/modeling/losses/base.py: -------------------------------------------------------------------------------- 1 | from ctypes import ArgumentError 2 | import torch.nn as nn 3 | import torch 4 | from utils import Odict 5 | import functools 6 | from utils import ddp_all_gather 7 | 8 | 9 | def gather_and_scale_wrapper(func): 10 | # gather each keyword tensor across all DDP ranks, then scale the loss by the world size 11 | @functools.wraps(func) 12 | def inner(*args, **kwds): 13 | try: 14 | 15 | for k, v in kwds.items(): 16 | kwds[k] = ddp_all_gather(v) 17 | 18 | loss, loss_info = func(*args, **kwds) 19 | loss *= torch.distributed.get_world_size() 20 | return loss, loss_info 21 | except Exception as e: 22 | raise ArgumentError(str(e)) from e  # re-raise with context instead of silently swallowing the original error 23 | return inner 24 | 25 | 26 | class BaseLoss(nn.Module): 27 | 28 | def __init__(self, loss_term_weight=1.0): 29 | 30 | super(BaseLoss, self).__init__() 31 | self.loss_term_weight = loss_term_weight 32 | self.info = Odict() 33 | 34 | def forward(self, logits, labels): 35 | 36 | return .0, self.info 37 | -------------------------------------------------------------------------------- /lib/modeling/backbones/plain.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ..modules import BasicConv2d 3 | 4 | 5 | class Plain(nn.Module): 6 | 7 | def __init__(self, layers_cfg, in_channels=1): 8 | super(Plain, self).__init__() 9 | self.layers_cfg = layers_cfg 10 | self.in_channels = in_channels 11 | 12 | self.feature = self.make_layers() 13 | 14 | def forward(self, seqs): 15 | out =
self.feature(seqs) 16 | return out 17 | 18 | def make_layers(self): 19 | 20 | def get_layer(cfg, in_c, kernel_size, stride, padding): 21 | cfg = cfg.split('-') 22 | typ = cfg[0] 23 | if typ not in ['BC', 'FC']: 24 | raise ValueError('Only support BC or FC, but got {}'.format(typ)) 25 | out_c = int(cfg[1]) 26 | 27 | if typ == 'BC': 28 | return BasicConv2d(in_c, out_c, kernel_size=kernel_size, stride=stride, padding=padding) 29 | 30 | Layers = [get_layer(self.layers_cfg[0], self.in_channels, 31 | 5, 1, 2), nn.LeakyReLU(inplace=True)] 32 | in_c = int(self.layers_cfg[0].split('-')[1]) 33 | for cfg in self.layers_cfg[1:]: 34 | if cfg == 'M': 35 | Layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 36 | else: 37 | conv2d = get_layer(cfg, in_c, 3, 1, 1) 38 | Layers += [conv2d, nn.LeakyReLU(inplace=True)] 39 | in_c = int(cfg.split('-')[1]) 40 | return nn.Sequential(*Layers) -------------------------------------------------------------------------------- /config/hstl.yaml: -------------------------------------------------------------------------------- 1 | data_cfg: 2 | dataset_name: CASIA-B 3 | dataset_root: your dataset path 4 | dataset_partition: ./misc/partitions/CASIA-B_include_005.json 5 | num_workers: 32 6 | remove_no_gallery: false 7 | test_dataset_name: CASIA-B 8 | 9 | evaluator_cfg: 10 | enable_distributed: true 11 | enable_float16: false 12 | restore_ckpt_strict: true 13 | restore_hint: 90000 14 | save_name: HSTL 15 | sampler: 16 | batch_size: 2 17 | sample_type: all_ordered 18 | type: InferenceSampler 19 | 20 | loss_cfg: 21 | - loss_term_weight: 1.0 22 | margin: 0.2 23 | type: TripletLoss 24 | log_prefix: triplet 25 | - loss_term_weight: 1.0 26 | scale: 1 27 | type: CrossEntropyLoss 28 | log_accuracy: true 29 | label_smooth: false 30 | log_prefix: softmax 31 | 32 | model_cfg: 33 | model: HSTL 34 | channels: [32, 64, 128] 35 | class_num: 74 36 | 37 | optimizer_cfg: 38 | lr: 1.0e-4 39 | solver: Adam 40 | weight_decay: 5.0e-4 41 | 42 | scheduler_cfg: 43 | gamma: 0.1 44 | milestones: 45 | - 70000 46 | scheduler: MultiStepLR 47 | 48 | trainer_cfg: 49 | enable_distributed: true 50 | enable_float16: true 51 | log_iter: 100 52 | restore_ckpt_strict: true 53 | restore_hint: 0 54 | save_iter: 10000 55 | save_name: HSTL 56 | sync_BN: true 57 | total_iter: 90000 58 | sampler: 59 | batch_shuffle: true 60 | batch_size: 61 | - 8 62 | - 8 63 | frames_num_fixed: 30 64 | frames_skip_num: 0 65 | sample_type: fixed_ordered 66 | type: TripletSampler 67 | -------------------------------------------------------------------------------- /config/hstl_oumvlp.yaml: -------------------------------------------------------------------------------- 1 | # Note : *** the batch_size should be equal to the gpus number at the test phase!!! 
*** 2 | data_cfg: 3 | dataset_name: OUMVLP 4 | dataset_root: your path 5 | dataset_partition: ./misc/partitions/OUMVLP.json 6 | num_workers: 1 7 | remove_no_gallery: false 8 | test_dataset_name: OUMVLP 9 | 10 | evaluator_cfg: 11 | enable_float16: true 12 | restore_ckpt_strict: true 13 | restore_hint: 250000 14 | save_name: HSTL_OU 15 | sampler: 16 | batch_size: 8 17 | sample_type: all_ordered 18 | type: InferenceSampler 19 | 20 | loss_cfg: 21 | - loss_term_weight: 1.0 22 | margin: 0.2 23 | type: TripletLoss 24 | log_prefix: triplet 25 | - loss_term_weight: 1.0 26 | scale: 1 27 | type: CrossEntropyLoss 28 | log_accuracy: true 29 | label_smooth: true 30 | log_prefix: softmax 31 | 32 | model_cfg: 33 | model: HSTL_OU 34 | channels: [32, 64, 128, 256] 35 | class_num: 5153 36 | 37 | optimizer_cfg: 38 | lr: 0.1 39 | momentum: 0.9 40 | solver: SGD 41 | weight_decay: 0.0005 42 | 43 | scheduler_cfg: 44 | gamma: 0.1 45 | milestones: 46 | - 150000 47 | - 200000 48 | scheduler: MultiStepLR 49 | 50 | trainer_cfg: 51 | enable_float16: true 52 | with_test: false 53 | log_iter: 100 54 | restore_ckpt_strict: true 55 | restore_hint: 0 56 | save_iter: 10000 57 | save_name: HSTL_OU 58 | sync_BN: true 59 | total_iter: 250000 60 | sampler: 61 | batch_shuffle: true 62 | batch_size: 63 | - 32 64 | - 8 65 | frames_num_fixed: 30 66 | frames_skip_num: 0 67 | sample_type: fixed_ordered 68 | type: TripletSampler 69 | 70 | -------------------------------------------------------------------------------- /config/hstl_gait3d.yaml: -------------------------------------------------------------------------------- 1 | # Note : *** the test-phase batch_size should equal the number of GPUs!!! *** 2 | data_cfg: 3 | dataset_name: Gait3D 4 | dataset_root: your path 5 | dataset_partition: ./misc/partitions/Gait3D.json 6 | num_workers: 1 7 | remove_no_gallery: false 8 | test_dataset_name: Gait3D 9 | 10 | evaluator_cfg: 11 | enable_distributed: true 12 | enable_float16: false 13 | restore_ckpt_strict: true 14 | restore_hint: 210000 15 | save_name: HSTL_Gait3D 16 | eval_func: evaluation_Gait3D # identification_real_scene # identification_GREW_submission 17 | sampler: 18 | batch_size: 8 19 | sample_type: all_ordered 20 | type: InferenceSampler 21 | 22 | loss_cfg: 23 | - loss_term_weight: 1.0 24 | margin: 0.2 25 | type: TripletLoss 26 | log_prefix: triplet 27 | - loss_term_weight: 1.0 28 | scale: 16 29 | type: CrossEntropyLoss 30 | log_accuracy: true 31 | label_smooth: true 32 | log_prefix: softmax 33 | 34 | model_cfg: 35 | model: HSTL_Gait3D 36 | channels: [32, 64, 128, 256] 37 | class_num: 3000 38 | 39 | optimizer_cfg: 40 | lr: 0.1 41 | momentum: 0.9 42 | solver: SGD 43 | weight_decay: 0.0005 44 | 45 | scheduler_cfg: 46 | gamma: 0.1 47 | milestones: 48 | - 80000 49 | - 150000 50 | - 200000 51 | scheduler: MultiStepLR 52 | 53 | trainer_cfg: 54 | enable_float16: true 55 | with_test: false 56 | log_iter: 100 57 | restore_ckpt_strict: true 58 | restore_hint: 0 59 | save_iter: 10000 60 | save_name: HSTL_Gait3D 61 | sync_BN: true 62 | total_iter: 210000 63 | sampler: 64 | batch_shuffle: true 65 | batch_size: 66 | - 32 67 | - 8 68 | frames_num_fixed: 30 69 | frames_skip_num: 4 70 | sample_type: fixed_ordered 71 | type: TripletSampler 72 | -------------------------------------------------------------------------------- /lib/modeling/losses/softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .base import BaseLoss 5 | 6 | 7 | class
CrossEntropyLoss(BaseLoss): 8 | def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weight=1.0, log_accuracy=False): 9 | super(CrossEntropyLoss, self).__init__(loss_term_weight) 10 | self.scale = scale 11 | self.label_smooth = label_smooth 12 | self.eps = eps 13 | self.log_accuracy = log_accuracy 14 | 15 | def forward(self, logits, labels): 16 | """ 17 | logits: [n, p, c] (p: parts number) 18 | labels: [n] 19 | """ 20 | logits = logits.permute(1, 0, 2).contiguous() # [n, p, c] -> [p, n, c] 21 | p, _, c = logits.size() 22 | log_preds = F.log_softmax(logits * self.scale, dim=-1) # [p, n, c] 23 | one_hot_labels = self.label2one_hot( 24 | labels, c).unsqueeze(0).repeat(p, 1, 1) # [p, n, c] 25 | loss = self.compute_loss(log_preds, one_hot_labels) 26 | self.info.update({'loss': loss.detach().clone()}) 27 | if self.log_accuracy: 28 | pred = logits.argmax(dim=-1) # [p, n] 29 | accu = (pred == labels.unsqueeze(0)).float().mean() 30 | self.info.update({'accuracy': accu}) 31 | return loss, self.info 32 | 33 | def compute_loss(self, predis, labels): 34 | softmax_loss = -(labels * predis).sum(-1) # [p, n] 35 | losses = softmax_loss.mean(-1) 36 | 37 | if self.label_smooth: 38 | smooth_loss = - predis.mean(dim=-1) # [p, n] 39 | smooth_loss = smooth_loss.mean() # scalar 40 | smooth_loss = smooth_loss * self.eps 41 | losses = smooth_loss + losses * (1. - self.eps) 42 | return losses 43 | 44 | def label2one_hot(self, label, class_num): 45 | label = label.unsqueeze(-1) 46 | batch_size = label.size(0) 47 | device = label.device 48 | return torch.zeros(batch_size, class_num).to(device).scatter(1, label, 1) 49 | -------------------------------------------------------------------------------- /lib/data/transform.py: -------------------------------------------------------------------------------- 1 | from data import transform as base_transform 2 | import numpy as np 3 | 4 | from utils import is_list, is_dict, get_valid_args 5 | 6 | 7 | class NoOperation(): 8 | def __call__(self, x): 9 | return x 10 | 11 | 12 | class BaseSilTransform(): 13 | def __init__(self, disvor=255.0, img_shape=None): 14 | self.disvor = disvor 15 | self.img_shape = img_shape 16 | 17 | def __call__(self, x): 18 | if self.img_shape is not None: 19 | s = x.shape[0] 20 | _ = [s] + [*self.img_shape] 21 | x = x.reshape(*_) 22 | return x / self.disvor 23 | 24 | 25 | class BaseSilCuttingTransform(): 26 | def __init__(self, img_w=64, disvor=255.0, cutting=None): 27 | self.img_w = img_w 28 | self.disvor = disvor 29 | self.cutting = cutting 30 | 31 | def __call__(self, x): 32 | if self.cutting is not None: 33 | cutting = self.cutting 34 | else: 35 | cutting = int(self.img_w // 64) * 10 36 | x = x[..., cutting:-cutting] 37 | return x / self.disvor 38 | 39 | 40 | class BaseRgbTransform(): 41 | def __init__(self, mean=None, std=None): 42 | if mean is None: 43 | mean = [0.485*255, 0.456*255, 0.406*255] 44 | if std is None: 45 | std = [0.229*255, 0.224*255, 0.225*255] 46 | self.mean = np.array(mean).reshape((1, 3, 1, 1)) 47 | self.std = np.array(std).reshape((1, 3, 1, 1)) 48 | 49 | def __call__(self, x): 50 | return (x - self.mean) / self.std 51 | 52 | 53 | def get_transform(trf_cfg=None): 54 | if is_dict(trf_cfg): 55 | transform = getattr(base_transform, trf_cfg['type']) 56 | valid_trf_arg = get_valid_args(transform, trf_cfg, ['type']) 57 | return transform(**valid_trf_arg) 58 | if trf_cfg is None: 59 | return lambda x: x 60 | if is_list(trf_cfg): 61 | transform = [get_transform(cfg) for cfg in trf_cfg] 62 | return transform 63 | raise ValueError("Error type for
-Transform-Cfg-") 64 | -------------------------------------------------------------------------------- /lib/modeling/loss_aggregator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import losses 3 | from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module 4 | from utils import Odict 5 | from utils import get_msg_mgr 6 | 7 | 8 | class LossAggregator(): 9 | 10 | def __init__(self, loss_cfg) -> None: 11 | 12 | self.losses = {loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \ 13 | else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg} 14 | 15 | def _build_loss_(self, loss_cfg): 16 | 17 | Loss = get_attr_from([losses], loss_cfg['type']) 18 | valid_loss_arg = get_valid_args( 19 | Loss, loss_cfg, ['type', 'gather_and_scale']) 20 | loss = get_ddp_module(Loss(**valid_loss_arg).cuda()) 21 | return loss 22 | 23 | def __call__(self, training_feats): 24 | 25 | loss_sum = .0 26 | loss_info = Odict() 27 | 28 | for k, v in training_feats.items(): 29 | if k in self.losses: 30 | loss_func = self.losses[k] 31 | loss, info = loss_func(**v) 32 | for name, value in info.items(): 33 | loss_info['scalar/%s/%s' % (k, name)] = value 34 | loss = loss.mean() * loss_func.loss_term_weight 35 | loss_sum += loss 36 | 37 | else: 38 | if isinstance(v, dict): 39 | raise ValueError( 40 | "The key %s in -Training-Feat- should be stated as the log_prefix of a certain loss defined in your loss_cfg." % k 41 | ) 42 | elif is_tensor(v): 43 | _ = v.mean() 44 | loss_info['scalar/%s' % k] = _ 45 | loss_sum += _ 46 | get_msg_mgr().log_debug( 47 | "Please check whether %s is needed in training." % k) 48 | else: 49 | raise ValueError( 50 | "Error type for -Training-Feat-, supported: A feature dict or loss tensor.") 51 | 52 | return loss_sum, loss_info -------------------------------------------------------------------------------- /misc/partitions/CASIA-B.json: -------------------------------------------------------------------------------- 1 | { 2 | "TRAIN_SET": [ 3 | "001", 4 | "002", 5 | "003", 6 | "004", 7 | "006", 8 | "007", 9 | "008", 10 | "009", 11 | "010", 12 | "011", 13 | "012", 14 | "013", 15 | "014", 16 | "015", 17 | "016", 18 | "017", 19 | "018", 20 | "019", 21 | "020", 22 | "021", 23 | "022", 24 | "023", 25 | "024", 26 | "025", 27 | "026", 28 | "027", 29 | "028", 30 | "029", 31 | "030", 32 | "031", 33 | "032", 34 | "033", 35 | "034", 36 | "035", 37 | "036", 38 | "037", 39 | "038", 40 | "039", 41 | "040", 42 | "041", 43 | "042", 44 | "043", 45 | "044", 46 | "045", 47 | "046", 48 | "047", 49 | "048", 50 | "049", 51 | "050", 52 | "051", 53 | "052", 54 | "053", 55 | "054", 56 | "055", 57 | "056", 58 | "057", 59 | "058", 60 | "059", 61 | "060", 62 | "061", 63 | "062", 64 | "063", 65 | "064", 66 | "065", 67 | "066", 68 | "067", 69 | "068", 70 | "069", 71 | "070", 72 | "071", 73 | "072", 74 | "073", 75 | "074" 76 | ], 77 | "TEST_SET": [ 78 | "075", 79 | "076", 80 | "077", 81 | "078", 82 | "079", 83 | "080", 84 | "081", 85 | "082", 86 | "083", 87 | "084", 88 | "085", 89 | "086", 90 | "087", 91 | "088", 92 | "089", 93 | "090", 94 | "091", 95 | "092", 96 | "093", 97 | "094", 98 | "095", 99 | "096", 100 | "097", 101 | "098", 102 | "099", 103 | "100", 104 | "101", 105 | "102", 106 | "103", 107 | "104", 108 | "105", 109 | "106", 110 | "107", 111 | "108", 112 | "109", 113 | "110", 114 | "111", 115 | "112", 116 | "113", 117 | "114", 118 | "115", 119 | "116", 120 | "117", 121 | "118", 122 | "119", 123 | "120", 124 | 
"121", 125 | "122", 126 | "123", 127 | "124" 128 | ] 129 | } -------------------------------------------------------------------------------- /misc/partitions/CASIA-B_include_005.json: -------------------------------------------------------------------------------- 1 | { 2 | "TRAIN_SET": [ 3 | "001", 4 | "002", 5 | "003", 6 | "004", 7 | "005", 8 | "006", 9 | "007", 10 | "008", 11 | "009", 12 | "010", 13 | "011", 14 | "012", 15 | "013", 16 | "014", 17 | "015", 18 | "016", 19 | "017", 20 | "018", 21 | "019", 22 | "020", 23 | "021", 24 | "022", 25 | "023", 26 | "024", 27 | "025", 28 | "026", 29 | "027", 30 | "028", 31 | "029", 32 | "030", 33 | "031", 34 | "032", 35 | "033", 36 | "034", 37 | "035", 38 | "036", 39 | "037", 40 | "038", 41 | "039", 42 | "040", 43 | "041", 44 | "042", 45 | "043", 46 | "044", 47 | "045", 48 | "046", 49 | "047", 50 | "048", 51 | "049", 52 | "050", 53 | "051", 54 | "052", 55 | "053", 56 | "054", 57 | "055", 58 | "056", 59 | "057", 60 | "058", 61 | "059", 62 | "060", 63 | "061", 64 | "062", 65 | "063", 66 | "064", 67 | "065", 68 | "066", 69 | "067", 70 | "068", 71 | "069", 72 | "070", 73 | "071", 74 | "072", 75 | "073", 76 | "074" 77 | ], 78 | "TEST_SET": [ 79 | "075", 80 | "076", 81 | "077", 82 | "078", 83 | "079", 84 | "080", 85 | "081", 86 | "082", 87 | "083", 88 | "084", 89 | "085", 90 | "086", 91 | "087", 92 | "088", 93 | "089", 94 | "090", 95 | "091", 96 | "092", 97 | "093", 98 | "094", 99 | "095", 100 | "096", 101 | "097", 102 | "098", 103 | "099", 104 | "100", 105 | "101", 106 | "102", 107 | "103", 108 | "104", 109 | "105", 110 | "106", 111 | "107", 112 | "108", 113 | "109", 114 | "110", 115 | "111", 116 | "112", 117 | "113", 118 | "114", 119 | "115", 120 | "116", 121 | "117", 122 | "118", 123 | "119", 124 | "120", 125 | "121", 126 | "122", 127 | "123", 128 | "124" 129 | ] 130 | } -------------------------------------------------------------------------------- /lib/modeling/losses/triplet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .base import BaseLoss, gather_and_scale_wrapper 5 | 6 | 7 | class TripletLoss(BaseLoss): 8 | def __init__(self, margin, loss_term_weight=1.0): 9 | super(TripletLoss, self).__init__(loss_term_weight) 10 | self.margin = margin 11 | 12 | @gather_and_scale_wrapper 13 | def forward(self, embeddings, labels): 14 | # embeddings: [n, v, c], label: [n] 15 | embeddings = embeddings.permute( 16 | 1, 0, 2).contiguous() # [n, v, c] -> [v, n, c] 17 | embeddings = embeddings.float() 18 | 19 | ref_embed, ref_label = embeddings, labels 20 | dist = self.ComputeDistance(embeddings, ref_embed) # [v, n1, n2] 21 | mean_dist = dist.mean(1).mean(1) 22 | ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist) 23 | dist_diff = ap_dist - an_dist 24 | loss = F.relu(dist_diff + self.margin) 25 | 26 | hard_loss = torch.max(loss, -1)[0] 27 | loss_avg, loss_num = self.AvgNonZeroReducer(loss) 28 | 29 | self.info.update({ 30 | 'loss': loss_avg.detach().clone(), 31 | 'hard_loss': hard_loss.detach().clone(), 32 | 'loss_num': loss_num.detach().clone(), 33 | 'mean_dist': mean_dist.detach().clone()}) 34 | 35 | return loss_avg, self.info 36 | 37 | def AvgNonZeroReducer(self, loss): 38 | eps = 1.0e-9 39 | loss_sum = loss.sum(-1) 40 | loss_num = (loss != 0).sum(-1).float() 41 | 42 | loss_avg = loss_sum / (loss_num + eps) 43 | loss_avg[loss_num == 0] = 0 44 | return loss_avg, loss_num 45 | 46 | def ComputeDistance(self, x, y): 47 | """ 48 | x: [v, n_x, c] 49 | 
y: [v, n_y, c] 50 | Per-part distances use the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>; small negative residues are clamped before the sqrt. 51 | """ 52 | x2 = torch.sum(x ** 2, -1).unsqueeze(2) # [v, n_x, 1] 53 | y2 = torch.sum(y ** 2, -1).unsqueeze(1) # [v, 1, n_y] 54 | inner = x.matmul(y.transpose(-1, -2)) # [v, n_x, n_y] 55 | dist = x2 + y2 - 2 * inner 56 | dist = torch.sqrt(F.relu(dist)) # [v, n_x, n_y] 57 | return dist 58 | 59 | def Convert2Triplets(self, row_labels, clo_label, dist): 60 | """ 61 | row_labels: tensor with size [n_r] 62 | clo_label : tensor with size [n_c] 63 | """ 64 | matches = (row_labels.unsqueeze(1) == 65 | clo_label.unsqueeze(0)).byte() # [n_r, n_c] 66 | diffenc = matches ^ 1 # [n_r, n_c], non-matching (negative) pairs 67 | mask = matches.unsqueeze(2) * diffenc.unsqueeze(1) 68 | a_idx, p_idx, n_idx = torch.where(mask) 69 | 70 | ap_dist = dist[:, a_idx, p_idx] 71 | an_dist = dist[:, a_idx, n_idx] 72 | return ap_dist, an_dist 73 | -------------------------------------------------------------------------------- /lib/main.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | from modeling import models 7 | from utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr 8 | 9 | parser = argparse.ArgumentParser(description='Main program') 10 | parser.add_argument('--local_rank', type=int, default=0, 11 | help="passed by torch.distributed.launch module") 12 | parser.add_argument('--cfgs', type=str, 13 | default='config/default.yaml', help="path of config file") 14 | parser.add_argument('--phase', default='train', 15 | choices=['train', 'test'], help="choose train or test phase") 16 | parser.add_argument('--log_to_file', action='store_true', 17 | help="log to file, default path is: output/<dataset_name>/<model>/<save_name>/logs/<datetime>.txt") 18 | parser.add_argument('--iter', default=0, type=int, help="iteration to restore") 19 | opt = parser.parse_args() 20 | 21 | 22 | def initialization(cfgs, training): 23 | msg_mgr = get_msg_mgr() 24 | engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg'] 25 | output_path = os.path.join('output/', cfgs['data_cfg']['dataset_name'], 26 | cfgs['model_cfg']['model'], engine_cfg['save_name']) 27 | if training: 28 | msg_mgr.init_manager(output_path, opt.log_to_file, engine_cfg['log_iter'], 29 | engine_cfg['restore_hint'] if isinstance(engine_cfg['restore_hint'], (int)) else 0) 30 | else: 31 | msg_mgr.init_logger(output_path, opt.log_to_file) 32 | 33 | msg_mgr.log_info(engine_cfg) 34 | 35 | seed = torch.distributed.get_rank() 36 | init_seeds(seed) 37 | 38 | 39 | def run_model(cfgs, training): 40 | msg_mgr = get_msg_mgr() 41 | model_cfg = cfgs['model_cfg'] 42 | msg_mgr.log_info(model_cfg) 43 | Model = getattr(models, model_cfg['model']) 44 | model = Model(cfgs, training) 45 | if training and cfgs['trainer_cfg']['sync_BN']: 46 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 47 | model = get_ddp_module(model) 48 | msg_mgr.log_info(params_count(model)) 49 | msg_mgr.log_info("Model Initialization Finished!") 50 | 51 | if training: 52 | Model.run_train(model) 53 | else: 54 | Model.run_test(model) 55 | 56 | 57 | if __name__ == '__main__': 58 | torch.distributed.init_process_group('nccl', init_method='env://') 59 | if torch.distributed.get_world_size() != torch.cuda.device_count(): 60 | raise ValueError("Expect the world size({}) to equal the number of available GPUs({}).".format( 61 | torch.distributed.get_world_size(), torch.cuda.device_count())) 62 | cfgs = config_loader(opt.cfgs) 63 | if opt.iter != 0: 64 | cfgs['evaluator_cfg']['restore_hint'] = int(opt.iter) 65 | 
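# keep the trainer's restore hint in sync with the evaluator's, so train and test resume from the same checkpoint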
cfgs['trainer_cfg']['restore_hint'] = int(opt.iter) 66 | 67 | training = (opt.phase == 'train') 68 | initialization(cfgs, training) 69 | run_model(cfgs, training) -------------------------------------------------------------------------------- /lib/data/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.distributed as dist 4 | import torch.utils.data as tordata 5 | 6 | 7 | class TripletSampler(tordata.sampler.Sampler): 8 | def __init__(self, dataset, batch_size, batch_shuffle=False): 9 | self.dataset = dataset 10 | self.batch_size = batch_size 11 | self.batch_shuffle = batch_shuffle 12 | 13 | self.world_size = dist.get_world_size() 14 | self.rank = dist.get_rank() 15 | 16 | def __iter__(self): 17 | while True: 18 | sample_indices = [] 19 | pid_list = sync_random_sample_list( 20 | self.dataset.label_set, self.batch_size[0]) 21 | 22 | for pid in pid_list: 23 | indices = self.dataset.indices_dict[pid] 24 | indices = sync_random_sample_list( 25 | indices, k=self.batch_size[1]) 26 | sample_indices += indices 27 | 28 | if self.batch_shuffle: 29 | sample_indices = sync_random_sample_list( 30 | sample_indices, len(sample_indices)) 31 | 32 | total_batch_size = self.batch_size[0] * self.batch_size[1] 33 | total_size = int(math.ceil(total_batch_size / 34 | self.world_size)) * self.world_size 35 | sample_indices += sample_indices[:( 36 | total_batch_size - len(sample_indices))] 37 | 38 | sample_indices = sample_indices[self.rank:total_size:self.world_size] 39 | yield sample_indices 40 | 41 | def __len__(self): 42 | return len(self.dataset) 43 | 44 | 45 | def sync_random_sample_list(obj_list, k): 46 | idx = torch.randperm(len(obj_list))[:k] 47 | if torch.cuda.is_available(): 48 | idx = idx.cuda() 49 | torch.distributed.broadcast(idx, src=0) 50 | idx = idx.tolist() 51 | return [obj_list[i] for i in idx] 52 | 53 | 54 | class InferenceSampler(tordata.sampler.Sampler): 55 | def __init__(self, dataset, batch_size): 56 | self.dataset = dataset 57 | self.batch_size = batch_size 58 | 59 | self.size = len(dataset) 60 | indices = list(range(self.size)) 61 | 62 | world_size = dist.get_world_size() 63 | rank = dist.get_rank() 64 | 65 | if batch_size % world_size != 0: 66 | raise ValueError("Batch size({}) is not divisible by the world size({})".format( 67 | batch_size, world_size)) 68 | 69 | if batch_size != 1: 70 | complement_size = math.ceil(self.size / batch_size) * \ 71 | batch_size 72 | indices += indices[:(complement_size - self.size)] 73 | self.size = complement_size 74 | 75 | batch_size_per_rank = int(self.batch_size / world_size) 76 | indx_batch_per_rank = [] 77 | 78 | for i in range(int(self.size / batch_size_per_rank)): 79 | indx_batch_per_rank.append( 80 | indices[i*batch_size_per_rank:(i+1)*batch_size_per_rank]) 81 | 82 | self.idx_batch_this_rank = indx_batch_per_rank[rank::world_size] 83 | 84 | def __iter__(self): 85 | yield from self.idx_batch_this_rank 86 | 87 | def __len__(self): 88 | return len(self.dataset) 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HSTL 2 | 3 | This is the code for the paper [Hierarchical Spatio-Temporal Representation Learning for Gait Recognition](https://openaccess.thecvf.com/content/ICCV2023/papers/Wang_Hierarchical_Spatio-Temporal_Representation_Learning_for_Gait_Recognition_ICCV_2023_paper.pdf). 
4 | 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hierarchical-spatio-temporal-representation/gait-recognition-in-the-wild-on-gait3d)](https://paperswithcode.com/sota/gait-recognition-in-the-wild-on-gait3d?p=hierarchical-spatio-temporal-representation) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hierarchical-spatio-temporal-representation/gait-recognition-on-gait3d)](https://paperswithcode.com/sota/gait-recognition-on-gait3d?p=hierarchical-spatio-temporal-representation) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hierarchical-spatio-temporal-representation/gait-recognition-on-oumvlp)](https://paperswithcode.com/sota/gait-recognition-on-oumvlp?p=hierarchical-spatio-temporal-representation) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hierarchical-spatio-temporal-representation/multiview-gait-recognition-on-casia-b)](https://paperswithcode.com/sota/multiview-gait-recognition-on-casia-b?p=hierarchical-spatio-temporal-representation) 9 | 10 | # Operating Environments 11 | ## Hardware Environment 12 | Our code runs on a server with 8 GeForce RTX 3090 GPUs 13 | and an Intel(R) Core(TM) i7-9800X CPU @ 3.80GHz. 14 | ## Software Environment 15 | - pytorch = 1.10 16 | - torchvision 17 | - pyyaml 18 | - tensorboard 19 | - opencv-python 20 | - tqdm 21 | 22 | # Checkpoints 23 | * The checkpoints for CASIA-B [link](https://drive.google.com/file/d/1keZBtWr9O8gfeqBB9qHNbZ-96Eh6LggB/view?usp=sharing) 24 | * The checkpoints for OUMVLP [link](https://drive.google.com/file/d/1VNYC0QbHxw1aaBTFLj4DMIC2D36B1-ng/view?usp=sharing) 25 | 26 | # Train and test 27 | ## Train 28 | Train a model by 29 | ``` 30 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 lib/main.py --cfgs ./config/hstl.yaml --phase train 31 | ``` 32 | - `python -m torch.distributed.launch` The [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) launch utility. 33 | - `--nproc_per_node` The number of GPUs to use; it must match the number of devices listed in `CUDA_VISIBLE_DEVICES`. 34 | - `--cfgs` The path to the config file. 35 | - `--phase` Specified as `train`. 36 | 37 | - `--log_to_file` If specified, the terminal log will also be written to disk. 38 | 39 | You can run the commands in [train.sh](train.sh) to train the different models. 40 | 41 | ## Test 42 | Evaluate the trained model by 43 | ``` 44 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 lib/main.py --cfgs ./config/hstl.yaml --phase test 45 | ``` 46 | - `--phase` Specified as `test`. 47 | - `--iter` Specify an iteration checkpoint to restore (see the example after the Acknowledgement section). 48 | 49 | **Tip**: The other arguments are the same as in the train phase. 50 | 51 | You can run the commands in [test.sh](test.sh) to test the different models. 52 | 53 | # Acknowledgement 54 | * The codebase is based on [OpenGait](https://github.com/ShiqiYu/OpenGait). 
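**Example**: to evaluate a released checkpoint at a specific iteration, pass it via `--iter`. This is a hypothetical invocation assembled from the flags documented above; 90000 matches `restore_hint` in `config/hstl.yaml`:
```
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 lib/main.py --cfgs ./config/hstl.yaml --phase test --iter 90000
```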
55 | 56 | # Citation 57 | ``` 58 | @InProceedings{Wang_2023_ICCV, 59 | author = {Wang, Lei and Liu, Bo and Liang, Fangfang and Wang, Bincheng}, 60 | title = {Hierarchical Spatio-Temporal Representation Learning for Gait Recognition}, 61 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 62 | month = {October}, 63 | year = {2023}, 64 | pages = {19639-19649} 65 | } 66 | ``` 67 | -------------------------------------------------------------------------------- /lib/modeling/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils import clones, is_list_or_tuple 6 | 7 | class BasicConv2d(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, **kwargs): 9 | super(BasicConv2d, self).__init__() 10 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, 11 | stride=stride, padding=padding, bias=False, **kwargs) 12 | 13 | def forward(self, x): 14 | x = self.conv(x) 15 | return x 16 | 17 | class PackSequenceWrapper(nn.Module): 18 | def __init__(self, pooling_func): 19 | super(PackSequenceWrapper, self).__init__() 20 | self.pooling_func = pooling_func 21 | 22 | def forward(self, seqs, seqL, seq_dim=1, **kwargs): 23 | """ 24 | In seqs: [n, s, ...] 25 | Out rets: [n, ...] 26 | """ 27 | if seqL is None: 28 | return self.pooling_func(seqs, **kwargs) 29 | seqL = seqL[0].data.cpu().numpy().tolist() 30 | start = [0] + np.cumsum(seqL).tolist()[:-1] 31 | 32 | rets = [] 33 | for curr_start, curr_seqL in zip(start, seqL): 34 | narrowed_seq = seqs.narrow(seq_dim, curr_start, curr_seqL) 35 | # save the memory 36 | # splited_narrowed_seq = torch.split(narrowed_seq, 256, dim=1) 37 | # ret = [] 38 | # for seq_to_pooling in splited_narrowed_seq: 39 | # ret.append(self.pooling_func(seq_to_pooling, keepdim=True, **kwargs) 40 | # [0] if self.is_tuple_result else self.pooling_func(seq_to_pooling, **kwargs)) 41 | rets.append(self.pooling_func(narrowed_seq, **kwargs)) 42 | if len(rets) > 0 and is_list_or_tuple(rets[0]): 43 | return [torch.cat([ret[j] for ret in rets]) 44 | for j in range(len(rets[0]))] 45 | return torch.cat(rets) 46 | 47 | 48 | class SeparateFCs(nn.Module): 49 | def __init__(self, parts_num, in_channels, out_channels, norm=False): 50 | super(SeparateFCs, self).__init__() 51 | self.p = parts_num 52 | self.fc_bin = nn.Parameter( 53 | nn.init.xavier_uniform_( 54 | torch.zeros(parts_num, in_channels, out_channels))) 55 | self.norm = norm 56 | 57 | def forward(self, x): 58 | if self.norm: 59 | out = x.matmul(F.normalize(self.fc_bin, dim=1)) 60 | else: 61 | out = x.matmul(self.fc_bin) 62 | return out 63 | 64 | 65 | class BasicConv3d(nn.Module): 66 | def __init__(self, in_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False, **kwargs): 67 | super(BasicConv3d, self).__init__() 68 | self.conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, 69 | stride=stride, padding=padding, bias=bias, **kwargs) 70 | 71 | def forward(self, ipts): 72 | ''' 73 | ipts: [n, c, d, h, w] 74 | outs: [n, c, d, h, w] 75 | ''' 76 | outs = self.conv3d(ipts) 77 | return outs 78 | 79 | 80 | def RmBN2dAffine(model): 81 | for m in model.modules(): 82 | if isinstance(m, nn.BatchNorm2d): 83 | m.weight.requires_grad = False 84 | m.bias.requires_grad = False -------------------------------------------------------------------------------- /lib/utils/msg_manager.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | import numpy as np 5 | import torchvision.utils as vutils 6 | import os.path as osp 7 | from time import strftime, localtime 8 | 9 | from torch.utils.tensorboard import SummaryWriter 10 | from .common import is_list, is_tensor, ts2np, mkdir, Odict, NoOp 11 | import logging 12 | 13 | 14 | class MessageManager: 15 | def __init__(self): 16 | self.info_dict = Odict() 17 | self.writer_hparams = ['image', 'scalar'] 18 | self.time = time.time() 19 | 20 | def init_manager(self, save_path, log_to_file, log_iter, iteration=0): 21 | self.iteration = iteration 22 | self.log_iter = log_iter 23 | mkdir(osp.join(save_path, "summary/")) 24 | self.writer = SummaryWriter( 25 | osp.join(save_path, "summary/"), purge_step=self.iteration) 26 | self.init_logger(save_path, log_to_file) 27 | 28 | def init_logger(self, save_path, log_to_file): 29 | # init logger 30 | self.logger = logging.getLogger('gait') 31 | self.logger.setLevel(logging.INFO) 32 | self.logger.propagate = False 33 | formatter = logging.Formatter( 34 | fmt='[%(asctime)s] [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 35 | if log_to_file: 36 | mkdir(osp.join(save_path, "logs/")) 37 | vlog = logging.FileHandler( 38 | osp.join(save_path, "logs/", strftime('%Y-%m-%d-%H-%M-%S', localtime())+'.txt')) 39 | vlog.setLevel(logging.INFO) 40 | vlog.setFormatter(formatter) 41 | self.logger.addHandler(vlog) 42 | 43 | console = logging.StreamHandler() 44 | console.setFormatter(formatter) 45 | console.setLevel(logging.DEBUG) 46 | self.logger.addHandler(console) 47 | 48 | def append(self, info): 49 | for k, v in info.items(): 50 | v = [v] if not is_list(v) else v 51 | v = [ts2np(_) if is_tensor(_) else _ for _ in v] 52 | info[k] = v 53 | self.info_dict.append(info) 54 | 55 | def flush(self): 56 | self.info_dict.clear() 57 | self.writer.flush() 58 | 59 | def write_to_tensorboard(self, summary): 60 | 61 | for k, v in summary.items(): 62 | module_name = k.split('/')[0] 63 | if module_name not in self.writer_hparams: 64 | self.log_warning( 65 | 'Not Expected --Summary-- type [{}] appear!!!{}'.format(k, self.writer_hparams)) 66 | continue 67 | board_name = k.replace(module_name + "/", '') 68 | writer_module = getattr(self.writer, 'add_' + module_name) 69 | v = v.detach() if is_tensor(v) else v 70 | v = vutils.make_grid( 71 | v, normalize=True, scale_each=True) if 'image' in module_name else v 72 | if module_name == 'scalar': 73 | try: 74 | v = v.mean() 75 | except Exception: 76 | pass 77 | writer_module(board_name, v, self.iteration) 78 | 79 | def log_training_info(self): 80 | now = time.time() 81 | string = "Iteration {:0>5}, Cost {:.2f}s".format( 82 | self.iteration, now-self.time) 83 | for i, (k, v) in enumerate(self.info_dict.items()): 84 | if 'scalar' not in k: 85 | continue 86 | k = k.replace('scalar/', '').replace('/', '_') 87 | string += ", {0}={1:.4f}".format(k, np.mean(v)) 88 | self.log_info(string) 89 | self.reset_time() 90 | 91 | def reset_time(self): 92 | self.time = time.time() 93 | 94 | 95 | 
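# train_step (below) is the per-iteration entry point: it appends info to the running statistics, prints a training summary every log_iter iterations, and flushes the accumulated scalars/images to TensorBoard.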
def train_step(self, info, summary): 96 | self.iteration += 1 97 | self.append(info) 98 | if self.iteration % self.log_iter == 0: 99 | self.log_training_info() 100 | self.flush() 101 | self.write_to_tensorboard(summary) 102 | 103 | def log_debug(self, *args, **kwargs): 104 | self.logger.debug(*args, **kwargs) 105 | 106 | def log_info(self, *args, **kwargs): 107 | self.logger.info(*args, **kwargs) 108 | 109 | def log_warning(self, *args, **kwargs): 110 | self.logger.warning(*args, **kwargs) 111 | 112 | 113 | msg_mgr = MessageManager() 114 | noop = NoOp() 115 | 116 | 117 | def get_msg_mgr(): 118 | if torch.distributed.get_rank() > 0: 119 | return noop 120 | else: 121 | return msg_mgr 122 | -------------------------------------------------------------------------------- /lib/data/collate_fn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import numpy as np 4 | from utils import get_msg_mgr 5 | 6 | 7 | class CollateFn(object): 8 | def __init__(self, label_set, sample_config): 9 | self.label_set = label_set 10 | sample_type = sample_config['sample_type'] 11 | sample_type = sample_type.split('_') 12 | self.sampler = sample_type[0] 13 | self.ordered = sample_type[1] 14 | if self.sampler not in ['fixed', 'unfixed', 'all']: 15 | raise ValueError 16 | if self.ordered not in ['ordered', 'unordered']: 17 | raise ValueError 18 | self.ordered = sample_type[1] == 'ordered' 19 | 20 | # fixed cases 21 | if self.sampler == 'fixed': 22 | self.frames_num_fixed = sample_config['frames_num_fixed'] 23 | 24 | # unfixed cases 25 | if self.sampler == 'unfixed': 26 | self.frames_num_max = sample_config['frames_num_max'] 27 | self.frames_num_min = sample_config['frames_num_min'] 28 | 29 | if self.sampler != 'all' and self.ordered: 30 | self.frames_skip_num = sample_config['frames_skip_num'] 31 | 32 | self.frames_all_limit = -1 33 | if self.sampler == 'all' and 'frames_all_limit' in sample_config: 34 | self.frames_all_limit = sample_config['frames_all_limit'] 35 | 36 | def __call__(self, batch): 37 | batch_size = len(batch) 38 | feature_num = len(batch[0][0]) 39 | seqs_batch, labs_batch, typs_batch, vies_batch = [], [], [], [] 40 | 41 | for bt in batch: 42 | seqs_batch.append(bt[0]) 43 | labs_batch.append(self.label_set.index(bt[1][0])) 44 | typs_batch.append(bt[1][1]) 45 | vies_batch.append(bt[1][2]) 46 | 47 | global count 48 | count = 0 49 | 50 | def sample_frames(seqs): 51 | global count 52 | sampled_fras = [[] for i in range(feature_num)] 53 | seq_len = len(seqs[0]) 54 | indices = list(range(seq_len)) 55 | 56 | if self.sampler in ['fixed', 'unfixed']: 57 | if self.sampler == 'fixed': 58 | frames_num = self.frames_num_fixed 59 | else: 60 | frames_num = random.choice( 61 | list(range(self.frames_num_min, self.frames_num_max+1))) 62 | 63 | if self.ordered: 64 | fs_n = frames_num + self.frames_skip_num 65 | if seq_len < fs_n: 66 | it = math.ceil(fs_n / seq_len) 67 | seq_len = seq_len * it 68 | indices = indices * it 69 | 70 | start = random.choice(list(range(0, seq_len - fs_n + 1))) 71 | end = start + fs_n 72 | idx_lst = list(range(seq_len)) 73 | idx_lst = idx_lst[start:end] 74 | idx_lst = sorted(np.random.choice( 75 | idx_lst, frames_num, replace=False)) 76 | indices = [indices[i] for i in idx_lst] 77 | else: 78 | replace = seq_len < frames_num 79 | 80 | if seq_len == 0: 81 | get_msg_mgr().log_debug('Find no frames in the sequence %s-%s-%s.' 
82 | % (str(labs_batch[count]), str(typs_batch[count]), str(vies_batch[count]))) 83 | 84 | count += 1 85 | indices = np.random.choice( 86 | indices, frames_num, replace=replace) 87 | 88 | for i in range(feature_num): 89 | for j in indices[:self.frames_all_limit] if self.frames_all_limit > -1 and len(indices) > self.frames_all_limit else indices: 90 | sampled_fras[i].append(seqs[i][j]) 91 | return sampled_fras 92 | 93 | # f: feature_num 94 | # b: batch_size 95 | # p: batch_size_per_gpu 96 | # g: gpus_num 97 | fras_batch = [sample_frames(seqs) for seqs in seqs_batch] # [b, f] 98 | batch = [fras_batch, labs_batch, typs_batch, vies_batch, None] 99 | 100 | if self.sampler == "fixed": 101 | fras_batch = [[np.asarray(fras_batch[i][j]) for i in range(batch_size)] 102 | for j in range(feature_num)] # [f, b] 103 | else: 104 | seqL_batch = [[len(fras_batch[i][0]) 105 | for i in range(batch_size)]] # [1, p] 106 | 107 | def my_cat(k): return np.concatenate( 108 | [fras_batch[i][k] for i in range(batch_size)], 0) 109 | fras_batch = [[my_cat(k)] for k in range(feature_num)] # [f, g] 110 | 111 | batch[-1] = np.asarray(seqL_batch) 112 | 113 | batch[0] = fras_batch 114 | return batch 115 | -------------------------------------------------------------------------------- /lib/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import strftime, localtime 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from utils import get_msg_mgr, mkdir 7 | 8 | 9 | def cuda_dist(x, y, metric='euc'): 10 | x = torch.from_numpy(x).cuda() 11 | y = torch.from_numpy(y).cuda() 12 | if metric == 'cos': 13 | x = F.normalize(x, p=2, dim=2) # n v c 14 | y = F.normalize(y, p=2, dim=2) # n v c 15 | num_bin = x.size(1) 16 | n_x = x.size(0) 17 | n_y = y.size(0) 18 | dist = torch.zeros(n_x, n_y).cuda() 19 | for i in range(num_bin): 20 | _x = x[:, i, ...] 21 | _y = y[:, i, ...] 22 | if metric == 'cos': 23 | dist += torch.matmul(_x, _y.transpose(0, 1)) 24 | else: 25 | _dist = torch.sum(_x ** 2, 1).unsqueeze(1) + torch.sum(_y ** 2, 1).unsqueeze( 26 | 1).transpose(0, 1) - 2 * torch.matmul(_x, _y.transpose(0, 1)) 27 | dist += torch.sqrt(F.relu(_dist)) 28 | return 1 - dist/num_bin if metric == 'cos' else dist / num_bin 29 | 30 | # Exclude identical-view cases 31 | 32 | 33 | def de_diag(acc, each_angle=False): 34 | dividend = acc.shape[1] - 1. 35 | result = np.sum(acc - np.diag(np.diag(acc)), 1) / dividend 36 | if not each_angle: 37 | result = np.mean(result) 38 | return result 39 | 40 | 41 | 42 | def identification(data, dataset, metric='euc'): 43 | msg_mgr = get_msg_mgr() 44 | feature, label, seq_type, view = data['embeddings'], data['labels'], data['types'], data['views'] 45 | label = np.array(label) 46 | 47 | view_list = list(set(view)) 48 | view_list.sort() 49 | view_num = len(view_list) 50 | 51 | probe_seq_dict = {'CASIA-B': [['nm-05', 'nm-06'], ['bg-01', 'bg-02'], ['cl-01', 'cl-02']], 52 | 'OUMVLP': [['00']]} 53 | 54 | gallery_seq_dict = {'CASIA-B': [['nm-01', 'nm-02', 'nm-03', 'nm-04']], 55 | 'OUMVLP': [['01']]} 56 | if dataset not in probe_seq_dict or dataset not in gallery_seq_dict: 57 | raise KeyError("DataSet %s hasn't been supported!" % dataset) 58 | num_rank = 5 59 | acc = np.zeros([len(probe_seq_dict[dataset]), 60 | view_num, view_num, num_rank]) - 1. 
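# acc[p, v1, v2, r] stores the rank-(r+1) accuracy (%) for probe condition p, probe view v1 and gallery view v2; cells are initialized to -1 so unfilled entries are distinguishable from a true 0% accuracy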
61 | for (p, probe_seq) in enumerate(probe_seq_dict[dataset]): 62 | for gallery_seq in gallery_seq_dict[dataset]: 63 | for (v1, probe_view) in enumerate(view_list): 64 | for (v2, gallery_view) in enumerate(view_list): 65 | gseq_mask = np.isin(seq_type, gallery_seq) & np.isin( 66 | view, [gallery_view]) 67 | gallery_x = feature[gseq_mask, :] 68 | gallery_y = label[gseq_mask] 69 | 70 | pseq_mask = np.isin(seq_type, probe_seq) & np.isin( 71 | view, [probe_view]) 72 | probe_x = feature[pseq_mask, :] 73 | probe_y = label[pseq_mask] 74 | 75 | dist = cuda_dist(probe_x, gallery_x, metric) 76 | # print('dis',dist.shape[0]) 77 | idx = dist.sort(1)[1].cpu().numpy() 78 | acc[p, v1, v2, :] = np.round( 79 | np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0, 80 | 0) * 100 / dist.shape[0], 2) 81 | result_dict = {} 82 | np.set_printoptions(precision=3, suppress=True) 83 | if 'OUMVLP' not in dataset: 84 | for i in range(1): 85 | msg_mgr.log_info( 86 | '===Rank-%d (Include identical-view cases)===' % (i + 1)) 87 | msg_mgr.log_info('NM: %.1f,\tBG: %.1f,\tCL: %.1f' % ( 88 | np.mean(acc[0, :, :, i]), 89 | np.mean(acc[1, :, :, i]), 90 | np.mean(acc[2, :, :, i]))) 91 | for i in range(1): 92 | msg_mgr.log_info( 93 | '===Rank-%d (Exclude identical-view cases)===' % (i + 1)) 94 | msg_mgr.log_info('NM: %.1f,\tBG: %.1f,\tCL: %.1f' % ( 95 | de_diag(acc[0, :, :, i]), 96 | de_diag(acc[1, :, :, i]), 97 | de_diag(acc[2, :, :, i]))) 98 | result_dict["scalar/test_accuracy/NM"] = de_diag(acc[0, :, :, i]) 99 | result_dict["scalar/test_accuracy/BG"] = de_diag(acc[1, :, :, i]) 100 | result_dict["scalar/test_accuracy/CL"] = de_diag(acc[2, :, :, i]) 101 | np.set_printoptions(precision=2, floatmode='fixed') 102 | for i in range(1): 103 | msg_mgr.log_info( 104 | '===Rank-%d of each angle (Exclude identical-view cases)===' % (i + 1)) 105 | msg_mgr.log_info('NM: {}'.format(de_diag(acc[0, :, :, i], True))) 106 | msg_mgr.log_info('BG: {}'.format(de_diag(acc[1, :, :, i], True))) 107 | msg_mgr.log_info('CL: {}'.format(de_diag(acc[2, :, :, i], True))) 108 | else: 109 | msg_mgr.log_info('===Rank-1 (Include identical-view cases)===') 110 | msg_mgr.log_info('NM: %.1f ' % (np.mean(acc[0, :, :, 0]))) 111 | msg_mgr.log_info('===Rank-1 (Exclude identical-view cases)===') 112 | msg_mgr.log_info('NM: %.1f ' % (de_diag(acc[0, :, :, 0]))) 113 | msg_mgr.log_info( 114 | '===Rank-1 of each angle (Exclude identical-view cases)===') 115 | msg_mgr.log_info('NM: {}'.format(de_diag(acc[0, :, :, 0], True))) 116 | result_dict["scalar/test_accuracy/NM"] = de_diag(acc[0, :, :, 0]) 117 | return result_dict 118 | 119 | 120 | 121 | 122 | -------------------------------------------------------------------------------- /lib/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import os.path as osp 4 | import torch.utils.data as tordata 5 | import json 6 | from utils import get_msg_mgr 7 | 8 | 9 | class DataSet(tordata.Dataset): 10 | def __init__(self, data_cfg, training): 11 | """ 12 | seqs_info: the list with each element indicating 13 | a certain gait sequence presented as [label, type, view, paths]; 14 | """ 15 | self.__dataset_parser(data_cfg, training) 16 | self.cache = data_cfg['cache'] 17 | self.label_list = [seq_info[0] for seq_info in self.seqs_info] 18 | self.types_list = [seq_info[1] for seq_info in self.seqs_info] 19 | self.views_list = [seq_info[2] for seq_info in self.seqs_info] 20 | 21 | self.label_set = 
sorted(list(set(self.label_list))) 22 | self.types_set = sorted(list(set(self.types_list))) 23 | self.views_set = sorted(list(set(self.views_list))) 24 | self.seqs_data = [None] * len(self) 25 | self.indices_dict = {label: [] for label in self.label_set} 26 | for i, seq_info in enumerate(self.seqs_info): 27 | self.indices_dict[seq_info[0]].append(i) 28 | if self.cache: 29 | self.__load_all_data() 30 | 31 | def __len__(self): 32 | return len(self.seqs_info) 33 | 34 | def __loader__(self, paths): 35 | paths = sorted(paths) 36 | data_list = [] 37 | for pth in paths: 38 | if pth.endswith('.pkl'): 39 | with open(pth, 'rb') as f: 40 | _ = pickle.load(f) 41 | f.close() 42 | else: 43 | raise ValueError('- Loader - just support .pkl !!!') 44 | data_list.append(_) 45 | for idx, data in enumerate(data_list): 46 | if len(data) != len(data_list[0]): 47 | raise ValueError( 48 | 'Each input data({}) should have the same length.'.format(paths[idx])) 49 | if len(data) == 0: 50 | raise ValueError( 51 | 'Each input data({}) should have at least one element.'.format(paths[idx])) 52 | return data_list 53 | 54 | def __getitem__(self, idx): 55 | if not self.cache: 56 | data_list = self.__loader__(self.seqs_info[idx][-1]) 57 | elif self.seqs_data[idx] is None: 58 | data_list = self.__loader__(self.seqs_info[idx][-1]) 59 | self.seqs_data[idx] = data_list 60 | else: 61 | data_list = self.seqs_data[idx] 62 | seq_info = self.seqs_info[idx] 63 | return data_list, seq_info 64 | 65 | def __load_all_data(self): 66 | for idx in range(len(self)): 67 | self.__getitem__(idx) 68 | 69 | def __dataset_parser(self, data_config, training): 70 | dataset_root = data_config['dataset_root'] 71 | try: 72 | data_in_use = data_config['data_in_use'] # [n], true or false 73 | except: 74 | data_in_use = None 75 | 76 | with open(data_config['dataset_partition'], "rb") as f: 77 | partition = json.load(f) 78 | train_set = partition["TRAIN_SET"] 79 | test_set = partition["TEST_SET"] 80 | label_list = os.listdir(dataset_root) 81 | train_set = [label for label in train_set if label in label_list] 82 | test_set = [label for label in test_set if label in label_list] 83 | miss_pids = [label for label in label_list if label not in ( 84 | train_set + test_set)] 85 | msg_mgr = get_msg_mgr() 86 | 87 | def log_pid_list(pid_list): 88 | if len(pid_list) >= 3: 89 | msg_mgr.log_info('[%s, %s, ..., %s]' % 90 | (pid_list[0], pid_list[1], pid_list[-1])) 91 | else: 92 | msg_mgr.log_info(pid_list) 93 | 94 | if len(miss_pids) > 0: 95 | msg_mgr.log_debug('-------- Miss Pid List --------') 96 | msg_mgr.log_debug(miss_pids) 97 | if training: 98 | msg_mgr.log_info("-------- Train Pid List --------") 99 | log_pid_list(train_set) 100 | else: 101 | msg_mgr.log_info("-------- Test Pid List --------") 102 | log_pid_list(test_set) 103 | 104 | def get_seqs_info_list(label_set): 105 | seqs_info_list = [] 106 | for lab in label_set: 107 | for typ in sorted(os.listdir(osp.join(dataset_root, lab))): 108 | for vie in sorted(os.listdir(osp.join(dataset_root, lab, typ))): 109 | seq_info = [lab, typ, vie] 110 | seq_path = osp.join(dataset_root, *seq_info) 111 | seq_dirs = sorted(os.listdir(seq_path)) 112 | if seq_dirs != []: 113 | seq_dirs = [osp.join(seq_path, dir) 114 | for dir in seq_dirs] 115 | if data_in_use is not None: 116 | seq_dirs = [dir for dir, use_bl in zip( 117 | seq_dirs, data_in_use) if use_bl] 118 | seqs_info_list.append([*seq_info, seq_dirs]) 119 | else: 120 | msg_mgr.log_debug( 121 | 'Find no .pkl file in %s-%s-%s.' 
% (lab, typ, vie)) 122 | return seqs_info_list 123 | 124 | self.seqs_info = get_seqs_info_list( 125 | train_set) if training else get_seqs_info_list(test_set) 126 | -------------------------------------------------------------------------------- /lib/utils/common.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import inspect 4 | import logging 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.autograd as autograd 9 | import yaml 10 | import random 11 | from torch.nn.parallel import DistributedDataParallel as DDP 12 | from collections import OrderedDict, namedtuple 13 | 14 | 15 | class NoOp: 16 | def __getattr__(self, *args): 17 | def no_op(*args, **kwargs): pass 18 | return no_op 19 | 20 | 21 | class Odict(OrderedDict): 22 | def append(self, odict): 23 | dst_keys = self.keys() 24 | for k, v in odict.items(): 25 | if not is_list(v): 26 | v = [v] 27 | if k in dst_keys: 28 | if is_list(self[k]): 29 | self[k] += v 30 | else: 31 | self[k] = [self[k]] + v 32 | else: 33 | self[k] = v 34 | 35 | 36 | def Ntuple(description, keys, values): 37 | if not is_list_or_tuple(keys): 38 | keys = [keys] 39 | values = [values] 40 | Tuple = namedtuple(description, keys) 41 | return Tuple._make(values) 42 | 43 | 44 | def get_valid_args(obj, input_args, free_keys=[]): 45 | if inspect.isfunction(obj): 46 | expected_keys = inspect.getfullargspec(obj)[0] 47 | elif inspect.isclass(obj): 48 | expected_keys = inspect.getfullargspec(obj.__init__)[0] 49 | else: 50 | raise ValueError('Only functions and classes are supported!') 51 | unexpect_keys = list() 52 | expected_args = {} 53 | for k, v in input_args.items(): 54 | if k in expected_keys: 55 | expected_args[k] = v 56 | elif k in free_keys: 57 | pass 58 | else: 59 | unexpect_keys.append(k) 60 | if unexpect_keys != []: 61 | logging.info("Find Unexpected Args(%s) in the Configuration of - %s -" % 62 | (', '.join(unexpect_keys), obj.__name__)) 63 | return expected_args 64 | 65 | 66 | def get_attr_from(sources, name): 67 | try: 68 | return getattr(sources[0], name) 69 | except AttributeError: 70 | return get_attr_from(sources[1:], name) if len(sources) > 1 else getattr(sources[0], name) 71 | 72 | 73 | def is_list_or_tuple(x): 74 | return isinstance(x, (list, tuple)) 75 | 76 | 77 | def is_bool(x): 78 | return isinstance(x, bool) 79 | 80 | 81 | def is_str(x): 82 | return isinstance(x, str) 83 | 84 | 85 | def is_list(x): 86 | return isinstance(x, list) or isinstance(x, nn.ModuleList) 87 | 88 | 89 | def is_dict(x): 90 | return isinstance(x, dict) or isinstance(x, OrderedDict) or isinstance(x, Odict) 91 | 92 | 93 | def is_tensor(x): 94 | return isinstance(x, torch.Tensor) 95 | 96 | 97 | def is_array(x): 98 | return isinstance(x, np.ndarray) 99 | 100 | 101 | def ts2np(x): 102 | return x.cpu().data.numpy() 103 | 104 | 105 | def ts2var(x, **kwargs): 106 | return autograd.Variable(x, **kwargs).cuda() 107 | 108 | 109 | def np2var(x, **kwargs): 110 | return ts2var(torch.from_numpy(x), **kwargs) 111 | 112 | 113 | def list2var(x, **kwargs): 114 | return np2var(np.array(x), **kwargs) 115 | 116 | 117 | def mkdir(path): 118 | if not os.path.exists(path): 119 | os.makedirs(path) 120 | 121 | 122 | def MergeCfgsDict(src, dst): 123 | for k, v in src.items(): 124 | if (k not in dst.keys()) or (type(v) != type(dict())): 125 | dst[k] = v 126 | else: 127 | if is_dict(src[k]) and is_dict(dst[k]): 128 | MergeCfgsDict(src[k], dst[k]) 129 | else: 130 | dst[k] = v 131 | 132 | 133 | def clones(module, N): 134 | "Produce N
133 | def clones(module, N):
134 |     "Produce N identical layers."
135 |     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
136 | 
137 | 
138 | def config_loader(path):
139 |     with open(path, 'r') as stream:
140 |         src_cfgs = yaml.safe_load(stream)
141 |     with open("./config/default.yaml", 'r') as stream:
142 |         dst_cfgs = yaml.safe_load(stream)
143 |     MergeCfgsDict(src_cfgs, dst_cfgs)
144 |     return dst_cfgs
145 | 
146 | 
147 | def init_seeds(seed=0, cuda_deterministic=True):
148 |     random.seed(seed)
149 |     np.random.seed(seed)
150 |     torch.manual_seed(seed)
151 |     torch.cuda.manual_seed_all(seed)
152 |     if cuda_deterministic:  # slower, more reproducible
153 |         torch.backends.cudnn.deterministic = True
154 |         torch.backends.cudnn.benchmark = False
155 |     else:  # faster, less reproducible
156 |         torch.backends.cudnn.deterministic = False
157 |         torch.backends.cudnn.benchmark = True
158 | 
159 | 
160 | def handler(signum, frame):
161 |     logging.info('Ctrl+c/z pressed')
162 |     os.system(
163 |         "kill $(ps aux | grep main.py | grep -v grep | awk '{print $2}') ")
164 |     logging.info('process group flushed!')
165 | 
166 | 
167 | def ddp_all_gather(features, dim=0, requires_grad=True):
168 |     '''
169 |     inputs: [n, ...]
170 |     '''
171 | 
172 |     world_size = torch.distributed.get_world_size()
173 |     rank = torch.distributed.get_rank()
174 |     feature_list = [torch.ones_like(features) for _ in range(world_size)]
175 |     torch.distributed.all_gather(feature_list, features.contiguous())
176 | 
177 |     if requires_grad:
178 |         feature_list[rank] = features
179 |     feature = torch.cat(feature_list, dim=dim)
180 |     return feature
181 | 
182 | 
183 | class DDPPassthrough(DDP):
184 |     def __getattr__(self, name):
185 |         try:
186 |             return super().__getattr__(name)
187 |         except AttributeError:
188 |             return getattr(self.module, name)
189 | 
190 | 
191 | def get_ddp_module(module, **kwargs):
192 |     if len(list(module.parameters())) == 0:
193 |         # for the case that the loss module has no parameters.
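        # DistributedDataParallel raises an error when the wrapped module has no
        # trainable parameters, and a parameterless module is identical on every
        # rank anyway, so it can safely be used unwrapped.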
194 | return module 195 | device = torch.cuda.current_device() 196 | module = DDPPassthrough(module, device_ids=[device], output_device=device, 197 | find_unused_parameters=False, **kwargs) 198 | return module 199 | 200 | 201 | def params_count(net): 202 | n_parameters = sum(p.numel() for p in net.parameters()) 203 | return 'Parameters Count: {:.5f}M'.format(n_parameters / 1e6) 204 | -------------------------------------------------------------------------------- /lib/modeling/models/HSTL-CB.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..base_model import BaseModel 6 | from ..modules import SeparateFCs, BasicConv3d, PackSequenceWrapper, BasicConv2d 7 | 8 | # frame-level temporal aggregation (FTA) 9 | # frame-level temporal aggregation (FTA) 10 | class FTA(nn.Module): 11 | 12 | def __init__(self, channel=64, FP_kernels=[3, 5]): 13 | super().__init__() 14 | self.FP_3 = nn.MaxPool3d(kernel_size=(FP_kernels[0],1,1),stride=(3,1,1),padding=(0,0,0)) 15 | self.FP_5 = nn.MaxPool3d(kernel_size=(FP_kernels[1],1,1),stride=(3,1,1),padding=(1,0,0)) 16 | self.gap = nn.AdaptiveAvgPool3d((None,1,1)) 17 | self.fcs = nn.ModuleList([]) 18 | for i in range(len(FP_kernels)): 19 | self.fcs.append(nn.Conv1d(channel,channel,1)) 20 | self.softmax = nn.Softmax(dim=0) 21 | 22 | def forward(self, x): 23 | bs, c, t,_, _ = x.size() 24 | aggregate_outs = [] 25 | outs1 = self.FP_3(x) 26 | aggregate_outs.append(outs1) 27 | outs2 = self.FP_5(x) 28 | aggregate_outs.append(outs2) 29 | aggregate_features = torch.stack(aggregate_outs, 0) 30 | hat_out_mid = sum(aggregate_outs) 31 | hat_out = self.gap(hat_out_mid).squeeze(-1).squeeze(-1) 32 | temporal = hat_out.size(-1) 33 | weights = [] 34 | for fc in self.fcs: 35 | weight = fc(hat_out) 36 | weights.append(weight.view(bs, c, temporal, 1, 1)) 37 | select_weights = torch.stack(weights, 0) 38 | select_weights = self.softmax(select_weights) 39 | outs = (select_weights * aggregate_features).sum(0) 40 | return outs 41 | 42 | class FTA_Block(nn.Module): 43 | def __init__(self, split_param, m, in_channels): 44 | super(FTA_Block, self).__init__() 45 | self.split_param = split_param 46 | self.m = m 47 | self.mma = nn.ModuleList([ 48 | FTA(channel=in_channels, FP_kernels=[3, 5]) 49 | for i in range(self.m)]) 50 | def forward(self, x): 51 | feat = x.split(self.split_param, 3) 52 | feat = torch.cat([self.mma[i](_) for i, _ in enumerate(feat)], 3) 53 | return feat 54 | 55 | # adaptive region-based motion extractor (ARME) 56 | class ARME_Conv(nn.Module): 57 | def __init__(self, in_channels, out_channels, split_param ,m, kernel_size=(3, 3, 3), stride=(1, 1, 1), 58 | padding=(1, 1, 1),bias=False,**kwargs): 59 | super(ARME_Conv, self).__init__() 60 | self.m = m 61 | 62 | self.split_param = split_param 63 | 64 | self.conv3d = nn.ModuleList([ 65 | BasicConv3d(in_channels, out_channels, kernel_size, stride, padding,bias ,**kwargs) 66 | for i in range(self.m)]) 67 | 68 | 69 | def forward(self, x): 70 | ''' 71 | x: [n, c, s, h, w] 72 | ''' 73 | feat = x.split(self.split_param, 3) 74 | feat = torch.cat([self.conv3d[i](_) for i, _ in enumerate(feat)], 3) 75 | feat = F.leaky_relu(feat) 76 | return feat 77 | 78 | # Generalized Mean Pooling (GeM) 79 | class GeMHPP(nn.Module): 80 | def __init__(self, bin_num=[64], p=6.5, eps=1.0e-6): 81 | super(GeMHPP, self).__init__() 82 | self.bin_num = bin_num 83 | self.p = nn.Parameter( 84 | torch.ones(1)*p) 85 | self.eps = eps 86 | 87 | 
def gem(self, ipts): 88 | return F.avg_pool2d(ipts.clamp(min=self.eps).pow(self.p), (1, ipts.size(-1))).pow(1. / self.p) 89 | 90 | def forward(self, x): 91 | """ 92 | x : [n, c, h, w] 93 | ret: [n, c, p] 94 | """ 95 | n, c = x.size()[:2] 96 | features = [] 97 | for b in self.bin_num: 98 | z = x.view(n, c, b, -1) 99 | z = self.gem(z).squeeze(-1) 100 | features.append(z) 101 | return torch.cat(features, -1) 102 | 103 | 104 | # adaptive spatio-temporal pooling 105 | class ASTP(nn.Module): 106 | def __init__(self, split_param, m, in_channels, out_channels, flag=True): 107 | super(ASTP, self).__init__() 108 | self.split_param = split_param 109 | self.m = m 110 | self.hpp = nn.ModuleList([ 111 | GeMHPP(bin_num=[1]) for i in range(self.m)]) 112 | 113 | self.flag = flag 114 | if self.flag: 115 | self.proj = BasicConv2d(in_channels, out_channels, 1, 1, 0) 116 | 117 | 118 | self.SP1 = PackSequenceWrapper(torch.max) 119 | def forward(self, x, seqL): 120 | x = self.SP1(x, seqL=seqL, options={"dim": 2})[0] 121 | if self.flag: 122 | x = self.proj(x) 123 | feat = x.split(self.split_param, 2) 124 | feat = torch.cat([self.hpp[i](_) for i, _ in enumerate(feat)], -1) 125 | return feat 126 | 127 | class HSTL(BaseModel): 128 | """ 129 | Hierarchical Spatio-Temporal Feature Learning for Gait Recognition 130 | """ 131 | 132 | def __init__(self, *args, **kargs): 133 | super(HSTL, self).__init__(*args, **kargs) 134 | 135 | def build_network(self, model_cfg): 136 | in_c = model_cfg['channels'] 137 | class_num = model_cfg['class_num'] 138 | # For CASIA-B dataset. 139 | self.arme1 = nn.Sequential( 140 | BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3), 141 | stride=(1, 1, 1), padding=(1, 1, 1)), 142 | nn.LeakyReLU(inplace=True) 143 | ) 144 | 145 | self.astp1 = ASTP(split_param=[64], m=1, in_channels=in_c[0], out_channels=in_c[-1]) 146 | 147 | self.arme2 = nn.Sequential( 148 | ARME_Conv(in_c[0], in_c[0], split_param=[40, 24], m=2, kernel_size=( 149 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)), 150 | ARME_Conv(in_c[0], in_c[1], split_param=[40, 24], m=2, kernel_size=( 151 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) 152 | ) 153 | 154 | self.astp2 = ASTP(split_param=[40, 24], m=2, in_channels=in_c[1], out_channels=in_c[-1]) 155 | 156 | self.fta = FTA_Block(split_param=[40, 24], m=2, in_channels=in_c[1]) 157 | 158 | self.astp2_fta = ASTP(split_param=[40, 24], m=2, in_channels=in_c[1], out_channels=in_c[-1]) 159 | 160 | self.arme3 = nn.Sequential( 161 | ARME_Conv(in_c[1], in_c[2], split_param=[8, 32, 16, 8], m=4, kernel_size=( 162 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)), 163 | ARME_Conv(in_c[2], in_c[2], split_param=[8, 32, 16, 8], m=4, kernel_size=( 164 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) 165 | ) 166 | 167 | self.astp3 = ASTP(split_param=[8, 32, 16, 8], m=4, in_channels=in_c[2], out_channels=in_c[-1], flag=False) 168 | 169 | # self.astp4 = ASTP(split_param=[1,1,1,1,1,1,1,1, 170 | # 1,1,1,1,1,1,1,1, 171 | # 1,1,1,1,1,1,1,1, 172 | # 1,1,1,1,1,1,1,1, 173 | # 1,1,1,1,1,1,1,1, 174 | # 1,1,1,1,1,1,1,1, 175 | # 1,1,1,1,1,1,1,1, 176 | # 1,1,1,1,1,1,1,1], m=64, in_channels=in_c[2], out_channels=in_c[-1]) 177 | self.HPP = GeMHPP() 178 | 179 | # separable fully connected layer (SeFC) 180 | self.Head0 = SeparateFCs(73, in_c[-1], in_c[-1]) 181 | # batchnorm layer (BN) 182 | self.Bn = nn.BatchNorm1d(in_c[-1]) 183 | # separable fully connected layer (SeFC) 184 | self.Head1 = SeparateFCs(73, in_c[-1], class_num) 185 | # Temporal Pooling (TP) 186 | self.TP = PackSequenceWrapper(torch.max) 187 | 188 | def 
forward(self, inputs): 189 | ipts, labs, _, _, seqL = inputs 190 | seqL = None if not self.training else seqL 191 | if not self.training and len(labs) != 1: 192 | raise ValueError( 193 | 'The input size of each GPU must be 1 in testing mode, but got {}!'.format(len(labs))) 194 | sils = ipts[0].unsqueeze(1) 195 | del ipts 196 | n, _, s, h, w = sils.size() 197 | if s < 3: 198 | repeat = 3 if s == 1 else 2 199 | sils = sils.repeat(1, 1, repeat, 1, 1) 200 | outs = self.arme1(sils) 201 | astp1 = self.astp1(outs, seqL) 202 | outs = self.arme2(outs) 203 | astp2 = self.astp2(outs, seqL) 204 | outs = self.fta(outs) 205 | astp2_fta = self.astp2_fta(outs, seqL) 206 | outs = self.arme3(outs) 207 | astp3 = self.astp3(outs, seqL) 208 | astp4 = self.TP(outs, seqL=seqL, options={"dim": 2})[0] # [n, c, h, w] 209 | astp4 = self.HPP(astp4) 210 | # astp4 = self.astp4(outs, seqL) 211 | outs = torch.cat([astp1,astp2, astp2_fta, astp3, astp4], dim=-1) # [n, c, p] 212 | outs = outs.permute(2, 0, 1).contiguous() # [p, n, c] 213 | gait = self.Head0(outs) # [p, n, c] 214 | gait = gait.permute(1, 2, 0).contiguous() # [n, c, p] 215 | bnft = self.Bn(gait) # [n, c, p] 216 | logi = self.Head1(bnft.permute(2, 0, 1).contiguous()) # [p, n, c] 217 | 218 | gait = gait.permute(0, 2, 1).contiguous() # [n, p, c] 219 | bnft = bnft.permute(0, 2, 1).contiguous() # [n, p, c] 220 | logi = logi.permute(1, 0, 2).contiguous() # [n, p, c] 221 | # print(logi.size()) 222 | 223 | n, _, s, h, w = sils.size() 224 | retval = { 225 | 'training_feat': { 226 | 'triplet': {'embeddings': bnft, 'labels': labs}, 227 | 'softmax': {'logits': logi, 'labels': labs} 228 | }, 229 | 'visual_summary': { 230 | 'image/sils': sils.view(n * s, 1, h, w) 231 | }, 232 | 'inference_feat': { 233 | 'embeddings': bnft 234 | } 235 | } 236 | return retval 237 | -------------------------------------------------------------------------------- /lib/modeling/models/HSTL-OU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..base_model import BaseModel 6 | from ..modules import SeparateFCs, BasicConv3d, PackSequenceWrapper, BasicConv2d 7 | import einops 8 | 9 | class FTA(nn.Module): 10 | 11 | def __init__(self, channel=64, FP_kernels=[3, 5]): 12 | super().__init__() 13 | self.FP_3 = nn.MaxPool3d(kernel_size=(FP_kernels[0],1,1),stride=(3,1,1),padding=(0,0,0)) 14 | self.FP_5 = nn.MaxPool3d(kernel_size=(FP_kernels[1],1,1),stride=(3,1,1),padding=(1,0,0)) 15 | self.gap = nn.AdaptiveAvgPool3d((None,1,1)) 16 | self.fcs = nn.ModuleList([]) 17 | for i in range(len(FP_kernels)): 18 | self.fcs.append(nn.Conv1d(channel,channel,1)) 19 | self.softmax = nn.Softmax(dim=0) 20 | 21 | def forward(self, x): 22 | bs, c, t,_, _ = x.size() 23 | aggregate_outs = [] 24 | outs1 = self.FP_3(x) 25 | aggregate_outs.append(outs1) 26 | outs2 = self.FP_5(x) 27 | aggregate_outs.append(outs2) 28 | aggregate_features = torch.stack(aggregate_outs, 0) 29 | hat_out_mid = sum(aggregate_outs) 30 | hat_out = self.gap(hat_out_mid).squeeze(-1).squeeze(-1) 31 | temporal = hat_out.size(-1) 32 | weights = [] 33 | for fc in self.fcs: 34 | weight = fc(hat_out) 35 | weights.append(weight.view(bs, c, temporal, 1, 1)) 36 | select_weights = torch.stack(weights, 0) 37 | select_weights = self.softmax(select_weights) 38 | outs = (select_weights * aggregate_features).sum(0) 39 | return outs 40 | 41 | # adaptive region-based motion extractor (ARME) 42 | class ARME_Conv(nn.Module): 43 | def __init__(self, 
in_channels, out_channels, split_param, m, kernel_size=(3, 3, 3), stride=(1, 1, 1),
44 |                  padding=(1, 1, 1), bias=False, **kwargs):
45 |         super(ARME_Conv, self).__init__()
46 |         self.m = m
47 | 
48 |         self.split_param = split_param
49 | 
50 |         self.conv3d = nn.ModuleList([
51 |             BasicConv3d(in_channels, out_channels, kernel_size, stride, padding, bias, **kwargs)
52 |             for i in range(self.m)])
53 | 
54 | 
55 |     def forward(self, x):
56 |         '''
57 |         x: [n, c, s, h, w]
58 |         '''
59 |         feat = x.split(self.split_param, 3)
60 |         feat = torch.cat([self.conv3d[i](_) for i, _ in enumerate(feat)], 3)
61 |         feat = F.leaky_relu(feat)
62 |         return feat
63 | 
64 | 
65 | # Generalized Mean Pooling (GeM)
66 | class GeMHPP(nn.Module):
67 |     def __init__(self, bin_num=[32], p=6.5, eps=1.0e-6):
68 |         super(GeMHPP, self).__init__()
69 |         self.bin_num = bin_num
70 |         self.p = nn.Parameter(
71 |             torch.ones(1)*p)
72 |         self.eps = eps
73 | 
74 |     def gem(self, ipts):
75 |         return F.avg_pool2d(ipts.clamp(min=self.eps).pow(self.p), (1, ipts.size(-1))).pow(1. / self.p)
76 | 
77 |     def forward(self, x):
78 |         """
79 |         x  : [n, c, h, w]
80 |         ret: [n, c, p]
81 |         """
82 |         n, c = x.size()[:2]
83 |         features = []
84 |         for b in self.bin_num:
85 |             z = x.view(n, c, b, -1)
86 |             z = self.gem(z).squeeze(-1)
87 |             features.append(z)
88 |         return torch.cat(features, -1)
89 | 
90 | class MGeMHPP(nn.Module):
91 |     def __init__(self, split_param, m):
92 |         super(MGeMHPP, self).__init__()
93 |         self.split_param = split_param
94 |         self.m = m
95 |         self.hpp = nn.ModuleList([
96 |             GeMHPP(bin_num=[1]) for i in range(self.m)])
97 |     def forward(self, x):
98 |         feat = x.split(self.split_param, 2)
99 |         # print(feat[0].size())
100 |         feat = torch.cat([self.hpp[i](_) for i, _ in enumerate(feat)], -1)
101 |         return feat
102 | 
103 | 
104 | class FTA_Block(nn.Module):
105 |     def __init__(self, split_param, m, in_channels):
106 |         super(FTA_Block, self).__init__()
107 |         self.split_param = split_param
108 |         self.m = m
109 |         self.mma = nn.ModuleList([
110 |             FTA(channel=in_channels, FP_kernels=[3, 5])
111 |             for i in range(self.m)])
112 |     def forward(self, x):
113 |         feat = x.split(self.split_param, 3)
114 |         feat = torch.cat([self.mma[i](_) for i, _ in enumerate(feat)], 3)
115 |         return feat
116 | 
117 | # adaptive spatio-temporal pooling (ASTP)
118 | class ASTP(nn.Module):
119 |     def __init__(self, split_param, m, in_channels, out_channels, flag=True):
120 |         super(ASTP, self).__init__()
121 |         self.split_param = split_param
122 |         self.m = m
123 |         self.hpp = nn.ModuleList([
124 |             GeMHPP(bin_num=[1]) for i in range(self.m)])
125 | 
126 |         self.flag = flag
127 |         if self.flag:
128 |             self.proj = BasicConv2d(in_channels, out_channels, 1, 1, 0)
129 | 
130 | 
131 |         self.SP1 = PackSequenceWrapper(torch.max)
132 |     def forward(self, x, seqL):
133 |         x = self.SP1(x, seqL=seqL, options={"dim": 2})[0]
134 |         if self.flag:
135 |             x = self.proj(x)
136 |         feat = x.split(self.split_param, 2)
137 |         feat = torch.cat([self.hpp[i](_) for i, _ in enumerate(feat)], -1)
138 |         return feat
139 | 
140 | class HSTL_OU(BaseModel):
141 |     """
142 |     HSTL: Hierarchical Spatio-Temporal Feature Learning for Gait Recognition
143 |     (variant configured for the OUMVLP and GREW datasets).
144 |     """
145 | 
146 |     def __init__(self, *args, **kwargs):
147 |         super(HSTL_OU, self).__init__(*args, **kwargs)
148 | 
149 |     def build_network(self, model_cfg):
150 |         in_c = model_cfg['channels']
151 |         class_num = model_cfg['class_num']
152 |         dataset_name = self.cfgs['data_cfg']['dataset_name']
153 | 
154 |         if dataset_name in ['OUMVLP', 'GREW']:
155 |             # For OUMVLP
and GREW 156 | self.arme1 = nn.Sequential( 157 | BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3), 158 | stride=(1, 1, 1), padding=(1, 1, 1)), 159 | nn.LeakyReLU(inplace=True), 160 | BasicConv3d(in_c[0], in_c[0], kernel_size=(3, 3, 3), 161 | stride=(1, 1, 1), padding=(1, 1, 1)), 162 | nn.LeakyReLU(inplace=True) 163 | ) 164 | 165 | self.astp1 = ASTP(split_param=[64], m=1, in_channels=in_c[0], out_channels=in_c[-1]) 166 | 167 | self.MaxPool0 = nn.MaxPool3d( 168 | kernel_size=(1, 2, 2), stride=(1, 2, 2)) 169 | 170 | self.arme2 = nn.Sequential( 171 | ARME_Conv(in_c[0], in_c[1], split_param=[4, 24, 4], m=3, kernel_size=( 172 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)), 173 | ARME_Conv(in_c[1], in_c[1], split_param=[4, 24, 4], m=3, kernel_size=( 174 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) 175 | ) 176 | 177 | 178 | self.astp2 = ASTP(split_param=[4, 24, 4], m=3, in_channels=in_c[1], out_channels=in_c[-1]) 179 | 180 | self.fta = FTA_Block(split_param=[4, 24, 4], m=3, in_channels=in_c[1]) 181 | 182 | self.astp2_fta = ASTP(split_param=[4, 24, 4], m=3, in_channels=in_c[1], out_channels=in_c[-1]) 183 | 184 | # self.MaxPool0 = nn.MaxPool3d( 185 | # kernel_size=(1, 2, 2), stride=(1, 2, 2)) 186 | 187 | self.arme3 = nn.Sequential( 188 | ARME_Conv(in_c[1], in_c[2], split_param=[4, 4, 12, 8, 4], m=5, kernel_size=( 189 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)), 190 | ARME_Conv(in_c[2], in_c[2], split_param=[4, 4, 12, 8, 4], m=5, kernel_size=( 191 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) 192 | ) 193 | 194 | self.astp3 = ASTP(split_param=[4, 4, 12, 8, 4], m=5, in_channels=in_c[2], out_channels=in_c[-1]) 195 | 196 | # self.MaxPool1 = nn.MaxPool3d( 197 | # kernel_size=(1, 2, 2), stride=(1, 2, 2)) 198 | 199 | self.arme4 = nn.Sequential( 200 | ARME_Conv(in_c[2], in_c[3], split_param=[4, 4, 12, 4, 4, 4], m=6, kernel_size=( 201 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)), 202 | ARME_Conv(in_c[3], in_c[3], split_param=[4, 4, 12, 4, 4, 4], m=6, kernel_size=( 203 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) 204 | ) 205 | 206 | self.astp4 = ASTP(split_param=[4, 4, 12, 4, 4, 4], m=6, in_channels=in_c[3], out_channels=in_c[-1], flag=False) 207 | 208 | self.TP = PackSequenceWrapper(torch.max) 209 | 210 | self.HPP = GeMHPP() 211 | # self.mHPP = MGeMHPP(split_param=[40, 24], m=2) 212 | 213 | self.Head0 = SeparateFCs(50, in_c[-1], in_c[-1]) 214 | 215 | self.Bn = nn.BatchNorm1d(in_c[-1]) 216 | self.Head1 = SeparateFCs(50, in_c[-1], class_num) 217 | 218 | def forward(self, inputs): 219 | ipts, labs, _, _, seqL = inputs 220 | seqL = None if not self.training else seqL 221 | if not self.training and len(labs) != 1: 222 | raise ValueError( 223 | 'The input size of each GPU must be 1 in testing mode, but got {}!'.format(len(labs))) 224 | sils = ipts[0].unsqueeze(1) 225 | del ipts 226 | n, _, s, h, w = sils.size() 227 | if s < 3: 228 | repeat = 3 if s == 1 else 2 229 | sils = sils.repeat(1, 1, repeat, 1, 1) 230 | 231 | outs = self.arme1(sils) 232 | astp1 = self.astp1(outs, seqL) 233 | outs = self.MaxPool0(outs) 234 | outs = self.arme2(outs) 235 | astp2 = self.astp2(outs, seqL) 236 | outs = self.fta(outs) 237 | astp2_fta = self.astp2_fta(outs, seqL) 238 | outs = self.arme3(outs) 239 | astp3 = self.astp3(outs, seqL) 240 | outs = self.arme4(outs) 241 | astp4 = self.astp4(outs, seqL) 242 | astp5 = self.TP(outs, seqL=seqL, options={"dim": 2})[0] # [n, c, h, w] 243 | astp5 = self.HPP(astp5) 244 | # astp4 = self.astp4(outs, seqL) 245 | outs = torch.cat([astp1, astp2, astp2_fta, astp3, astp4, astp5], 
dim=-1) # [n, c, p] 246 | 247 | gait = self.Head0(outs) # [n, c, p] 248 | # print(gait.size()) 249 | 250 | bnft = self.Bn(gait) # [n, c, p] 251 | logi = self.Head1(bnft) # [n, c, p] 252 | 253 | gait = gait.permute(0, 2, 1).contiguous() # [n, p, c] 254 | bnft = bnft.permute(0, 2, 1).contiguous() # [n, p, c] 255 | logi = logi.permute(1, 0, 2).contiguous() # [n, p, c] 256 | # print(logi.size()) 257 | 258 | n, _, s, h, w = sils.size() 259 | retval = { 260 | 'training_feat': { 261 | 'triplet': {'embeddings': bnft, 'labels': labs}, 262 | 'softmax': {'logits': logi, 'labels': labs} 263 | }, 264 | 'visual_summary': { 265 | 'image/sils': sils.view(n * s, 1, h, w) 266 | }, 267 | 'inference_feat': { 268 | 'embeddings': bnft 269 | } 270 | } 271 | return retval 272 | -------------------------------------------------------------------------------- /lib/modeling/models/HSTL-Gait3D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..base_model import BaseModel 6 | from ..modules import SeparateFCs, BasicConv3d, PackSequenceWrapper, BasicConv2d 7 | import einops 8 | 9 | class FTA(nn.Module): 10 | 11 | def __init__(self, channel=64, FP_kernels=[3, 5]): 12 | super().__init__() 13 | self.FP_3 = nn.MaxPool3d(kernel_size=(FP_kernels[0],1,1),stride=(3,1,1),padding=(0,0,0)) 14 | self.FP_5 = nn.MaxPool3d(kernel_size=(FP_kernels[1],1,1),stride=(3,1,1),padding=(1,0,0)) 15 | self.gap = nn.AdaptiveAvgPool3d((None,1,1)) 16 | self.fcs = nn.ModuleList([]) 17 | for i in range(len(FP_kernels)): 18 | self.fcs.append(nn.Conv1d(channel,channel,1)) 19 | self.softmax = nn.Softmax(dim=0) 20 | 21 | def forward(self, x): 22 | bs, c, t,_, _ = x.size() 23 | aggregate_outs = [] 24 | outs1 = self.FP_3(x) 25 | aggregate_outs.append(outs1) 26 | outs2 = self.FP_5(x) 27 | aggregate_outs.append(outs2) 28 | aggregate_features = torch.stack(aggregate_outs, 0) 29 | hat_out_mid = sum(aggregate_outs) 30 | hat_out = self.gap(hat_out_mid).squeeze(-1).squeeze(-1) 31 | temporal = hat_out.size(-1) 32 | weights = [] 33 | for fc in self.fcs: 34 | weight = fc(hat_out) 35 | weights.append(weight.view(bs, c, temporal, 1, 1)) 36 | select_weights = torch.stack(weights, 0) 37 | select_weights = self.softmax(select_weights) 38 | outs = (select_weights * aggregate_features).sum(0) 39 | return outs 40 | 41 | # adaptive region-based motion extractor (ARME) 42 | class ARME_Conv(nn.Module): 43 | def __init__(self, in_channels, out_channels, split_param ,m, kernel_size=(3, 3, 3), stride=(1, 1, 1), 44 | padding=(1, 1, 1),bias=False,**kwargs): 45 | super(ARME_Conv, self).__init__() 46 | self.m = m 47 | 48 | self.split_param = split_param 49 | 50 | self.conv3d = nn.ModuleList([ 51 | BasicConv3d(in_channels, out_channels, kernel_size, stride, padding,bias ,**kwargs) 52 | for i in range(self.m)]) 53 | 54 | 55 | def forward(self, x): 56 | ''' 57 | x: [n, c, s, h, w] 58 | ''' 59 | feat = x.split(self.split_param, 3) 60 | feat = torch.cat([self.conv3d[i](_) for i, _ in enumerate(feat)], 3) 61 | feat = F.leaky_relu(feat) 62 | return feat 63 | 64 | 65 | # Generalized Mean Pooling (GeM) 66 | class GeMHPP(nn.Module): 67 | def __init__(self, bin_num=[32], p=6.5, eps=1.0e-6): 68 | super(GeMHPP, self).__init__() 69 | self.bin_num = bin_num 70 | self.p = nn.Parameter( 71 | torch.ones(1)*p) 72 | self.eps = eps 73 | 74 | def gem(self, ipts): 75 | return F.avg_pool2d(ipts.clamp(min=self.eps).pow(self.p), (1, ipts.size(-1))).pow(1. 
/ self.p)
76 | 
77 |     def forward(self, x):
78 |         """
79 |         x  : [n, c, h, w]
80 |         ret: [n, c, p]
81 |         """
82 |         n, c = x.size()[:2]
83 |         features = []
84 |         for b in self.bin_num:
85 |             z = x.view(n, c, b, -1)
86 |             z = self.gem(z).squeeze(-1)
87 |             features.append(z)
88 |         return torch.cat(features, -1)
89 | 
90 | class MGeMHPP(nn.Module):
91 |     def __init__(self, split_param, m):
92 |         super(MGeMHPP, self).__init__()
93 |         self.split_param = split_param
94 |         self.m = m
95 |         self.hpp = nn.ModuleList([
96 |             GeMHPP(bin_num=[1]) for i in range(self.m)])
97 |     def forward(self, x):
98 |         feat = x.split(self.split_param, 2)
99 |         # print(feat[0].size())
100 |         feat = torch.cat([self.hpp[i](_) for i, _ in enumerate(feat)], -1)
101 |         return feat
102 | 
103 | 
104 | class FTA_Block(nn.Module):
105 |     def __init__(self, split_param, m, in_channels):
106 |         super(FTA_Block, self).__init__()
107 |         self.split_param = split_param
108 |         self.m = m
109 |         self.mma = nn.ModuleList([
110 |             FTA(channel=in_channels, FP_kernels=[3, 5])
111 |             for i in range(self.m)])
112 |     def forward(self, x):
113 |         feat = x.split(self.split_param, 3)
114 |         feat = torch.cat([self.mma[i](_) for i, _ in enumerate(feat)], 3)
115 |         return feat
116 | 
117 | # adaptive spatio-temporal pooling (ASTP)
118 | class ASTP(nn.Module):
119 |     def __init__(self, split_param, m, in_channels, out_channels, flag=True):
120 |         super(ASTP, self).__init__()
121 |         self.split_param = split_param
122 |         self.m = m
123 |         self.hpp = nn.ModuleList([
124 |             GeMHPP(bin_num=[1]) for i in range(self.m)])
125 | 
126 |         self.flag = flag
127 |         if self.flag:
128 |             self.proj = BasicConv2d(in_channels, out_channels, 1, 1, 0)
129 | 
130 | 
131 |         self.SP1 = PackSequenceWrapper(torch.max)
132 |     def forward(self, x, seqL):
133 |         x = self.SP1(x, seqL=seqL, options={"dim": 2})[0]
134 |         if self.flag:
135 |             x = self.proj(x)
136 |         feat = x.split(self.split_param, 2)
137 |         feat = torch.cat([self.hpp[i](_) for i, _ in enumerate(feat)], -1)
138 |         return feat
139 | 
140 | class HSTL_Gait3D(BaseModel):
141 |     """
142 |     HSTL: Hierarchical Spatio-Temporal Feature Learning for Gait Recognition
143 |     (variant configured for the Gait3D dataset).
144 |     """
145 | 
146 |     def __init__(self, *args, **kwargs):
147 |         super(HSTL_Gait3D, self).__init__(*args, **kwargs)
148 | 
149 |     def build_network(self, model_cfg):
150 |         in_c = model_cfg['channels']
151 |         class_num = model_cfg['class_num']
152 |         dataset_name = self.cfgs['data_cfg']['dataset_name']
153 | 
154 |         if dataset_name in ['OUMVLP', 'GREW', 'Gait3D']:
155 |             # For OUMVLP, GREW and Gait3D
156 |             self.arme1 = nn.Sequential(
157 |                 BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3),
158 |                             stride=(1, 1, 1), padding=(1, 1, 1)),
159 |                 nn.LeakyReLU(inplace=True),
160 |                 BasicConv3d(in_c[0], in_c[0], kernel_size=(3, 3, 3),
161 |                             stride=(1, 1, 1), padding=(1, 1, 1)),
162 |                 nn.LeakyReLU(inplace=True)
163 |             )
164 | 
165 |             self.astp1 = ASTP(split_param=[64], m=1, in_channels=in_c[0], out_channels=in_c[-1])
166 | 
167 |             self.MaxPool0 = nn.MaxPool3d(
168 |                 kernel_size=(1, 2, 2), stride=(1, 2, 2))
169 | 
170 |             self.arme2 = nn.Sequential(
171 |                 ARME_Conv(in_c[0], in_c[1], split_param=[20, 12], m=2, kernel_size=(
172 |                     3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
173 |                 ARME_Conv(in_c[1], in_c[1], split_param=[20, 12], m=2, kernel_size=(
174 |                     3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
175 |             )
176 | 
177 | 
178 |             self.astp2 = ASTP(split_param=[20, 12], m=2, in_channels=in_c[1], out_channels=in_c[-1])
179 | 
180 |             self.fta = FTA_Block(split_param=[20, 12], m=2, in_channels=in_c[1])
181 | 
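            # astp2 pools the pre-FTA features and astp2_fta (below) pools the
            # FTA-refined ones, so the final concatenation keeps both views of
            # each body region.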
182 | self.astp2_fta = ASTP(split_param=[20, 12], m=2, in_channels=in_c[1], out_channels=in_c[-1]) 183 | 184 | # self.MaxPool0 = nn.MaxPool3d( 185 | # kernel_size=(1, 2, 2), stride=(1, 2, 2)) 186 | 187 | self.arme3 = nn.Sequential( 188 | ARME_Conv(in_c[1], in_c[2], split_param=[4, 16, 8, 4], m=4, kernel_size=( 189 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)), 190 | ARME_Conv(in_c[2], in_c[2], split_param=[4, 16, 8, 4], m=4, kernel_size=( 191 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) 192 | ) 193 | 194 | self.astp3 = ASTP(split_param=[4, 16, 8, 4], m=4, in_channels=in_c[2], out_channels=in_c[-1]) 195 | 196 | # self.MaxPool1 = nn.MaxPool3d( 197 | # kernel_size=(1, 2, 2), stride=(1, 2, 2)) 198 | 199 | self.arme4 = nn.Sequential( 200 | ARME_Conv(in_c[2], in_c[3], split_param=[4, 4, 12, 8, 4], m=5, kernel_size=( 201 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)), 202 | ARME_Conv(in_c[3], in_c[3], split_param=[4, 4, 12, 8, 4], m=5, kernel_size=( 203 | 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) 204 | ) 205 | 206 | self.astp4 = ASTP(split_param=[4, 4, 12, 8, 4], m=5, in_channels=in_c[3], out_channels=in_c[-1], flag=False) 207 | 208 | self.TP = PackSequenceWrapper(torch.max) 209 | 210 | self.HPP = GeMHPP() 211 | # self.mHPP = MGeMHPP(split_param=[40, 24], m=2) 212 | 213 | self.Head0 = SeparateFCs(46, in_c[-1], in_c[-1]) 214 | 215 | self.Bn = nn.BatchNorm1d(in_c[-1]) 216 | self.Head1 = SeparateFCs(46, in_c[-1], class_num) 217 | 218 | def forward(self, inputs): 219 | ipts, labs, _, _, seqL = inputs 220 | seqL = None if not self.training else seqL 221 | if not self.training and len(labs) != 1: 222 | raise ValueError( 223 | 'The input size of each GPU must be 1 in testing mode, but got {}!'.format(len(labs))) 224 | sils = ipts[0].unsqueeze(1) 225 | del ipts 226 | n, _, s, h, w = sils.size() 227 | if s < 3: 228 | repeat = 3 if s == 1 else 2 229 | sils = sils.repeat(1, 1, repeat, 1, 1) 230 | 231 | outs = self.arme1(sils) 232 | astp1 = self.astp1(outs, seqL) 233 | outs = self.MaxPool0(outs) 234 | outs = self.arme2(outs) 235 | astp2 = self.astp2(outs, seqL) 236 | outs = self.fta(outs) 237 | astp2_fta = self.astp2_fta(outs, seqL) 238 | outs = self.arme3(outs) 239 | astp3 = self.astp3(outs, seqL) 240 | outs = self.arme4(outs) 241 | astp4 = self.astp4(outs, seqL) 242 | astp5 = self.TP(outs, seqL=seqL, options={"dim": 2})[0] # [n, c, h, w] 243 | astp5 = self.HPP(astp5) 244 | # astp4 = self.astp4(outs, seqL) 245 | outs = torch.cat([astp1, astp2, astp2_fta, astp3, astp4, astp5], dim=-1) # [n, c, p] 246 | 247 | gait = self.Head0(outs) # [n, c, p] 248 | # print(gait.size()) 249 | 250 | bnft = self.Bn(gait) # [n, c, p] 251 | logi = self.Head1(bnft) # [n, c, p] 252 | 253 | gait = gait.permute(0, 2, 1).contiguous() # [n, p, c] 254 | bnft = bnft.permute(0, 2, 1).contiguous() # [n, p, c] 255 | logi = logi.permute(1, 0, 2).contiguous() # [n, p, c] 256 | # print(logi.size()) 257 | 258 | n, _, s, h, w = sils.size() 259 | retval = { 260 | 'training_feat': { 261 | 'triplet': {'embeddings': bnft, 'labels': labs}, 262 | 'softmax': {'logits': logi, 'labels': labs} 263 | }, 264 | 'visual_summary': { 265 | 'image/sils': sils.view(n * s, 1, h, w) 266 | }, 267 | 'inference_feat': { 268 | 'embeddings': bnft 269 | } 270 | } 271 | return retval -------------------------------------------------------------------------------- /lib/modeling/base_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | import numpy as np 5 | import os.path as osp 6 
| import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.utils.data as tordata 9 | 10 | from tqdm import tqdm 11 | from torch.cuda.amp import autocast 12 | from torch.cuda.amp import GradScaler 13 | from abc import ABCMeta 14 | from abc import abstractmethod 15 | 16 | from . import backbones 17 | from .loss_aggregator import LossAggregator 18 | from data.transform import get_transform 19 | from data.collate_fn import CollateFn 20 | from data.dataset import DataSet 21 | import data.sampler as Samplers 22 | from utils import Odict, mkdir, ddp_all_gather 23 | from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from 24 | from utils import evaluation as eval_functions 25 | from utils import NoOp 26 | from utils import get_msg_mgr 27 | 28 | __all__ = ['BaseModel'] 29 | 30 | 31 | class MetaModel(metaclass=ABCMeta): 32 | @abstractmethod 33 | def get_loader(self, data_cfg): 34 | """Based on the given data_cfg, we get the data loader.""" 35 | raise NotImplementedError 36 | 37 | @abstractmethod 38 | def build_network(self, model_cfg): 39 | """Build your network here.""" 40 | raise NotImplementedError 41 | 42 | @abstractmethod 43 | def init_parameters(self): 44 | """Initialize the parameters of your network.""" 45 | raise NotImplementedError 46 | 47 | @abstractmethod 48 | def get_optimizer(self, optimizer_cfg): 49 | """Based on the given optimizer_cfg, we get the optimizer.""" 50 | raise NotImplementedError 51 | 52 | @abstractmethod 53 | def get_scheduler(self, scheduler_cfg): 54 | """Based on the given scheduler_cfg, we get the scheduler.""" 55 | raise NotImplementedError 56 | 57 | @abstractmethod 58 | def save_ckpt(self, iteration): 59 | """Save the checkpoint, including model parameter, optimizer and scheduler.""" 60 | raise NotImplementedError 61 | 62 | @abstractmethod 63 | def resume_ckpt(self, restore_hint): 64 | """Resume the model from the checkpoint, including model parameter, optimizer and scheduler.""" 65 | raise NotImplementedError 66 | 67 | @abstractmethod 68 | def inputs_pretreament(self, inputs): 69 | """Transform the input data based on transform setting.""" 70 | raise NotImplementedError 71 | 72 | @abstractmethod 73 | def train_step(self, loss_num) -> bool: 74 | """Do one training step.""" 75 | raise NotImplementedError 76 | 77 | @abstractmethod 78 | def inference(self): 79 | """Do inference (calculate features.).""" 80 | raise NotImplementedError 81 | 82 | @abstractmethod 83 | def run_train(model): 84 | """Run a whole train schedule.""" 85 | raise NotImplementedError 86 | 87 | @abstractmethod 88 | def run_test(model): 89 | """Run a whole test schedule.""" 90 | raise NotImplementedError 91 | 92 | 93 | class BaseModel(MetaModel, nn.Module): 94 | 95 | def __init__(self, cfgs, training): 96 | 97 | super(BaseModel, self).__init__() 98 | self.msg_mgr = get_msg_mgr() 99 | self.cfgs = cfgs 100 | self.iteration = 0 101 | self.engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg'] 102 | if self.engine_cfg is None: 103 | raise Exception("Initialize a model without -Engine-Cfgs-") 104 | 105 | if training and self.engine_cfg['enable_float16']: 106 | self.Scaler = GradScaler() 107 | self.save_path = osp.join('output/', cfgs['data_cfg']['dataset_name'], 108 | cfgs['model_cfg']['model'], self.engine_cfg['save_name']) 109 | 110 | self.build_network(cfgs['model_cfg']) 111 | self.init_parameters() 112 | 113 | self.msg_mgr.log_info(cfgs['data_cfg']) 114 | if training: 115 | self.train_loader = self.get_loader( 116 | cfgs['data_cfg'], 
train=True) 117 | if not training or self.engine_cfg['with_test']: 118 | self.test_loader = self.get_loader( 119 | cfgs['data_cfg'], train=False) 120 | 121 | self.device = torch.distributed.get_rank() 122 | torch.cuda.set_device(self.device) 123 | self.to(device=torch.device( 124 | "cuda", self.device)) 125 | 126 | if training: 127 | self.loss_aggregator = LossAggregator(cfgs['loss_cfg']) 128 | self.optimizer = self.get_optimizer(self.cfgs['optimizer_cfg']) 129 | self.scheduler = self.get_scheduler(cfgs['scheduler_cfg']) 130 | self.train(training) 131 | restore_hint = self.engine_cfg['restore_hint'] 132 | if restore_hint != 0: 133 | self.resume_ckpt(restore_hint) 134 | 135 | if training: 136 | if cfgs['trainer_cfg']['fix_BN']: 137 | self.fix_BN() 138 | 139 | def get_backbone(self, backbone_cfg): 140 | if is_dict(backbone_cfg): 141 | Backbone = get_attr_from([backbones], backbone_cfg['type']) 142 | valid_args = get_valid_args(Backbone, backbone_cfg, ['type']) 143 | return Backbone(**valid_args) 144 | if is_list(backbone_cfg): 145 | Backbone = nn.ModuleList([self.get_backbone(cfg) 146 | for cfg in backbone_cfg]) 147 | return Backbone 148 | raise ValueError( 149 | "Error type for -Backbone-Cfg-, supported: (A list of) dict.") 150 | 151 | def build_network(self, model_cfg): 152 | if 'backbone_cfg' in model_cfg.keys(): 153 | self.Backbone = self.get_backbone(model_cfg['backbone_cfg']) 154 | 155 | def init_parameters(self): 156 | for m in self.modules(): 157 | if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Conv1d)): 158 | nn.init.xavier_uniform_(m.weight.data) 159 | if m.bias is not None: 160 | nn.init.constant_(m.bias.data, 0.0) 161 | elif isinstance(m, nn.Linear): 162 | nn.init.xavier_uniform_(m.weight.data) 163 | if m.bias is not None: 164 | nn.init.constant_(m.bias.data, 0.0) 165 | elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)): 166 | if m.affine: 167 | nn.init.normal_(m.weight.data, 1.0, 0.02) 168 | nn.init.constant_(m.bias.data, 0.0) 169 | 170 | def get_loader(self, data_cfg, train=True): 171 | sampler_cfg = self.cfgs['trainer_cfg']['sampler'] if train else self.cfgs['evaluator_cfg']['sampler'] 172 | dataset = DataSet(data_cfg, train) 173 | 174 | Sampler = get_attr_from([Samplers], sampler_cfg['type']) 175 | vaild_args = get_valid_args(Sampler, sampler_cfg, free_keys=[ 176 | 'sample_type', 'type']) 177 | sampler = Sampler(dataset, **vaild_args) 178 | 179 | loader = tordata.DataLoader( 180 | dataset=dataset, 181 | batch_sampler=sampler, 182 | collate_fn=CollateFn(dataset.label_set, sampler_cfg), 183 | num_workers=data_cfg['num_workers']) 184 | return loader 185 | 186 | def get_optimizer(self, optimizer_cfg): 187 | self.msg_mgr.log_info(optimizer_cfg) 188 | optimizer = get_attr_from([optim], optimizer_cfg['solver']) 189 | valid_arg = get_valid_args(optimizer, optimizer_cfg, ['solver']) 190 | optimizer = optimizer( 191 | filter(lambda p: p.requires_grad, self.parameters()), **valid_arg) 192 | return optimizer 193 | 194 | def get_scheduler(self, scheduler_cfg): 195 | self.msg_mgr.log_info(scheduler_cfg) 196 | Scheduler = get_attr_from( 197 | [optim.lr_scheduler], scheduler_cfg['scheduler']) 198 | valid_arg = get_valid_args(Scheduler, scheduler_cfg, ['scheduler']) 199 | scheduler = Scheduler(self.optimizer, **valid_arg) 200 | return scheduler 201 | 202 | def save_ckpt(self, iteration): 203 | if torch.distributed.get_rank() == 0: 204 | mkdir(osp.join(self.save_path, "checkpoints/")) 205 | save_name = self.engine_cfg['save_name'] 206 | checkpoint = { 207 | 'model': 
self.state_dict(), 208 | 'optimizer': self.optimizer.state_dict(), 209 | 'scheduler': self.scheduler.state_dict(), 210 | 'iteration': iteration} 211 | torch.save(checkpoint, 212 | osp.join(self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, iteration))) 213 | 214 | def _load_ckpt(self, save_name): 215 | load_ckpt_strict = self.engine_cfg['restore_ckpt_strict'] 216 | 217 | checkpoint = torch.load(save_name, map_location=torch.device( 218 | "cuda", self.device)) 219 | model_state_dict = checkpoint['model'] 220 | 221 | if not load_ckpt_strict: 222 | self.msg_mgr.log_info("-------- Restored Params List --------") 223 | self.msg_mgr.log_info(sorted(set(model_state_dict.keys()).intersection( 224 | set(self.state_dict().keys())))) 225 | 226 | self.load_state_dict(model_state_dict, strict=load_ckpt_strict) 227 | if self.training: 228 | if not self.engine_cfg["optimizer_reset"] and 'optimizer' in checkpoint: 229 | self.optimizer.load_state_dict(checkpoint['optimizer']) 230 | else: 231 | self.msg_mgr.log_warning( 232 | "Restore NO Optimizer from %s !!!" % save_name) 233 | if not self.engine_cfg["scheduler_reset"] and 'scheduler' in checkpoint: 234 | self.scheduler.load_state_dict( 235 | checkpoint['scheduler']) 236 | else: 237 | self.msg_mgr.log_warning( 238 | "Restore NO Scheduler from %s !!!" % save_name) 239 | self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name) 240 | 241 | def resume_ckpt(self, restore_hint): 242 | if isinstance(restore_hint, int): 243 | save_name = self.engine_cfg['save_name'] 244 | save_name = osp.join( 245 | self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, restore_hint)) 246 | self.iteration = restore_hint 247 | elif isinstance(restore_hint, str): 248 | save_name = restore_hint 249 | self.iteration = 0 250 | else: 251 | raise ValueError( 252 | "Error type for -Restore_Hint-, supported: int or string.") 253 | self._load_ckpt(save_name) 254 | 255 | def fix_BN(self): 256 | for module in self.modules(): 257 | classname = module.__class__.__name__ 258 | if classname.find('BatchNorm') != -1: 259 | module.eval() 260 | 261 | def inputs_pretreament(self, inputs): 262 | seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs 263 | trf_cfgs = self.engine_cfg['transform'] 264 | seq_trfs = get_transform(trf_cfgs) 265 | 266 | requires_grad = bool(self.training) 267 | seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float() 268 | for trf, seq in zip(seq_trfs, seqs_batch)] 269 | 270 | typs = typs_batch 271 | vies = vies_batch 272 | 273 | labs = list2var(labs_batch).long() 274 | 275 | if seqL_batch is not None: 276 | seqL_batch = np2var(seqL_batch).int() 277 | seqL = seqL_batch 278 | 279 | if seqL is not None: 280 | seqL_sum = int(seqL.sum().data.cpu().numpy()) 281 | ipts = [_[:, :seqL_sum] for _ in seqs] 282 | else: 283 | ipts = seqs 284 | del seqs 285 | return ipts, labs, typs, vies, seqL 286 | 287 | def train_step(self, loss_sum) -> bool: 288 | 289 | self.optimizer.zero_grad() 290 | if loss_sum <= 1e-9: 291 | self.msg_mgr.log_warning( 292 | "Find the loss sum less than 1e-9 but the training process will continue!") 293 | 294 | if self.engine_cfg['enable_float16']: 295 | self.Scaler.scale(loss_sum).backward() 296 | self.Scaler.step(self.optimizer) 297 | scale = self.Scaler.get_scale() 298 | self.Scaler.update() 299 | if scale != self.Scaler.get_scale(): 300 | self.msg_mgr.log_debug("Training step skip. 
Expected the former scale equals to the present, got {} and {}".format( 301 | scale, self.Scaler.get_scale())) 302 | return False 303 | else: 304 | loss_sum.backward() 305 | self.optimizer.step() 306 | 307 | self.iteration += 1 308 | self.scheduler.step() 309 | return True 310 | 311 | def inference(self, rank): 312 | total_size = len(self.test_loader) 313 | if rank == 0: 314 | pbar = tqdm(total=total_size, desc='Transforming') 315 | else: 316 | pbar = NoOp() 317 | batch_size = self.test_loader.batch_sampler.batch_size 318 | rest_size = total_size 319 | info_dict = Odict() 320 | for inputs in self.test_loader: 321 | ipts = self.inputs_pretreament(inputs) 322 | with autocast(enabled=self.engine_cfg['enable_float16']): 323 | retval = self.forward(ipts) 324 | inference_feat = retval['inference_feat'] 325 | for k, v in inference_feat.items(): 326 | inference_feat[k] = ddp_all_gather(v, requires_grad=False) 327 | del retval 328 | for k, v in inference_feat.items(): 329 | inference_feat[k] = ts2np(v) 330 | info_dict.append(inference_feat) 331 | rest_size -= batch_size 332 | if rest_size >= 0: 333 | update_size = batch_size 334 | else: 335 | update_size = total_size % batch_size 336 | pbar.update(update_size) 337 | pbar.close() 338 | for k, v in info_dict.items(): 339 | v = np.concatenate(v)[:total_size] 340 | info_dict[k] = v 341 | return info_dict 342 | 343 | @ staticmethod 344 | def run_train(model): 345 | for inputs in model.train_loader: 346 | ipts = model.inputs_pretreament(inputs) 347 | with autocast(enabled=model.engine_cfg['enable_float16']): 348 | retval = model(ipts) 349 | training_feat, visual_summary = retval['training_feat'], retval['visual_summary'] 350 | del retval 351 | loss_sum, loss_info = model.loss_aggregator(training_feat) 352 | ok = model.train_step(loss_sum) 353 | if not ok: 354 | continue 355 | 356 | visual_summary.update(loss_info) 357 | visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr'] 358 | 359 | model.msg_mgr.train_step(loss_info, visual_summary) 360 | if model.iteration % model.engine_cfg['save_iter'] == 0: 361 | # save the checkpoint 362 | model.save_ckpt(model.iteration) 363 | 364 | # run test if with_test = true 365 | if model.engine_cfg['with_test']: 366 | model.msg_mgr.log_info("Running test...") 367 | model.eval() 368 | result_dict = BaseModel.run_test(model) 369 | model.train() 370 | model.msg_mgr.write_to_tensorboard(result_dict) 371 | model.msg_mgr.reset_time() 372 | if model.iteration >= model.engine_cfg['total_iter']: 373 | break 374 | 375 | @ staticmethod 376 | def run_test(model): 377 | 378 | rank = torch.distributed.get_rank() 379 | with torch.no_grad(): 380 | info_dict = model.inference(rank) 381 | if rank == 0: 382 | loader = model.test_loader 383 | label_list = loader.dataset.label_list 384 | types_list = loader.dataset.types_list 385 | views_list = loader.dataset.views_list 386 | 387 | info_dict.update({ 388 | 'labels': label_list, 'types': types_list, 'views': views_list}) 389 | 390 | if 'eval_func' in model.cfgs["evaluator_cfg"].keys(): 391 | eval_func = model.cfgs['evaluator_cfg']["eval_func"] 392 | else: 393 | eval_func = 'identification' 394 | eval_func = getattr(eval_functions, eval_func) 395 | valid_args = get_valid_args( 396 | eval_func, model.cfgs["evaluator_cfg"], ['metric']) 397 | try: 398 | dataset_name = model.cfgs['data_cfg']['test_dataset_name'] 399 | except: 400 | dataset_name = model.cfgs['data_cfg']['dataset_name'] 401 | return eval_func(info_dict, dataset_name, **valid_args) 
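# Note: 'test_dataset_name' is optional; when it is absent from data_cfg the
# evaluation falls back to 'dataset_name', so cross-dataset testing must be
# enabled explicitly in the config.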
--------------------------------------------------------------------------------
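For reference, a self-contained sketch of the GeM horizontal pooling that all three HSTL variants build on (mirroring GeMHPP above with a fixed p instead of the learnable parameter; the bin counts and input shape are illustrative):

import torch
import torch.nn.functional as F

def gem_hpp(x, bin_num=(1, 2, 4), p=6.5, eps=1e-6):
    # x: [n, c, h, w] feature maps; returns [n, c, sum(bin_num)] part features.
    n, c = x.size()[:2]
    feats = []
    for b in bin_num:
        z = x.view(n, c, b, -1)                        # split the height axis into b strips
        z = F.avg_pool2d(z.clamp(min=eps).pow(p),      # mean of x^p over each strip...
                         (1, z.size(-1))).pow(1. / p)  # ...then the p-th root
        feats.append(z.squeeze(-1))
    return torch.cat(feats, -1)

x = torch.rand(2, 128, 64, 44)  # e.g. 2 sequences, 128 channels, 64x44 silhouette features
print(gem_hpp(x).shape)         # torch.Size([2, 128, 7])

With p = 1 this reduces to average pooling, while large p approaches max pooling; in the models above p is a learnable scalar initialised to 6.5, so each model can tune the pooling behaviour during training.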