├── reid ├── train │ ├── __init__.py │ ├── trainer.pyc │ ├── __init__.pyc │ └── trainer.py ├── __init__.pyc ├── loss │ ├── oim.pyc │ ├── __init__.pyc │ ├── pairloss.pyc │ ├── __init__.py │ ├── pairloss.py │ └── oim.py ├── data │ ├── __init__.pyc │ ├── sampler.pyc │ ├── dataloader.pyc │ ├── datasequence.pyc │ ├── seqtransforms.pyc │ ├── seqpreprocessor.pyc │ ├── __init__.py │ ├── dataloader.py │ ├── sampler.py │ ├── seqpreprocessor.py │ ├── datasequence.py │ └── seqtransforms.py ├── models │ ├── resnet.pyc │ ├── __init__.pyc │ ├── alexnet.pyc │ ├── attmodel.pyc │ ├── classifier.pyc │ ├── crosspoolingdir.pyc │ ├── selfpoolingdir.pyc │ ├── __init__.py │ ├── selfpoolingdir.py │ ├── alexnet.py │ ├── crosspoolingdir.py │ ├── attmodel.py │ ├── classifier.py │ └── resnet.py ├── dataset │ ├── __init__.pyc │ ├── ilidsvidsequence.pyc │ ├── prid2011sequence.pyc │ ├── __init__.py │ ├── prid2011sequence.py │ └── ilidsvidsequence.py └── evaluator │ ├── __init__.pyc │ ├── evaluator.pyc │ ├── attevaluator.pyc │ ├── eva_functions.pyc │ ├── __init__.py │ ├── eva_functions.py │ ├── evaluator.py │ └── attevaluator.py ├── utils ├── logging.pyc ├── meters.pyc ├── osutils.pyc ├── __init__.pyc ├── serialization.pyc ├── osutils.py ├── meters.py ├── __init__.py ├── logging.py └── serialization.py ├── tensorboardX ├── graph.pyc ├── x2num.pyc ├── crc32c.pyc ├── summary.pyc ├── writer.pyc ├── embedding.pyc ├── graph_onnx.pyc ├── record_writer.pyc ├── src │ ├── event_pb2.pyc │ ├── graph_pb2.pyc │ ├── summary_pb2.pyc │ ├── tensor_pb2.pyc │ ├── types_pb2.pyc │ ├── node_def_pb2.pyc │ ├── versions_pb2.pyc │ ├── attr_value_pb2.pyc │ ├── tensor_shape_pb2.pyc │ ├── plugin_pr_curve_pb2.pyc │ ├── resource_handle_pb2.pyc │ ├── plugin_pr_curve.proto │ ├── resource_handle.proto │ ├── versions.proto │ ├── tensor_shape.proto │ ├── types.proto │ ├── graph.proto │ ├── event.proto │ ├── attr_value.proto │ ├── node_def.proto │ ├── plugin_pr_curve_pb2.py │ ├── tensor.proto │ ├── versions_pb2.py │ ├── graph_pb2.py │ ├── resource_handle_pb2.py │ ├── summary.proto │ ├── tensor_shape_pb2.py │ ├── node_def_pb2.py │ ├── types_pb2.py │ └── tensor_pb2.py ├── event_file_writer.pyc ├── record_writer.py ├── graph.py ├── embedding.py ├── x2num.py ├── crc32c.py ├── graph_onnx.py └── event_file_writer.py ├── README.md ├── scripts ├── ilid │ ├── cnn_scan_test.sh │ └── cnn_scan_train.sh └── prid │ ├── cnn_scan_test.sh │ └── cnn_scan_train.sh └── .gitignore /reid/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import SEQTrainer -------------------------------------------------------------------------------- /reid/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/__init__.pyc -------------------------------------------------------------------------------- /reid/loss/oim.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/loss/oim.pyc -------------------------------------------------------------------------------- /utils/logging.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/utils/logging.pyc -------------------------------------------------------------------------------- /utils/meters.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/utils/meters.pyc -------------------------------------------------------------------------------- /utils/osutils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/utils/osutils.pyc -------------------------------------------------------------------------------- /utils/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/utils/__init__.pyc -------------------------------------------------------------------------------- /reid/data/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/data/__init__.pyc -------------------------------------------------------------------------------- /reid/data/sampler.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/data/sampler.pyc -------------------------------------------------------------------------------- /reid/loss/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/loss/__init__.pyc -------------------------------------------------------------------------------- /reid/loss/pairloss.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/loss/pairloss.pyc -------------------------------------------------------------------------------- /reid/models/resnet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/models/resnet.pyc -------------------------------------------------------------------------------- /reid/train/trainer.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/train/trainer.pyc -------------------------------------------------------------------------------- /tensorboardX/graph.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/graph.pyc -------------------------------------------------------------------------------- /tensorboardX/x2num.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/x2num.pyc -------------------------------------------------------------------------------- /reid/data/dataloader.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/data/dataloader.pyc -------------------------------------------------------------------------------- /reid/models/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/models/__init__.pyc -------------------------------------------------------------------------------- /reid/models/alexnet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/models/alexnet.pyc 
-------------------------------------------------------------------------------- /reid/models/attmodel.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/models/attmodel.pyc -------------------------------------------------------------------------------- /reid/train/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/train/__init__.pyc -------------------------------------------------------------------------------- /tensorboardX/crc32c.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/crc32c.pyc -------------------------------------------------------------------------------- /tensorboardX/summary.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/summary.pyc -------------------------------------------------------------------------------- /tensorboardX/writer.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/writer.pyc -------------------------------------------------------------------------------- /utils/serialization.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/utils/serialization.pyc -------------------------------------------------------------------------------- /reid/data/datasequence.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/data/datasequence.pyc -------------------------------------------------------------------------------- /reid/data/seqtransforms.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/data/seqtransforms.pyc -------------------------------------------------------------------------------- /reid/dataset/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/dataset/__init__.pyc -------------------------------------------------------------------------------- /reid/evaluator/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/evaluator/__init__.pyc -------------------------------------------------------------------------------- /reid/models/classifier.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/models/classifier.pyc -------------------------------------------------------------------------------- /tensorboardX/embedding.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/embedding.pyc -------------------------------------------------------------------------------- /tensorboardX/graph_onnx.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/graph_onnx.pyc -------------------------------------------------------------------------------- /reid/data/seqpreprocessor.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/data/seqpreprocessor.pyc -------------------------------------------------------------------------------- /reid/evaluator/evaluator.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/evaluator/evaluator.pyc -------------------------------------------------------------------------------- /reid/evaluator/attevaluator.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/evaluator/attevaluator.pyc -------------------------------------------------------------------------------- /reid/evaluator/eva_functions.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/evaluator/eva_functions.pyc -------------------------------------------------------------------------------- /reid/models/crosspoolingdir.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/models/crosspoolingdir.pyc -------------------------------------------------------------------------------- /reid/models/selfpoolingdir.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/models/selfpoolingdir.pyc -------------------------------------------------------------------------------- /tensorboardX/record_writer.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/record_writer.pyc -------------------------------------------------------------------------------- /tensorboardX/src/event_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/src/event_pb2.pyc -------------------------------------------------------------------------------- /tensorboardX/src/graph_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/src/graph_pb2.pyc -------------------------------------------------------------------------------- /tensorboardX/src/summary_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/src/summary_pb2.pyc -------------------------------------------------------------------------------- /tensorboardX/src/tensor_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/src/tensor_pb2.pyc -------------------------------------------------------------------------------- /tensorboardX/src/types_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/src/types_pb2.pyc 
-------------------------------------------------------------------------------- /reid/dataset/ilidsvidsequence.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/dataset/ilidsvidsequence.pyc -------------------------------------------------------------------------------- /reid/dataset/prid2011sequence.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/reid/dataset/prid2011sequence.pyc -------------------------------------------------------------------------------- /tensorboardX/event_file_writer.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/event_file_writer.pyc -------------------------------------------------------------------------------- /tensorboardX/src/node_def_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/src/node_def_pb2.pyc -------------------------------------------------------------------------------- /tensorboardX/src/versions_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/src/versions_pb2.pyc -------------------------------------------------------------------------------- /tensorboardX/src/attr_value_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/src/attr_value_pb2.pyc -------------------------------------------------------------------------------- /tensorboardX/src/tensor_shape_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/src/tensor_shape_pb2.pyc -------------------------------------------------------------------------------- /tensorboardX/src/plugin_pr_curve_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/src/plugin_pr_curve_pb2.pyc -------------------------------------------------------------------------------- /tensorboardX/src/resource_handle_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruixuejianfei/SCAN/HEAD/tensorboardX/src/resource_handle_pb2.pyc -------------------------------------------------------------------------------- /reid/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .oim import oim, OIM, OIMLoss 4 | from .pairloss import PairLoss 5 | 6 | __all__ = [ 7 | 'oim', 8 | 'OIM', 9 | 'OIMLoss', 10 | 'PairLoss' 11 | ] 12 | 13 | 14 | -------------------------------------------------------------------------------- /reid/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .sampler import * 3 | from .datasequence import Datasequence 4 | from .seqpreprocessor import SeqTrainPreprocessor 5 | from .seqpreprocessor import SeqTestPreprocessor 6 | from .dataloader import get_data 7 | 8 |
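# A minimal usage sketch of the exports above, based on the get_data
# signature in reid/data/dataloader.py; the dataset name and flag values
# follow scripts/ilid/cnn_scan_train.sh, while data_dir and workers are
# illustrative assumptions:
#
#     from reid.data import get_data
#     dataset, num_classes, train_loader, query_loader, gallery_loader = \
#         get_data('ilidsvidsequence', split_id=0, data_dir='./data',
#                  batch_size=32, seq_len=10, seq_srd=5, workers=4,
#                  train_mode='cnn_rnn')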
-------------------------------------------------------------------------------- /utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /reid/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .eva_functions import accuracy, cmc, mean_ap 3 | from .evaluator import CNNEvaluator 4 | from .attevaluator import ATTEvaluator 5 | 6 | 7 | __all__ = [ 8 | 'accuracy', 9 | 'cmc', 10 | 'mean_ap', 11 | ] 12 | 13 | 14 | -------------------------------------------------------------------------------- /reid/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .ilidsvidsequence import iLIDSVIDSEQUENCE 3 | from .prid2011sequence import PRID2011SEQUENCE 4 | 5 | 6 | def get_sequence(name, root, *args, **kwargs): 7 | __factory = { 8 | 'ilidsvidsequence': iLIDSVIDSEQUENCE, 9 | 'prid2011sequence': PRID2011SEQUENCE 10 | } 11 | 12 | if name not in __factory: 13 | raise KeyError("Unknown dataset", name) 14 | return __factory[name](root, *args, **kwargs) -------------------------------------------------------------------------------- /utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def to_numpy(tensor): 5 | if torch.is_tensor(tensor): 6 | return tensor.cpu().numpy() 7 | elif type(tensor).__module__ != 'numpy': 8 | raise ValueError("Cannot convert {} to numpy array" 9 | .format(type(tensor))) 10 | return tensor 11 | 12 | 13 | def to_torch(ndarray): 14 | if type(ndarray).__module__ == 'numpy': 15 | return torch.from_numpy(ndarray) 16 | elif not torch.is_tensor(ndarray): 17 | raise ValueError("Cannot convert {} to torch tensor" 18 | .format(type(ndarray))) 19 | return ndarray 20 | 21 | -------------------------------------------------------------------------------- /tensorboardX/src/plugin_pr_curve.proto: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | syntax = "proto3"; 17 | 18 | package tensorboard; 19 | 20 | message PrCurvePluginData { 21 | // Version `0` is the only supported version. 22 | int32 version = 1; 23 | 24 | uint32 num_thresholds = 2; 25 | } 26 | -------------------------------------------------------------------------------- /tensorboardX/src/resource_handle.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "ResourceHandle"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // Protocol buffer representing a handle to a tensorflow resource. Handles are 10 | // not valid across executions, but can be serialized back and forth from within 11 | // a single run. 12 | message ResourceHandleProto { 13 | // Unique name for the device containing the resource. 14 | string device = 1; 15 | 16 | // Container in which this resource is placed. 17 | string container = 2; 18 | 19 | // Unique name of this resource. 20 | string name = 3; 21 | 22 | // Hash code for the type of the resource. Is only valid in the same device 23 | // and in the same execution. 24 | uint64 hash_code = 4; 25 | 26 | // For debug-only, the name of the type pointed to by this handle, if 27 | // available. 28 | string maybe_type_name = 5; 29 | }; 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-and-Collaborative Attention Network 2 | 3 | This repository contains the source code for the project "SCAN: Self-and-Collaborative Attention Network for Video Person Re-identification". 4 | 5 | The source code is for educational and research use only, without any warranty; if you use any part of it, please cite the related paper: 6 | 7 | 8 | ``` 9 | @article{zhang2018scan, 10 | title={SCAN: Self-and-Collaborative Attention Network for Video Person Re-identification}, 11 | author={Zhang, Ruimao and Sun, Hongbin and Li, Jingyu and Ge, Yuying and Lin, Liang and Luo, Ping and Wang, Xiaogang}, 12 | journal={IEEE Trans. on Image Processing}, 13 | year={2019} 14 | } 15 | ``` 16 | 17 | A more detailed description of the codebase is coming soon. 18 | 19 | 20 | If you have any questions about the code, please email ruimao.zhang@ieee.org or jingyuli@cuhk.edu.hk 21 | 22 | 23 | # License 24 | 25 | All materials in this repository are released under the [CC-BY-NC 4.0 LICENSE](https://creativecommons.org/licenses/by-nc/4.0/).
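# Usage

A minimal single-machine invocation, distilled from `scripts/ilid/cnn_scan_train.sh` (the cluster-specific `srun` command and environment exports are omitted; `--logs-dir` is an arbitrary example value, all other flags are taken verbatim from that script):

```
python -u train_val.py -d ilidsvidsequence -b 32 --seq_len 10 --seq_srd 5 \
    --split 0 --epoch 30 --features 128 --a1 resnet50 \
    --lr1 1e-3 --lr2 1e-3 --lr3 1 --train_mode cnn_rnn \
    --lr1step 20 --lr2step 10 --lr3step 30 --logs-dir logs/example
```

Adding `--test 1`, as in `scripts/ilid/cnn_scan_test.sh`, runs evaluation instead of training.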
26 | 27 | 28 | -------------------------------------------------------------------------------- /scripts/ilid/cnn_scan_test.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | 3 | export PATH=/mnt/lustre/lijingyu/Data_t1/anaconda2/envs/py27pt02/bin:$PATH 4 | export TORCH_MODEL_ZOO=/mnt/lustre/DATAshare2/sunhongbin/pytorch_pretrained_models 5 | 6 | split=0 7 | jobname=ilid-$split-scan-128-test 8 | 9 | num_gpus=1 10 | log_dir=logs/ilid-split${split}-scan-128 11 | 12 | if [ ! -d $log_dir ]; then 13 | echo create log $log_dir 14 | mkdir -p $log_dir 15 | fi 16 | 17 | srun -p P100 --job-name=$jobname --gres=gpu:$num_gpus \ 18 | python -u train_val.py \ 19 | -d ilidsvidsequence \ 20 | -b 32 \ 21 | --seq_len 10 \ 22 | --seq_srd 5 \ 23 | --split $split \ 24 | --features 128 \ 25 | --a1 resnet50 \ 26 | --lr1 1e-3 \ 27 | --lr2 1e-3 \ 28 | --lr3 1 \ 29 | --train_mode cnn_rnn \ 30 | --lr1step 20 \ 31 | --lr2step 10 \ 32 | --lr3step 30 \ 33 | --test 1 \ 34 | --logs-dir $log_dir \ 35 | 2>&1 | tee ${log_dir}/record-test-${now}.txt &\ 36 | -------------------------------------------------------------------------------- /scripts/ilid/cnn_scan_train.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | 3 | export PATH=/mnt/lustre/lijingyu/Data_t1/anaconda2/envs/py27pt02/bin:$PATH 4 | export TORCH_MODEL_ZOO=/mnt/lustre/DATAshare2/sunhongbin/pytorch_pretrained_models 5 | 6 | split=0 7 | jobname=ilid-$split-scan-128 8 | 9 | num_gpus=4 10 | log_dir=logs/ilid-split${split}-scan-128 11 | 12 | if [ ! -d $log_dir ]; then 13 | echo create log $log_dir 14 | mkdir -p $log_dir 15 | fi 16 | 17 | srun -p TITANXP --job-name=$jobname --gres=gpu:$num_gpus \ 18 | python -u train_val.py \ 19 | -d ilidsvidsequence \ 20 | -b 32 \ 21 | --seq_len 10 \ 22 | --seq_srd 5 \ 23 | --split $split \ 24 | --epoch 30 \ 25 | --features 128 \ 26 | --a1 resnet50 \ 27 | --lr1 1e-3 \ 28 | --lr2 1e-3 \ 29 | --lr3 1 \ 30 | --train_mode cnn_rnn \ 31 | --lr1step 20 \ 32 | --lr2step 10 \ 33 | --lr3step 30 \ 34 | --logs-dir $log_dir \ 35 | 2>&1 | tee ${log_dir}/record-train-${now}.txt &\ 36 | -------------------------------------------------------------------------------- /scripts/prid/cnn_scan_test.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | 3 | export PATH=/mnt/lustre/lijingyu/Data_t1/anaconda2/envs/py27pt02/bin:$PATH 4 | export TORCH_MODEL_ZOO=/mnt/lustre/DATAshare2/sunhongbin/pytorch_pretrained_models 5 | 6 | split=0 7 | jobname=prid-$split-scan-128-test 8 | 9 | num_gpus=1 10 | log_dir=logs/prid-split${split}-scan-128 11 | 12 | if [ !
-d $log_dir ]; then 13 | echo create log $log_dir 14 | mkdir -p $log_dir 15 | fi 16 | 17 | srun -p P100 --job-name=$jobname --gres=gpu:$num_gpus \ 18 | python -u train_val.py \ 19 | -d prid2011sequence \ 20 | -b 32 \ 21 | --seq_len 10 \ 22 | --seq_srd 5 \ 23 | --split $split \ 24 | --features 128 \ 25 | --a1 resnet50 \ 26 | --lr1 1e-3 \ 27 | --lr2 1e-3 \ 28 | --lr3 1 \ 29 | --train_mode cnn_rnn \ 30 | --lr1step 20 \ 31 | --lr2step 10 \ 32 | --lr3step 30 \ 33 | --test 1 \ 34 | --logs-dir $log_dir \ 35 | 2>&1 | tee ${log_dir}/record-test-${now}.txt &\ 36 | -------------------------------------------------------------------------------- /scripts/prid/cnn_scan_train.sh: -------------------------------------------------------------------------------- 1 | now=$(date +"%Y%m%d_%H%M%S") 2 | 3 | export PATH=/mnt/lustre/lijingyu/Data_t1/anaconda2/envs/py27pt02/bin:$PATH 4 | export TORCH_MODEL_ZOO=/mnt/lustre/DATAshare2/sunhongbin/pytorch_pretrained_models 5 | 6 | split=0 7 | jobname=prid-$split-scan-128 8 | 9 | num_gpus=4 10 | log_dir=logs/prid-split${split}-scan-128 11 | 12 | if [ ! -d $log_dir ]; then 13 | echo create log $log_dir 14 | mkdir -p $log_dir 15 | fi 16 | 17 | srun -p TITANXP --job-name=$jobname --gres=gpu:$num_gpus \ 18 | python -u train_val.py \ 19 | -d prid2011sequence \ 20 | -b 32 \ 21 | --seq_len 10 \ 22 | --seq_srd 5 \ 23 | --split $split \ 24 | --epoch 30 \ 25 | --features 128 \ 26 | --a1 resnet50 \ 27 | --lr1 1e-3 \ 28 | --lr2 1e-3 \ 29 | --lr3 1 \ 30 | --train_mode cnn_rnn \ 31 | --lr1step 20 \ 32 | --lr2step 10 \ 33 | --lr3step 30 \ 34 | --logs-dir $log_dir \ 35 | 2>&1 | tee ${log_dir}/record-train-${now}.txt &\ 36 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | return self 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /tensorboardX/src/versions.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "VersionsProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // Version information for a piece of serialized data 10 | // 11 | // There are different types of versions for each type of data 12 | // (GraphDef, etc.), but they all have the same common shape 13 | // described here. 14 | // 15 | // Each consumer has "consumer" and "min_producer" versions (specified 16 | // elsewhere).
A consumer is allowed to consume this data if 17 | // 18 | // producer >= min_producer 19 | // consumer >= min_consumer 20 | // consumer not in bad_consumers 21 | // 22 | message VersionDef { 23 | // The version of the code that produced this data. 24 | int32 producer = 1; 25 | 26 | // Any consumer below this version is not allowed to consume this data. 27 | int32 min_consumer = 2; 28 | 29 | // Specific consumer versions which are disallowed (e.g. due to bugs). 30 | repeated int32 bad_consumers = 3; 31 | }; 32 | -------------------------------------------------------------------------------- /tensorboardX/record_writer.py: -------------------------------------------------------------------------------- 1 | """ 2 | To write tf_record into file. Here we use it for tensorboard's event writing. 3 | The code was borrowed from https://github.com/TeamHG-Memex/tensorboard_logger 4 | """ 5 | 6 | import re 7 | import struct 8 | 9 | from .crc32c import crc32c 10 | 11 | _VALID_OP_NAME_START = re.compile('^[A-Za-z0-9.]') 12 | _VALID_OP_NAME_PART = re.compile('[A-Za-z0-9_.\\-/]+') 13 | 14 | 15 | class RecordWriter(object): 16 | def __init__(self, path, flush_secs=2): 17 | self._name_to_tf_name = {} 18 | self._tf_names = set() 19 | self.path = path 20 | self.flush_secs = flush_secs # TODO. flush every flush_secs, not every time. 21 | self._writer = None 22 | self._writer = open(path, 'wb') 23 | 24 | def write(self, event_str): 25 | w = self._writer.write 26 | header = struct.pack('Q', len(event_str)) 27 | w(header) 28 | w(struct.pack('I', masked_crc32c(header))) 29 | w(event_str) 30 | w(struct.pack('I', masked_crc32c(event_str))) 31 | self._writer.flush() 32 | 33 | 34 | def masked_crc32c(data): 35 | x = u32(crc32c(data)) 36 | return u32(((x >> 15) | u32(x << 17)) + 0xa282ead8) 37 | 38 | 39 | def u32(x): 40 | return x & 0xffffffff 41 | 42 | 43 | def make_valid_tf_name(name): 44 | if not _VALID_OP_NAME_START.match(name): 45 | # Must make it valid somehow, but don't want to remove stuff 46 | name = '.' + name 47 | return '_'.join(_VALID_OP_NAME_PART.findall(name)) 48 | -------------------------------------------------------------------------------- /tensorboardX/src/tensor_shape.proto: -------------------------------------------------------------------------------- 1 | // Protocol buffer representing the shape of tensors. 2 | 3 | syntax = "proto3"; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TensorShapeProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | package tensorboard; 10 | 11 | // Dimensions of a tensor. 12 | message TensorShapeProto { 13 | // One dimension of the tensor. 14 | message Dim { 15 | // Size of the tensor in that dimension. 16 | // This value must be >= -1, but values of -1 are reserved for "unknown" 17 | // shapes (values of -1 mean "unknown" dimension). Certain wrappers 18 | // that work with TensorShapeProto may fail at runtime when deserializing 19 | // a TensorShapeProto containing a dim value of -1. 20 | int64 size = 1; 21 | 22 | // Optional name of the tensor dimension. 23 | string name = 2; 24 | }; 25 | 26 | // Dimensions of the tensor, such as {"input", 30}, {"output", 40} 27 | // for a 30 x 40 2D tensor. If an entry has size -1, this 28 | // corresponds to a dimension of unknown size. The names are 29 | // optional. 30 | // 31 | // The order of entries in "dim" matters: It indicates the layout of the 32 | // values in the tensor in-memory representation.
33 | // 34 | // The first entry in "dim" is the outermost dimension used to layout the 35 | // values, the last entry is the innermost dimension. This matches the 36 | // in-memory layout of RowMajor Eigen tensors. 37 | // 38 | // If "dim.size()" > 0, "unknown_rank" must be false. 39 | repeated Dim dim = 2; 40 | 41 | // If true, the number of dimensions in the shape is unknown. 42 | // 43 | // If true, "dim.size()" must be 0. 44 | bool unknown_rank = 3; 45 | }; 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /reid/loss/pairloss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | from reid.evaluator import accuracy 7 | 8 | 9 | class PairLoss(nn.Module): 10 | def __init__(self, sampling_rate=3): 11 | super(PairLoss, self).__init__() 12 | self.sampling_rate = sampling_rate 13 | self.BCE = nn.BCELoss() 14 | self.BCE.size_average = False 15 | 16 | def forward(self, score, tar_probe, tar_gallery): 17 | cls_Size = score.size() 18 | N_probe = cls_Size[0] 19 | N_gallery = cls_Size[1] 20 | 21 | tar_gallery = tar_gallery.unsqueeze(0) 22 | tar_probe = tar_probe.unsqueeze(1) 23 | mask = tar_probe.expand(N_probe, N_gallery).eq(tar_gallery.expand(N_probe, N_gallery)) 24 | mask = mask.view(-1).cpu().numpy().tolist() 25 | 26 | score = score.contiguous() 27 | samplers = score.view(-1) 28 | labels = Variable(torch.Tensor(mask).cuda()) 29 | 30 | positivelabel = torch.Tensor(mask) 31 | negativelabel = 1 - positivelabel 32 | positiveweightsum = torch.sum(positivelabel) 
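        # Positive probe-gallery pairs are far rarer than negatives in the
        # score matrix, so each negative target is down-weighted below by
        # (num_positives / num_negatives) * sampling_rate before the summed
        # BCE: with the default sampling_rate of 3, the negatives as a group
        # carry roughly three times the total weight of the positives.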
33 | negativeweightsum = torch.sum(negativelabel) 34 | neg_relativeweight = positiveweightsum / negativeweightsum * self.sampling_rate 35 | weights = (positivelabel + negativelabel * neg_relativeweight) 36 | weights = weights / torch.sum(weights) / 10 37 | self.BCE.weight = weights.cuda() 38 | loss = self.BCE(samplers, labels) 39 | 40 | samplers_data = samplers.data 41 | samplers_neg = 1 - samplers_data 42 | samplerdata = torch.cat((samplers_neg.unsqueeze(1), samplers_data.unsqueeze(1)), 1) 43 | 44 | labeldata = torch.LongTensor(mask).cuda() 45 | prec, = accuracy(samplerdata, labeldata) 46 | 47 | return loss, prec[0] 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /reid/loss/oim.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, autograd 6 | 7 | 8 | class OIM(autograd.Function): 9 | def __init__(self, lut, momentum=0.5): 10 | super(OIM, self).__init__() 11 | self.lut = lut 12 | self.momentum = momentum 13 | 14 | def forward(self, inputs, targets): 15 | self.save_for_backward(inputs, targets) 16 | outputs = inputs.mm(self.lut.t()) 17 | return outputs 18 | 19 | def backward(self, grad_outputs): 20 | inputs, targets = self.saved_tensors 21 | grad_inputs = None 22 | if self.needs_input_grad[0]: 23 | grad_inputs = grad_outputs.mm(self.lut) 24 | for x, y in zip(inputs, targets): 25 | self.lut[y] = self.momentum * self.lut[y] + (1. - self.momentum) * x 26 | self.lut[y] /= self.lut[y].norm() 27 | return grad_inputs, None 28 | 29 | 30 | def oim(inputs, targets, lut, momentum=0.5): 31 | return OIM(lut, momentum=momentum)(inputs, targets) 32 | 33 | 34 | class OIMLoss(nn.Module): 35 | def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5, 36 | weight=None, size_average=True): 37 | super(OIMLoss, self).__init__() 38 | self.num_features = num_features 39 | self.num_classes = num_classes 40 | self.momentum = momentum 41 | self.scalar = scalar 42 | self.weight = weight 43 | self.size_average = size_average 44 | 45 | self.register_buffer('lut', torch.zeros(num_classes, num_features)) 46 | 47 | def forward(self, inputs, targets): 48 | inputs = oim(inputs, targets, self.lut, momentum=self.momentum) 49 | inputs *= self.scalar 50 | loss = F.cross_entropy(inputs, targets, weight=self.weight, 51 | size_average=self.size_average) 52 | return loss, inputs -------------------------------------------------------------------------------- /tensorboardX/src/types.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TypesProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // LINT.IfChange 10 | enum DataType { 11 | // Not a legal value for DataType. Used to indicate a DataType field 12 | // has not been set. 13 | DT_INVALID = 0; 14 | 15 | // Data types that all computation devices are expected to be 16 | // capable to support. 
17 | DT_FLOAT = 1; 18 | DT_DOUBLE = 2; 19 | DT_INT32 = 3; 20 | DT_UINT8 = 4; 21 | DT_INT16 = 5; 22 | DT_INT8 = 6; 23 | DT_STRING = 7; 24 | DT_COMPLEX64 = 8; // Single-precision complex 25 | DT_INT64 = 9; 26 | DT_BOOL = 10; 27 | DT_QINT8 = 11; // Quantized int8 28 | DT_QUINT8 = 12; // Quantized uint8 29 | DT_QINT32 = 13; // Quantized int32 30 | DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. 31 | DT_QINT16 = 15; // Quantized int16 32 | DT_QUINT16 = 16; // Quantized uint16 33 | DT_UINT16 = 17; 34 | DT_COMPLEX128 = 18; // Double-precision complex 35 | DT_HALF = 19; 36 | DT_RESOURCE = 20; 37 | 38 | // TODO(josh11b): DT_GENERIC_PROTO = ??; 39 | // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? 40 | 41 | // Do not use! These are only for parameters. Every enum above 42 | // should have a corresponding value below (verified by types_test). 43 | DT_FLOAT_REF = 101; 44 | DT_DOUBLE_REF = 102; 45 | DT_INT32_REF = 103; 46 | DT_UINT8_REF = 104; 47 | DT_INT16_REF = 105; 48 | DT_INT8_REF = 106; 49 | DT_STRING_REF = 107; 50 | DT_COMPLEX64_REF = 108; 51 | DT_INT64_REF = 109; 52 | DT_BOOL_REF = 110; 53 | DT_QINT8_REF = 111; 54 | DT_QUINT8_REF = 112; 55 | DT_QINT32_REF = 113; 56 | DT_BFLOAT16_REF = 114; 57 | DT_QINT16_REF = 115; 58 | DT_QUINT16_REF = 116; 59 | DT_UINT16_REF = 117; 60 | DT_COMPLEX128_REF = 118; 61 | DT_HALF_REF = 119; 62 | DT_RESOURCE_REF = 120; 63 | } 64 | // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go) 65 | -------------------------------------------------------------------------------- /tensorboardX/graph.py: -------------------------------------------------------------------------------- 1 | from .src.graph_pb2 import GraphDef 2 | from .src.node_def_pb2 import NodeDef 3 | from .src.versions_pb2 import VersionDef 4 | from .src.attr_value_pb2 import AttrValue 5 | from .src.tensor_shape_pb2 import TensorShapeProto 6 | 7 | 8 | def replace(name, scope): 9 | return '/'.join([scope[name], name]) 10 | 11 | 12 | def parse(graph): 13 | scope = {} 14 | for n in graph.nodes(): 15 | inputs = [i.uniqueName() for i in n.inputs()] 16 | for i in range(1, len(inputs)): 17 | scope[inputs[i]] = n.scopeName() 18 | 19 | uname = next(n.outputs()).uniqueName() 20 | assert n.scopeName() != '', '{} has empty scope name'.format(n) 21 | scope[uname] = n.scopeName() 22 | scope['0'] = 'input' 23 | 24 | nodes = [] 25 | for n in graph.nodes(): 26 | attrs = {k: n[k] for k in n.attributeNames()} 27 | attrs = str(attrs).replace("'", ' ') # singlequote will be escaped by tensorboard 28 | inputs = [replace(i.uniqueName(), scope) for i in n.inputs()] 29 | uname = next(n.outputs()).uniqueName() 30 | nodes.append({'name': replace(uname, scope), 'op': n.kind(), 'inputs': inputs, 'attr': attrs}) 31 | 32 | for n in graph.inputs(): 33 | uname = n.uniqueName() 34 | if uname not in scope.keys(): 35 | scope[uname] = 'unused' 36 | nodes.append({'name': replace(uname, scope), 'op': 'Parameter', 'inputs': [], 'attr': str(n.type())}) 37 | 38 | return nodes 39 | 40 | 41 | def graph(model, args, verbose=False): 42 | import torch 43 | with torch.onnx.set_training(model, False): 44 | trace, _ = torch.jit.trace(model, args) 45 | torch.onnx._optimize_trace(trace, False) 46 | graph = trace.graph() 47 | if verbose: 48 | print(graph) 49 | list_of_nodes = parse(graph) 50 | nodes = [] 51 | for node in list_of_nodes: 52 | nodes.append( 53 | NodeDef(name=node['name'], op=node['op'], input=node['inputs'], 54 | attr={'lanpa': 
AttrValue(s=node['attr'].encode(encoding='utf_8'))})) 55 | return GraphDef(node=nodes, versions=VersionDef(producer=22)) 56 | -------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .resnet import * 3 | from .alexnet import * 4 | from .selfpoolingdir import SelfPoolingDir 5 | from .crosspoolingdir import CrossPoolingDir 6 | from .classifier import Classifier 7 | from .attmodel import AttModuleDir 8 | 9 | 10 | 11 | __factory = { 12 | 'alexnet': alexnet, 13 | 'resnet18': resnet18, 14 | 'resnet34': resnet34, 15 | 'resnet50': resnet50, 16 | 'resnet101': resnet101, 17 | 'resnet152': resnet152, 18 | 'attmodel': AttModuleDir, 19 | 'classifier': Classifier, 20 | } 21 | 22 | 23 | def names(): 24 | return sorted(__factory.keys()) 25 | 26 | 27 | def create(name, *args, **kwargs): 28 | """ 29 | Create a model instance. 30 | Parameters 31 | ---------- 32 | name : str 33 | Model name. Can be one of 'alexnet', 'resnet18', 'resnet34', 34 | 'resnet50', 'resnet101', 'resnet152', 'attmodel', and 'classifier'. 35 | pretrained : bool, optional 36 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 37 | model. Default: True 38 | cut_at_pooling : bool, optional 39 | If True, will cut the model before the last global pooling layer and 40 | ignore the remaining kwargs. Default: False 41 | num_features : int, optional 42 | If positive, will append a Linear layer after the global pooling layer, 43 | with this number of output units, followed by a BatchNorm layer. 44 | Otherwise these layers will not be appended. Default: 0 for 45 | 'resnet*' models. 46 | norm : bool, optional 47 | If True, will normalize the feature to be unit L2-norm for each sample. 48 | Otherwise will append a ReLU layer after the above Linear layer if 49 | num_features > 0. Default: False 50 | dropout : float, optional 51 | If positive, will append a Dropout layer with this dropout rate.
52 | Default: 0 53 | """ 54 | if name not in __factory: 55 | raise KeyError("Unknown model:", name) 56 | return __factory[name](*args, **kwargs) 57 | -------------------------------------------------------------------------------- /tensorboardX/embedding.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def make_tsv(metadata, save_path): 5 | metadata = [str(x) for x in metadata] 6 | with open(os.path.join(save_path, 'metadata.tsv'), 'w') as f: 7 | for x in metadata: 8 | f.write(x + '\n') 9 | 10 | 11 | # https://github.com/tensorflow/tensorboard/issues/44 image label will be squared 12 | def make_sprite(label_img, save_path): 13 | import math 14 | import torch 15 | import torchvision 16 | from .x2num import makenp 17 | # this ensures the sprite image has correct dimension as described in 18 | # https://www.tensorflow.org/get_started/embedding_viz 19 | nrow = int(math.ceil((label_img.size(0)) ** 0.5)) 20 | 21 | label_img = torch.from_numpy(makenp(label_img)) # for other framework 22 | # augment images so that #images equals nrow*nrow 23 | label_img = torch.cat((label_img, torch.randn(nrow ** 2 - label_img.size(0), *label_img.size()[1:]) * 255), 0) 24 | 25 | torchvision.utils.save_image(label_img, os.path.join(save_path, 'sprite.png'), nrow=nrow, padding=0) 26 | 27 | 28 | def append_pbtxt(metadata, label_img, save_path, global_step, tag): 29 | with open(os.path.join(save_path, 'projector_config.pbtxt'), 'a') as f: 30 | # step = os.path.split(save_path)[-1] 31 | f.write('embeddings {\n') 32 | f.write('tensor_name: "{}:{}"\n'.format(tag, global_step)) 33 | f.write('tensor_path: "{}"\n'.format(os.path.join(global_step, 'tensors.tsv'))) 34 | if metadata is not None: 35 | f.write('metadata_path: "{}"\n'.format(os.path.join(global_step, 'metadata.tsv'))) 36 | if label_img is not None: 37 | f.write('sprite {\n') 38 | f.write('image_path: "{}"\n'.format(os.path.join(global_step, 'sprite.png'))) 39 | f.write('single_image_dim: {}\n'.format(label_img.size(3))) 40 | f.write('single_image_dim: {}\n'.format(label_img.size(2))) 41 | f.write('}\n') 42 | f.write('}\n') 43 | 44 | 45 | def make_mat(matlist, save_path): 46 | with open(os.path.join(save_path, 'tensors.tsv'), 'w') as f: 47 | for x in matlist: 48 | x = [str(i) for i in x] 49 | f.write('\t'.join(x) + '\n') 50 | -------------------------------------------------------------------------------- /tensorboardX/src/graph.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "GraphProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/src/node_def.proto"; 10 | //import "tensorflow/core/framework/function.proto"; 11 | import "tensorboardX/src/versions.proto"; 12 | 13 | // Represents the graph of operations 14 | message GraphDef { 15 | repeated NodeDef node = 1; 16 | 17 | // Compatibility versions of the graph. See core/public/version.h for version 18 | // history. The GraphDef version is distinct from the TensorFlow version, and 19 | // each release of TensorFlow will support a range of GraphDef versions. 20 | VersionDef versions = 4; 21 | 22 | // Deprecated single version field; use versions above instead. Since all 23 | // GraphDef changes before "versions" was introduced were forward 24 | // compatible, this field is entirely ignored. 
25 | int32 version = 3 [deprecated = true]; 26 | 27 | // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. 28 | // 29 | // "library" provides user-defined functions. 30 | // 31 | // Naming: 32 | // * library.function.name are in a flat namespace. 33 | // NOTE: We may need to change it to be hierarchical to support 34 | // different orgs. E.g., 35 | // { "/google/nn", { ... }}, 36 | // { "/google/vision", { ... }} 37 | // { "/org_foo/module_bar", { ... }} 38 | // map<string, FunctionDefLibrary> named_lib; 39 | // * If node[i].op is the name of one function in "library", 40 | // node[i] is deemed as a function call. Otherwise, node[i].op 41 | // must be a primitive operation supported by the runtime. 42 | // 43 | // 44 | // Function call semantics: 45 | // 46 | // * The callee may start execution as soon as some of its inputs 47 | // are ready. The caller may want to use Tuple() mechanism to 48 | // ensure all inputs are ready at the same time. 49 | // 50 | // * The consumer of return values may start executing as soon as 51 | // the return values the consumer depends on are ready. The 52 | // consumer may want to use Tuple() mechanism to ensure the 53 | // consumer does not start until all return values of the callee 54 | // function are ready. 55 | //FunctionDefLibrary library = 2; 56 | }; 57 | -------------------------------------------------------------------------------- /tensorboardX/src/event.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "EventProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.util"; 8 | 9 | import "tensorboardX/src/summary.proto"; 10 | 11 | // Protocol buffer representing an event that happened during 12 | // the execution of a Brain model. 13 | message Event { 14 | // Timestamp of the event. 15 | double wall_time = 1; 16 | 17 | // Global step of the event. 18 | int64 step = 2; 19 | 20 | oneof what { 21 | // An event file was started, with the specified version. 22 | // This is used to identify the contents of the record IO files 23 | // easily. Current version is "brain.Event:2". All versions 24 | // start with "brain.Event:". 25 | string file_version = 3; 26 | // An encoded version of a GraphDef. 27 | bytes graph_def = 4; 28 | // A summary was generated. 29 | Summary summary = 5; 30 | // The user output a log message. Not all messages are logged, only ones 31 | // generated via the Python tensorboard_logging module. 32 | LogMessage log_message = 6; 33 | // The state of the session which can be used for restarting after crashes. 34 | SessionLog session_log = 7; 35 | // The metadata returned by running a session.run() call. 36 | TaggedRunMetadata tagged_run_metadata = 8; 37 | // An encoded version of a MetaGraphDef. 38 | bytes meta_graph_def = 9; 39 | } 40 | } 41 | 42 | // Protocol buffer used for logging messages to the events file. 43 | message LogMessage { 44 | enum Level { 45 | UNKNOWN = 0; 46 | DEBUG = 10; 47 | INFO = 20; 48 | WARN = 30; 49 | ERROR = 40; 50 | FATAL = 50; 51 | } 52 | Level level = 1; 53 | string message = 2; 54 | } 55 | 56 | // Protocol buffer used for logging session state. 57 | message SessionLog { 58 | enum SessionStatus { 59 | STATUS_UNSPECIFIED = 0; 60 | START = 1; 61 | STOP = 2; 62 | CHECKPOINT = 3; 63 | } 64 | 65 | SessionStatus status = 1; 66 | // This checkpoint_path contains both the path and filename.
67 | string checkpoint_path = 2; 68 | string msg = 3; 69 | } 70 | 71 | // For logging the metadata output for a single session.run() call. 72 | message TaggedRunMetadata { 73 | // Tag name associated with this metadata. 74 | string tag = 1; 75 | // Byte-encoded version of the `RunMetadata` proto in order to allow lazy 76 | // deserialization. 77 | bytes run_metadata = 2; 78 | } 79 | -------------------------------------------------------------------------------- /utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_cnn_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'cnnmodel_best.pth.tar')) 29 | 30 | 31 | def save_att_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 32 | mkdir_if_missing(osp.dirname(fpath)) 33 | torch.save(state, fpath) 34 | if is_best: 35 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'attmodel_best.pth.tar')) 36 | 37 | 38 | def save_cls_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 39 | mkdir_if_missing(osp.dirname(fpath)) 40 | torch.save(state, fpath) 41 | if is_best: 42 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'clsmodel_best.pth.tar')) 43 | 44 | def load_checkpoint(fpath): 45 | if osp.isfile(fpath): 46 | checkpoint = torch.load(fpath) 47 | print("=> Loaded checkpoint '{}'".format(fpath)) 48 | return checkpoint 49 | else: 50 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 51 | 52 | 53 | def copy_state_dict(state_dict, model, strip=None): 54 | tgt_state = model.state_dict() 55 | copied_names = set() 56 | for name, param in state_dict.items(): 57 | if strip is not None and name.startswith(strip): 58 | name = name[len(strip):] 59 | if name not in tgt_state: 60 | continue 61 | if isinstance(param, Parameter): 62 | param = param.data 63 | if param.size() != tgt_state[name].size(): 64 | print('mismatch:', name, param.size(), tgt_state[name].size()) 65 | continue 66 | tgt_state[name].copy_(param) 67 | copied_names.add(name) 68 | 69 | missing = set(tgt_state.keys()) - copied_names 70 | if len(missing) > 0: 71 | print("missing keys in state_dict:", missing) 72 | 73 | return model 74 | -------------------------------------------------------------------------------- /reid/data/dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | from torch.utils.data import DataLoader 4 | from reid.dataset import get_sequence 5 | from reid.data import seqtransforms as T 6 | from reid.data import SeqTrainPreprocessor 7 | from reid.data import SeqTestPreprocessor 8 | from reid.data import RandomPairSampler 9 | 10 | 11 | def get_data(dataset_name, split_id, data_dir, batch_size, seq_len, seq_srd, workers, train_mode): 12 | 13 | root = osp.join(data_dir, dataset_name) 14 | dataset = get_sequence(dataset_name, root, 
split_id=split_id, 15 | seq_len=seq_len, seq_srd=seq_srd, num_val=1, download=True) 16 | train_set = dataset.trainval 17 | num_classes = dataset.num_trainval_ids 18 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 19 | 20 | train_processor = SeqTrainPreprocessor(train_set, dataset, seq_len, 21 | transform=T.Compose([T.RectScale(256, 128), 22 | T.RandomHorizontalFlip(), 23 | T.RandomSizedEarser(), 24 | T.ToTensor(), normalizer])) 25 | 26 | query_processor = SeqTestPreprocessor(dataset.query, dataset, seq_len, 27 | transform=T.Compose([T.RectScale(256, 128), 28 | T.ToTensor(), normalizer])) 29 | 30 | gallery_processor = SeqTestPreprocessor(dataset.gallery, dataset, seq_len, 31 | transform=T.Compose([T.RectScale(256, 128), 32 | T.ToTensor(), normalizer])) 33 | 34 | if train_mode == 'cnn_rnn': 35 | train_loader = DataLoader(train_processor, batch_size=batch_size, num_workers=workers, sampler=RandomPairSampler(train_set), pin_memory=True) 36 | elif train_mode == 'cnn': 37 | train_loader = DataLoader(train_processor, batch_size=batch_size, num_workers=workers, shuffle=True, pin_memory=True) 38 | else: 39 | raise ValueError('no such train mode') 40 | 41 | query_loader = DataLoader( 42 | query_processor, batch_size=8, num_workers=workers, shuffle=False, 43 | pin_memory=True) 44 | 45 | gallery_loader = DataLoader( 46 | gallery_processor, batch_size=8, num_workers=workers, shuffle=False, 47 | pin_memory=True) 48 | 49 | return dataset, num_classes, train_loader, query_loader, gallery_loader 50 | -------------------------------------------------------------------------------- /reid/models/selfpoolingdir.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | from torch import nn 4 | import torch.nn.init as init 5 | 6 | class SelfPoolingDir(nn.Module): 7 | def __init__(self, input_num, output_num, feat_fc=None): # 2048,128 8 | super(SelfPoolingDir, self).__init__() 9 | self.input_num = input_num 10 | self.output_num = output_num 11 | 12 | ## Linear_Q 13 | if feat_fc is None: 14 | self.featQ = nn.Sequential(nn.Linear(self.input_num, self.output_num), 15 | nn.BatchNorm1d(self.output_num)) 16 | for m in self.featQ.modules(): 17 | if isinstance(m, nn.Linear): 18 | init.kaiming_normal(m.weight, mode='fan_out') 19 | init.constant(m.bias, 0) 20 | elif isinstance(m, nn.BatchNorm1d): 21 | init.constant(m.weight, 1) 22 | init.constant(m.bias, 0) 23 | else: 24 | print(type(m)) 25 | else: 26 | self.featQ = feat_fc 27 | 28 | ## Softmax 29 | self.softmax = nn.Softmax() 30 | 31 | def forward(self, probe_value, probe_base): #(bz/2)*sq*128; (bz/2)*sq*2048 32 | pro_size = probe_value.size() 33 | pro_batch = pro_size[0] # 32 34 | pro_len = pro_size[1] # 10 35 | 36 | # generating Querys 37 | Qs = probe_base.view(pro_batch * pro_len, -1) # 320*2048 38 | Qs = self.featQ(Qs) 39 | # Qs = self.featQ_bn(Qs) 40 | Qs = Qs.view(pro_batch, pro_len, -1) # 32*10*128 41 | tmp_K = Qs 42 | Qmean = torch.mean(Qs, 1, keepdim=True) # 32*1*128 43 | Hs = Qmean.expand(pro_batch, pro_len, self.output_num) # 32*10*128 44 | 45 | weights = Hs * tmp_K # 32*10*128 46 | weights = weights.permute(0, 2, 1) # 32*128*10 47 | weights = weights.contiguous() 48 | weights = weights.view(-1, pro_len) 49 | weights = self.softmax(weights) 50 | weights = weights.view(pro_batch, self.output_num, pro_len) 51 | weights = weights.permute(0, 2, 1) # 32*10*128 52 | pool_probe = probe_value * weights 53 | pool_probe = pool_probe.sum(1) 54 | 
pool_probe = pool_probe.squeeze(1) # 32*128 55 | """ 56 | pool_probe = torch.mean(probe_value,1) 57 | pool_probe = pool_probe.squeeze(1) # 32*128 58 | """ 59 | 60 | # pool_probe Batch x featnum 61 | # Hs Batch x hidden_num 62 | 63 | return pool_probe, pool_probe 64 | -------------------------------------------------------------------------------- /tensorboardX/x2num.py: -------------------------------------------------------------------------------- 1 | # DO NOT alter/distruct/free input object ! 2 | 3 | import numpy as np 4 | 5 | 6 | def makenp(x, modality=None): 7 | # if already numpy, return 8 | if isinstance(x, np.ndarray): 9 | if modality == 'IMG' and x.dtype == np.uint8: 10 | return x.astype(np.float32) / 255.0 11 | return x 12 | if np.isscalar(x): 13 | return np.array([x]) 14 | if 'torch' in str(type(x)): 15 | return pytorch_np(x, modality) 16 | if 'chainer' in str(type(x)): 17 | return chainer_np(x, modality) 18 | if 'mxnet' in str(type(x)): 19 | return mxnet_np(x, modality) 20 | 21 | 22 | def pytorch_np(x, modality): 23 | import torch 24 | if isinstance(x, torch.autograd.variable.Variable): 25 | x = x.data 26 | x = x.cpu().numpy() 27 | if modality == 'IMG': 28 | x = _prepare_image(x) 29 | return x 30 | 31 | 32 | def theano_np(x): 33 | import theano 34 | pass 35 | 36 | 37 | def caffe2_np(x): 38 | pass 39 | 40 | 41 | def mxnet_np(x, modality): 42 | x = x.asnumpy() 43 | if modality == 'IMG': 44 | x = _prepare_image(x) 45 | return x 46 | 47 | 48 | def chainer_np(x, modality): 49 | import chainer 50 | x = chainer.cuda.to_cpu(x.data) 51 | if modality == 'IMG': 52 | x = _prepare_image(x) 53 | return x 54 | 55 | 56 | def make_grid(I, ncols=8): 57 | assert isinstance(I, np.ndarray), 'plugin error, should pass numpy array here' 58 | assert I.ndim == 4 and I.shape[1] == 3 59 | nimg = I.shape[0] 60 | H = I.shape[2] 61 | W = I.shape[3] 62 | ncols = min(nimg, ncols) 63 | nrows = int(np.ceil(float(nimg) / ncols)) 64 | canvas = np.zeros((3, H * nrows, W * ncols)) 65 | i = 0 66 | for y in range(nrows): 67 | for x in range(ncols): 68 | if i >= nimg: 69 | break 70 | canvas[:, y * H:(y + 1) * H, x * W:(x + 1) * W] = I[i] 71 | i = i + 1 72 | return canvas 73 | 74 | 75 | def _prepare_image(I): 76 | assert isinstance(I, np.ndarray), 'plugin error, should pass numpy array here' 77 | assert I.ndim == 2 or I.ndim == 3 or I.ndim == 4 78 | if I.ndim == 4: # NCHW 79 | if I.shape[1] == 1: # N1HW 80 | I = np.concatenate((I, I, I), 1) # N3HW 81 | assert I.shape[1] == 3 82 | I = make_grid(I) # 3xHxW 83 | if I.ndim == 3 and I.shape[0] == 1: # 1xHxW 84 | I = np.concatenate((I, I, I), 0) # 3xHxW 85 | if I.ndim == 2: # HxW 86 | I = np.expand_dims(I, 0) # 1xHxW 87 | I = np.concatenate((I, I, I), 0) # 3xHxW 88 | I = I.transpose(1, 2, 0) 89 | 90 | return I 91 | -------------------------------------------------------------------------------- /tensorboardX/src/attr_value.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "AttrValueProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/src/tensor.proto"; 10 | import "tensorboardX/src/tensor_shape.proto"; 11 | import "tensorboardX/src/types.proto"; 12 | 13 | // Protocol buffer representing the value for an attr used to configure an Op. 14 | // Comment indicates the corresponding attr type. 
Only the field matching the 15 | // attr type may be filled. 16 | message AttrValue { 17 | // LINT.IfChange 18 | message ListValue { 19 | repeated bytes s = 2; // "list(string)" 20 | repeated int64 i = 3 [packed = true]; // "list(int)" 21 | repeated float f = 4 [packed = true]; // "list(float)" 22 | repeated bool b = 5 [packed = true]; // "list(bool)" 23 | repeated DataType type = 6 [packed = true]; // "list(type)" 24 | repeated TensorShapeProto shape = 7; // "list(shape)" 25 | repeated TensorProto tensor = 8; // "list(tensor)" 26 | repeated NameAttrList func = 9; // "list(attr)" 27 | } 28 | // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) 29 | 30 | oneof value { 31 | bytes s = 2; // "string" 32 | int64 i = 3; // "int" 33 | float f = 4; // "float" 34 | bool b = 5; // "bool" 35 | DataType type = 6; // "type" 36 | TensorShapeProto shape = 7; // "shape" 37 | TensorProto tensor = 8; // "tensor" 38 | ListValue list = 1; // any "list(...)" 39 | 40 | // "func" represents a function. func.name is a function's name or 41 | // a primitive op's name. func.attr.first is the name of an attr 42 | // defined for that function. func.attr.second is the value for 43 | // that attr in the instantiation. 44 | NameAttrList func = 10; 45 | 46 | // This is a placeholder only used in nodes defined inside a 47 | // function. It indicates the attr value will be supplied when 48 | // the function is instantiated. For example, let us suppose a 49 | // node "N" in function "FN". "N" has an attr "A" with value 50 | // placeholder = "foo". When FN is instantiated with attr "foo" 51 | // set to "bar", the instantiated node N's attr A will have been 52 | // given the value "bar". 53 | string placeholder = 9; 54 | } 55 | } 56 | 57 | // A list of attr names and their values. The whole list is attached 58 | // with a string name. E.g., MatMul[T=float]. 59 | message NameAttrList { 60 | string name = 1; 61 | map attr = 2; 62 | } 63 | -------------------------------------------------------------------------------- /tensorboardX/src/node_def.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "NodeProto"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/src/attr_value.proto"; 10 | 11 | message NodeDef { 12 | // The name given to this operator. Used for naming inputs, 13 | // logging, visualization, etc. Unique within a single GraphDef. 14 | // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". 15 | string name = 1; 16 | 17 | // The operation name. There may be custom parameters in attrs. 18 | // Op names starting with an underscore are reserved for internal use. 19 | string op = 2; 20 | 21 | // Each input is "node:src_output" with "node" being a string name and 22 | // "src_output" indicating which output tensor to use from "node". If 23 | // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs 24 | // may optionally be followed by control inputs that have the format 25 | // "^node". 26 | repeated string input = 3; 27 | 28 | // A (possibly partial) specification for the device on which this 29 | // node should be placed. 
30 | // The expected syntax for this string is as follows: 31 | // 32 | // DEVICE_SPEC ::= PARTIAL_SPEC 33 | // 34 | // PARTIAL_SPEC ::= ("/" CONSTRAINT) * 35 | // CONSTRAINT ::= ("job:" JOB_NAME) 36 | // | ("replica:" [1-9][0-9]*) 37 | // | ("task:" [1-9][0-9]*) 38 | // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") ) 39 | // 40 | // Valid values for this string include: 41 | // * "/job:worker/replica:0/task:1/gpu:3" (full specification) 42 | // * "/job:worker/gpu:3" (partial specification) 43 | // * "" (no specification) 44 | // 45 | // If the constraints do not resolve to a single device (or if this 46 | // field is empty or not present), the runtime will attempt to 47 | // choose a device automatically. 48 | string device = 4; 49 | 50 | // Operation-specific graph-construction-time configuration. 51 | // Note that this should include all attrs defined in the 52 | // corresponding OpDef, including those with a value matching 53 | // the default -- this allows the default to change and makes 54 | // NodeDefs easier to interpret on their own. However, if 55 | // an attr with a default is not specified in this list, the 56 | // default will be used. 57 | // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and 58 | // one of the names from the corresponding OpDef's attr field). 59 | // The values must have a type matching the corresponding OpDef 60 | // attr's type field. 61 | // TODO(josh11b): Add some examples here showing best practices. 62 | map attr = 5; 63 | }; 64 | -------------------------------------------------------------------------------- /tensorboardX/src/plugin_pr_curve_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: tensorboardX/src/plugin_pr_curve.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='tensorboardX/src/plugin_pr_curve.proto', 20 | package='tensorboard', 21 | syntax='proto3', 22 | serialized_pb=_b('\n&tensorboardX/src/plugin_pr_curve.proto\x12\x0btensorboard\"<\n\x11PrCurvePluginData\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x16\n\x0enum_thresholds\x18\x02 \x01(\rb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _PRCURVEPLUGINDATA = _descriptor.Descriptor( 29 | name='PrCurvePluginData', 30 | full_name='tensorboard.PrCurvePluginData', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='version', full_name='tensorboard.PrCurvePluginData.version', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='num_thresholds', full_name='tensorboard.PrCurvePluginData.num_thresholds', index=1, 44 | number=2, type=13, cpp_type=3, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, 
extension_scope=None, 48 | options=None), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=None, 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=55, 62 | serialized_end=115, 63 | ) 64 | 65 | DESCRIPTOR.message_types_by_name['PrCurvePluginData'] = _PRCURVEPLUGINDATA 66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 67 | 68 | PrCurvePluginData = _reflection.GeneratedProtocolMessageType('PrCurvePluginData', (_message.Message,), dict( 69 | DESCRIPTOR = _PRCURVEPLUGINDATA, 70 | __module__ = 'tensorboardX.src.plugin_pr_curve_pb2' 71 | # @@protoc_insertion_point(class_scope:tensorboard.PrCurvePluginData) 72 | )) 73 | _sym_db.RegisterMessage(PrCurvePluginData) 74 | 75 | 76 | # @@protoc_insertion_point(module_scope) 77 | -------------------------------------------------------------------------------- /reid/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data.sampler import ( 7 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 8 | WeightedRandomSampler) 9 | 10 | def No_index(a, b): 11 | assert isinstance(a, list) 12 | return [i for i, j in enumerate(a) if j != b] 13 | 14 | 15 | class RandomIdentitySampler(Sampler): 16 | 17 | def __init__(self, data_source, num_instances=1): 18 | self.data_source = data_source 19 | self.num_instances = num_instances 20 | self.index_dic = defaultdict(list) 21 | for index, (_, pid, _) in enumerate(data_source): 22 | self.index_dic[pid].append(index) 23 | self.pids = list(self.index_dic.keys()) 24 | self.num_samples = len(self.pids) # one entry per identity: len(data_source) would index past the end of self.pids in __iter__ 25 | 26 | def __len__(self): 27 | return self.num_samples * self.num_instances 28 | 29 | 30 | def __iter__(self): 31 | indices = torch.randperm(self.num_samples) 32 | ret = [] 33 | for i in indices: 34 | pid = self.pids[i] 35 | t = self.index_dic[pid] 36 | if len(t) >= self.num_instances: 37 | t = np.random.choice(t, size=self.num_instances, replace=False) 38 | else: 39 | t = np.random.choice(t, size=self.num_instances, replace=True) 40 | ret.extend(t) 41 | return iter(ret) 42 | 43 | 44 | class RandomPairSampler(Sampler): 45 | def __init__(self, data_source): 46 | self.data_source = data_source 47 | self.index_pid = defaultdict(int) 48 | self.pid_cam = defaultdict(list) 49 | self.pid_index = defaultdict(list) 50 | self.num_samples = len(data_source) 51 | for index, (_, _, _, pid, cam) in enumerate(data_source): 52 | self.index_pid[index] = pid 53 | self.pid_cam[pid].append(cam) 54 | self.pid_index[pid].append(index) 55 | 56 | def __len__(self): 57 | return self.num_samples * 2 58 | 59 | def __iter__(self): 60 | indices = torch.randperm(self.num_samples) 61 | ret = [] 62 | for i in indices: 63 | _, _, i_label, i_pid, i_cam = self.data_source[i] 64 | ret.append(i) 65 | pid_i = self.index_pid[i] 66 | cams = self.pid_cam[pid_i] 67 | index = self.pid_index[pid_i] 68 | select_cams = No_index(cams, i_cam) 69 | try: 70 | select_camind = np.random.choice(select_cams) 71 | except ValueError: 72 | print(cams) 73 | print(pid_i) 74 | print(i_label) 75 | raise # this identity has no frames under another camera, so no cross-camera pair can be formed 76 | select_ind = index[select_camind] 77 | ret.append(select_ind) 78 | 79 | return iter(ret) -------------------------------------------------------------------------------- /reid/models/alexnet.py: -------------------------------------------------------------------------------- 1 |
from __future__ import absolute_import 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.nn import init 5 | import torchvision 6 | 7 | class AlexNet(nn.Module): 8 | def __init__(self, pretrained=True, cut_at_pooling=False, num_features=0, dropout=0): 9 | super(AlexNet, self).__init__() 10 | 11 | self.pretrained = pretrained 12 | self.has_embedding = num_features > 0 13 | self.num_features = num_features 14 | 15 | self.base = torchvision.models.alexnet(pretrained=pretrained) 16 | self.features = self.base.features 17 | self.classifier = self.base.classifier 18 | conv0 = nn.Conv2d(2, 64, kernel_size=11, stride=4, padding=2, bias=False) 19 | init.kaiming_normal(conv0.weight, mode='fan_out') 20 | self.conv0 = conv0 21 | 22 | out_planes = self.classifier._modules['1'].in_features 23 | 24 | self.feat1 = nn.Linear(5376, 2048) 25 | self.feat_bn1 = nn.BatchNorm1d(2048) 26 | init.kaiming_normal(self.feat1.weight, mode='fan_out') 27 | init.constant(self.feat1.bias, 0) 28 | init.constant(self.feat_bn1.weight, 1) 29 | init.constant(self.feat_bn1.bias, 0) 30 | 31 | if self.has_embedding: 32 | self.feat = nn.Linear(2048, self.num_features) 33 | self.feat_bn = nn.BatchNorm1d(self.num_features) 34 | init.kaiming_normal(self.feat.weight, mode='fan_out') 35 | init.constant(self.feat.bias, 0) 36 | init.constant(self.feat_bn.weight, 1) 37 | init.constant(self.feat_bn.bias, 0) 38 | 39 | 40 | def forward(self, imgs, motions, mode): 41 | img_size = imgs.size() 42 | motion_size = motions.size() 43 | batch_sz = img_size[0] 44 | seq_len = img_size[1] 45 | imgs = imgs.view(-1, img_size[2], img_size[3], img_size[4]) 46 | motions = motions.view(-1, motion_size[2], motion_size[3], motion_size[4]) 47 | motions = motions[:, 1:3] 48 | 49 | for name, module in self.features._modules.items(): 50 | if name == '0': 51 | x = module(imgs) + self.conv0(motions) 52 | continue 53 | x = module(x) 54 | 55 | x = x.view(x.size(0), -1) 56 | x = self.feat1(x) 57 | x = self.feat_bn1(x) 58 | if mode == 'cnn_rnn': 59 | raw = x.view(batch_sz, seq_len, -1) 60 | if self.has_embedding: 61 | x = self.feat(x) 62 | x = self.feat_bn(x) 63 | 64 | if mode == 'cnn_rnn': 65 | # x = x / x.norm(2, 1).expand_as(x) 66 | x = x / x.norm(2, 1).unsqueeze(1).expand_as(x) 67 | x = x.view(batch_sz, seq_len, -1) 68 | return x, raw 69 | 70 | def alexnet(**kwargs): 71 | return AlexNet(**kwargs) 72 | -------------------------------------------------------------------------------- /tensorboardX/src/tensor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TensorProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/src/resource_handle.proto"; 10 | import "tensorboardX/src/tensor_shape.proto"; 11 | import "tensorboardX/src/types.proto"; 12 | 13 | // Protocol buffer representing a tensor. 14 | message TensorProto { 15 | DataType dtype = 1; 16 | 17 | // Shape of the tensor. TODO(touts): sort out the 0-rank issues. 18 | TensorShapeProto tensor_shape = 2; 19 | 20 | // Only one of the representations below is set, one of "tensor_contents" and 21 | // the "xxx_val" attributes. We are not using oneof because as oneofs cannot 22 | // contain repeated fields it would require another extra set of messages. 23 | 24 | // Version number. 
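// (An illustrative instance of the version-0 rule described below, not taken from the
// schema itself: a 2x2 DT_FLOAT tensor whose four entries all equal 0.5 can be encoded
// with tensor_shape = {2, 2}, version_number = 0, and a single float_val entry of 0.5,
// which readers repeat to fill the shape.)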
25 | // 26 | // In version 0, if the "repeated xxx" representations contain only one 27 | // element, that element is repeated to fill the shape. This makes it easy 28 | // to represent a constant Tensor with a single value. 29 | int32 version_number = 3; 30 | 31 | // Serialized raw tensor content from either Tensor::AsProtoTensorContent or 32 | // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation 33 | // can be used for all tensor types. The purpose of this representation is to 34 | // reduce serialization overhead during RPC call by avoiding serialization of 35 | // many repeated small items. 36 | bytes tensor_content = 4; 37 | 38 | // Type specific representations that make it easy to create tensor protos in 39 | // all languages. Only the representation corresponding to "dtype" can 40 | // be set. The values hold the flattened representation of the tensor in 41 | // row major order. 42 | 43 | // DT_HALF. Note that since protobuf has no int16 type, we'll have some 44 | // pointless zero padding for each value here. 45 | repeated int32 half_val = 13 [packed = true]; 46 | 47 | // DT_FLOAT. 48 | repeated float float_val = 5 [packed = true]; 49 | 50 | // DT_DOUBLE. 51 | repeated double double_val = 6 [packed = true]; 52 | 53 | // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. 54 | repeated int32 int_val = 7 [packed = true]; 55 | 56 | // DT_STRING 57 | repeated bytes string_val = 8; 58 | 59 | // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real 60 | // and imaginary parts of i-th single precision complex. 61 | repeated float scomplex_val = 9 [packed = true]; 62 | 63 | // DT_INT64 64 | repeated int64 int64_val = 10 [packed = true]; 65 | 66 | // DT_BOOL 67 | repeated bool bool_val = 11 [packed = true]; 68 | 69 | // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real 70 | // and imaginary parts of i-th double precision complex. 
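// For example, the complex values (1+2i, 3-4i) are flattened into
// dcomplex_val = [1.0, 2.0, 3.0, -4.0].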
71 | repeated double dcomplex_val = 12 [packed = true]; 72 | 73 | // DT_RESOURCE 74 | repeated ResourceHandleProto resource_handle_val = 14; 75 | }; 76 | -------------------------------------------------------------------------------- /reid/models/crosspoolingdir.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | from torch import nn 4 | import torch.nn.init as init 5 | 6 | class CrossPoolingDir(nn.Module): 7 | 8 | def __init__(self, input_num, output_num, feat_fc=None): 9 | super(CrossPoolingDir, self).__init__() 10 | self.input_num = input_num 11 | self.output_num = output_num 12 | 13 | ## Linear_K 14 | if feat_fc is None: 15 | self.featK = nn.Sequential(nn.Linear(self.input_num, self.output_num), 16 | nn.BatchNorm1d(self.output_num)) 17 | for m in self.featK.modules(): 18 | if isinstance(m, nn.Linear): 19 | init.kaiming_normal(m.weight, mode='fan_out') 20 | init.constant(m.bias, 0) 21 | elif isinstance(m, nn.BatchNorm1d): 22 | init.constant(m.weight, 1) 23 | init.constant(m.bias, 0) 24 | else: 25 | print(type(m)) 26 | else: 27 | self.featK = feat_fc 28 | 29 | 30 | ## Softmax 31 | self.softmax = nn.Softmax() 32 | 33 | def forward(self, gallery_value, gallery_base, querys): 34 | 35 | gal_size = gallery_value.size() 36 | gal_batch = gal_size[0] 37 | gal_len = gal_size[1] 38 | 39 | ## Linear self-transformation 40 | Q_size = querys.size() 41 | pro_batch = Q_size[0] 42 | Q_featnum = Q_size[1] 43 | 44 | K = gallery_base.view(gal_batch * gal_len, -1) 45 | K = self.featK(K) 46 | # K = self.featK_bn(K) 47 | K = K.view(gal_batch, gal_len, -1) 48 | # K: gal_batch x gal_len x H_featnum 49 | # query: pro_batch x H_featnum 50 | 51 | Q = querys.unsqueeze(1) 52 | Q = Q.unsqueeze(1) 53 | K = K.unsqueeze(0) 54 | 55 | # Q: pro_batch x 1 x 1 x Q_featnum 56 | # K: 1 x gal_batch x gal_len x Q_featnum 57 | 58 | Q = Q.expand(pro_batch, gal_batch, gal_len, Q_featnum) 59 | K = K.expand(pro_batch, gal_batch, gal_len, Q_featnum) 60 | 61 | QK = Q * K 62 | QK = QK.permute(0, 1, 3, 2) 63 | 64 | # pro_batch x gal_batch x Q_featnum x gal_len 65 | QK = QK.contiguous() 66 | QK = QK.view(-1, gal_len) 67 | weights = self.softmax(QK) 68 | weights = weights.view(pro_batch, gal_batch, Q_featnum, gal_len) 69 | 70 | # gallery : gal_batch x gal_len x Q_featnum 71 | gallery_value = gallery_value.permute(0, 2, 1) 72 | # gallery : gal_batch x Q_featnum x gal_len 73 | gallery_value = gallery_value.contiguous() 74 | gallery_value = gallery_value.unsqueeze(0) 75 | # gallery : 1 x gal_batch x Q_featnum x gal_len 76 | gallery_value = gallery_value.expand(pro_batch, gal_batch, Q_featnum, gal_len) 77 | # gallery : pro_batch x gal_batch x Q_featnum x gal_len 78 | pool_gallery = (weights * gallery_value).sum(3) 79 | # pool_gallery = pool_gallery.squeeze(3) 80 | 81 | return pool_gallery 82 | -------------------------------------------------------------------------------- /reid/models/attmodel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from torch import nn 3 | from reid.models import SelfPoolingDir 4 | from reid.models import CrossPoolingDir 5 | import torch.nn.init as init 6 | 7 | 8 | class AttModuleDir(nn.Module): 9 | 10 | def __init__(self, input_num, output_num, same_fc=True): # 2048, 128 11 | super(AttModuleDir, self).__init__() 12 | 13 | self.input_num = input_num 14 | self.output_num = output_num 15 | 16 | ## attention modules 17 | if same_fc: 18 |
self.feat_fc = nn.Sequential(nn.Linear(self.input_num, self.output_num), 19 | nn.BatchNorm1d(self.output_num)) 20 | for m in self.feat_fc.modules(): 21 | if isinstance(m, nn.Linear): 22 | init.kaiming_normal(m.weight, mode='fan_out') 23 | init.constant(m.bias, 0) 24 | elif isinstance(m, nn.BatchNorm1d): 25 | init.constant(m.weight, 1) 26 | init.constant(m.bias, 0) 27 | else: 28 | print(type(m)) 29 | 30 | self.selfpooling_model = SelfPoolingDir(self.input_num, self.output_num, feat_fc=self.feat_fc) 31 | self.crosspooling_model = CrossPoolingDir(self.input_num, self.output_num, feat_fc=self.feat_fc) 32 | else: 33 | self.selfpooling_model = SelfPoolingDir(self.input_num, self.output_num) 34 | self.crosspooling_model = CrossPoolingDir(self.input_num, self.output_num) 35 | 36 | 37 | def forward(self, x, inputs): #x(bz*sq*128) input(bz*sq*2048) 38 | xsize = x.size() 39 | sample_num = xsize[0] # 64 40 | 41 | if sample_num % 2 != 0: 42 | raise RuntimeError("the batch size should be even number!") 43 | 44 | seq_len = x.size()[1] # 10 45 | x = x.view(int(sample_num/2), 2, seq_len, -1) #32*2*10*128 46 | inputs = inputs.view(int(sample_num/2), 2, seq_len, -1) #32*2*10*2048 47 | probe_x = x[:, 0, :, :] # 32*10*128 48 | probe_x = probe_x.contiguous() 49 | gallery_x = x[:, 1, :, :] # 32*10*128 50 | gallery_x = gallery_x.contiguous() 51 | 52 | probe_input = inputs[:, 0, :, :] # 32*10*2048 53 | probe_input = probe_input.contiguous() 54 | gallery_input = inputs[:, 1, :, :] # 32*10*2048 55 | gallery_input = gallery_input.contiguous() 56 | 57 | ## self-pooling 58 | pooled_probe, hidden_probe = self.selfpooling_model(probe_x, probe_input) 59 | pooled_gallery, hidden_gallery = self.selfpooling_model(gallery_x, gallery_input) 60 | 61 | ## cross-pooling 62 | # gallery_x(32*10*128), gallery_input(32*10*2048), pooled_probe(32*128) 63 | pooled_gallery_2 = self.crosspooling_model(gallery_x, gallery_input, pooled_probe) 64 | pooled_probe_2 = self.crosspooling_model(probe_x, probe_input, pooled_gallery) 65 | 66 | pooled_probe_2 = pooled_probe_2.permute(1, 0, 2) 67 | pooled_probe, pooled_gallery = pooled_probe.unsqueeze(1), pooled_gallery.unsqueeze(0) 68 | # 32*1*128, 32*32*128, 32*32*128, 1*32*128 69 | return pooled_probe, pooled_gallery_2, pooled_probe_2, pooled_gallery # (bz/2) * 128, (bz/2)*(bz/2)*128 70 | -------------------------------------------------------------------------------- /tensorboardX/src/versions_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
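# This generated module defines the VersionDef message (int32 producer, int32
# min_consumer, repeated int32 bad_consumers), mirroring tensorboardX/src/versions.proto.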
2 | # source: tensorboardX/src/versions.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='tensorboardX/src/versions.proto', 20 | package='tensorboard', 21 | syntax='proto3', 22 | serialized_pb=_b('\n\x1ftensorboardX/src/versions.proto\x12\x0btensorboard\"K\n\nVersionDef\x12\x10\n\x08producer\x18\x01 \x01(\x05\x12\x14\n\x0cmin_consumer\x18\x02 \x01(\x05\x12\x15\n\rbad_consumers\x18\x03 \x03(\x05\x42/\n\x18org.tensorflow.frameworkB\x0eVersionsProtosP\x01\xf8\x01\x01\x62\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _VERSIONDEF = _descriptor.Descriptor( 29 | name='VersionDef', 30 | full_name='tensorboard.VersionDef', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='producer', full_name='tensorboard.VersionDef.producer', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='min_consumer', full_name='tensorboard.VersionDef.min_consumer', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='bad_consumers', full_name='tensorboard.VersionDef.bad_consumers', index=2, 51 | number=3, type=5, cpp_type=1, label=3, 52 | has_default_value=False, default_value=[], 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto3', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=48, 69 | serialized_end=123, 70 | ) 71 | 72 | DESCRIPTOR.message_types_by_name['VersionDef'] = _VERSIONDEF 73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 74 | 75 | VersionDef = _reflection.GeneratedProtocolMessageType('VersionDef', (_message.Message,), dict( 76 | DESCRIPTOR = _VERSIONDEF, 77 | __module__ = 'tensorboardX.src.versions_pb2' 78 | # @@protoc_insertion_point(class_scope:tensorboard.VersionDef) 79 | )) 80 | _sym_db.RegisterMessage(VersionDef) 81 | 82 | 83 | DESCRIPTOR.has_options = True 84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\016VersionsProtosP\001\370\001\001')) 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /reid/data/seqpreprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os.path as osp 3 | import torch 4 | from PIL import Image 5 | 6 | 7 | 8 | class SeqTrainPreprocessor(object): 9 | def __init__(self, seqset, dataset, seq_len, 
transform=None): 10 | super(SeqTrainPreprocessor, self).__init__() 11 | self.seqset = seqset 12 | self.identities = dataset.identities 13 | self.transform = transform 14 | self.seq_len = seq_len 15 | self.root = [dataset.images_dir] 16 | self.root.append(dataset.other_dir) 17 | 18 | def __len__(self): 19 | return len(self.seqset) 20 | 21 | 22 | def __getitem__(self, indices): 23 | if isinstance(indices, (tuple, list)): 24 | return [self._get_single_item(index) for index in indices] 25 | return self._get_single_item(indices) 26 | 27 | def _get_single_item(self, index): 28 | 29 | start_ind, end_ind, pid, label, camid = self.seqset[index] 30 | 31 | imgseq = [] 32 | flowseq = [] 33 | for ind in range(start_ind, end_ind): 34 | fname = self.identities[pid][camid][ind] 35 | fpath_img = osp.join(self.root[0], fname) 36 | imgrgb = Image.open(fpath_img).convert('RGB') 37 | fpath_flow = osp.join(self.root[1], fname) 38 | flowrgb = Image.open(fpath_flow).convert('RGB') 39 | imgseq.append(imgrgb) 40 | flowseq.append(flowrgb) 41 | 42 | while (len(imgseq) < self.seq_len): 43 | imgseq.append(imgrgb) 44 | flowseq.append(flowrgb) 45 | 46 | seq = [imgseq, flowseq] 47 | 48 | if self.transform is not None: 49 | seq = self.transform(seq) 50 | 51 | img_tensor = torch.stack(seq[0], 0) 52 | 53 | flow_tensor = torch.stack(seq[1], 0) 54 | 55 | return img_tensor, flow_tensor, label, camid 56 | 57 | 58 | 59 | class SeqTestPreprocessor(object): 60 | 61 | def __init__(self, seqset, dataset, seq_len, transform=None): 62 | super(SeqTestPreprocessor, self).__init__() 63 | self.seqset = seqset 64 | self.identities = dataset.identities 65 | self.transform = transform 66 | self.seq_len = seq_len 67 | self.root = [dataset.images_dir] 68 | self.root.append(dataset.other_dir) 69 | 70 | def __len__(self): 71 | return len(self.seqset) 72 | 73 | def __getitem__(self, indices): 74 | if isinstance(indices, (tuple, list)): 75 | return [self._get_single_item(index) for index in indices] 76 | return self._get_single_item(indices) 77 | 78 | def _get_single_item(self, index): 79 | 80 | start_ind, end_ind, pid, label, camid = self.seqset[index] 81 | 82 | imgseq = [] 83 | flowseq = [] 84 | for ind in range(start_ind, end_ind): 85 | fname = self.identities[pid][camid][ind] 86 | fpath_img = osp.join(self.root[0], fname) 87 | imgrgb = Image.open(fpath_img).convert('RGB') 88 | fpath_flow = osp.join(self.root[1], fname) 89 | flowrgb = Image.open(fpath_flow).convert('RGB') 90 | imgseq.append(imgrgb) 91 | flowseq.append(flowrgb) 92 | 93 | while (len(imgseq) < self.seq_len): 94 | imgseq.append(imgrgb) 95 | flowseq.append(flowrgb) 96 | 97 | seq = [imgseq, flowseq] 98 | 99 | if self.transform is not None: 100 | seq = self.transform(seq) 101 | 102 | img_tensor = torch.stack(seq[0], 0) 103 | 104 | if len(self.root) == 2: 105 | flow_tensor = torch.stack(seq[1], 0) 106 | else: 107 | flow_tensor = None 108 | 109 | return img_tensor, flow_tensor, pid, camid -------------------------------------------------------------------------------- /reid/data/datasequence.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | import numpy as np 4 | from utils.serialization import read_json 5 | 6 | 7 | def _pluckseq(identities, indices, seq_len, seq_str): 8 | ret = [] 9 | for index, pid in enumerate(indices): 10 | pid_images = identities[pid] 11 | for camid, cam_images in enumerate(pid_images): 12 | seqall = len(cam_images) 13 | seq_inds = [(start_ind, start_ind + 
seq_len)\ 14 | for start_ind in range(0, seqall-seq_len, seq_str)] 15 | 16 | if not seq_inds: 17 | seq_inds = [(0, seqall)] 18 | for seq_ind in seq_inds: 19 | ret.append((seq_ind[0], seq_ind[1], pid, index, camid)) 20 | return ret 21 | 22 | 23 | 24 | class Datasequence(object): 25 | def __init__(self, root, split_id= 0): 26 | self.root = root 27 | self.split_id = split_id 28 | self.meta = None 29 | self.split = None 30 | self.train, self.val, self.trainval = [], [], [] 31 | self.query, self.gallery = [], [] 32 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 33 | self.identities = [] 34 | 35 | @property 36 | def images_dir(self): 37 | return osp.join(self.root, 'images') 38 | 39 | def load(self, seq_len, seq_str, num_val=0.3, verbose=True): 40 | splits = read_json(osp.join(self.root, 'splits.json')) 41 | if self.split_id >= len(splits): 42 | raise ValueError("split_id exceeds total splits {}" 43 | .format(len(splits))) 44 | 45 | self.split = splits[self.split_id] 46 | 47 | # Randomly split train / val 48 | trainval_pids = np.asarray(self.split['trainval']) 49 | np.random.shuffle(trainval_pids) 50 | num = len(trainval_pids) 51 | 52 | if isinstance(num_val, float): 53 | num_val = int(round(num * num_val)) 54 | if num_val >= num or num_val < 0: 55 | raise ValueError("num_val exceeds total identities {}" 56 | .format(num)) 57 | 58 | train_pids = sorted(trainval_pids[:-num_val]) 59 | val_pids = sorted(trainval_pids[-num_val:]) 60 | 61 | # comments validation set changes every time it loads 62 | 63 | self.meta = read_json(osp.join(self.root, 'meta.json')) 64 | identities = self.meta['identities'] 65 | self.identities = identities 66 | self.train = _pluckseq(identities, train_pids, seq_len, seq_str) 67 | self.val = _pluckseq(identities, val_pids, seq_len, seq_str) 68 | self.trainval = _pluckseq(identities, trainval_pids, seq_len, seq_str) 69 | self.num_train_ids = len(train_pids) 70 | self.num_val_ids = len(val_pids) 71 | self.num_trainval_ids = len(trainval_pids) 72 | 73 | 74 | 75 | 76 | if verbose: 77 | print(self.__class__.__name__, "dataset loaded") 78 | print(" subset | # ids | # sequences") 79 | print(" ---------------------------") 80 | print(" train | {:5d} | {:8d}" 81 | .format(self.num_train_ids, len(self.train))) 82 | print(" val | {:5d} | {:8d}" 83 | .format(self.num_val_ids, len(self.val))) 84 | print(" trainval | {:5d} | {:8d}" 85 | .format(self.num_trainval_ids, len(self.trainval))) 86 | print(" query | {:5d} | {:8d}" 87 | .format(len(self.split['query']), len(self.split['query']))) 88 | print(" gallery | {:5d} | {:8d}" 89 | .format(len(self.split['gallery']), len(self.split['gallery']))) 90 | 91 | def _check_integrity(self): 92 | return osp.isdir(osp.join(self.root, 'images')) and \ 93 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 94 | osp.isfile(osp.join(self.root, 'splits.json')) -------------------------------------------------------------------------------- /tensorboardX/src/graph_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
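# This generated module defines the GraphDef message (repeated NodeDef node, a VersionDef
# versions field, and a deprecated int32 version field), mirroring
# tensorboardX/src/graph.proto.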
2 | # source: tensorboardX/src/graph.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from tensorboardX.src import node_def_pb2 as tensorboardX_dot_src_dot_node__def__pb2 17 | from tensorboardX.src import versions_pb2 as tensorboardX_dot_src_dot_versions__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='tensorboardX/src/graph.proto', 22 | package='tensorboard', 23 | syntax='proto3', 24 | serialized_pb=_b('\n\x1ctensorboardX/src/graph.proto\x12\x0btensorboard\x1a\x1ftensorboardX/src/node_def.proto\x1a\x1ftensorboardX/src/versions.proto\"n\n\x08GraphDef\x12\"\n\x04node\x18\x01 \x03(\x0b\x32\x14.tensorboard.NodeDef\x12)\n\x08versions\x18\x04 \x01(\x0b\x32\x17.tensorboard.VersionDef\x12\x13\n\x07version\x18\x03 \x01(\x05\x42\x02\x18\x01\x42,\n\x18org.tensorflow.frameworkB\x0bGraphProtosP\x01\xf8\x01\x01\x62\x06proto3') 25 | , 26 | dependencies=[tensorboardX_dot_src_dot_node__def__pb2.DESCRIPTOR,tensorboardX_dot_src_dot_versions__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _GRAPHDEF = _descriptor.Descriptor( 32 | name='GraphDef', 33 | full_name='tensorboard.GraphDef', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='node', full_name='tensorboard.GraphDef.node', index=0, 40 | number=1, type=11, cpp_type=10, label=3, 41 | has_default_value=False, default_value=[], 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | name='versions', full_name='tensorboard.GraphDef.versions', index=1, 47 | number=4, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | _descriptor.FieldDescriptor( 53 | name='version', full_name='tensorboard.GraphDef.version', index=2, 54 | number=3, type=5, cpp_type=1, label=1, 55 | has_default_value=False, default_value=0, 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\030\001'))), 59 | ], 60 | extensions=[ 61 | ], 62 | nested_types=[], 63 | enum_types=[ 64 | ], 65 | options=None, 66 | is_extendable=False, 67 | syntax='proto3', 68 | extension_ranges=[], 69 | oneofs=[ 70 | ], 71 | serialized_start=111, 72 | serialized_end=221, 73 | ) 74 | 75 | _GRAPHDEF.fields_by_name['node'].message_type = tensorboardX_dot_src_dot_node__def__pb2._NODEDEF 76 | _GRAPHDEF.fields_by_name['versions'].message_type = tensorboardX_dot_src_dot_versions__pb2._VERSIONDEF 77 | DESCRIPTOR.message_types_by_name['GraphDef'] = _GRAPHDEF 78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 79 | 80 | GraphDef = _reflection.GeneratedProtocolMessageType('GraphDef', (_message.Message,), dict( 81 | DESCRIPTOR = _GRAPHDEF, 82 | __module__ = 'tensorboardX.src.graph_pb2' 83 | # @@protoc_insertion_point(class_scope:tensorboard.GraphDef) 84 | )) 85 | _sym_db.RegisterMessage(GraphDef) 86 | 87 
| 88 | DESCRIPTOR.has_options = True 89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\013GraphProtosP\001\370\001\001')) 90 | _GRAPHDEF.fields_by_name['version'].has_options = True 91 | _GRAPHDEF.fields_by_name['version']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\030\001')) 92 | # @@protoc_insertion_point(module_scope) 93 | -------------------------------------------------------------------------------- /tensorboardX/src/resource_handle_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: tensorboardX/src/resource_handle.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='tensorboardX/src/resource_handle.proto', 20 | package='tensorboard', 21 | syntax='proto3', 22 | serialized_pb=_b('\n&tensorboardX/src/resource_handle.proto\x12\x0btensorboard\"r\n\x13ResourceHandleProto\x12\x0e\n\x06\x64\x65vice\x18\x01 \x01(\t\x12\x11\n\tcontainer\x18\x02 \x01(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x11\n\thash_code\x18\x04 \x01(\x04\x12\x17\n\x0fmaybe_type_name\x18\x05 \x01(\tB/\n\x18org.tensorflow.frameworkB\x0eResourceHandleP\x01\xf8\x01\x01\x62\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _RESOURCEHANDLEPROTO = _descriptor.Descriptor( 29 | name='ResourceHandleProto', 30 | full_name='tensorboard.ResourceHandleProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='device', full_name='tensorboard.ResourceHandleProto.device', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='container', full_name='tensorboard.ResourceHandleProto.container', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='name', full_name='tensorboard.ResourceHandleProto.name', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='hash_code', full_name='tensorboard.ResourceHandleProto.hash_code', index=3, 58 | number=4, type=4, cpp_type=4, label=1, 59 | has_default_value=False, default_value=0, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='maybe_type_name', 
full_name='tensorboard.ResourceHandleProto.maybe_type_name', index=4, 65 | number=5, type=9, cpp_type=9, label=1, 66 | has_default_value=False, default_value=_b("").decode('utf-8'), 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | ], 71 | extensions=[ 72 | ], 73 | nested_types=[], 74 | enum_types=[ 75 | ], 76 | options=None, 77 | is_extendable=False, 78 | syntax='proto3', 79 | extension_ranges=[], 80 | oneofs=[ 81 | ], 82 | serialized_start=55, 83 | serialized_end=169, 84 | ) 85 | 86 | DESCRIPTOR.message_types_by_name['ResourceHandleProto'] = _RESOURCEHANDLEPROTO 87 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 88 | 89 | ResourceHandleProto = _reflection.GeneratedProtocolMessageType('ResourceHandleProto', (_message.Message,), dict( 90 | DESCRIPTOR = _RESOURCEHANDLEPROTO, 91 | __module__ = 'tensorboardX.src.resource_handle_pb2' 92 | # @@protoc_insertion_point(class_scope:tensorboard.ResourceHandleProto) 93 | )) 94 | _sym_db.RegisterMessage(ResourceHandleProto) 95 | 96 | 97 | DESCRIPTOR.has_options = True 98 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\016ResourceHandleP\001\370\001\001')) 99 | # @@protoc_insertion_point(module_scope) 100 | -------------------------------------------------------------------------------- /tensorboardX/src/summary.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "SummaryProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/src/tensor.proto"; 10 | 11 | // Metadata associated with a series of Summary data 12 | message SummaryDescription { 13 | // Hint on how plugins should process the data in this series. 14 | // Supported values include "scalar", "histogram", "image", "audio" 15 | string type_hint = 1; 16 | } 17 | 18 | // Serialization format for histogram module in 19 | // core/lib/histogram/histogram.h 20 | message HistogramProto { 21 | double min = 1; 22 | double max = 2; 23 | double num = 3; 24 | double sum = 4; 25 | double sum_squares = 5; 26 | 27 | // Parallel arrays encoding the bucket boundaries and the bucket values. 28 | // bucket(i) is the count for the bucket i. The range for 29 | // a bucket is: 30 | // i == 0: -DBL_MAX .. bucket_limit(0) 31 | // i != 0: bucket_limit(i-1) .. bucket_limit(i) 32 | repeated double bucket_limit = 6 [packed = true]; 33 | repeated double bucket = 7 [packed = true]; 34 | }; 35 | 36 | // A SummaryMetadata encapsulates information on which plugins are able to make 37 | // use of a certain summary value. 38 | message SummaryMetadata { 39 | message PluginData { 40 | // The name of the plugin this data pertains to. 41 | string plugin_name = 1; 42 | 43 | // The content to store for the plugin. The best practice is for this JSON 44 | // string to be the canonical JSON serialization of a protocol buffer 45 | // defined by the plugin. Converting that protobuf to and from JSON is the 46 | // responsibility of the plugin code, and is not enforced by 47 | // TensorFlow/TensorBoard. 48 | string content = 2; 49 | } 50 | 51 | // A list of plugin data. A single summary value instance may be used by more 52 | // than 1 plugin. 
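// For instance, a PR-curve value may carry a single PluginData entry whose content is a
// serialized PrCurvePluginData message (see tensorboardX/src/plugin_pr_curve.proto in
// this tree).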
53 | repeated PluginData plugin_data = 1; 54 | }; 55 | 56 | // A Summary is a set of named values to be displayed by the 57 | // visualizer. 58 | // 59 | // Summaries are produced regularly during training, as controlled by 60 | // the "summary_interval_secs" attribute of the training operation. 61 | // Summaries are also produced at the end of an evaluation. 62 | message Summary { 63 | message Image { 64 | // Dimensions of the image. 65 | int32 height = 1; 66 | int32 width = 2; 67 | // Valid colorspace values are 68 | // 1 - grayscale 69 | // 2 - grayscale + alpha 70 | // 3 - RGB 71 | // 4 - RGBA 72 | // 5 - DIGITAL_YUV 73 | // 6 - BGRA 74 | int32 colorspace = 3; 75 | // Image data in encoded format. All image formats supported by 76 | // image_codec::CoderUtil can be stored here. 77 | bytes encoded_image_string = 4; 78 | } 79 | 80 | message Audio { 81 | // Sample rate of the audio in Hz. 82 | float sample_rate = 1; 83 | // Number of channels of audio. 84 | int64 num_channels = 2; 85 | // Length of the audio in frames (samples per channel). 86 | int64 length_frames = 3; 87 | // Encoded audio data and its associated RFC 2045 content type (e.g. 88 | // "audio/wav"). 89 | bytes encoded_audio_string = 4; 90 | string content_type = 5; 91 | } 92 | 93 | message Value { 94 | // Name of the node that output this summary; in general, the name of a 95 | // TensorSummary node. If the node in question has multiple outputs, then 96 | // a ":\d+" suffix will be appended, like "some_op:13". 97 | // Might not be set for legacy summaries (i.e. those not using the tensor 98 | // value field) 99 | string node_name = 7; 100 | 101 | // Tag name for the data. Will only be used by legacy summaries 102 | // (ie. those not using the tensor value field) 103 | // For legacy summaries, will be used as the title of the graph 104 | // in the visualizer. 105 | // 106 | // Tag is usually "op_name:value_name", where "op_name" itself can have 107 | // structure to indicate grouping. 108 | string tag = 1; 109 | SummaryMetadata metadata = 9; 110 | // Value associated with the tag. 111 | oneof value { 112 | float simple_value = 2; 113 | bytes obsolete_old_style_histogram = 3; 114 | Image image = 4; 115 | HistogramProto histo = 5; 116 | Audio audio = 6; 117 | TensorProto tensor = 8; 118 | } 119 | } 120 | 121 | // Set of values for the summary. 
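// A scalar summary, for example, is a single Value with its tag set and simple_value
// filled in.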
122 | repeated Value value = 1; 123 | } 124 | -------------------------------------------------------------------------------- /reid/models/classifier.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | from torch import nn 4 | import torch.nn.init as init 5 | import numpy as np 6 | 7 | 8 | class Classifier(nn.Module): 9 | def __init__(self, feat_num, class_num, drop=0): 10 | super(Classifier, self).__init__() 11 | self.feat_num = feat_num 12 | self.class_num = class_num 13 | self.drop = drop 14 | 15 | # BN layer 16 | self.classifierBN = nn.BatchNorm1d(self.feat_num) 17 | # feature classifier 18 | self.classifierlinear = nn.Linear(self.feat_num, self.class_num) 19 | # dropout layer 20 | self.drop = drop # redundant: drop was already stored above 21 | if self.drop > 0: 22 | self.droplayer = nn.Dropout(drop) 23 | 24 | init.constant(self.classifierBN.weight, 1) 25 | init.constant(self.classifierBN.bias, 0) 26 | 27 | init.normal(self.classifierlinear.weight, std=0.001) 28 | init.constant(self.classifierlinear.bias, 0) 29 | 30 | def forward(self, probe, gallery2, probe2, gallery): 31 | S_gallery2 = gallery2.size() 32 | N_probe = S_gallery2[0] 33 | N_gallery = S_gallery2[1] 34 | feat_num = S_gallery2[2] 35 | 36 | probe = probe.expand(N_probe, N_gallery, feat_num) 37 | gallery = gallery.expand(N_probe, N_gallery, feat_num) 38 | 39 | 40 | slice0 = 30 # process the probe axis in chunks of 30 rows to bound memory use 41 | if N_probe < slice0: 42 | diff1, diff2 = probe - gallery, probe2 - gallery2 43 | diff = diff1 * diff2 44 | pg_size = diff.size() 45 | p_size, g_size = pg_size[0], pg_size[1] 46 | diff = diff.view(p_size * g_size, -1) 47 | diff = diff.contiguous() 48 | diff = self.classifierBN(diff) 49 | if self.drop > 0: 50 | diff = self.droplayer(diff) 51 | cls_encode = self.classifierlinear(diff) 52 | cls_encode = cls_encode.view(p_size, g_size, -1) 53 | 54 | else: 55 | iter_time_0 = int(np.floor(N_probe / slice0)) 56 | for i in range(0, iter_time_0): 57 | before_index_0 = i * slice0 58 | after_index_0 = (i + 1) * slice0 59 | probe_tmp = probe[before_index_0:after_index_0, :, :] 60 | gallery_tmp = gallery[before_index_0:after_index_0, :, :] 61 | probe2_tmp = probe2[before_index_0:after_index_0, :, :] 62 | gallery2_tmp = gallery2[before_index_0:after_index_0, :, :] 63 | diff1_tmp, diff2_tmp = probe_tmp - gallery_tmp, probe2_tmp - gallery2_tmp 64 | # diff1_tmp = diff1[before_index_0:after_index_0, :, :] 65 | # diff2_tmp = diff2[before_index_0:after_index_0, :, :] 66 | diff_tmp = diff1_tmp * diff2_tmp 67 | pg_size = diff_tmp.size() 68 | p_size, g_size = pg_size[0], pg_size[1] 69 | diff_tmp = diff_tmp.view(p_size * g_size, -1) 70 | diff_tmp = diff_tmp.contiguous() 71 | diff_tmp = self.classifierBN(diff_tmp) 72 | if self.drop > 0: 73 | diff_tmp = self.droplayer(diff_tmp) 74 | cls_encode_tmp = self.classifierlinear(diff_tmp) 75 | cls_encode_tmp = cls_encode_tmp.view(p_size, g_size, -1) 76 | if i == 0: 77 | cls_encode = cls_encode_tmp 78 | else: 79 | cls_encode = torch.cat((cls_encode, cls_encode_tmp), 0) 80 | before_index_0 = iter_time_0 * slice0 81 | after_index_0 = N_probe 82 | if after_index_0 > before_index_0: 83 | probe_tmp = probe[before_index_0:after_index_0, :, :] 84 | gallery_tmp = gallery[before_index_0:after_index_0, :, :] 85 | probe2_tmp = probe2[before_index_0:after_index_0, :, :] 86 | gallery2_tmp = gallery2[before_index_0:after_index_0, :, :] 87 | diff1_tmp, diff2_tmp = probe_tmp - gallery_tmp, probe2_tmp - gallery2_tmp 88 | # diff1_tmp = diff1[before_index_0:after_index_0, :, :] 89 | # diff2_tmp =
diff2[before_index_0:after_index_0, :, :] 90 | diff_tmp = diff1_tmp * diff2_tmp 91 | pg_size = diff_tmp.size() 92 | p_size, g_size = pg_size[0], pg_size[1] 93 | diff_tmp = diff_tmp.view(p_size * g_size, -1) 94 | diff_tmp = diff_tmp.contiguous() 95 | diff_tmp = self.classifierBN(diff_tmp) 96 | if self.drop > 0: 97 | diff_tmp = self.droplayer(diff_tmp) 98 | cls_encode_tmp = self.classifierlinear(diff_tmp) 99 | cls_encode_tmp = cls_encode_tmp.view(p_size, g_size, -1) 100 | cls_encode = torch.cat((cls_encode, cls_encode_tmp), 0) 101 | 102 | return cls_encode 103 | -------------------------------------------------------------------------------- /tensorboardX/src/tensor_shape_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: tensorboardX/src/tensor_shape.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='tensorboardX/src/tensor_shape.proto', 20 | package='tensorboard', 21 | syntax='proto3', 22 | serialized_pb=_b('\n#tensorboardX/src/tensor_shape.proto\x12\x0btensorboard\"{\n\x10TensorShapeProto\x12.\n\x03\x64im\x18\x02 \x03(\x0b\x32!.tensorboard.TensorShapeProto.Dim\x12\x14\n\x0cunknown_rank\x18\x03 \x01(\x08\x1a!\n\x03\x44im\x12\x0c\n\x04size\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\tB2\n\x18org.tensorflow.frameworkB\x11TensorShapeProtosP\x01\xf8\x01\x01\x62\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _TENSORSHAPEPROTO_DIM = _descriptor.Descriptor( 29 | name='Dim', 30 | full_name='tensorboard.TensorShapeProto.Dim', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='size', full_name='tensorboard.TensorShapeProto.Dim.size', index=0, 37 | number=1, type=3, cpp_type=2, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='name', full_name='tensorboard.TensorShapeProto.Dim.name', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=None, 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=142, 62 | serialized_end=175, 63 | ) 64 | 65 | _TENSORSHAPEPROTO = _descriptor.Descriptor( 66 | name='TensorShapeProto', 67 | full_name='tensorboard.TensorShapeProto', 68 | filename=None, 69 | file=DESCRIPTOR, 70 | containing_type=None, 71 | fields=[ 72 | _descriptor.FieldDescriptor( 73 | name='dim', full_name='tensorboard.TensorShapeProto.dim', index=0, 74 | number=2, type=11, cpp_type=10, label=3, 75 | has_default_value=False, default_value=[], 76 | message_type=None, enum_type=None, 
containing_type=None, 77 | is_extension=False, extension_scope=None, 78 | options=None), 79 | _descriptor.FieldDescriptor( 80 | name='unknown_rank', full_name='tensorboard.TensorShapeProto.unknown_rank', index=1, 81 | number=3, type=8, cpp_type=7, label=1, 82 | has_default_value=False, default_value=False, 83 | message_type=None, enum_type=None, containing_type=None, 84 | is_extension=False, extension_scope=None, 85 | options=None), 86 | ], 87 | extensions=[ 88 | ], 89 | nested_types=[_TENSORSHAPEPROTO_DIM, ], 90 | enum_types=[ 91 | ], 92 | options=None, 93 | is_extendable=False, 94 | syntax='proto3', 95 | extension_ranges=[], 96 | oneofs=[ 97 | ], 98 | serialized_start=52, 99 | serialized_end=175, 100 | ) 101 | 102 | _TENSORSHAPEPROTO_DIM.containing_type = _TENSORSHAPEPROTO 103 | _TENSORSHAPEPROTO.fields_by_name['dim'].message_type = _TENSORSHAPEPROTO_DIM 104 | DESCRIPTOR.message_types_by_name['TensorShapeProto'] = _TENSORSHAPEPROTO 105 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 106 | 107 | TensorShapeProto = _reflection.GeneratedProtocolMessageType('TensorShapeProto', (_message.Message,), dict( 108 | 109 | Dim = _reflection.GeneratedProtocolMessageType('Dim', (_message.Message,), dict( 110 | DESCRIPTOR = _TENSORSHAPEPROTO_DIM, 111 | __module__ = 'tensorboardX.src.tensor_shape_pb2' 112 | # @@protoc_insertion_point(class_scope:tensorboard.TensorShapeProto.Dim) 113 | )) 114 | , 115 | DESCRIPTOR = _TENSORSHAPEPROTO, 116 | __module__ = 'tensorboardX.src.tensor_shape_pb2' 117 | # @@protoc_insertion_point(class_scope:tensorboard.TensorShapeProto) 118 | )) 119 | _sym_db.RegisterMessage(TensorShapeProto) 120 | _sym_db.RegisterMessage(TensorShapeProto.Dim) 121 | 122 | 123 | DESCRIPTOR.has_options = True 124 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\021TensorShapeProtosP\001\370\001\001')) 125 | # @@protoc_insertion_point(module_scope) 126 | -------------------------------------------------------------------------------- /tensorboardX/crc32c.py: -------------------------------------------------------------------------------- 1 | import array 2 | 3 | 4 | CRC_TABLE = ( 5 | 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 6 | 0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb, 7 | 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, 8 | 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, 9 | 0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b, 10 | 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384, 11 | 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 12 | 0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b, 13 | 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, 14 | 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, 15 | 0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5, 16 | 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa, 17 | 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 18 | 0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a, 19 | 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, 20 | 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, 21 | 0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48, 22 | 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957, 23 | 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 24 | 0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198, 25 | 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, 26 | 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, 27 | 0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8, 28 | 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7, 29 | 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 30 | 0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789, 31 | 
0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, 32 | 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, 33 | 0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9, 34 | 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6, 35 | 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 36 | 0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829, 37 | 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, 38 | 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, 39 | 0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043, 40 | 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c, 41 | 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 42 | 0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc, 43 | 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, 44 | 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, 45 | 0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652, 46 | 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d, 47 | 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 48 | 0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982, 49 | 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, 50 | 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, 51 | 0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2, 52 | 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed, 53 | 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 54 | 0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f, 55 | 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, 56 | 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, 57 | 0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f, 58 | 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540, 59 | 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 60 | 0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f, 61 | 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, 62 | 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, 63 | 0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321, 64 | 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e, 65 | 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 66 | 0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e, 67 | 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e, 68 | 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351, 69 | ) 70 | 71 | 72 | CRC_INIT = 0 73 | 74 | _MASK = 0xFFFFFFFF 75 | 76 | 77 | def crc_update(crc, data): 78 | """Update CRC-32C checksum with data. 79 | 80 | Args: 81 | crc: 32-bit checksum to update as long. 82 | data: byte array, string or iterable over bytes. 83 | 84 | Returns: 85 | 32-bit updated CRC-32C as long. 86 | """ 87 | 88 | if type(data) != array.array or data.itemsize != 1: 89 | buf = array.array("B", data) 90 | else: 91 | buf = data 92 | 93 | crc ^= _MASK 94 | for b in buf: 95 | table_index = (crc ^ b) & 0xff 96 | crc = (CRC_TABLE[table_index] ^ (crc >> 8)) & _MASK 97 | return crc ^ _MASK 98 | 99 | 100 | def crc_finalize(crc): 101 | """Finalize CRC-32C checksum. 102 | 103 | This function should be called as last step of crc calculation. 104 | 105 | Args: 106 | crc: 32-bit checksum as long. 107 | 108 | Returns: 109 | finalized 32-bit checksum as long 110 | """ 111 | return crc & _MASK 112 | 113 | 114 | def crc32c(data): 115 | """Compute CRC-32C checksum of the data. 116 | 117 | Args: 118 | data: byte array, string or iterable over bytes. 119 | 120 | Returns: 121 | 32-bit CRC-32C checksum of data as long. 
122 | """ 123 | return crc_finalize(crc_update(CRC_INIT, data)) 124 | -------------------------------------------------------------------------------- /reid/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152'] 10 | 11 | 12 | class ResNet(nn.Module): 13 | __factory = { 14 | 18: torchvision.models.resnet18, 15 | 34: torchvision.models.resnet34, 16 | 50: torchvision.models.resnet50, 17 | 101: torchvision.models.resnet101, 18 | 152: torchvision.models.resnet152, 19 | } 20 | 21 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 22 | num_features=0, dropout=0): 23 | super(ResNet, self).__init__() 24 | 25 | self.depth = depth 26 | self.pretrained = pretrained 27 | self.cut_at_pooling = cut_at_pooling 28 | 29 | # Construct base (pretrained) resnet 30 | if depth not in ResNet.__factory: 31 | raise KeyError("Unsupported depth:", depth) 32 | 33 | self.base = ResNet.__factory[depth](pretrained=pretrained) 34 | 35 | conv0 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 36 | init.kaiming_normal(conv0.weight, mode='fan_out') 37 | #init.kaiming_uniform(conv0.weight, mode='fan_out') 38 | self.conv0 = conv0 39 | 40 | if not self.cut_at_pooling: 41 | self.num_features = num_features 42 | self.dropout = dropout 43 | self.has_embedding = num_features > 0 44 | 45 | out_planes = self.base.fc.in_features 46 | 47 | # Append new layers 48 | if self.has_embedding: 49 | self.feat = nn.Linear(out_planes, self.num_features) 50 | self.feat_bn = nn.BatchNorm1d(self.num_features) 51 | init.kaiming_normal(self.feat.weight, mode='fan_out') 52 | init.constant(self.feat.bias, 0) 53 | init.constant(self.feat_bn.weight, 1) 54 | init.constant(self.feat_bn.bias, 0) 55 | else: 56 | # Change the num_features to CNN output channels 57 | self.num_features = out_planes 58 | if self.dropout > 0: 59 | self.drop = nn.Dropout(self.dropout) 60 | 61 | 62 | if not self.pretrained: 63 | self.reset_params() 64 | 65 | def forward(self, imgs, motions, mode): 66 | 67 | img_size = imgs.size() 68 | motion_size = motions.size() 69 | batch_sz = img_size[0] 70 | seq_len = img_size[1] 71 | imgs = imgs.view(-1, img_size[2], img_size[3], img_size[4]) 72 | motions = motions.view(-1, motion_size[2], motion_size[3], motion_size[4]) 73 | motions = motions[:, 0:3] 74 | x = imgs 75 | for name, module in self.base._modules.items(): 76 | if name == 'conv1': 77 | x = module(imgs) + self.conv0(motions) 78 | continue 79 | if name == 'avgpool': 80 | break 81 | x = module(x) 82 | 83 | if self.cut_at_pooling: 84 | return x 85 | 86 | x = F.avg_pool2d(x, x.size()[2:]) 87 | x = x.view(x.size(0), -1) 88 | 89 | if mode == 'cnn_rnn': 90 | raw = x.view(batch_sz, seq_len, -1) 91 | 92 | if self.has_embedding: 93 | x = self.feat(x) 94 | x = self.feat_bn(x) 95 | 96 | if self.dropout > 0: 97 | x = self.drop(x) 98 | 99 | 100 | if mode == 'cnn_rnn': 101 | # x = x / x.norm(2, 1).expand_as(x) 102 | x = x / x.norm(2, 1).unsqueeze(1).expand_as(x) 103 | x = x.view(batch_sz, seq_len, -1) 104 | return x, raw 105 | elif mode == 'cnn': 106 | # x = x / x.norm(2, 1).expand_as(x) 107 | x = x / x.norm(2, 1).unsqueeze(1).expand_as(x) 108 | x = x.view(batch_sz, seq_len, -1) 109 | x = torch.squeeze(torch.mean(x, 1), 1) 110 | return x 111 | 112 | def 
reset_params(self): 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | init.kaiming_normal(m.weight, mode='fan_out') 116 | if m.bias is not None: 117 | init.constant(m.bias, 0) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | init.constant(m.weight, 1) 120 | init.constant(m.bias, 0) 121 | elif isinstance(m, nn.Linear): 122 | init.normal(m.weight, std=0.001) 123 | if m.bias is not None: 124 | init.constant(m.bias, 0) 125 | 126 | 127 | def resnet18(**kwargs): 128 | return ResNet(18, **kwargs) 129 | 130 | 131 | def resnet34(**kwargs): 132 | return ResNet(34, **kwargs) 133 | 134 | 135 | def resnet50(**kwargs): 136 | return ResNet(50, **kwargs) 137 | 138 | 139 | def resnet101(**kwargs): 140 | return ResNet(101, **kwargs) 141 | 142 | 143 | def resnet152(**kwargs): 144 | return ResNet(152, **kwargs) 145 | -------------------------------------------------------------------------------- /reid/evaluator/eva_functions.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from sklearn.metrics import average_precision_score 7 | from utils import to_torch, to_numpy 8 | 9 | 10 | 11 | def _unique_sample(ids_dict, num): 12 | mask = np.zeros(num, dtype=np.bool) 13 | for _, indices in ids_dict.items(): 14 | i = np.random.choice(indices) 15 | mask[i] = True 16 | return mask 17 | 18 | 19 | def cmc(distmat, query_ids=None, gallery_ids=None, 20 | query_cams=None, gallery_cams=None, topk=100, 21 | separate_camera_set=False, 22 | single_gallery_shot=False, 23 | first_match_break=False): 24 | distmat = to_numpy(distmat) 25 | m, n = distmat.shape 26 | # Fill up default values 27 | if query_ids is None: 28 | query_ids = np.arange(m) 29 | if gallery_ids is None: 30 | gallery_ids = np.arange(n) 31 | if query_cams is None: 32 | query_cams = np.zeros(m).astype(np.int32) 33 | if gallery_cams is None: 34 | gallery_cams = np.ones(n).astype(np.int32) 35 | # Ensure numpy array 36 | query_ids = np.asarray(query_ids) 37 | gallery_ids = np.asarray(gallery_ids) 38 | query_cams = np.asarray(query_cams) 39 | gallery_cams = np.asarray(gallery_cams) 40 | # Sort and find correct matches 41 | indices = np.argsort(distmat, axis=1) 42 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 43 | # Compute CMC for each query 44 | ret = np.zeros(topk) 45 | num_valid_queries = 0 46 | for i in range(m): 47 | # Filter out the same id and same camera 48 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 49 | (gallery_cams[indices[i]] != query_cams[i])) 50 | if separate_camera_set: 51 | # Filter out samples from same camera 52 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 53 | if not np.any(matches[i, valid]): continue 54 | if single_gallery_shot: 55 | repeat = 10 56 | gids = gallery_ids[indices[i][valid]] 57 | inds = np.where(valid)[0] 58 | ids_dict = defaultdict(list) 59 | for j, x in zip(inds, gids): 60 | ids_dict[x].append(j) 61 | else: 62 | repeat = 1 63 | for _ in range(repeat): 64 | if single_gallery_shot: 65 | # Randomly choose one instance for each id 66 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 67 | index = np.nonzero(matches[i, sampled])[0] 68 | else: 69 | index = np.nonzero(matches[i, valid])[0] 70 | delta = 1. 
/ (len(index) * repeat) 71 | for j, k in enumerate(index): 72 | if k - j >= topk: break 73 | if first_match_break: 74 | ret[k - j] += 1 75 | break 76 | ret[k - j] += delta 77 | num_valid_queries += 1 78 | if num_valid_queries == 0: 79 | raise RuntimeError("No valid query") 80 | return ret.cumsum() / num_valid_queries 81 | 82 | 83 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 84 | query_cams=None, gallery_cams=None): 85 | distmat = to_numpy(distmat) 86 | m, n = distmat.shape 87 | # Fill up default values 88 | if query_ids is None: 89 | query_ids = np.arange(m) 90 | if gallery_ids is None: 91 | gallery_ids = np.arange(n) 92 | if query_cams is None: 93 | query_cams = np.zeros(m).astype(np.int32) 94 | if gallery_cams is None: 95 | gallery_cams = np.ones(n).astype(np.int32) 96 | # Ensure numpy array 97 | query_ids = np.asarray(query_ids) 98 | gallery_ids = np.asarray(gallery_ids) 99 | query_cams = np.asarray(query_cams) 100 | gallery_cams = np.asarray(gallery_cams) 101 | # Sort and find correct matches 102 | indices = np.argsort(distmat, axis=1) 103 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 104 | # Compute AP for each query 105 | aps = [] 106 | for i in range(m): 107 | # Filter out the same id and same camera 108 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 109 | (gallery_cams[indices[i]] != query_cams[i])) 110 | y_true = matches[i, valid] 111 | y_score = -distmat[i][indices[i]][valid] 112 | if not np.any(y_true): continue 113 | aps.append(average_precision_score(y_true, y_score)) 114 | if len(aps) == 0: 115 | raise RuntimeError("No valid query") 116 | return np.mean(aps) 117 | 118 | def accuracy(output, target, topk=(1,)): 119 | output, target = to_torch(output), to_torch(target) 120 | maxk = max(topk) 121 | batch_size = target.size(0) 122 | 123 | _, pred = output.topk(maxk, 1, True, True) 124 | pred = pred.t() 125 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 126 | 127 | ret = [] 128 | for k in topk: 129 | correct_k = correct[:k].view(-1).float().sum(0) 130 | ret.append(correct_k.mul_(1. 
/ batch_size)) 131 | return ret 132 | -------------------------------------------------------------------------------- /tensorboardX/graph_onnx.py: -------------------------------------------------------------------------------- 1 | from .src.graph_pb2 import GraphDef 2 | from .src.node_def_pb2 import NodeDef 3 | from .src.versions_pb2 import VersionDef 4 | from .src.attr_value_pb2 import AttrValue 5 | from .src.tensor_shape_pb2 import TensorShapeProto 6 | # from .src.onnx_pb2 import ModelProto 7 | 8 | 9 | def gg(fname): 10 | import onnx # 0.2.1 11 | m = onnx.load(fname) 12 | nodes_proto = [] 13 | nodes = [] 14 | g = m.graph 15 | import itertools 16 | for node in itertools.chain(g.input, g.output): 17 | nodes_proto.append(node) 18 | 19 | for node in nodes_proto: 20 | shapeproto = TensorShapeProto( 21 | dim=[TensorShapeProto.Dim(size=d.dim_value) for d in node.type.tensor_type.shape.dim]) 22 | nodes.append(NodeDef( 23 | name=node.name, 24 | op='Variable', 25 | input=[], 26 | attr={ 27 | 'dtype': AttrValue(type=node.type.tensor_type.elem_type), 28 | 'shape': AttrValue(shape=shapeproto), 29 | }) 30 | ) 31 | 32 | for node in g.node: 33 | attr = [] 34 | for s in node.attribute: 35 | attr.append(' = '.join([str(f[1]) for f in s.ListFields()])) 36 | attr = ', '.join(attr).encode(encoding='utf_8') 37 | 38 | nodes.append(NodeDef( 39 | name=node.output[0], 40 | op=node.op_type, 41 | input=node.input, 42 | attr={'parameters': AttrValue(s=attr)}, 43 | )) 44 | # two pass token replacement, appends opname to object id 45 | mapping = {} 46 | for node in nodes: 47 | mapping[node.name] = node.op + '_' + node.name 48 | 49 | nodes, mapping = updatenodes(nodes, mapping) 50 | mapping = smartGrouping(nodes, mapping) 51 | nodes, mapping = updatenodes(nodes, mapping) 52 | 53 | return GraphDef(node=nodes, versions=VersionDef(producer=22)) 54 | 55 | 56 | def updatenodes(nodes, mapping): 57 | for node in nodes: 58 | newname = mapping[node.name] 59 | node.name = newname 60 | newinput = [] 61 | for inputnode in list(node.input): 62 | newinput.append(mapping[inputnode]) 63 | node.input.remove(inputnode) 64 | node.input.extend(newinput) 65 | newmap = {} 66 | for k, v in mapping.items(): 67 | newmap[v] = v 68 | return nodes, newmap 69 | 70 | 71 | def findnode(nodes, name): 72 | """ input: node name 73 | returns: node object 74 | """ 75 | for n in nodes: 76 | if n.name == name: 77 | return n 78 | 79 | 80 | def parser(s, nodes, node): 81 | print(s) 82 | if len(s) == 0: 83 | return 84 | if len(s) > 0: 85 | if s[0] == node.op: 86 | print(s[0], node.name, s[1], node.input) 87 | for n in node.input: 88 | print(n, s[1]) 89 | parser(s[1], nodes, findnode(nodes, n)) 90 | else: 91 | return False 92 | 93 | 94 | # TODO: use recursive parse 95 | 96 | def smartGrouping(nodes, mapping): 97 | # a Fully Conv is: (TODO: check var1.size(0)==var2.size(0)) 98 | # GEMM <-- Variable (c1) 99 | # ^-- Transpose (c2) <-- Variable (c3) 100 | 101 | # a Conv with bias is: (TODO: check var1.size(0)==var2.size(0)) 102 | # Add <-- Conv (c2) <-- Variable (c3) 103 | # ^-- Variable (c1) 104 | # 105 | # gemm = ('Gemm', ('Variable', ('Transpose', ('Variable')))) 106 | 107 | FCcounter = 1 108 | Convcounter = 1 109 | for node in nodes: 110 | if node.op == 'Gemm': 111 | c1 = c2 = c3 = False 112 | for name_in in node.input: 113 | n = findnode(nodes, name_in) 114 | if n.op == 'Variable': 115 | c1 = True 116 | c1name = n.name 117 | if n.op == 'Transpose': 118 | c2 = True 119 | c2name = n.name 120 | if len(n.input) == 1: 121 | nn = findnode(nodes, n.input[0]) 122 | 
if nn.op == 'Variable': 123 | c3 = True 124 | c3name = nn.name 125 | # print(n.op, n.name, c1, c2, c3) 126 | if c1 and c2 and c3: 127 | # print(c1name, c2name, c3name) 128 | mapping[c1name] = 'FC{}/{}'.format(FCcounter, c1name) 129 | mapping[c2name] = 'FC{}/{}'.format(FCcounter, c2name) 130 | mapping[c3name] = 'FC{}/{}'.format(FCcounter, c3name) 131 | mapping[node.name] = 'FC{}/{}'.format(FCcounter, node.name) 132 | FCcounter += 1 133 | continue 134 | if node.op == 'Add': 135 | c1 = c2 = c3 = False 136 | for name_in in node.input: 137 | n = findnode(nodes, name_in) 138 | if n.op == 'Variable': 139 | c1 = True 140 | c1name = n.name 141 | if n.op == 'Conv': 142 | c2 = True 143 | c2name = n.name 144 | if len(n.input) >= 1: 145 | for nn_name in n.input: 146 | nn = findnode(nodes, nn_name) 147 | if nn.op == 'Variable': 148 | c3 = True 149 | c3name = nn.name 150 | 151 | if c1 and c2 and c3: 152 | # print(c1name, c2name, c3name) 153 | mapping[c1name] = 'Conv{}/{}'.format(Convcounter, c1name) 154 | mapping[c2name] = 'Conv{}/{}'.format(Convcounter, c2name) 155 | mapping[c3name] = 'Conv{}/{}'.format(Convcounter, c3name) 156 | mapping[node.name] = 'Conv{}/{}'.format(Convcounter, node.name) 157 | Convcounter += 1 158 | return mapping 159 | -------------------------------------------------------------------------------- /reid/train/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import torch 4 | from torch.autograd import Variable 5 | from reid.evaluator import accuracy 6 | from utils.meters import AverageMeter 7 | import torch.nn.functional as F 8 | import sys 9 | from tensorboardX import SummaryWriter 10 | 11 | class BaseTrainer(object): 12 | 13 | def __init__(self, model, criterion): 14 | super(BaseTrainer, self).__init__() 15 | self.model = model 16 | self.criterion = criterion 17 | 18 | def train(self, epoch, data_loader, optimizer1, optimizer2): 19 | self.model.train() 20 | 21 | batch_time = AverageMeter() 22 | data_time = AverageMeter() 23 | losses = AverageMeter() 24 | precisions = AverageMeter() 25 | precisions1 = AverageMeter() 26 | precisions2 = AverageMeter() 27 | 28 | end = time.time() 29 | for i, inputs in enumerate(data_loader): 30 | data_time.update(time.time() - end) 31 | 32 | inputs, targets = self._parse_data(inputs) 33 | 34 | 35 | loss, prec_oim, prec_score, prec_finalscore = self._forward(inputs, targets) 36 | losses.update(loss.data[0], targets.size(0)) 37 | 38 | precisions.update(prec_oim, targets.size(0)) 39 | precisions1.update(prec_score, targets.size(0)) 40 | precisions2.update(prec_finalscore, targets.size(0)) 41 | 42 | optimizer1.zero_grad() 43 | optimizer2.zero_grad() 44 | loss.backward() 45 | optimizer1.step() 46 | optimizer2.step() 47 | 48 | batch_time.update(time.time() - end) 49 | end = time.time() 50 | print_freq = 5 51 | num_step = len(data_loader) 52 | num_iter = num_step * epoch + i 53 | self.writer.add_scalar('train/loss_step', losses.val, num_iter) 54 | self.writer.add_scalar('train/loss_avg', losses.avg, num_iter) 55 | self.writer.add_scalar('train/prec_pairloss', precisions1.avg, num_iter) 56 | self.writer.add_scalar('train/prec_oimloss', precisions.avg, num_iter) 57 | if (i + 1) % print_freq == 0: 58 | print('Epoch: [{}][{}/{}]\t' 59 | 'Loss {:.3f} ({:.3f})\t' 60 | 'prec_oim {:.2%} ({:.2%})\t' 61 | 'prec_score {:.2%} ({:.2%})\t' 62 | .format(epoch, i + 1, len(data_loader), 63 | losses.val, losses.avg, 64 | precisions.val, precisions.avg, 
65 | precisions1.val, precisions1.avg)) 66 | 67 | 68 | def _parse_data(self, inputs): 69 | raise NotImplementedError 70 | 71 | def _forward(self, inputs, targets): 72 | raise NotImplementedError 73 | 74 | 75 | class SEQTrainer(BaseTrainer): 76 | 77 | def __init__(self, cnn_model, att_model, classifier_model, criterion_veri, criterion_oim, mode, rate, logdir): 78 | super(SEQTrainer, self).__init__(cnn_model, criterion_veri) 79 | self.att_model = att_model 80 | self.classifier_model = classifier_model 81 | self.regular_criterion = criterion_oim 82 | self.mode = mode 83 | self.rate = rate 84 | self.writer = SummaryWriter(log_dir=logdir) 85 | 86 | def _parse_data(self, inputs): 87 | imgs, flows, pids, _ = inputs 88 | inputs = [Variable(imgs), Variable(flows)] 89 | targets = Variable(pids).cuda() 90 | return inputs, targets 91 | 92 | def _forward(self, inputs, targets): 93 | 94 | if self.mode == 'cnn': 95 | out_feat = self.model(inputs[0], inputs[1], self.mode) 96 | 97 | loss, outputs = self.regular_criterion(out_feat, targets) 98 | prec, = accuracy(outputs.data, targets.data) 99 | prec = prec[0] 100 | 101 | return loss, prec, 0, 0 102 | 103 | elif self.mode == 'cnn_rnn': 104 | 105 | 106 | feat, feat_raw = self.model(inputs[0], inputs[1], self.mode) 107 | featsize = feat.size() 108 | featbatch = featsize[0] 109 | seqlen = featsize[1] 110 | 111 | ## expand the target label ID loss 112 | featX = feat.view(featbatch * seqlen, -1) 113 | 114 | targetX = targets.unsqueeze(1) 115 | targetX = targetX.expand(featbatch, seqlen) 116 | targetX = targetX.contiguous() 117 | targetX = targetX.view(featbatch * seqlen, -1) 118 | targetX = targetX.squeeze(1) 119 | loss_id, outputs_id = self.regular_criterion(featX, targetX) 120 | 121 | prec_id, = accuracy(outputs_id.data, targetX.data) 122 | prec_id = prec_id[0] 123 | 124 | ## verification label 125 | 126 | featsize = feat.size() 127 | sample_num = featsize[0] 128 | targets = targets.data 129 | targets = targets.view(int(sample_num / 2), -1) 130 | tar_probe = targets[:, 0] 131 | tar_gallery = targets[:, 1] 132 | 133 | pooled_probe, pooled_gallery_2, pooled_probe_2, pooled_gallery = self.att_model(feat, feat_raw) 134 | 135 | encode_scores = self.classifier_model(pooled_probe, pooled_gallery_2, pooled_probe_2, pooled_gallery) 136 | 137 | encode_size = encode_scores.size() 138 | encodemat = encode_scores.view(-1, 2) 139 | encodemat = F.softmax(encodemat) 140 | encodemat = encodemat.view(encode_size[0], encode_size[1], 2) 141 | encodemat = encodemat[:, :, 1] 142 | 143 | loss_ver, prec_ver = self.criterion(encodemat, tar_probe, tar_gallery) 144 | 145 | 146 | loss = loss_id*self.rate + 100*loss_ver 147 | 148 | return loss, prec_id, prec_ver, 0 149 | else: 150 | raise ValueError("Unsupported loss:", self.criterion) 151 | 152 | def train(self, epoch, data_loader, optimizer1, optimizer2, rate): 153 | self.att_model.train() 154 | self.classifier_model.train() 155 | self.rate = rate 156 | super(SEQTrainer, self).train(epoch, data_loader, optimizer1, optimizer2) 157 | -------------------------------------------------------------------------------- /tensorboardX/src/node_def_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
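# ---------------------------------------------------------------------------
# Editor's note (an assumption, not part of the generated output): the
# *_pb2.py modules under tensorboardX/src are protobuf-compiler output and
# should be regenerated rather than edited by hand. With protoc installed,
# a module like this one would be rebuilt from the repository root with
# roughly:
#     protoc --python_out=. tensorboardX/src/node_def.proto
# The descriptor name= below mirrors the path given to protoc, which is also
# why the generated modules import one another as tensorboardX.src.*_pb2.
# ---------------------------------------------------------------------------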
2 | # source: tensorboardX/src/node_def.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from tensorboardX.src import attr_value_pb2 as tensorboardX_dot_src_dot_attr__value__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='tensorboardX/src/node_def.proto', 21 | package='tensorboard', 22 | syntax='proto3', 23 | serialized_pb=_b('\n\x1ftensorboardX/src/node_def.proto\x12\x0btensorboard\x1a!tensorboardX/src/attr_value.proto\"\xb5\x01\n\x07NodeDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02op\x18\x02 \x01(\t\x12\r\n\x05input\x18\x03 \x03(\t\x12\x0e\n\x06\x64\x65vice\x18\x04 \x01(\t\x12,\n\x04\x61ttr\x18\x05 \x03(\x0b\x32\x1e.tensorboard.NodeDef.AttrEntry\x1a\x43\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.tensorboard.AttrValue:\x02\x38\x01\x42*\n\x18org.tensorflow.frameworkB\tNodeProtoP\x01\xf8\x01\x01\x62\x06proto3') 24 | , 25 | dependencies=[tensorboardX_dot_src_dot_attr__value__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | 30 | _NODEDEF_ATTRENTRY = _descriptor.Descriptor( 31 | name='AttrEntry', 32 | full_name='tensorboard.NodeDef.AttrEntry', 33 | filename=None, 34 | file=DESCRIPTOR, 35 | containing_type=None, 36 | fields=[ 37 | _descriptor.FieldDescriptor( 38 | name='key', full_name='tensorboard.NodeDef.AttrEntry.key', index=0, 39 | number=1, type=9, cpp_type=9, label=1, 40 | has_default_value=False, default_value=_b("").decode('utf-8'), 41 | message_type=None, enum_type=None, containing_type=None, 42 | is_extension=False, extension_scope=None, 43 | options=None), 44 | _descriptor.FieldDescriptor( 45 | name='value', full_name='tensorboard.NodeDef.AttrEntry.value', index=1, 46 | number=2, type=11, cpp_type=10, label=1, 47 | has_default_value=False, default_value=None, 48 | message_type=None, enum_type=None, containing_type=None, 49 | is_extension=False, extension_scope=None, 50 | options=None), 51 | ], 52 | extensions=[ 53 | ], 54 | nested_types=[], 55 | enum_types=[ 56 | ], 57 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 58 | is_extendable=False, 59 | syntax='proto3', 60 | extension_ranges=[], 61 | oneofs=[ 62 | ], 63 | serialized_start=198, 64 | serialized_end=265, 65 | ) 66 | 67 | _NODEDEF = _descriptor.Descriptor( 68 | name='NodeDef', 69 | full_name='tensorboard.NodeDef', 70 | filename=None, 71 | file=DESCRIPTOR, 72 | containing_type=None, 73 | fields=[ 74 | _descriptor.FieldDescriptor( 75 | name='name', full_name='tensorboard.NodeDef.name', index=0, 76 | number=1, type=9, cpp_type=9, label=1, 77 | has_default_value=False, default_value=_b("").decode('utf-8'), 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=None), 81 | _descriptor.FieldDescriptor( 82 | name='op', full_name='tensorboard.NodeDef.op', index=1, 83 | number=2, type=9, cpp_type=9, label=1, 84 | has_default_value=False, default_value=_b("").decode('utf-8'), 85 | message_type=None, enum_type=None, containing_type=None, 86 | is_extension=False, extension_scope=None, 87 | options=None), 88 | _descriptor.FieldDescriptor( 89 | 
name='input', full_name='tensorboard.NodeDef.input', index=2, 90 | number=3, type=9, cpp_type=9, label=3, 91 | has_default_value=False, default_value=[], 92 | message_type=None, enum_type=None, containing_type=None, 93 | is_extension=False, extension_scope=None, 94 | options=None), 95 | _descriptor.FieldDescriptor( 96 | name='device', full_name='tensorboard.NodeDef.device', index=3, 97 | number=4, type=9, cpp_type=9, label=1, 98 | has_default_value=False, default_value=_b("").decode('utf-8'), 99 | message_type=None, enum_type=None, containing_type=None, 100 | is_extension=False, extension_scope=None, 101 | options=None), 102 | _descriptor.FieldDescriptor( 103 | name='attr', full_name='tensorboard.NodeDef.attr', index=4, 104 | number=5, type=11, cpp_type=10, label=3, 105 | has_default_value=False, default_value=[], 106 | message_type=None, enum_type=None, containing_type=None, 107 | is_extension=False, extension_scope=None, 108 | options=None), 109 | ], 110 | extensions=[ 111 | ], 112 | nested_types=[_NODEDEF_ATTRENTRY, ], 113 | enum_types=[ 114 | ], 115 | options=None, 116 | is_extendable=False, 117 | syntax='proto3', 118 | extension_ranges=[], 119 | oneofs=[ 120 | ], 121 | serialized_start=84, 122 | serialized_end=265, 123 | ) 124 | 125 | _NODEDEF_ATTRENTRY.fields_by_name['value'].message_type = tensorboardX_dot_src_dot_attr__value__pb2._ATTRVALUE 126 | _NODEDEF_ATTRENTRY.containing_type = _NODEDEF 127 | _NODEDEF.fields_by_name['attr'].message_type = _NODEDEF_ATTRENTRY 128 | DESCRIPTOR.message_types_by_name['NodeDef'] = _NODEDEF 129 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 130 | 131 | NodeDef = _reflection.GeneratedProtocolMessageType('NodeDef', (_message.Message,), dict( 132 | 133 | AttrEntry = _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), dict( 134 | DESCRIPTOR = _NODEDEF_ATTRENTRY, 135 | __module__ = 'tensorboardX.src.node_def_pb2' 136 | # @@protoc_insertion_point(class_scope:tensorboard.NodeDef.AttrEntry) 137 | )) 138 | , 139 | DESCRIPTOR = _NODEDEF, 140 | __module__ = 'tensorboardX.src.node_def_pb2' 141 | # @@protoc_insertion_point(class_scope:tensorboard.NodeDef) 142 | )) 143 | _sym_db.RegisterMessage(NodeDef) 144 | _sym_db.RegisterMessage(NodeDef.AttrEntry) 145 | 146 | 147 | DESCRIPTOR.has_options = True 148 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\tNodeProtoP\001\370\001\001')) 149 | _NODEDEF_ATTRENTRY.has_options = True 150 | _NODEDEF_ATTRENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 151 | # @@protoc_insertion_point(module_scope) 152 | -------------------------------------------------------------------------------- /reid/dataset/prid2011sequence.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import os.path as osp 4 | from reid.data.datasequence import Datasequence 5 | from utils.osutils import mkdir_if_missing 6 | from utils.serialization import write_json 7 | import tarfile 8 | import zipfile 9 | from glob import glob 10 | import shutil 11 | import numpy as np 12 | 13 | datasetname = 'prid_2011' 14 | # flowname = 'prid2011flow_opencv' 15 | flowname = 'prid_xiong' 16 | 17 | class infostruct(object): 18 | pass 19 | 20 | class PRID2011SEQUENCE(Datasequence): 21 | 22 | def __init__(self, root, split_id=0, seq_len=12, seq_srd=6, num_val=1, download=False): 23 | super(PRID2011SEQUENCE, self).__init__(root, split_id=split_id) 24 | 25 | if 
download: 26 | self.download() 27 | 28 | if not self._check_integrity(): 29 | self.imgextract() 30 | 31 | self.load(seq_len, seq_srd, num_val) 32 | 33 | self.query, query_pid, query_camid, query_num = self._pluckseq_cam(self.identities, self.split['query'], seq_len, seq_srd, 0) 34 | self.queryinfo = infostruct() 35 | self.queryinfo.pid = query_pid 36 | self.queryinfo.camid = query_camid 37 | self.queryinfo.tranum = query_num 38 | 39 | self.gallery, gallery_pid, gallery_camid, gallery_num = self._pluckseq_cam(self.identities, self.split['gallery'], seq_len, seq_srd, 1) 40 | self.galleryinfo = infostruct() 41 | self.galleryinfo.pid = gallery_pid 42 | self.galleryinfo.camid = gallery_camid 43 | self.galleryinfo.tranum = gallery_num 44 | 45 | @property 46 | def other_dir(self): 47 | return osp.join(self.root, 'others') 48 | 49 | def download(self): 50 | 51 | if self._check_integrity(): 52 | print("Files already downloaded and verified") 53 | return 54 | 55 | raw_dir = osp.join(self.root, 'raw') 56 | mkdir_if_missing(raw_dir) 57 | 58 | fpath1 = osp.join(raw_dir, datasetname + '.zip') 59 | fpath2 = osp.join(raw_dir, flowname + '.tar') 60 | 61 | if osp.isfile(fpath1) and osp.isfile(fpath2): 62 | print("Using the download file:" + fpath1 + " " + fpath2) 63 | else: 64 | print("Please firstly download the files") 65 | raise RuntimeError("Downloaded file missing!") 66 | 67 | def imgextract(self): 68 | 69 | raw_dir = osp.join(self.root, 'raw') 70 | exdir1 = osp.join(raw_dir, datasetname) 71 | exdir2 = osp.join(raw_dir, flowname) 72 | fpath1 = osp.join(raw_dir, datasetname + '.zip') 73 | fpath2 = osp.join(raw_dir, flowname + '.tar') 74 | 75 | if not osp.isdir(exdir1): 76 | print("Extracting tar file") 77 | cwd = os.getcwd() 78 | zip_ref = zipfile.ZipFile(fpath1, 'r') 79 | mkdir_if_missing(exdir1) 80 | zip_ref.extractall(exdir1) 81 | zip_ref.close() 82 | os.chdir(cwd) 83 | 84 | if not osp.isdir(exdir2): 85 | print("Extracting tar file") 86 | cwd = os.getcwd() 87 | tar_ref = tarfile.open(fpath2) 88 | mkdir_if_missing(exdir2) 89 | os.chdir(exdir2) 90 | tar_ref.extractall() 91 | tar_ref.close() 92 | os.chdir(cwd) 93 | 94 | ## recognizing the dataset 95 | # Format 96 | 97 | images_dir = osp.join(self.root, 'images') 98 | mkdir_if_missing(images_dir) 99 | 100 | others_dir = osp.join(self.root, 'others') 101 | mkdir_if_missing(others_dir) 102 | 103 | fpaths1 = sorted(glob(osp.join(exdir1, 'multi_shot', '*/*/*.png'))) 104 | fpaths2 = sorted(glob(osp.join(exdir2, '*/*/*.png'))) 105 | 106 | identities_images = [[[] for _ in range(2)] for _ in range(200)] 107 | identities_others = [[[] for _ in range(2)] for _ in range(200)] 108 | 109 | for fpath in fpaths1: 110 | fname = fpath 111 | fname_list = fname.split('/') 112 | cam_name = fname_list[-3] 113 | pid_name = fname_list[-2] 114 | frame_name = fname_list[-1] 115 | cam_id = 1 if cam_name =='cam_a' else 2 116 | pid_id = int(pid_name.split('_')[-1]) 117 | if pid_id > 200: 118 | continue 119 | frame_id = int(frame_name.split('.')[-2]) 120 | imagefname = ('{:08d}_{:02d}_{:04d}.png' 121 | .format(pid_id-1, cam_id-1, frame_id-1)) 122 | identities_images[pid_id - 1][cam_id - 1].append(imagefname) 123 | shutil.copy(fpath, osp.join(images_dir, imagefname)) 124 | 125 | for fpath in fpaths2: 126 | fname = fpath 127 | fname_list = fname.split('/') 128 | cam_name = fname_list[-3] 129 | pid_name = fname_list[-2] 130 | frame_name = fname_list[-1] 131 | cam_id = 1 if cam_name =='cam_a' else 2 132 | pid_id = int(pid_name.split('_')[-1]) 133 | if pid_id > 200: 134 | continue 135 | 
frame_id = int(frame_name.split('.')[-2]) 136 | flowfname = ('{:08d}_{:02d}_{:04d}.png' 137 | .format(pid_id-1, cam_id-1, frame_id-1)) 138 | identities_others[pid_id - 1][cam_id - 1].append(flowfname) 139 | shutil.copy(fname, osp.join(others_dir, flowfname)) 140 | 141 | 142 | 143 | 144 | meta = {'name': 'prid2011-sequence', 'shot': 'sequence', 'num_cameras': 2, 145 | 'identities': identities_images} 146 | 147 | write_json(meta, osp.join(self.root, 'meta.json')) 148 | # Consider fixed training and testing split 149 | num = 200 150 | splits = [] 151 | for i in range(10): 152 | pids = np.random.permutation(num) 153 | pids = (pids -1).tolist() 154 | trainval_pids = pids[:num // 2] 155 | test_pids = pids[num // 2:] 156 | split = {'trainval': trainval_pids, 157 | 'query': test_pids, 158 | 'gallery': test_pids} 159 | 160 | splits.append(split) 161 | write_json(splits, osp.join(self.root, 'splits.json')) 162 | def _pluckseq_cam(self, identities, indices, seq_len, seq_str, camid): 163 | ret = [] 164 | per_id = [] 165 | cam_id = [] 166 | tra_num = [] 167 | 168 | for index, pid in enumerate(indices): 169 | pid_images = identities[pid] 170 | cam_images = pid_images[camid] 171 | seqall = len(cam_images) 172 | seq_inds = [(start_ind, start_ind + seq_len) \ 173 | for start_ind in range(0, seqall - seq_len, seq_str)] 174 | if not seq_inds: 175 | seq_inds = [(0, seqall)] 176 | for seq_ind in seq_inds: 177 | ret.append((seq_ind[0], seq_ind[1], pid, index, camid)) 178 | per_id.append(pid) 179 | cam_id.append(camid) 180 | tra_num.append(len(seq_inds)) 181 | return ret, per_id, cam_id, tra_num 182 | -------------------------------------------------------------------------------- /tensorboardX/event_file_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Writes events to disk in a logdir.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import logging 22 | import os.path 23 | import socket 24 | import threading 25 | import time 26 | 27 | import six 28 | 29 | from .src import event_pb2 30 | from .record_writer import RecordWriter 31 | 32 | 33 | def directory_check(path): 34 | '''Initialize the directory for log files.''' 35 | # If the direcotry does not exist, create it! 36 | if not os.path.exists(path): 37 | os.makedirs(path) 38 | 39 | 40 | class EventsWriter(object): 41 | '''Writes `Event` protocol buffers to an event file.''' 42 | 43 | def __init__(self, file_prefix): 44 | ''' 45 | Events files have a name of the form 46 | '/some/file/path/events.out.tfevents.[timestamp].[hostname]' 47 | ''' 48 | self._file_prefix = file_prefix + ".out.tfevents." + str(time.time())[:10] + "." 
+ socket.gethostname() 49 | 50 | # Open(Create) the log file with the particular form of name. 51 | logging.basicConfig(filename=self._file_prefix) 52 | 53 | self._num_outstanding_events = 0 54 | 55 | self._py_recordio_writer = RecordWriter(self._file_prefix) 56 | 57 | # Initialize an event instance. 58 | self._event = event_pb2.Event() 59 | 60 | self._event.wall_time = time.time() 61 | 62 | self.write_event(self._event) 63 | 64 | def write_event(self, event): 65 | '''Append "event" to the file.''' 66 | 67 | # Check if event is of type event_pb2.Event proto. 68 | if not isinstance(event, event_pb2.Event): 69 | raise TypeError("Expected an event_pb2.Event proto, " 70 | " but got %s" % type(event)) 71 | return self._write_serialized_event(event.SerializeToString()) 72 | 73 | def _write_serialized_event(self, event_str): 74 | self._num_outstanding_events += 1 75 | self._py_recordio_writer.write(event_str) 76 | 77 | def flush(self): 78 | '''Flushes the event file to disk.''' 79 | self._num_outstanding_events = 0 80 | return True 81 | 82 | def close(self): 83 | '''Call self.flush().''' 84 | return_value = self.flush() 85 | return return_value 86 | 87 | 88 | class EventFileWriter(object): 89 | """Writes `Event` protocol buffers to an event file. 90 | The `EventFileWriter` class creates an event file in the specified directory, 91 | and asynchronously writes Event protocol buffers to the file. The Event file 92 | is encoded using the tfrecord format, which is similar to RecordIO. 93 | @@__init__ 94 | @@add_event 95 | @@flush 96 | @@close 97 | """ 98 | 99 | def __init__(self, logdir, max_queue=10, flush_secs=120): 100 | """Creates a `EventFileWriter` and an event file to write to. 101 | On construction the summary writer creates a new event file in `logdir`. 102 | This event file will contain `Event` protocol buffers, which are written to 103 | disk via the add_event method. 104 | The other arguments to the constructor control the asynchronous writes to 105 | the event file: 106 | * `flush_secs`: How often, in seconds, to flush the added summaries 107 | and events to disk. 108 | * `max_queue`: Maximum number of summaries or events pending to be 109 | written to disk before one of the 'add' calls block. 110 | Args: 111 | logdir: A string. Directory where event file will be written. 112 | max_queue: Integer. Size of the queue for pending events and summaries. 113 | flush_secs: Number. How often, in seconds, to flush the 114 | pending events and summaries to disk. 115 | """ 116 | self._logdir = logdir 117 | directory_check(self._logdir) 118 | self._event_queue = six.moves.queue.Queue(max_queue) 119 | self._ev_writer = EventsWriter(os.path.join(self._logdir, "events")) 120 | self._closed = False 121 | self._worker = _EventLoggerThread(self._event_queue, self._ev_writer, 122 | flush_secs) 123 | 124 | self._worker.start() 125 | 126 | def get_logdir(self): 127 | """Returns the directory where event file will be written.""" 128 | return self._logdir 129 | 130 | def reopen(self): 131 | """Reopens the EventFileWriter. 132 | Can be called after `close()` to add more events in the same directory. 133 | The events will go into a new events file. 134 | Does nothing if the EventFileWriter was not closed. 135 | """ 136 | if self._closed: 137 | self._closed = False 138 | 139 | def add_event(self, event): 140 | """Adds an event to the event file. 141 | Args: 142 | event: An `Event` protocol buffer. 
143 | """ 144 | if not self._closed: 145 | self._event_queue.put(event) 146 | 147 | def flush(self): 148 | """Flushes the event file to disk. 149 | Call this method to make sure that all pending events have been written to 150 | disk. 151 | """ 152 | self._event_queue.join() 153 | self._ev_writer.flush() 154 | 155 | def close(self): 156 | """Flushes the event file to disk and close the file. 157 | Call this method when you do not need the summary writer anymore. 158 | """ 159 | self.flush() 160 | self._ev_writer.close() 161 | self._closed = True 162 | 163 | 164 | class _EventLoggerThread(threading.Thread): 165 | """Thread that logs events.""" 166 | 167 | def __init__(self, queue, ev_writer, flush_secs): 168 | """Creates an _EventLoggerThread. 169 | Args: 170 | queue: A Queue from which to dequeue events. 171 | ev_writer: An event writer. Used to log brain events for 172 | the visualizer. 173 | flush_secs: How often, in seconds, to flush the 174 | pending file to disk. 175 | """ 176 | threading.Thread.__init__(self) 177 | self.daemon = True 178 | self._queue = queue 179 | self._ev_writer = ev_writer 180 | self._flush_secs = flush_secs 181 | # The first event will be flushed immediately. 182 | self._next_event_flush_time = 0 183 | 184 | def run(self): 185 | while True: 186 | event = self._queue.get() 187 | try: 188 | self._ev_writer.write_event(event) 189 | # Flush the event writer every so often. 190 | now = time.time() 191 | if now > self._next_event_flush_time: 192 | self._ev_writer.flush() 193 | # Do it again in two minutes. 194 | self._next_event_flush_time = now + self._flush_secs 195 | finally: 196 | self._queue.task_done() 197 | -------------------------------------------------------------------------------- /reid/data/seqtransforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import random 4 | from PIL import Image, ImageOps 5 | import numpy as np 6 | 7 | 8 | class Compose(object): 9 | """Composes several transforms together. 10 | 11 | Args: 12 | transforms (List[Transform]): list of transforms to compose. 
13 | 14 | Example: 15 | >>> transforms.Compose([ 16 | >>> transforms.CenterCrop(10), 17 | >>> transforms.ToTensor(), 18 | >>> ]) 19 | """ 20 | 21 | def __init__(self, transforms): 22 | self.transforms = transforms 23 | 24 | def __call__(self, seqs): 25 | for t in self.transforms: 26 | seqs = t(seqs) 27 | return seqs 28 | 29 | 30 | class RectScale(object): 31 | def __init__(self, height, width, interpolation=Image.BILINEAR): 32 | self.height = height 33 | self.width = width 34 | self.interpolation = interpolation 35 | 36 | def __call__(self, seqs): 37 | modallen = len(seqs) 38 | framelen = len(seqs[0]) 39 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)] 40 | 41 | 42 | for modal_ind, modal in enumerate(seqs): 43 | for frame_ind, frame in enumerate(modal): 44 | w, h = frame.size 45 | if h == self.height and w == self.width: 46 | new_seqs[modal_ind][frame_ind] = frame 47 | else: 48 | new_seqs[modal_ind][frame_ind] = frame.resize((self.width, self.height), self.interpolation) 49 | 50 | return new_seqs 51 | 52 | 53 | 54 | class RandomSizedRectCrop(object): 55 | def __init__(self, height, width, interpolation=Image.BILINEAR): 56 | self.height = height 57 | self.width = width 58 | self.interpolation = interpolation 59 | 60 | def __call__(self, seqs): 61 | sample_img = seqs[0][0] 62 | for attempt in range(10): 63 | area = sample_img.size[0] * sample_img.size[1] 64 | target_area = random.uniform(0.64, 1.0) * area 65 | aspect_ratio = random.uniform(2, 3) 66 | 67 | h = int(round(math.sqrt(target_area * aspect_ratio))) 68 | w = int(round(math.sqrt(target_area / aspect_ratio))) 69 | 70 | if w <= sample_img.size[0] and h <= sample_img.size[1]: 71 | x1 = random.randint(0, sample_img.size[0] - w) 72 | y1 = random.randint(0, sample_img.size[1] - h) 73 | 74 | sample_img = sample_img.crop((x1, y1, x1 + w, y1 + h)) 75 | assert (sample_img.size == (w, h)) 76 | modallen = len(seqs) 77 | framelen = len(seqs[0]) 78 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)] 79 | 80 | for modal_ind, modal in enumerate(seqs): 81 | for frame_ind, frame in enumerate(modal): 82 | 83 | frame = frame.crop((x1, y1, x1 + w, y1 + h)) 84 | new_seqs[modal_ind][frame_ind] = frame.resize((self.width, self.height), self.interpolation) 85 | 86 | return new_seqs 87 | 88 | # Fallback 89 | scale = RectScale(self.height, self.width, 90 | interpolation=self.interpolation) 91 | return scale(seqs) 92 | 93 | class RandomSizedEarser(object): 94 | 95 | def __init__(self, sl=0.02, sh=0.2, asratio=0.3, p=0.5): 96 | self.sl = sl 97 | self.sh = sh 98 | self.asratio = asratio 99 | self.p = p 100 | 101 | def __call__(self, seqs): 102 | modallen = len(seqs) 103 | framelen = len(seqs[0]) 104 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)] 105 | for modal_ind, modal in enumerate(seqs): 106 | for frame_ind, frame in enumerate(modal): 107 | p1 = random.uniform(0.0, 1.0) 108 | W = frame.size[0] 109 | H = frame.size[1] 110 | area = H * W 111 | 112 | if p1 > self.p: 113 | new_seqs[modal_ind][frame_ind] = frame 114 | else: 115 | gen = True 116 | while gen: 117 | Se = random.uniform(self.sl, self.sh) * area 118 | re = random.uniform(self.asratio, 1 / self.asratio) 119 | He = np.sqrt(Se * re) 120 | We = np.sqrt(Se / re) 121 | xe = random.uniform(0, W - We) 122 | ye = random.uniform(0, H - He) 123 | if xe + We <= W and ye + He <= H and xe > 0 and ye > 0: 124 | x1 = int(np.ceil(xe)) 125 | y1 = int(np.ceil(ye)) 126 | x2 = int(np.floor(x1 + We)) 127 | y2 = int(np.floor(y1 + He)) 128 | part1 = frame.crop((x1, 
y1, x2, y2)) 129 | Rc = random.randint(0, 255) 130 | Gc = random.randint(0, 255) 131 | Bc = random.randint(0, 255) 132 | I = Image.new('RGB', part1.size, (Rc, Gc, Bc)) 133 | frame.paste(I, (x1, y1))  # fix: paste the random patch back at its crop origin; previously part1.size (a width/height pair) was passed as the paste position 134 | break 135 | 136 | new_seqs[modal_ind][frame_ind] = frame 137 | 138 | return new_seqs 139 | 140 | 141 | class RandomHorizontalFlip(object): 142 | """Randomly horizontally flips the given PIL.Image Sequence with a probability of 0.5 143 | """ 144 | def __call__(self, seqs): 145 | if random.random() < 0.5: 146 | modallen = len(seqs) 147 | framelen = len(seqs[0]) 148 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)] 149 | for modal_ind, modal in enumerate(seqs): 150 | for frame_ind, frame in enumerate(modal): 151 | new_seqs[modal_ind][frame_ind] = frame.transpose(Image.FLIP_LEFT_RIGHT) 152 | return new_seqs 153 | return seqs 154 | 155 | 156 | class ToTensor(object): 157 | 158 | def __call__(self, seqs): 159 | modallen = len(seqs) 160 | framelen = len(seqs[0]) 161 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)] 162 | pic = seqs[0][0] 163 | 164 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 165 | if pic.mode == 'YCbCr': 166 | nchannel = 3 167 | elif pic.mode == 'I;16': 168 | nchannel = 1 169 | else: 170 | nchannel = len(pic.mode) 171 | 172 | if pic.mode == 'I': 173 | for modal_ind, modal in enumerate(seqs): 174 | for frame_ind, frame in enumerate(modal): 175 | img = torch.from_numpy(np.array(frame, np.int32, copy=False)) 176 | img = img.view(pic.size[1], pic.size[0], nchannel) 177 | new_seqs[modal_ind][frame_ind] = img.transpose(0, 1).transpose(0, 2).contiguous() 178 | 179 | elif pic.mode == 'I;16': 180 | for modal_ind, modal in enumerate(seqs): 181 | for frame_ind, frame in enumerate(modal): 182 | img = torch.from_numpy(np.array(frame, np.int16, copy=False)) 183 | img = img.view(pic.size[1], pic.size[0], nchannel) 184 | new_seqs[modal_ind][frame_ind] = img.transpose(0, 1).transpose(0, 2).contiguous() 185 | else: 186 | for modal_ind, modal in enumerate(seqs): 187 | for frame_ind, frame in enumerate(modal): 188 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(frame.tobytes())) 189 | img = img.view(pic.size[1], pic.size[0], nchannel) 190 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 191 | new_seqs[modal_ind][frame_ind] = img.float().div(255) 192 | 193 | 194 | return new_seqs 195 | 196 | 197 | 198 | class Normalize(object): 199 | """Given mean: (R, G, B) and std: (R, G, B), 200 | will normalize each channel of the torch.*Tensor, i.e.
201 | channel = (channel - mean) / std 202 | """ 203 | def __init__(self, mean, std): 204 | self.mean = mean 205 | self.std = std 206 | 207 | def __call__(self, seqs): 208 | # TODO: make efficient 209 | modallen = len(seqs) 210 | framelen = len(seqs[0]) 211 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)] 212 | 213 | for modal_ind, modal in enumerate(seqs): 214 | for frame_ind, frame in enumerate(modal): 215 | for t, m, s in zip(frame, self.mean, self.std): 216 | t.sub_(m).div_(s) 217 | new_seqs[modal_ind][frame_ind] = frame 218 | 219 | return new_seqs 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | -------------------------------------------------------------------------------- /reid/dataset/ilidsvidsequence.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import os.path as osp 4 | from reid.data.datasequence import Datasequence 5 | from utils.osutils import mkdir_if_missing 6 | from utils.serialization import write_json 7 | import tarfile 8 | from glob import glob 9 | import shutil 10 | import scipy.io as sio 11 | import numpy as np 12 | datasetname = 'iLIDS-VID' 13 | # flowname = 'Farneback' 14 | flowname = 'optical_flow' 15 | 16 | 17 | class infostruct(object): 18 | pass 19 | 20 | class iLIDSVIDSEQUENCE(Datasequence): 21 | 22 | def __init__(self, root, split_id=0, seq_len=12, seq_srd=6, num_val=1, download=False): 23 | super(iLIDSVIDSEQUENCE, self).__init__(root, split_id=split_id) 24 | 25 | if download: 26 | self.download() 27 | 28 | if not self._check_integrity(): 29 | self.imgextract() 30 | 31 | self.load(seq_len, seq_srd, num_val) 32 | 33 | self.query, query_pid, query_camid, query_num = self._pluckseq_cam(self.identities, self.split['query'], seq_len, seq_srd, 0) 34 | self.queryinfo = infostruct() 35 | self.queryinfo.pid = query_pid 36 | self.queryinfo.camid = query_camid 37 | self.queryinfo.tranum = query_num 38 | 39 | self.gallery, gallery_pid, gallery_camid, gallery_num = self._pluckseq_cam(self.identities, self.split['gallery'], seq_len, seq_srd, 1) 40 | self.galleryinfo = infostruct() 41 | self.galleryinfo.pid = gallery_pid 42 | self.galleryinfo.camid = gallery_camid 43 | self.galleryinfo.tranum = gallery_num 44 | 45 | @property 46 | def other_dir(self): 47 | return osp.join(self.root, 'others') 48 | 49 | 50 | def download(self): 51 | 52 | if self._check_integrity(): 53 | print("Files already downloaded and verified") 54 | return 55 | 56 | raw_dir = osp.join(self.root, 'raw') 57 | mkdir_if_missing(raw_dir) 58 | 59 | fpath1 = osp.join(raw_dir, datasetname + '.tar') 60 | fpath2 = osp.join(raw_dir, flowname + '.tar') 61 | if osp.isfile(fpath1) and osp.isfile(fpath2): 62 | print("Using the download file:" + fpath1 + " " + fpath2) 63 | else: 64 | print("Please firstly download the files") 65 | raise RuntimeError("Downloaded file missing!") 66 | 67 | def imgextract(self): 68 | 69 | raw_dir = osp.join(self.root, 'raw') 70 | exdir1 = osp.join(raw_dir, datasetname) 71 | exdir2 = osp.join(raw_dir, flowname) 72 | fpath1 = osp.join(raw_dir, datasetname + '.tar') 73 | fpath2 = osp.join(raw_dir, flowname + '.tar') 74 | 75 | 76 | if not osp.isdir(exdir1): 77 | print("Extracting tar file") 78 | cwd = os.getcwd() 79 | tar = tarfile.open(fpath1) 80 | mkdir_if_missing(exdir1) 81 | os.chdir(exdir1) 82 | tar.extractall() 83 | tar.close() 84 | os.chdir(cwd) 85 | 86 | if not osp.isdir(exdir2): 87 | print("Extracting 
tar file") 88 | cwd = os.getcwd() 89 | tar = tarfile.open(fpath2) 90 | mkdir_if_missing(exdir2) 91 | os.chdir(exdir2) 92 | tar.extractall() 93 | tar.close() 94 | os.chdir(cwd) 95 | 96 | # reorganzing the dataset 97 | # Format 98 | 99 | temp_images_dir = osp.join(self.root, 'temp_images') 100 | mkdir_if_missing(temp_images_dir) 101 | 102 | temp_others_dir = osp.join(self.root, 'temp_others') 103 | mkdir_if_missing(temp_others_dir) 104 | 105 | images_dir = osp.join(self.root, 'images') 106 | mkdir_if_missing(images_dir) 107 | 108 | others_dir = osp.join(self.root, 'others') 109 | mkdir_if_missing(others_dir) 110 | 111 | fpaths1 = sorted(glob(osp.join(exdir1, 'i-LIDS-VID/sequences', '*/*/*.png'))) 112 | fpaths2 = sorted(glob(osp.join(exdir2, flowname, '*/*/*.png'))) 113 | 114 | identities_imgraw = [[[] for _ in range(2)] for _ in range(319)] 115 | identities_otherraw = [[[] for _ in range(2)] for _ in range(319)] 116 | 117 | # image information 118 | for fpath in fpaths1: 119 | fname = osp.basename(fpath) 120 | fname_list = fname.split('_') 121 | cam_name = fname_list[0] 122 | pid_name = fname_list[1] 123 | cam = int(cam_name[-1]) 124 | pid = int(pid_name[-3:]) 125 | temp_fname = ('{:08d}_{:02d}_{:04d}.png' 126 | .format(pid, cam, len(identities_imgraw[pid - 1][cam - 1]))) 127 | identities_imgraw[pid - 1][cam - 1].append(temp_fname) 128 | shutil.copy(fpath, osp.join(temp_images_dir, temp_fname)) 129 | 130 | identities_temp = [x for x in identities_imgraw if x != [[], []]] 131 | identities_images = identities_temp 132 | 133 | for pid in range(len(identities_temp)): 134 | for cam in range(2): 135 | for img in range(len(identities_images[pid][cam])): 136 | temp_fname = identities_temp[pid][cam][img] 137 | fname = ('{:08d}_{:02d}_{:04d}.png' 138 | .format(pid, cam, img)) 139 | identities_images[pid][cam][img] = fname 140 | shutil.copy(osp.join(temp_images_dir, temp_fname), osp.join(images_dir, fname)) 141 | 142 | shutil.rmtree(temp_images_dir) 143 | 144 | # flow information 145 | 146 | for fpath in fpaths2: 147 | fname = osp.basename(fpath) 148 | fname_list = fname.split('_') 149 | cam_name = fname_list[0] 150 | pid_name = fname_list[1] 151 | cam = int(cam_name[-1]) 152 | pid = int(pid_name[-3:]) 153 | temp_fname = ('{:08d}_{:02d}_{:04d}.png' 154 | .format(pid, cam, len(identities_otherraw[pid - 1][cam - 1]))) 155 | identities_otherraw[pid - 1][cam - 1].append(temp_fname) 156 | shutil.copy(fpath, osp.join(temp_others_dir, temp_fname)) 157 | 158 | identities_temp = [x for x in identities_otherraw if x != [[], []]] 159 | identities_others = identities_temp 160 | 161 | for pid in range(len(identities_temp)): 162 | for cam in range(2): 163 | for img in range(len(identities_others[pid][cam])): 164 | temp_fname = identities_temp[pid][cam][img] 165 | fname = ('{:08d}_{:02d}_{:04d}.png' 166 | .format(pid, cam, img)) 167 | identities_others[pid][cam][img] = fname 168 | shutil.copy(osp.join(temp_others_dir, temp_fname), osp.join(others_dir, fname)) 169 | 170 | shutil.rmtree(temp_others_dir) 171 | 172 | meta = {'name': 'iLIDS-sequence', 'shot': 'sequence', 'num_cameras': 2, 173 | 'identities': identities_images} 174 | 175 | write_json(meta, osp.join(self.root, 'meta.json')) 176 | 177 | 178 | # Consider fixed training and testing split 179 | splitmat_name = osp.join(exdir1, 'train-test people splits', 'train_test_splits_ilidsvid.mat') 180 | data = sio.loadmat(splitmat_name) 181 | person_list = data['ls_set'] 182 | num = len(identities_images) 183 | splits = [] 184 | 185 | for i in range(10): 186 | pids = 
(person_list[i] - 1).tolist() 187 | trainval_pids = sorted(pids[:num // 2]) 188 | test_pids = sorted(pids[num // 2:]) 189 | split = {'trainval': trainval_pids, 190 | 'query': test_pids, 191 | 'gallery': test_pids} 192 | splits.append(split) 193 | write_json(splits, osp.join(self.root, 'splits.json')) 194 | 195 | def _pluckseq_cam(self, identities, indices, seq_len, seq_str, camid): 196 | ret = [] 197 | per_id = [] 198 | cam_id = [] 199 | tra_num = [] 200 | 201 | for index, pid in enumerate(indices): 202 | pid_images = identities[pid] 203 | cam_images = pid_images[camid] 204 | seqall = len(cam_images) 205 | seq_inds = [(start_ind, start_ind + seq_len) \ 206 | for start_ind in range(0, seqall - seq_len, seq_str)] 207 | if not seq_inds: 208 | seq_inds = [(0, seqall)] 209 | for seq_ind in seq_inds: 210 | ret.append((seq_ind[0], seq_ind[1], pid, index, camid)) 211 | per_id.append(pid) 212 | cam_id.append(camid) 213 | tra_num.append(len(seq_inds)) 214 | return ret, per_id, cam_id, tra_num 215 | -------------------------------------------------------------------------------- /reid/evaluator/evaluator.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import torch 4 | from torch.autograd import Variable 5 | from utils.meters import AverageMeter 6 | from utils import to_numpy 7 | from .eva_functions import cmc, mean_ap 8 | import numpy as np 9 | 10 | 11 | def evaluate_seq(distmat, query_pids, query_camids, gallery_pids, gallery_camids, cmc_topk=(1, 5, 10)): 12 | query_ids = np.array(query_pids) 13 | gallery_ids = np.array(gallery_pids) 14 | query_cams = np.array(query_camids) 15 | gallery_cams = np.array(gallery_camids) 16 | 17 | ## 18 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 19 | print('Mean AP: {:4.1%}'.format(mAP)) 20 | 21 | cmc_configs = { 22 | 'allshots': dict(separate_camera_set=False, 23 | single_gallery_shot=False, 24 | first_match_break=False), 25 | 'cuhk03': dict(separate_camera_set=True, 26 | single_gallery_shot=True, 27 | first_match_break=False), 28 | 'market1501': dict(separate_camera_set=False, 29 | single_gallery_shot=False, 30 | first_match_break=True)} 31 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 32 | query_cams, gallery_cams, **params) 33 | for name, params in cmc_configs.items()} 34 | 35 | print('CMC Scores{:>12}{:>12}{:>12}' 36 | .format('allshots', 'cuhk03', 'market1501')) 37 | for k in cmc_topk: 38 | print(' top-{:<4}{:12.1%}{:12.1%}{:12.1%}' 39 | .format(k, cmc_scores['allshots'][k - 1], 40 | cmc_scores['cuhk03'][k - 1], 41 | cmc_scores['market1501'][k - 1])) 42 | 43 | # Use the allshots cmc top-1 score for validation criterion 44 | return mAP 45 | 46 | 47 | def pairwise_distance_tensor(query_x, gallery_x): 48 | 49 | query_n = query_x.size(0) 50 | gallery_n = gallery_x.size(0) 51 | # query_squ = torch.pow(query_x, 2).sum(1).squeeze(1) 52 | # gallery_squ = torch.pow(gallery_x, 2).sum(1).squeeze(1) 53 | query_squ = torch.pow(query_x, 2).sum(1, keepdim=True).squeeze(1) 54 | gallery_squ = torch.pow(gallery_x, 2).sum(1, keepdim=True).squeeze(1) 55 | query_squ = query_squ.unsqueeze(1) 56 | gallery_squ = gallery_squ.unsqueeze(0) 57 | query_squ = query_squ.expand(query_n, gallery_n) 58 | gallery_squ = gallery_squ.expand(query_n, gallery_n) 59 | 60 | query_gallery_squ = query_squ + gallery_squ 61 | query_gallery_pro = torch.mm(query_x, gallery_x.t()) 62 | dist = query_gallery_squ - 2*query_gallery_pro 63 | 64 | return dist 
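# ---------------------------------------------------------------------------
# Editor's note (a sketch, not part of the original file):
# pairwise_distance_tensor() above computes all query/gallery squared
# Euclidean distances at once via the expansion
#     ||q - g||^2 = ||q||^2 + ||g||^2 - 2 * <q, g>,
# i.e. two broadcast norm terms plus one matrix multiply instead of a double
# loop over pairs. A minimal self-check of that identity, with made-up
# feature sizes; it only runs when this module is executed directly:
if __name__ == '__main__':
    _query = torch.randn(4, 128)    # 4 hypothetical query descriptors
    _gallery = torch.randn(6, 128)  # 6 hypothetical gallery descriptors
    _fast = pairwise_distance_tensor(_query, _gallery)
    # Reference: explicit broadcasted difference, squared and summed per pair.
    _ref = ((_query.unsqueeze(1) - _gallery.unsqueeze(0)) ** 2).sum(2)
    assert (_fast - _ref).abs().max() < 1e-3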
65 | 
66 | 
67 | class CNNEvaluator(object):
68 | 
69 |     def __init__(self, cnn_model, mode):
70 |         super(CNNEvaluator, self).__init__()
71 |         self.cnn_model = cnn_model
72 |         self.mode = mode
73 | 
74 |     def extract_feature(self, cnn_model, data_loader):
75 |         print_freq = 5
76 |         cnn_model.eval()
77 |         batch_time = AverageMeter()
78 |         data_time = AverageMeter()
79 |         end = time.time()
80 | 
81 |         allfeatures = 0
82 | 
83 |         for i, (imgs, flows, _, _) in enumerate(data_loader):
84 |             data_time.update(time.time() - end)
85 |             imgs = Variable(imgs, volatile=True)
86 |             flows = Variable(flows, volatile=True)
87 | 
88 |             if i == 0:
89 |                 out_feat = self.cnn_model(imgs, flows, self.mode)
90 |                 allfeatures = out_feat.data
91 |                 preimgs = imgs
92 |                 preflows = flows
93 |             elif imgs.size(0) < data_loader.batch_size:
94 |                 # Pad the last, smaller batch with frames from the previous
95 |                 # batch, then keep only the features of the real samples.
96 |                 flaw_batchsize = imgs.size(0)
97 |                 cat_batchsize = data_loader.batch_size - flaw_batchsize
98 |                 imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
99 |                 flows = torch.cat((flows, preflows[0:cat_batchsize]), 0)
100 |                 out_feat = self.cnn_model(imgs, flows, self.mode)
101 |                 out_feat = out_feat[0:flaw_batchsize]
102 |                 allfeatures = torch.cat((allfeatures, out_feat.data), 0)
103 |             else:
104 |                 out_feat = self.cnn_model(imgs, flows, self.mode)
105 |                 allfeatures = torch.cat((allfeatures, out_feat.data), 0)
106 | 
107 |             batch_time.update(time.time() - end)
108 |             end = time.time()
109 | 
110 |             if (i + 1) % print_freq == 0:
111 |                 print('Extract Features: [{}/{}]\t'
112 |                       'Time {:.3f} ({:.3f})\t'
113 |                       'Data {:.3f} ({:.3f})\t'
114 |                       .format(i + 1, len(data_loader),
115 |                               batch_time.val, batch_time.avg,
116 |                               data_time.val, data_time.avg))
117 | 
118 |         return allfeatures
119 | 
120 |     def evaluate(self, query_loader, gallery_loader, queryinfo, galleryinfo):
121 |         query_features = self.extract_feature(self.cnn_model, query_loader)
122 |         gallery_features = self.extract_feature(self.cnn_model, gallery_loader)
123 |         distmat = pairwise_distance_tensor(query_features, gallery_features)
124 |         return evaluate_seq(to_numpy(distmat), queryinfo.pid, queryinfo.camid,
125 |                             galleryinfo.pid, galleryinfo.camid)
--------------------------------------------------------------------------------
/reid/evaluator/attevaluator.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from utils.meters import AverageMeter
7 | from utils import to_numpy
8 | from .eva_functions import cmc, mean_ap
9 | import numpy as np
10 | 
11 | 
12 | def evaluate_seq(distmat, query_pids, query_camids, gallery_pids, gallery_camids, cmc_topk=(1, 5, 10)):
13 |     query_ids = np.array(query_pids)
14 |     gallery_ids = np.array(gallery_pids)
15 |     query_cams = np.array(query_camids)
16 |     gallery_cams = np.array(gallery_camids)
17 | 
18 |     # Compute mean average precision over all query-gallery pairs
19 |     mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
20 |     print('Mean AP: {:4.1%}'.format(mAP))
21 | 
22 |     cmc_configs = {
23 |         'allshots': dict(separate_camera_set=False,
24 |                          single_gallery_shot=False,
25 |                          first_match_break=False),
26 |         'cuhk03': dict(separate_camera_set=True,
27 |                        single_gallery_shot=True,
28 |                        first_match_break=False),
29 |         'market1501': dict(separate_camera_set=False,
30 |                            single_gallery_shot=False,
31 |                            first_match_break=True)}
32 |     cmc_scores = {name: cmc(distmat, query_ids, gallery_ids,
33 |                             query_cams, gallery_cams, **params)
34 |                   for name, params in cmc_configs.items()}
35 | 
36 |     print('CMC Scores{:>12}{:>12}{:>12}'
37 |           .format('allshots', 'cuhk03', 'market1501'))
38 |     for k in cmc_topk:
39 |         print('  top-{:<4}{:12.1%}{:12.1%}{:12.1%}'
40 |               .format(k, cmc_scores['allshots'][k - 1],
41 |                       cmc_scores['cuhk03'][k - 1],
42 |                       cmc_scores['market1501'][k - 1]))
43 | 
44 |     # Return mAP and the allshots CMC scores as the validation criterion
45 |     top1 = cmc_scores['allshots'][0]
46 |     top5 = cmc_scores['allshots'][4]
47 |     top10 = cmc_scores['allshots'][9]
48 |     top20 = cmc_scores['allshots'][19]
49 | 
50 |     return mAP, top1, top5, top10, top20
51 | 
52 | 
53 | 
54 | class ATTEvaluator(object):
55 | 
56 |     def __init__(self, cnn_model, att_model, classifier_model, mode, criterion):
57 |         super(ATTEvaluator, self).__init__()
58 |         self.cnn_model = cnn_model
59 |         self.att_model = att_model
60 |         self.classifier_model = classifier_model
61 |         self.mode = mode
62 |         self.criterion = criterion
63 | 
64 |     def extract_feature(self, data_loader):
65 |         print_freq = 5
66 |         self.cnn_model.eval()
67 |         self.att_model.eval()
68 | 
69 | 
70 |         batch_time = AverageMeter()
71 |         data_time = AverageMeter()
72 |         end = time.time()
73 | 
74 |         # allfeatures = 0
75 |         # allfeatures_raw = 0
76 | 
77 |         for i, (imgs, flows, _, _) in enumerate(data_loader):
78 |             data_time.update(time.time() - end)
79 |             imgs = Variable(imgs, volatile=True)
80 |             flows = Variable(flows, volatile=True)
81 | 
82 |             if i == 0:
83 |                 out_feat, out_raw = self.cnn_model(imgs, flows, self.mode)
84 |                 allfeatures = [out_feat]
85 |                 allfeatures_raw = [out_raw]
86 |                 preimgs = imgs
87 |                 preflows = flows
88 |             elif imgs.size(0) < data_loader.batch_size:
89 |                 flaw_batchsize = imgs.size(0)
90 |                 cat_batchsize = data_loader.batch_size - flaw_batchsize
91 |                 imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
92 |                 flows = torch.cat((flows, preflows[0:cat_batchsize]), 0)
93 | 
94 |                 out_feat, out_raw = self.cnn_model(imgs, flows, self.mode)
95 | 
96 |                 out_feat = out_feat[0:flaw_batchsize]
97 |                 out_raw = out_raw[0:flaw_batchsize]
98 | 
99 |                 allfeatures.append(out_feat)
100 |                 allfeatures_raw.append(out_raw)
101 |             else:
102 |                 out_feat, out_raw = self.cnn_model(imgs, flows, self.mode)
103 | 
104 |                 allfeatures.append(out_feat)
105 |                 allfeatures_raw.append(out_raw)
106 | 
107 |             batch_time.update(time.time() - end)
108 |             end = time.time()
109 | 
110 |             if (i + 1) % print_freq == 0:
111 |                 print('Extract Features: [{}/{}]\t'
112 |                       'Time {:.3f} ({:.3f})\t'
113 |                       'Data {:.3f} ({:.3f})\t'
114 |                       .format(i + 1, len(data_loader),
115 |                               batch_time.val, batch_time.avg,
116 |                               data_time.val, data_time.avg))
117 | 
118 |         allfeatures = torch.cat(allfeatures, 0)
119 |         allfeatures_raw = torch.cat(allfeatures_raw, 0)
120 |         return allfeatures, allfeatures_raw
121 | 
122 |     def evaluate(self, query_loader, gallery_loader, queryinfo, galleryinfo):
123 | 
124 | 
125 |         self.cnn_model.eval()
126 |         self.att_model.eval()
127 |         self.classifier_model.eval()
128 | 
129 | 
130 |         querypid = queryinfo.pid
131 |         querycamid = queryinfo.camid
132 |         querytranum = queryinfo.tranum
133 |         gallerypid = galleryinfo.pid
134 |         gallerycamid = galleryinfo.camid
135 |         gallerytranum = galleryinfo.tranum
136 | 
137 |         query_resfeatures, query_resraw = self.extract_feature(query_loader)
138 |         gallery_resfeatures, gallery_resraw = self.extract_feature(gallery_loader)
139 | 
140 |         querylen = len(querypid)
141 |         gallerylen = len(gallerypid)
142 | 
143 |         # online gallery extraction
144 |         single_distmat = np.zeros((querylen, gallerylen))
145 | 
146 |         q_start = 0
147 |         pooled_query = []
148 |         for qind, qnum in enumerate(querytranum):
149 |             query_feat_tmp = query_resfeatures[q_start:q_start+qnum, :, :]
150 |             query_featraw_tmp = query_resraw[q_start:q_start+qnum, :, :]
151 |             pooled_query_tmp, hidden_query_tmp = self.att_model.selfpooling_model(query_feat_tmp, query_featraw_tmp)
152 |             pooled_query.append(pooled_query_tmp)
153 |             q_start += qnum
154 |         pooled_query = torch.cat(pooled_query, 0)
155 | 
156 |         g_start = 0
157 |         pooled_gallery = []
158 |         for gind, gnum in enumerate(gallerytranum):
159 |             gallery_feat_tmp = gallery_resfeatures[g_start:g_start+gnum, :, :]
160 |             gallery_featraw_tmp = gallery_resraw[g_start:g_start+gnum, :, :]
161 |             pooled_gallery_tmp, hidden_gallery_tmp = self.att_model.selfpooling_model(gallery_feat_tmp, gallery_featraw_tmp)
162 |             pooled_gallery.append(pooled_gallery_tmp)
163 |             g_start += gnum
164 |         pooled_gallery = torch.cat(pooled_gallery, 0)
165 |         # pooled_query, hidden_query = self.att_model.selfpooling_model_1(query_resfeatures, query_resraw)
166 |         # pooled_gallery, hidden_gallery = self.att_model.selfpooling_model_2(gallery_resfeatures, gallery_resraw)
167 | 
168 |         g_start = 0
169 |         pooled_gallery_2 = []
170 |         for gind, gnum in enumerate(gallerytranum):
171 |             gallery_feat_tmp = gallery_resfeatures[g_start:g_start+gnum, :, :]
172 |             gallery_featraw_tmp = gallery_resraw[g_start:g_start+gnum, :, :]
173 |             pooled_gallery_2_tmp = self.att_model.crosspooling_model(gallery_feat_tmp, gallery_featraw_tmp, pooled_query)
174 |             pooled_gallery_2.append(pooled_gallery_2_tmp)
175 |             g_start += gnum
176 |         pooled_gallery_2 = torch.cat(pooled_gallery_2, 1)
177 | 
178 |         q_start = 0
179 |         pooled_query_2 = []
180 |         for qind, qnum in enumerate(querytranum):
181 |             query_feat_tmp = query_resfeatures[q_start:q_start+qnum, :, :]
182 |             query_featraw_tmp = query_resraw[q_start:q_start+qnum, :, :]
183 |             pooled_query_2_tmp = self.att_model.crosspooling_model(query_feat_tmp, query_featraw_tmp, pooled_gallery)
184 |             pooled_query_2.append(pooled_query_2_tmp)
185 |             q_start += qnum
186 |         pooled_query_2 = torch.cat(pooled_query_2, 1)
187 | 
188 |         pooled_query_2 = pooled_query_2.permute(1, 0, 2)
189 |         pooled_query, pooled_gallery = pooled_query.unsqueeze(1), pooled_gallery.unsqueeze(0)
190 | 
191 |         encode_scores = self.classifier_model(pooled_query, pooled_gallery_2, pooled_query_2, pooled_gallery)
192 | 
193 |         encode_size = encode_scores.size()
194 |         encodemat = encode_scores.view(-1, 2)
195 |         encodemat = F.softmax(encodemat, dim=1)
196 |         encodemat = encodemat.view(encode_size[0], encode_size[1], 2)
197 | 
198 |         single_distmat_all = encodemat[:, :, 0]
199 |         single_distmat_all = single_distmat_all.data.cpu().numpy()
200 |         q_start, g_start = 0, 0
201 |         for qind, qnum in enumerate(querytranum):
202 |             for gind, gnum in enumerate(gallerytranum):
203 |                 distmat_qg = single_distmat_all[q_start:q_start+qnum, g_start:g_start+gnum]
204 |                 # percile = np.percentile(distmat_qg, 20)
205 |                 percile = np.percentile(distmat_qg, 10)
206 |                 if distmat_qg[distmat_qg <= percile].size > 0:
207 |                     distmean = np.mean(distmat_qg[distmat_qg <= percile])
208 |                 else:
209 |                     distmean = np.mean(distmat_qg)
210 | 
211 |                 single_distmat[qind, gind] = distmean
212 |                 g_start = g_start + gnum
213 |             g_start = 0
214 |             q_start = q_start + qnum
215 | 
216 |         return evaluate_seq(single_distmat, querypid, querycamid, gallerypid, gallerycamid)
--------------------------------------------------------------------------------
/tensorboardX/src/types_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler.  DO NOT EDIT!
2 | # source: tensorboardX/src/types.proto
3 | 
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf.internal import enum_type_wrapper
7 | from google.protobuf import descriptor as _descriptor
8 | from google.protobuf import message as _message
9 | from google.protobuf import reflection as _reflection
10 | from google.protobuf import symbol_database as _symbol_database
11 | from google.protobuf import descriptor_pb2
12 | # @@protoc_insertion_point(imports)
13 | 
14 | _sym_db = _symbol_database.Default()
15 | 
16 | 
17 | 
18 | 
19 | DESCRIPTOR = _descriptor.FileDescriptor(
20 |   name='tensorboardX/src/types.proto',
21 |   package='tensorboard',
22 |   syntax='proto3',
23 |   serialized_pb=_b('\n\x1ctensorboardX/src/types.proto\x12\x0btensorboard*\xc2\x05\n\x08\x44\x61taType\x12\x0e\n\nDT_INVALID\x10\x00\x12\x0c\n\x08\x44T_FLOAT\x10\x01\x12\r\n\tDT_DOUBLE\x10\x02\x12\x0c\n\x08\x44T_INT32\x10\x03\x12\x0c\n\x08\x44T_UINT8\x10\x04\x12\x0c\n\x08\x44T_INT16\x10\x05\x12\x0b\n\x07\x44T_INT8\x10\x06\x12\r\n\tDT_STRING\x10\x07\x12\x10\n\x0c\x44T_COMPLEX64\x10\x08\x12\x0c\n\x08\x44T_INT64\x10\t\x12\x0b\n\x07\x44T_BOOL\x10\n\x12\x0c\n\x08\x44T_QINT8\x10\x0b\x12\r\n\tDT_QUINT8\x10\x0c\x12\r\n\tDT_QINT32\x10\r\x12\x0f\n\x0b\x44T_BFLOAT16\x10\x0e\x12\r\n\tDT_QINT16\x10\x0f\x12\x0e\n\nDT_QUINT16\x10\x10\x12\r\n\tDT_UINT16\x10\x11\x12\x11\n\rDT_COMPLEX128\x10\x12\x12\x0b\n\x07\x44T_HALF\x10\x13\x12\x0f\n\x0b\x44T_RESOURCE\x10\x14\x12\x10\n\x0c\x44T_FLOAT_REF\x10\x65\x12\x11\n\rDT_DOUBLE_REF\x10\x66\x12\x10\n\x0c\x44T_INT32_REF\x10g\x12\x10\n\x0c\x44T_UINT8_REF\x10h\x12\x10\n\x0c\x44T_INT16_REF\x10i\x12\x0f\n\x0b\x44T_INT8_REF\x10j\x12\x11\n\rDT_STRING_REF\x10k\x12\x14\n\x10\x44T_COMPLEX64_REF\x10l\x12\x10\n\x0c\x44T_INT64_REF\x10m\x12\x0f\n\x0b\x44T_BOOL_REF\x10n\x12\x10\n\x0c\x44T_QINT8_REF\x10o\x12\x11\n\rDT_QUINT8_REF\x10p\x12\x11\n\rDT_QINT32_REF\x10q\x12\x13\n\x0f\x44T_BFLOAT16_REF\x10r\x12\x11\n\rDT_QINT16_REF\x10s\x12\x12\n\x0e\x44T_QUINT16_REF\x10t\x12\x11\n\rDT_UINT16_REF\x10u\x12\x15\n\x11\x44T_COMPLEX128_REF\x10v\x12\x0f\n\x0b\x44T_HALF_REF\x10w\x12\x13\n\x0f\x44T_RESOURCE_REF\x10xB,\n\x18org.tensorflow.frameworkB\x0bTypesProtosP\x01\xf8\x01\x01\x62\x06proto3')
24 | )
25 | 
26 | _DATATYPE = _descriptor.EnumDescriptor(
27 |   name='DataType',
28 |   full_name='tensorboard.DataType',
29 |   filename=None,
30 |   file=DESCRIPTOR,
31 |   values=[
32 |     _descriptor.EnumValueDescriptor(
33 |       name='DT_INVALID', index=0, number=0,
34 |       options=None,
35 |       type=None),
36 |     _descriptor.EnumValueDescriptor(
37 |       name='DT_FLOAT', index=1, number=1,
38 |       options=None,
39 |       type=None),
40 |     _descriptor.EnumValueDescriptor(
41 |       name='DT_DOUBLE', index=2, number=2,
42 |       options=None,
43 |       type=None),
44 | 
_descriptor.EnumValueDescriptor( 45 | name='DT_INT32', index=3, number=3, 46 | options=None, 47 | type=None), 48 | _descriptor.EnumValueDescriptor( 49 | name='DT_UINT8', index=4, number=4, 50 | options=None, 51 | type=None), 52 | _descriptor.EnumValueDescriptor( 53 | name='DT_INT16', index=5, number=5, 54 | options=None, 55 | type=None), 56 | _descriptor.EnumValueDescriptor( 57 | name='DT_INT8', index=6, number=6, 58 | options=None, 59 | type=None), 60 | _descriptor.EnumValueDescriptor( 61 | name='DT_STRING', index=7, number=7, 62 | options=None, 63 | type=None), 64 | _descriptor.EnumValueDescriptor( 65 | name='DT_COMPLEX64', index=8, number=8, 66 | options=None, 67 | type=None), 68 | _descriptor.EnumValueDescriptor( 69 | name='DT_INT64', index=9, number=9, 70 | options=None, 71 | type=None), 72 | _descriptor.EnumValueDescriptor( 73 | name='DT_BOOL', index=10, number=10, 74 | options=None, 75 | type=None), 76 | _descriptor.EnumValueDescriptor( 77 | name='DT_QINT8', index=11, number=11, 78 | options=None, 79 | type=None), 80 | _descriptor.EnumValueDescriptor( 81 | name='DT_QUINT8', index=12, number=12, 82 | options=None, 83 | type=None), 84 | _descriptor.EnumValueDescriptor( 85 | name='DT_QINT32', index=13, number=13, 86 | options=None, 87 | type=None), 88 | _descriptor.EnumValueDescriptor( 89 | name='DT_BFLOAT16', index=14, number=14, 90 | options=None, 91 | type=None), 92 | _descriptor.EnumValueDescriptor( 93 | name='DT_QINT16', index=15, number=15, 94 | options=None, 95 | type=None), 96 | _descriptor.EnumValueDescriptor( 97 | name='DT_QUINT16', index=16, number=16, 98 | options=None, 99 | type=None), 100 | _descriptor.EnumValueDescriptor( 101 | name='DT_UINT16', index=17, number=17, 102 | options=None, 103 | type=None), 104 | _descriptor.EnumValueDescriptor( 105 | name='DT_COMPLEX128', index=18, number=18, 106 | options=None, 107 | type=None), 108 | _descriptor.EnumValueDescriptor( 109 | name='DT_HALF', index=19, number=19, 110 | options=None, 111 | type=None), 112 | _descriptor.EnumValueDescriptor( 113 | name='DT_RESOURCE', index=20, number=20, 114 | options=None, 115 | type=None), 116 | _descriptor.EnumValueDescriptor( 117 | name='DT_FLOAT_REF', index=21, number=101, 118 | options=None, 119 | type=None), 120 | _descriptor.EnumValueDescriptor( 121 | name='DT_DOUBLE_REF', index=22, number=102, 122 | options=None, 123 | type=None), 124 | _descriptor.EnumValueDescriptor( 125 | name='DT_INT32_REF', index=23, number=103, 126 | options=None, 127 | type=None), 128 | _descriptor.EnumValueDescriptor( 129 | name='DT_UINT8_REF', index=24, number=104, 130 | options=None, 131 | type=None), 132 | _descriptor.EnumValueDescriptor( 133 | name='DT_INT16_REF', index=25, number=105, 134 | options=None, 135 | type=None), 136 | _descriptor.EnumValueDescriptor( 137 | name='DT_INT8_REF', index=26, number=106, 138 | options=None, 139 | type=None), 140 | _descriptor.EnumValueDescriptor( 141 | name='DT_STRING_REF', index=27, number=107, 142 | options=None, 143 | type=None), 144 | _descriptor.EnumValueDescriptor( 145 | name='DT_COMPLEX64_REF', index=28, number=108, 146 | options=None, 147 | type=None), 148 | _descriptor.EnumValueDescriptor( 149 | name='DT_INT64_REF', index=29, number=109, 150 | options=None, 151 | type=None), 152 | _descriptor.EnumValueDescriptor( 153 | name='DT_BOOL_REF', index=30, number=110, 154 | options=None, 155 | type=None), 156 | _descriptor.EnumValueDescriptor( 157 | name='DT_QINT8_REF', index=31, number=111, 158 | options=None, 159 | type=None), 160 | _descriptor.EnumValueDescriptor( 161 
| name='DT_QUINT8_REF', index=32, number=112, 162 | options=None, 163 | type=None), 164 | _descriptor.EnumValueDescriptor( 165 | name='DT_QINT32_REF', index=33, number=113, 166 | options=None, 167 | type=None), 168 | _descriptor.EnumValueDescriptor( 169 | name='DT_BFLOAT16_REF', index=34, number=114, 170 | options=None, 171 | type=None), 172 | _descriptor.EnumValueDescriptor( 173 | name='DT_QINT16_REF', index=35, number=115, 174 | options=None, 175 | type=None), 176 | _descriptor.EnumValueDescriptor( 177 | name='DT_QUINT16_REF', index=36, number=116, 178 | options=None, 179 | type=None), 180 | _descriptor.EnumValueDescriptor( 181 | name='DT_UINT16_REF', index=37, number=117, 182 | options=None, 183 | type=None), 184 | _descriptor.EnumValueDescriptor( 185 | name='DT_COMPLEX128_REF', index=38, number=118, 186 | options=None, 187 | type=None), 188 | _descriptor.EnumValueDescriptor( 189 | name='DT_HALF_REF', index=39, number=119, 190 | options=None, 191 | type=None), 192 | _descriptor.EnumValueDescriptor( 193 | name='DT_RESOURCE_REF', index=40, number=120, 194 | options=None, 195 | type=None), 196 | ], 197 | containing_type=None, 198 | options=None, 199 | serialized_start=46, 200 | serialized_end=752, 201 | ) 202 | _sym_db.RegisterEnumDescriptor(_DATATYPE) 203 | 204 | DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE) 205 | DT_INVALID = 0 206 | DT_FLOAT = 1 207 | DT_DOUBLE = 2 208 | DT_INT32 = 3 209 | DT_UINT8 = 4 210 | DT_INT16 = 5 211 | DT_INT8 = 6 212 | DT_STRING = 7 213 | DT_COMPLEX64 = 8 214 | DT_INT64 = 9 215 | DT_BOOL = 10 216 | DT_QINT8 = 11 217 | DT_QUINT8 = 12 218 | DT_QINT32 = 13 219 | DT_BFLOAT16 = 14 220 | DT_QINT16 = 15 221 | DT_QUINT16 = 16 222 | DT_UINT16 = 17 223 | DT_COMPLEX128 = 18 224 | DT_HALF = 19 225 | DT_RESOURCE = 20 226 | DT_FLOAT_REF = 101 227 | DT_DOUBLE_REF = 102 228 | DT_INT32_REF = 103 229 | DT_UINT8_REF = 104 230 | DT_INT16_REF = 105 231 | DT_INT8_REF = 106 232 | DT_STRING_REF = 107 233 | DT_COMPLEX64_REF = 108 234 | DT_INT64_REF = 109 235 | DT_BOOL_REF = 110 236 | DT_QINT8_REF = 111 237 | DT_QUINT8_REF = 112 238 | DT_QINT32_REF = 113 239 | DT_BFLOAT16_REF = 114 240 | DT_QINT16_REF = 115 241 | DT_QUINT16_REF = 116 242 | DT_UINT16_REF = 117 243 | DT_COMPLEX128_REF = 118 244 | DT_HALF_REF = 119 245 | DT_RESOURCE_REF = 120 246 | 247 | 248 | DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE 249 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 250 | 251 | 252 | DESCRIPTOR.has_options = True 253 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\013TypesProtosP\001\370\001\001')) 254 | # @@protoc_insertion_point(module_scope) 255 | -------------------------------------------------------------------------------- /tensorboardX/src/tensor_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: tensorboardX/src/tensor.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from tensorboardX.src import resource_handle_pb2 as tensorboardX_dot_src_dot_resource__handle__pb2 17 | from tensorboardX.src import tensor_shape_pb2 as tensorboardX_dot_src_dot_tensor__shape__pb2 18 | from tensorboardX.src import types_pb2 as tensorboardX_dot_src_dot_types__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='tensorboardX/src/tensor.proto', 23 | package='tensorboard', 24 | syntax='proto3', 25 | serialized_pb=_b('\n\x1dtensorboardX/src/tensor.proto\x12\x0btensorboard\x1a&tensorboardX/src/resource_handle.proto\x1a#tensorboardX/src/tensor_shape.proto\x1a\x1ctensorboardX/src/types.proto\"\xa6\x03\n\x0bTensorProto\x12$\n\x05\x64type\x18\x01 \x01(\x0e\x32\x15.tensorboard.DataType\x12\x33\n\x0ctensor_shape\x18\x02 \x01(\x0b\x32\x1d.tensorboard.TensorShapeProto\x12\x16\n\x0eversion_number\x18\x03 \x01(\x05\x12\x16\n\x0etensor_content\x18\x04 \x01(\x0c\x12\x14\n\x08half_val\x18\r \x03(\x05\x42\x02\x10\x01\x12\x15\n\tfloat_val\x18\x05 \x03(\x02\x42\x02\x10\x01\x12\x16\n\ndouble_val\x18\x06 \x03(\x01\x42\x02\x10\x01\x12\x13\n\x07int_val\x18\x07 \x03(\x05\x42\x02\x10\x01\x12\x12\n\nstring_val\x18\x08 \x03(\x0c\x12\x18\n\x0cscomplex_val\x18\t \x03(\x02\x42\x02\x10\x01\x12\x15\n\tint64_val\x18\n \x03(\x03\x42\x02\x10\x01\x12\x14\n\x08\x62ool_val\x18\x0b \x03(\x08\x42\x02\x10\x01\x12\x18\n\x0c\x64\x63omplex_val\x18\x0c \x03(\x01\x42\x02\x10\x01\x12=\n\x13resource_handle_val\x18\x0e \x03(\x0b\x32 .tensorboard.ResourceHandleProtoB-\n\x18org.tensorflow.frameworkB\x0cTensorProtosP\x01\xf8\x01\x01\x62\x06proto3') 26 | , 27 | dependencies=[tensorboardX_dot_src_dot_resource__handle__pb2.DESCRIPTOR,tensorboardX_dot_src_dot_tensor__shape__pb2.DESCRIPTOR,tensorboardX_dot_src_dot_types__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _TENSORPROTO = _descriptor.Descriptor( 33 | name='TensorProto', 34 | full_name='tensorboard.TensorProto', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='dtype', full_name='tensorboard.TensorProto.dtype', index=0, 41 | number=1, type=14, cpp_type=8, label=1, 42 | has_default_value=False, default_value=0, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None), 46 | _descriptor.FieldDescriptor( 47 | name='tensor_shape', full_name='tensorboard.TensorProto.tensor_shape', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None), 53 | _descriptor.FieldDescriptor( 54 | name='version_number', full_name='tensorboard.TensorProto.version_number', index=2, 55 | number=3, type=5, cpp_type=1, label=1, 56 | has_default_value=False, default_value=0, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None), 60 | _descriptor.FieldDescriptor( 61 | 
name='tensor_content', full_name='tensorboard.TensorProto.tensor_content', index=3, 62 | number=4, type=12, cpp_type=9, label=1, 63 | has_default_value=False, default_value=_b(""), 64 | message_type=None, enum_type=None, containing_type=None, 65 | is_extension=False, extension_scope=None, 66 | options=None), 67 | _descriptor.FieldDescriptor( 68 | name='half_val', full_name='tensorboard.TensorProto.half_val', index=4, 69 | number=13, type=5, cpp_type=1, label=3, 70 | has_default_value=False, default_value=[], 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 74 | _descriptor.FieldDescriptor( 75 | name='float_val', full_name='tensorboard.TensorProto.float_val', index=5, 76 | number=5, type=2, cpp_type=6, label=3, 77 | has_default_value=False, default_value=[], 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 81 | _descriptor.FieldDescriptor( 82 | name='double_val', full_name='tensorboard.TensorProto.double_val', index=6, 83 | number=6, type=1, cpp_type=5, label=3, 84 | has_default_value=False, default_value=[], 85 | message_type=None, enum_type=None, containing_type=None, 86 | is_extension=False, extension_scope=None, 87 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 88 | _descriptor.FieldDescriptor( 89 | name='int_val', full_name='tensorboard.TensorProto.int_val', index=7, 90 | number=7, type=5, cpp_type=1, label=3, 91 | has_default_value=False, default_value=[], 92 | message_type=None, enum_type=None, containing_type=None, 93 | is_extension=False, extension_scope=None, 94 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 95 | _descriptor.FieldDescriptor( 96 | name='string_val', full_name='tensorboard.TensorProto.string_val', index=8, 97 | number=8, type=12, cpp_type=9, label=3, 98 | has_default_value=False, default_value=[], 99 | message_type=None, enum_type=None, containing_type=None, 100 | is_extension=False, extension_scope=None, 101 | options=None), 102 | _descriptor.FieldDescriptor( 103 | name='scomplex_val', full_name='tensorboard.TensorProto.scomplex_val', index=9, 104 | number=9, type=2, cpp_type=6, label=3, 105 | has_default_value=False, default_value=[], 106 | message_type=None, enum_type=None, containing_type=None, 107 | is_extension=False, extension_scope=None, 108 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 109 | _descriptor.FieldDescriptor( 110 | name='int64_val', full_name='tensorboard.TensorProto.int64_val', index=10, 111 | number=10, type=3, cpp_type=2, label=3, 112 | has_default_value=False, default_value=[], 113 | message_type=None, enum_type=None, containing_type=None, 114 | is_extension=False, extension_scope=None, 115 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 116 | _descriptor.FieldDescriptor( 117 | name='bool_val', full_name='tensorboard.TensorProto.bool_val', index=11, 118 | number=11, type=8, cpp_type=7, label=3, 119 | has_default_value=False, default_value=[], 120 | message_type=None, enum_type=None, containing_type=None, 121 | is_extension=False, extension_scope=None, 122 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 123 | _descriptor.FieldDescriptor( 124 | name='dcomplex_val', 
full_name='tensorboard.TensorProto.dcomplex_val', index=12, 125 | number=12, type=1, cpp_type=5, label=3, 126 | has_default_value=False, default_value=[], 127 | message_type=None, enum_type=None, containing_type=None, 128 | is_extension=False, extension_scope=None, 129 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), 130 | _descriptor.FieldDescriptor( 131 | name='resource_handle_val', full_name='tensorboard.TensorProto.resource_handle_val', index=13, 132 | number=14, type=11, cpp_type=10, label=3, 133 | has_default_value=False, default_value=[], 134 | message_type=None, enum_type=None, containing_type=None, 135 | is_extension=False, extension_scope=None, 136 | options=None), 137 | ], 138 | extensions=[ 139 | ], 140 | nested_types=[], 141 | enum_types=[ 142 | ], 143 | options=None, 144 | is_extendable=False, 145 | syntax='proto3', 146 | extension_ranges=[], 147 | oneofs=[ 148 | ], 149 | serialized_start=154, 150 | serialized_end=576, 151 | ) 152 | 153 | _TENSORPROTO.fields_by_name['dtype'].enum_type = tensorboardX_dot_src_dot_types__pb2._DATATYPE 154 | _TENSORPROTO.fields_by_name['tensor_shape'].message_type = tensorboardX_dot_src_dot_tensor__shape__pb2._TENSORSHAPEPROTO 155 | _TENSORPROTO.fields_by_name['resource_handle_val'].message_type = tensorboardX_dot_src_dot_resource__handle__pb2._RESOURCEHANDLEPROTO 156 | DESCRIPTOR.message_types_by_name['TensorProto'] = _TENSORPROTO 157 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 158 | 159 | TensorProto = _reflection.GeneratedProtocolMessageType('TensorProto', (_message.Message,), dict( 160 | DESCRIPTOR = _TENSORPROTO, 161 | __module__ = 'tensorboardX.src.tensor_pb2' 162 | # @@protoc_insertion_point(class_scope:tensorboard.TensorProto) 163 | )) 164 | _sym_db.RegisterMessage(TensorProto) 165 | 166 | 167 | DESCRIPTOR.has_options = True 168 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\014TensorProtosP\001\370\001\001')) 169 | _TENSORPROTO.fields_by_name['half_val'].has_options = True 170 | _TENSORPROTO.fields_by_name['half_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 171 | _TENSORPROTO.fields_by_name['float_val'].has_options = True 172 | _TENSORPROTO.fields_by_name['float_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 173 | _TENSORPROTO.fields_by_name['double_val'].has_options = True 174 | _TENSORPROTO.fields_by_name['double_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 175 | _TENSORPROTO.fields_by_name['int_val'].has_options = True 176 | _TENSORPROTO.fields_by_name['int_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 177 | _TENSORPROTO.fields_by_name['scomplex_val'].has_options = True 178 | _TENSORPROTO.fields_by_name['scomplex_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 179 | _TENSORPROTO.fields_by_name['int64_val'].has_options = True 180 | _TENSORPROTO.fields_by_name['int64_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 181 | _TENSORPROTO.fields_by_name['bool_val'].has_options = True 182 | _TENSORPROTO.fields_by_name['bool_val']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 183 | _TENSORPROTO.fields_by_name['dcomplex_val'].has_options = True 184 | _TENSORPROTO.fields_by_name['dcomplex_val']._options = 
_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) 185 | # @@protoc_insertion_point(module_scope) 186 | --------------------------------------------------------------------------------
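Closing aside (illustrative sketch, not part of the repository): the block near the end of ATTEvaluator.evaluate that turns frame-level match scores into a tracklet-level distance can be read in isolation. Each query/gallery tracklet pair yields a qnum x gnum block of frame-pair distances, and the final distance is the mean of the smallest 10% of its entries, which makes the score robust to poorly aligned frames. A minimal NumPy demo, with a hypothetical 8x8 block standing in for distmat_qg and the same empty-selection guard as above:

# Set-to-set distance used in ATTEvaluator.evaluate (standalone NumPy demo).
import numpy as np

distmat_qg = np.random.rand(8, 8)            # hypothetical block of frame-pair distances
percile = np.percentile(distmat_qg, 10)      # 10th-percentile threshold
closest = distmat_qg[distmat_qg <= percile]  # the best-matching frame pairs
distmean = closest.mean() if closest.size > 0 else distmat_qg.mean()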