├── .github └── stale.yml ├── .gitignore ├── LICENSE ├── README.md ├── applications ├── __init__.py └── resnet_common.py ├── callbacks └── __init__.py ├── datasets ├── __init__.py ├── dukemtmc_reid.py ├── market1501.py └── msmt17.py ├── evaluation ├── metrics │ ├── __init__.py │ └── rank_cy.pyx └── post_processing │ └── re_ranking_ranklist.py ├── image_augmentation ├── __init__.py └── random_erasing.py ├── metric_learning └── triplet_hermans.py ├── regularizers └── adaptation.py ├── solution.py └── utils ├── model_utils.py └── vis_utils.py /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Configuration for probot-stale - https://github.com/probot/stale 2 | 3 | # Number of days of inactivity before an issue becomes stale 4 | daysUntilStale: 7 5 | # Number of days of inactivity before a stale issue is closed 6 | daysUntilClose: 7 7 | # Issues with these labels will never be considered stale 8 | exemptLabels: 9 | - bug 10 | # Label to use when marking an issue as stale 11 | staleLabel: stale 12 | # Comment to post when marking an issue as stale. Set to `false` to disable 13 | markComment: > 14 | This issue has been automatically marked as stale because it has not had 15 | recent activity. It will be closed if no further activity occurs. Thank you 16 | for your contributions. 17 | # Comment to post when closing a stale issue. Set to `false` to disable 18 | closeComment: > 19 | Closing as stale. Please reopen if you'd like to work on this further. -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 nixingyang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Python](https://img.shields.io/badge/python-3.8-blue?style=flat-square&logo=python) 2 | ![TensorFlow](https://img.shields.io/badge/tensorflow-2.2.3-green?style=flat-square&logo=tensorflow) 3 | 4 | # Adaptive L2 Regularization in Person Re-Identification 5 | 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/adaptivereid-adaptive-l2-regularization-in/person-re-identification-on-msmt17)](https://paperswithcode.com/sota/person-re-identification-on-msmt17?p=adaptivereid-adaptive-l2-regularization-in) 7 | 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/adaptivereid-adaptive-l2-regularization-in/person-re-identification-on-market-1501)](https://paperswithcode.com/sota/person-re-identification-on-market-1501?p=adaptivereid-adaptive-l2-regularization-in) 9 | 10 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/adaptivereid-adaptive-l2-regularization-in/person-re-identification-on-dukemtmc-reid)](https://paperswithcode.com/sota/person-re-identification-on-dukemtmc-reid?p=adaptivereid-adaptive-l2-regularization-in) 11 | 12 | ## Overview 13 | 14 | We introduce an adaptive L2 regularization mechanism in the setting of person re-identification. 15 | In the literature, it is common practice to utilize hand-picked regularization factors which remain constant throughout the training procedure. 16 | Unlike existing approaches, the regularization factors in our proposed method are updated adaptively through backpropagation. 17 | This is achieved by incorporating trainable scalar variables as the regularization factors, which are further fed into a scaled hard sigmoid function. 18 | Extensive experiments on the Market-1501, DukeMTMC-reID and MSMT17 datasets validate the effectiveness of our framework. 19 | Most notably, we obtain state-of-the-art performance on MSMT17, which is the largest dataset for person re-identification. 20 | Source code is publicly available at https://github.com/nixingyang/AdaptiveL2Regularization. 21 | 22 | ## Environment 23 | 24 | ```bash 25 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 26 | bash Miniconda3-latest-Linux-x86_64.sh 27 | conda config --set auto_activate_base false 28 | conda create --yes --name TensorFlow2.2 python=3.8 29 | conda activate TensorFlow2.2 30 | conda install --yes cudatoolkit=10.1 cudnn=7.6 -c nvidia 31 | conda install --yes cython matplotlib numpy=1.18 pandas pydot scikit-learn 32 | pip install tensorflow==2.2.3 33 | pip install opencv-python 34 | pip install albumentations --no-binary imgaug,albumentations 35 | ``` 36 | 37 | ## Training 38 | 39 | ```bash 40 | python3 -u solution.py --dataset_name "Market1501" --backbone_model_name "ResNet50" 41 | ``` 42 | 43 | - To train on other datasets, replace `"Market1501"` with `"DukeMTMC_reID"` or `"MSMT17"`. 44 | - To train with deeper backbones, replace `"ResNet50"` with `"ResNet101"` or `"ResNet152"`. 45 | - To evaluate on a subset of the complete test set, append `--testing_size 0.5` to the command. Alternatively, you may turn this feature off by using `--testing_size 0.0`. 
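As described in the overview, each regularization factor is a trainable scalar that is updated through backpropagation and mapped through a scaled hard sigmoid. Below is a minimal sketch of that mapping, mirroring `preprocess_function` in `regularizers/adaptation.py`; the `amplitude` and `omega` values are purely illustrative, not the defaults used in training:

```python
import numpy as np

def effective_regularization_factor(theta, amplitude, omega):
    # Scaled hard sigmoid: amplitude * clip(0.2 * omega * theta + 0.5, 0, 1).
    # theta is the trainable scalar that backpropagation updates.
    return amplitude * np.clip(0.2 * omega * theta + 0.5, 0.0, 1.0)

# The scalar is initialized to 0, so training starts at half the amplitude.
print(effective_regularization_factor(theta=0.0, amplitude=0.1, omega=1))  # 0.05
```

The effective factor is therefore confined to `[0, amplitude]`, which keeps the learned regularization strength bounded while leaving its exact value to the optimizer.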
46 | 
47 | ## Evaluation
48 | 
49 | ```bash
50 | python3 -u solution.py --dataset_name "Market1501" --backbone_model_name "ResNet50" --pretrained_model_file_path "?.h5" --output_folder_path "evaluation_only" --evaluation_only --freeze_backbone_for_N_epochs 0 --testing_size 1.0 --evaluate_testing_every_N_epochs 1
51 | ```
52 | 
53 | - Fill in the `pretrained_model_file_path` argument with the path to the `.h5` file obtained during training.
54 | - To use the re-ranking method, append `--use_re_ranking` to the command.
55 | - You need to run this separate evaluation procedure only if `testing_size` was not set to `1.0` during training.
56 | 
57 | ## Model Zoo
58 | 
59 | | Dataset | Backbone | mAP | Weights |
60 | | - | - | - | - |
61 | | Market1501 | ResNet50 | 88.3 | [Link](https://1drv.ms/u/s!Av-teFsyVR6WjmR-Jys9yzGLnVqm) |
62 | | DukeMTMC_reID | ResNet50 | 79.9 | [Link](https://1drv.ms/u/s!Av-teFsyVR6WjmOBpAY4nCdnTrH3) |
63 | | MSMT17 | ResNet50 | 59.4 | [Link](https://1drv.ms/u/s!Av-teFsyVR6WjmJzh8DHdFc5edDK) |
64 | | MSMT17 | ResNet152 | 62.2 | [Link](https://1drv.ms/u/s!Av-teFsyVR6WjmWPtZpcZkNSYtoi) |
65 | 
66 | ## Acknowledgements
67 | 
68 | - Evaluation Metrics are adapted from [deep-person-reid](https://github.com/KaiyangZhou/deep-person-reid/blob/v1.0.6/torchreid/metrics/rank_cylib/rank_cy.pyx).
69 | - Re-Ranking is adapted from [person-re-ranking](https://github.com/zhunzhong07/person-re-ranking/blob/master/python-version/re_ranking_ranklist.py).
70 | - Random Erasing is adapted from [Random-Erasing](https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py).
71 | - Triplet Loss is adapted from [triplet-reid](https://github.com/VisualComputingInstitute/triplet-reid/blob/master/loss.py).
72 | 
73 | ## Third-Party Implementation
74 | 
75 | - A PyTorch implementation: the [adaptive-l2-regularization-pytorch](https://github.com/duyuanchao/adaptive-l2-regularization-pytorch) repository from [duyuanchao](https://github.com/duyuanchao).
76 | 
77 | ## Citation
78 | 
79 | Please consider citing [this work](https://ieeexplore.ieee.org/document/9412481) if it helps your research.
80 | 81 | ``` 82 | @inproceedings{ni2021adaptive, 83 | author={Ni, Xingyang and Fang, Liang and Huttunen, Heikki}, 84 | booktitle={2020 25th International Conference on Pattern Recognition (ICPR)}, 85 | title={Adaptive L2 Regularization in Person Re-Identification}, 86 | year={2021}, 87 | volume={}, 88 | number={}, 89 | pages={9601-9607}, 90 | doi={10.1109/ICPR48806.2021.9412481} 91 | } 92 | ``` -------------------------------------------------------------------------------- /applications/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet_common import (ResNet50, ResNet50V2, ResNet101, ResNet101V2, 2 | ResNet152, ResNet152V2, ResNeXt50, ResNeXt101) 3 | -------------------------------------------------------------------------------- /applications/resnet_common.py: -------------------------------------------------------------------------------- 1 | """ 2 | References: 3 | https://github.com/keras-team/keras-applications/blob/1.0.8/keras_applications/resnet_common.py 4 | """ 5 | 6 | from collections import OrderedDict 7 | 8 | from tensorflow.python.keras import backend as K 9 | from tensorflow.python.keras.applications.imagenet_utils import preprocess_input 10 | from tensorflow.python.keras.applications.resnet import (BASE_WEIGHTS_PATH, 11 | WEIGHTS_HASHES, ResNet, 12 | stack1, stack2, stack3) 13 | from tensorflow.python.keras.layers import Input 14 | from tensorflow.python.keras.models import Model 15 | from tensorflow.python.keras.utils import data_utils 16 | 17 | 18 | def init_resnet(stack_fn, preact, use_bias, model_name, input_shape, 19 | block_name_to_hyperparameters_dict, preprocess_input_mode): 20 | # Downloads the weights 21 | file_name = model_name + "_weights_tf_dim_ordering_tf_kernels_notop.h5" 22 | file_hash = WEIGHTS_HASHES[model_name][1] 23 | weights_path = data_utils.get_file(file_name, 24 | BASE_WEIGHTS_PATH + file_name, 25 | cache_subdir="models", 26 | file_hash=file_hash) 27 | 28 | # Define and initialize the first block 29 | first_block = ResNet(stack_fn=lambda x: x, 30 | preact=preact, 31 | use_bias=use_bias, 32 | include_top=False, 33 | weights=None, 34 | input_shape=input_shape) 35 | first_block.load_weights(weights_path, by_name=True) 36 | submodel_list = [first_block] 37 | 38 | # Define and initialize each block 39 | for block_name, (filters, blocks, 40 | stride1) in block_name_to_hyperparameters_dict.items(): 41 | input_tensor = Input(shape=K.int_shape(submodel_list[-1].output)[1:]) 42 | output_tensor = stack_fn(input_tensor, 43 | filters=filters, 44 | blocks=blocks, 45 | stride1=stride1, 46 | name=block_name) 47 | submodel = Model(inputs=input_tensor, 48 | outputs=output_tensor, 49 | name="{}_block".format(block_name)) 50 | submodel.load_weights(weights_path, by_name=True) 51 | submodel_list.append(submodel) 52 | 53 | return submodel_list, lambda x: preprocess_input(x, 54 | mode=preprocess_input_mode) 55 | 56 | 57 | ResNet50 = lambda input_shape, last_stride1=1: init_resnet( 58 | stack_fn=stack1, 59 | preact=False, 60 | use_bias=True, 61 | model_name="resnet50", 62 | input_shape=input_shape, 63 | block_name_to_hyperparameters_dict=OrderedDict([("conv2", (64, 3, 1)), 64 | ("conv3", (128, 4, 2)), 65 | ("conv4", (256, 6, 2)), 66 | ("conv5", 67 | (512, 3, last_stride1))]), 68 | preprocess_input_mode="caffe") 69 | 70 | ResNet101 = lambda input_shape, last_stride1=1: init_resnet( 71 | stack_fn=stack1, 72 | preact=False, 73 | use_bias=True, 74 | model_name="resnet101", 75 | input_shape=input_shape, 76 | 
block_name_to_hyperparameters_dict=OrderedDict([("conv2", (64, 3, 1)), 77 | ("conv3", (128, 4, 2)), 78 | ("conv4", (256, 23, 2)), 79 | ("conv5", 80 | (512, 3, last_stride1))]), 81 | preprocess_input_mode="caffe") 82 | 83 | ResNet152 = lambda input_shape, last_stride1=1: init_resnet( 84 | stack_fn=stack1, 85 | preact=False, 86 | use_bias=True, 87 | model_name="resnet152", 88 | input_shape=input_shape, 89 | block_name_to_hyperparameters_dict=OrderedDict([("conv2", (64, 3, 1)), 90 | ("conv3", (128, 8, 2)), 91 | ("conv4", (256, 36, 2)), 92 | ("conv5", 93 | (512, 3, last_stride1))]), 94 | preprocess_input_mode="caffe") 95 | 96 | ResNet50V2 = lambda input_shape: init_resnet( 97 | stack_fn=stack2, 98 | preact=False, 99 | use_bias=True, 100 | model_name="resnet50v2", 101 | input_shape=input_shape, 102 | block_name_to_hyperparameters_dict=OrderedDict([("conv2", (64, 3, 2)), 103 | ("conv3", (128, 4, 2)), 104 | ("conv4", (256, 6, 2)), 105 | ("conv5", (512, 3, 1))]), 106 | preprocess_input_mode="tf") 107 | 108 | ResNet101V2 = lambda input_shape: init_resnet( 109 | stack_fn=stack2, 110 | preact=False, 111 | use_bias=True, 112 | model_name="resnet101v2", 113 | input_shape=input_shape, 114 | block_name_to_hyperparameters_dict=OrderedDict([("conv2", (64, 3, 2)), 115 | ("conv3", (128, 4, 2)), 116 | ("conv4", (256, 23, 2)), 117 | ("conv5", (512, 3, 1))]), 118 | preprocess_input_mode="tf") 119 | 120 | ResNet152V2 = lambda input_shape: init_resnet( 121 | stack_fn=stack2, 122 | preact=False, 123 | use_bias=True, 124 | model_name="resnet152v2", 125 | input_shape=input_shape, 126 | block_name_to_hyperparameters_dict=OrderedDict([("conv2", (64, 3, 2)), 127 | ("conv3", (128, 8, 2)), 128 | ("conv4", (256, 36, 2)), 129 | ("conv5", (512, 3, 1))]), 130 | preprocess_input_mode="tf") 131 | 132 | ResNeXt50 = lambda input_shape: init_resnet( 133 | stack_fn=stack3, 134 | preact=False, 135 | use_bias=False, 136 | model_name="resnext50", 137 | input_shape=input_shape, 138 | block_name_to_hyperparameters_dict=OrderedDict([("conv2", (128, 3, 1)), 139 | ("conv3", (256, 4, 2)), 140 | ("conv4", (512, 6, 2)), 141 | ("conv5", (1024, 3, 2))]), 142 | preprocess_input_mode="torch") 143 | 144 | ResNeXt101 = lambda input_shape: init_resnet( 145 | stack_fn=stack3, 146 | preact=False, 147 | use_bias=False, 148 | model_name="resnext101", 149 | input_shape=input_shape, 150 | block_name_to_hyperparameters_dict=OrderedDict([("conv2", (128, 3, 1)), 151 | ("conv3", (256, 4, 2)), 152 | ("conv4", (512, 23, 2)), 153 | ("conv5", (1024, 3, 2))]), 154 | preprocess_input_mode="torch") 155 | -------------------------------------------------------------------------------- /callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from itertools import product 4 | 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from tensorflow.keras.callbacks import Callback 9 | 10 | # Specify the backend of matplotlib 11 | matplotlib.use("Agg") 12 | 13 | 14 | class HistoryLogger(Callback): 15 | 16 | def __init__(self, output_folder_path): 17 | super(HistoryLogger, self).__init__() 18 | 19 | self.accumulated_logs_dict = {} 20 | self.output_folder_path = output_folder_path 21 | 22 | if not os.path.isdir(output_folder_path): 23 | os.makedirs(output_folder_path) 24 | 25 | def visualize(self, loss_name): 26 | # Unpack the values 27 | epoch_to_loss_value_dict = self.accumulated_logs_dict[loss_name] 28 | epoch_list = sorted(epoch_to_loss_value_dict.keys()) 29 
| loss_value_list = [ 30 | epoch_to_loss_value_dict[epoch] for epoch in epoch_list 31 | ] 32 | epoch_list = (np.array(epoch_list) + 1).tolist() 33 | 34 | # Save the figure to disk 35 | figure = plt.figure() 36 | if isinstance(loss_value_list[0], dict): 37 | for metric_name in loss_value_list[0].keys(): 38 | metric_value_list = [ 39 | loss_value[metric_name] for loss_value in loss_value_list 40 | ] 41 | print("{} {} {:.6f}".format(loss_name, metric_name, 42 | metric_value_list[-1])) 43 | plt.plot(epoch_list, 44 | metric_value_list, 45 | label="{} {:.6f}".format(metric_name, 46 | metric_value_list[-1])) 47 | else: 48 | print("{} {:.6f}".format(loss_name, loss_value_list[-1])) 49 | plt.plot(epoch_list, 50 | loss_value_list, 51 | label="{} {:.6f}".format(loss_name, loss_value_list[-1])) 52 | plt.ylabel(loss_name) 53 | plt.xlabel("Epoch") 54 | plt.grid(True) 55 | plt.legend(loc="best") 56 | plt.savefig( 57 | os.path.join(self.output_folder_path, "{}.png".format(loss_name))) 58 | plt.close(figure) 59 | 60 | def on_epoch_end(self, epoch, logs=None): 61 | # Visualize each figure 62 | for loss_name, loss_value in logs.items(): 63 | if loss_name not in self.accumulated_logs_dict: 64 | self.accumulated_logs_dict[loss_name] = {} 65 | self.accumulated_logs_dict[loss_name][epoch] = loss_value 66 | self.visualize(loss_name) 67 | 68 | # Save the accumulated_logs_dict to disk 69 | with open( 70 | os.path.join(self.output_folder_path, 71 | "accumulated_logs_dict.pkl"), "wb") as file_object: 72 | pickle.dump(self.accumulated_logs_dict, file_object, 73 | pickle.HIGHEST_PROTOCOL) 74 | 75 | # Delete extra keys due to changes in ProgbarLogger 76 | loss_name_list = list(logs.keys()) 77 | split_name_list = ["valid", "test"] 78 | for loss_name, split_name in product(loss_name_list, split_name_list): 79 | if loss_name.startswith(split_name): 80 | _ = logs.pop(loss_name) 81 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | 4 | from sklearn.preprocessing import LabelEncoder 5 | 6 | from .dukemtmc_reid import load_DukeMTMC_reID 7 | from .market1501 import load_Market1501 8 | from .msmt17 import load_MSMT17 9 | 10 | 11 | def _get_root_folder_path(): 12 | root_folder_path_list = [ 13 | os.path.expanduser("~/Documents/Local Storage/Dataset"), 14 | "/sgn-data/MLG/nixingyang/Dataset" 15 | ] 16 | root_folder_path_mask = [ 17 | os.path.isdir(folder_path) for folder_path in root_folder_path_list 18 | ] 19 | root_folder_path = root_folder_path_list[root_folder_path_mask.index(True)] 20 | return root_folder_path 21 | 22 | 23 | def _get_attribute_name_to_label_encoder_dict(accumulated_info_dataframe): 24 | attribute_name_to_label_encoder_dict = OrderedDict({}) 25 | accumulated_info_dataframe = accumulated_info_dataframe.drop( 26 | columns=["image_file_path", "camera_ID"]) 27 | for attribute_name in accumulated_info_dataframe.columns: 28 | label_encoder = LabelEncoder() 29 | label_encoder.fit(accumulated_info_dataframe[attribute_name].values) 30 | attribute_name_to_label_encoder_dict[attribute_name] = label_encoder 31 | return attribute_name_to_label_encoder_dict 32 | 33 | 34 | def load_accumulated_info_of_dataset(root_folder_path, dataset_name): 35 | if not os.path.isdir(root_folder_path): 36 | root_folder_path = _get_root_folder_path() 37 | print("Use {} as root_folder_path ...".format(root_folder_path)) 38 | 39 | 
dataset_name_to_load_function_dict = { 40 | "Market1501": load_Market1501, 41 | "DukeMTMC_reID": load_DukeMTMC_reID, 42 | "MSMT17": load_MSMT17 43 | } 44 | assert dataset_name in dataset_name_to_load_function_dict 45 | load_function = dataset_name_to_load_function_dict[dataset_name] 46 | train_and_valid_accumulated_info_dataframe, test_query_accumulated_info_dataframe, test_gallery_accumulated_info_dataframe = load_function( 47 | root_folder_path=root_folder_path) 48 | 49 | assert not train_and_valid_accumulated_info_dataframe.isnull().values.any( 50 | ) # All fields contain value 51 | train_and_valid_attribute_name_to_label_encoder_dict = _get_attribute_name_to_label_encoder_dict( 52 | train_and_valid_accumulated_info_dataframe) 53 | 54 | return train_and_valid_accumulated_info_dataframe, test_query_accumulated_info_dataframe, test_gallery_accumulated_info_dataframe, \ 55 | train_and_valid_attribute_name_to_label_encoder_dict 56 | -------------------------------------------------------------------------------- /datasets/dukemtmc_reid.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import pandas as pd 5 | 6 | 7 | def _load_accumulated_info(root_folder_path, 8 | dataset_folder_name="DukeMTMC-reID", 9 | image_folder_name="bounding_box_train"): 10 | """ 11 | References: 12 | https://drive.google.com/file/d/1jjE85dRCMOgRtvJ5RQV9-Afs-2_5dY3O/view 13 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive 14 | gdrive download 1jjE85dRCMOgRtvJ5RQV9-Afs-2_5dY3O 15 | 7za x DukeMTMC-reID.zip 16 | sha256sum DukeMTMC-reID.zip 17 | 932ae18937b6a77bc59846d4fb00da4ee02cdda93329ca0537ad899a569e3505 DukeMTMC-reID.zip 18 | """ 19 | dataset_folder_path = os.path.join(root_folder_path, dataset_folder_name) 20 | image_folder_path = os.path.join(dataset_folder_path, image_folder_name) 21 | 22 | image_file_path_list = sorted( 23 | glob.glob(os.path.join(image_folder_path, "*.jpg"))) 24 | if image_folder_name == "bounding_box_train": 25 | assert len(image_file_path_list) == 16522 26 | elif image_folder_name == "bounding_box_test": 27 | assert len(image_file_path_list) == 17661 28 | elif image_folder_name == "query": 29 | assert len(image_file_path_list) == 2228 30 | else: 31 | assert False, "{} is an invalid argument!".format(image_folder_name) 32 | 33 | accumulated_info_list = [] 34 | for image_file_path in image_file_path_list: 35 | image_file_name = image_file_path.split(os.sep)[-1] 36 | identity_ID = int(image_file_name.split("_")[0]) 37 | camera_ID = int(image_file_name.split("_")[1][1]) 38 | # Append the records 39 | accumulated_info = { 40 | "image_file_path": image_file_path, 41 | "identity_ID": identity_ID, 42 | "camera_ID": camera_ID 43 | } 44 | accumulated_info_list.append(accumulated_info) 45 | 46 | # Convert list to data frame 47 | accumulated_info_dataframe = pd.DataFrame(accumulated_info_list) 48 | return accumulated_info_dataframe 49 | 50 | 51 | def load_DukeMTMC_reID(root_folder_path): 52 | train_and_valid_accumulated_info_dataframe = _load_accumulated_info( 53 | root_folder_path=root_folder_path, 54 | image_folder_name="bounding_box_train") 55 | test_gallery_accumulated_info_dataframe = _load_accumulated_info( 56 | root_folder_path=root_folder_path, 57 | image_folder_name="bounding_box_test") 58 | test_query_accumulated_info_dataframe = _load_accumulated_info( 59 | root_folder_path=root_folder_path, image_folder_name="query") 60 | return train_and_valid_accumulated_info_dataframe, 
test_query_accumulated_info_dataframe, test_gallery_accumulated_info_dataframe 61 | -------------------------------------------------------------------------------- /datasets/market1501.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import pandas as pd 5 | 6 | 7 | def _load_accumulated_info(root_folder_path, 8 | dataset_folder_name="Market-1501-v15.09.15", 9 | image_folder_name="bounding_box_train"): 10 | """ 11 | References: 12 | https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view 13 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive 14 | gdrive download 0B8-rUzbwVRk0c054eEozWG9COHM 15 | 7za x Market-1501-v15.09.15.zip 16 | sha256sum Market-1501-v15.09.15.zip 17 | 416bb77b5a2449b32e936f623cbee58becf1a9e7e936f36380cb8f9ab928fe96 Market-1501-v15.09.15.zip 18 | """ 19 | dataset_folder_path = os.path.join(root_folder_path, dataset_folder_name) 20 | image_folder_path = os.path.join(dataset_folder_path, image_folder_name) 21 | 22 | image_file_path_list = sorted( 23 | glob.glob(os.path.join(image_folder_path, "*.jpg"))) 24 | if image_folder_name == "bounding_box_train": 25 | assert len(image_file_path_list) == 12936 26 | elif image_folder_name == "bounding_box_test": 27 | assert len(image_file_path_list) == 19732 28 | elif image_folder_name == "query": 29 | assert len(image_file_path_list) == 3368 30 | else: 31 | assert False, "{} is an invalid argument!".format(image_folder_name) 32 | 33 | accumulated_info_list = [] 34 | for image_file_path in image_file_path_list: 35 | # Extract identity_ID 36 | image_file_name = image_file_path.split(os.sep)[-1] 37 | identity_ID = int(image_file_name.split("_")[0]) 38 | if identity_ID == -1: 39 | # Ignore junk images 40 | # https://github.com/Cysu/open-reid/issues/16 41 | # https://github.com/michuanhaohao/reid-strong-baseline/blob/\ 42 | # 69348ceb539fc4bafd006575f7bd432a4d08b9e6/data/datasets/market1501.py#L71 43 | continue 44 | 45 | # Extract camera_ID 46 | cam_seq_ID = image_file_name.split("_")[1] 47 | camera_ID = int(cam_seq_ID[1]) 48 | 49 | # Append the records 50 | accumulated_info = { 51 | "image_file_path": image_file_path, 52 | "identity_ID": identity_ID, 53 | "camera_ID": camera_ID 54 | } 55 | accumulated_info_list.append(accumulated_info) 56 | 57 | # Convert list to data frame 58 | accumulated_info_dataframe = pd.DataFrame(accumulated_info_list) 59 | return accumulated_info_dataframe 60 | 61 | 62 | def load_Market1501(root_folder_path): 63 | train_and_valid_accumulated_info_dataframe = _load_accumulated_info( 64 | root_folder_path=root_folder_path, 65 | image_folder_name="bounding_box_train") 66 | test_gallery_accumulated_info_dataframe = _load_accumulated_info( 67 | root_folder_path=root_folder_path, 68 | image_folder_name="bounding_box_test") 69 | test_query_accumulated_info_dataframe = _load_accumulated_info( 70 | root_folder_path=root_folder_path, image_folder_name="query") 71 | return train_and_valid_accumulated_info_dataframe, test_query_accumulated_info_dataframe, test_gallery_accumulated_info_dataframe 72 | -------------------------------------------------------------------------------- /datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | 5 | 6 | def _load_accumulated_info(root_folder_path, 7 | dataset_folder_name="MSMT17_V2", 8 | image_folder_name="mask_train_v2", 9 | list_file_name="list_train.txt"): 10 | # 
https://www.pkuvmc.com/publications/msmt17.html
11 |     dataset_folder_path = os.path.join(root_folder_path, dataset_folder_name)
12 |     image_folder_path = os.path.join(dataset_folder_path, image_folder_name)
13 | 
14 |     accumulated_info_list = []
15 |     list_file_path = os.path.join(dataset_folder_path, list_file_name)
16 |     with open(list_file_path) as file_object:
17 |         for line_content in file_object:
18 |             image_file_name, identity_ID = line_content.split(" ")
19 |             identity_ID = int(identity_ID)
20 |             image_file_path = os.path.join(image_folder_path, image_file_name)
21 |             assert os.path.isfile(
22 |                 image_file_path), "{} does not exist!".format(image_file_path)
23 |             assert identity_ID == int(image_file_name.split(os.sep)[0])
24 |             camera_ID = int(image_file_name.split(os.sep)[1].split("_")[2])
25 | 
26 |             # Append the records
27 |             accumulated_info = {
28 |                 "image_file_path": image_file_path,
29 |                 "identity_ID": identity_ID,
30 |                 "camera_ID": camera_ID
31 |             }
32 |             accumulated_info_list.append(accumulated_info)
33 | 
34 |     # Convert list to data frame
35 |     accumulated_info_dataframe = pd.DataFrame(accumulated_info_list)
36 |     return accumulated_info_dataframe
37 | 
38 | 
39 | def load_MSMT17(root_folder_path):
40 |     train_accumulated_info_dataframe = _load_accumulated_info(
41 |         root_folder_path=root_folder_path,
42 |         image_folder_name="mask_train_v2",
43 |         list_file_name="list_train.txt")
44 |     valid_accumulated_info_dataframe = _load_accumulated_info(
45 |         root_folder_path=root_folder_path,
46 |         image_folder_name="mask_train_v2",
47 |         list_file_name="list_val.txt")
48 |     train_and_valid_accumulated_info_dataframe = pd.concat(
49 |         [train_accumulated_info_dataframe, valid_accumulated_info_dataframe],
50 |         ignore_index=True)
51 |     test_query_accumulated_info_dataframe = _load_accumulated_info(
52 |         root_folder_path=root_folder_path,
53 |         image_folder_name="mask_test_v2",
54 |         list_file_name="list_query.txt")
55 |     test_gallery_accumulated_info_dataframe = _load_accumulated_info(
56 |         root_folder_path=root_folder_path,
57 |         image_folder_name="mask_test_v2",
58 |         list_file_name="list_gallery.txt")
59 |     return train_and_valid_accumulated_info_dataframe, test_query_accumulated_info_dataframe, test_gallery_accumulated_info_dataframe
60 | 
--------------------------------------------------------------------------------
/evaluation/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pyximport
3 | 
4 | # Cython Compilation
5 | pyximport.install(setup_args={"include_dirs": np.get_include()},
6 |                   language_level=3)
7 | 
8 | from .rank_cy import evaluate_cy
9 | 
10 | 
11 | def compute_CMC_mAP(distmat,
12 |                     q_pids,
13 |                     g_pids,
14 |                     q_camids,
15 |                     g_camids,
16 |                     max_rank=20,
17 |                     use_metric_cuhk03=False):
18 |     """Evaluates CMC rank.
19 | 
20 |     Args:
21 |         distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
22 |         q_pids (numpy.ndarray): 1-D array containing person identities
23 |             of each query instance.
24 |         g_pids (numpy.ndarray): 1-D array containing person identities
25 |             of each gallery instance.
26 |         q_camids (numpy.ndarray): 1-D array containing camera views under
27 |             which each query instance is captured.
28 |         g_camids (numpy.ndarray): 1-D array containing camera views under
29 |             which each gallery instance is captured.
30 |         max_rank (int, optional): maximum CMC rank to be computed. Default is 20.
31 |         use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
32 |             Default is False.
This should be enabled when using cuhk03 classic split. 33 | """ 34 | return evaluate_cy(distmat=distmat, 35 | q_pids=q_pids, 36 | g_pids=g_pids, 37 | q_camids=q_camids, 38 | g_camids=g_camids, 39 | max_rank=max_rank, 40 | use_metric_cuhk03=use_metric_cuhk03) 41 | -------------------------------------------------------------------------------- /evaluation/metrics/rank_cy.pyx: -------------------------------------------------------------------------------- 1 | # cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True 2 | 3 | from __future__ import print_function 4 | 5 | import cython 6 | import numpy as np 7 | cimport numpy as np 8 | from collections import defaultdict 9 | import random 10 | 11 | 12 | """ 13 | Compiler directives: 14 | https://github.com/cython/cython/wiki/enhancements-compilerdirectives 15 | 16 | Cython tutorial: 17 | https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html 18 | 19 | Credit to https://github.com/luzai 20 | """ 21 | 22 | 23 | # Main interface 24 | cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False): 25 | distmat = np.asarray(distmat, dtype=np.float32) 26 | q_pids = np.asarray(q_pids, dtype=np.int64) 27 | g_pids = np.asarray(g_pids, dtype=np.int64) 28 | q_camids = np.asarray(q_camids, dtype=np.int64) 29 | g_camids = np.asarray(g_camids, dtype=np.int64) 30 | if use_metric_cuhk03: 31 | return eval_cuhk03_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 32 | return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 33 | 34 | 35 | cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids, 36 | long[:]q_camids, long[:]g_camids, long max_rank): 37 | 38 | cdef long num_q = distmat.shape[0] 39 | cdef long num_g = distmat.shape[1] 40 | 41 | if num_g < max_rank: 42 | max_rank = num_g 43 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 44 | 45 | cdef: 46 | long num_repeats = 10 47 | long[:,:] indices = np.argsort(distmat, axis=1) 48 | long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) 49 | 50 | float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32) 51 | float[:] all_AP = np.zeros(num_q, dtype=np.float32) 52 | float num_valid_q = 0. 
# number of valid query 53 | 54 | long q_idx, q_pid, q_camid, g_idx 55 | long[:] order = np.zeros(num_g, dtype=np.int64) 56 | long keep 57 | 58 | float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches 59 | float[:] masked_raw_cmc = np.zeros(num_g, dtype=np.float32) 60 | float[:] cmc, masked_cmc 61 | long num_g_real, num_g_real_masked, rank_idx, rnd_idx 62 | unsigned long meet_condition 63 | float AP 64 | long[:] kept_g_pids, mask 65 | 66 | float num_rel 67 | float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32) 68 | float tmp_cmc_sum 69 | 70 | for q_idx in range(num_q): 71 | # get query pid and camid 72 | q_pid = q_pids[q_idx] 73 | q_camid = q_camids[q_idx] 74 | 75 | # remove gallery samples that have the same pid and camid with query 76 | for g_idx in range(num_g): 77 | order[g_idx] = indices[q_idx, g_idx] 78 | num_g_real = 0 79 | meet_condition = 0 80 | kept_g_pids = np.zeros(num_g, dtype=np.int64) 81 | 82 | for g_idx in range(num_g): 83 | if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid): 84 | raw_cmc[num_g_real] = matches[q_idx][g_idx] 85 | kept_g_pids[num_g_real] = g_pids[order[g_idx]] 86 | num_g_real += 1 87 | if matches[q_idx][g_idx] > 1e-31: 88 | meet_condition = 1 89 | 90 | if not meet_condition: 91 | # this condition is true when query identity does not appear in gallery 92 | continue 93 | 94 | # cuhk03-specific setting 95 | g_pids_dict = defaultdict(list) # overhead! 96 | for g_idx in range(num_g_real): 97 | g_pids_dict[kept_g_pids[g_idx]].append(g_idx) 98 | 99 | cmc = np.zeros(max_rank, dtype=np.float32) 100 | for _ in range(num_repeats): 101 | mask = np.zeros(num_g_real, dtype=np.int64) 102 | 103 | for _, idxs in g_pids_dict.items(): 104 | # randomly sample one image for each gallery person 105 | rnd_idx = np.random.choice(idxs) 106 | #rnd_idx = idxs[0] # use deterministic for debugging 107 | mask[rnd_idx] = 1 108 | 109 | num_g_real_masked = 0 110 | for g_idx in range(num_g_real): 111 | if mask[g_idx] == 1: 112 | masked_raw_cmc[num_g_real_masked] = raw_cmc[g_idx] 113 | num_g_real_masked += 1 114 | 115 | masked_cmc = np.zeros(num_g, dtype=np.float32) 116 | function_cumsum(masked_raw_cmc, masked_cmc, num_g_real_masked) 117 | for g_idx in range(num_g_real_masked): 118 | if masked_cmc[g_idx] > 1: 119 | masked_cmc[g_idx] = 1 120 | 121 | for rank_idx in range(max_rank): 122 | cmc[rank_idx] += masked_cmc[rank_idx] / num_repeats 123 | 124 | for rank_idx in range(max_rank): 125 | all_cmc[q_idx, rank_idx] = cmc[rank_idx] 126 | # compute average precision 127 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 128 | function_cumsum(raw_cmc, tmp_cmc, num_g_real) 129 | num_rel = 0 130 | tmp_cmc_sum = 0 131 | for g_idx in range(num_g_real): 132 | tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx] 133 | num_rel += raw_cmc[g_idx] 134 | all_AP[q_idx] = tmp_cmc_sum / num_rel 135 | num_valid_q += 1. 
136 | 137 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 138 | 139 | # compute averaged cmc 140 | cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32) 141 | for rank_idx in range(max_rank): 142 | for q_idx in range(num_q): 143 | avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx] 144 | avg_cmc[rank_idx] /= num_valid_q 145 | 146 | cdef float mAP = 0 147 | for q_idx in range(num_q): 148 | mAP += all_AP[q_idx] 149 | mAP /= num_valid_q 150 | 151 | return np.asarray(avg_cmc).astype(np.float32), mAP 152 | 153 | 154 | cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids, 155 | long[:]q_camids, long[:]g_camids, long max_rank): 156 | 157 | cdef long num_q = distmat.shape[0] 158 | cdef long num_g = distmat.shape[1] 159 | 160 | if num_g < max_rank: 161 | max_rank = num_g 162 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 163 | 164 | cdef: 165 | long[:,:] indices = np.argsort(distmat, axis=1) 166 | long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) 167 | 168 | float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32) 169 | float[:] all_AP = np.zeros(num_q, dtype=np.float32) 170 | float num_valid_q = 0. # number of valid query 171 | 172 | long q_idx, q_pid, q_camid, g_idx 173 | long[:] order = np.zeros(num_g, dtype=np.int64) 174 | long keep 175 | 176 | float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches 177 | float[:] cmc = np.zeros(num_g, dtype=np.float32) 178 | long num_g_real, rank_idx 179 | unsigned long meet_condition 180 | 181 | float num_rel 182 | float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32) 183 | float tmp_cmc_sum 184 | 185 | for q_idx in range(num_q): 186 | # get query pid and camid 187 | q_pid = q_pids[q_idx] 188 | q_camid = q_camids[q_idx] 189 | 190 | # remove gallery samples that have the same pid and camid with query 191 | for g_idx in range(num_g): 192 | order[g_idx] = indices[q_idx, g_idx] 193 | num_g_real = 0 194 | meet_condition = 0 195 | 196 | for g_idx in range(num_g): 197 | if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid): 198 | raw_cmc[num_g_real] = matches[q_idx][g_idx] 199 | num_g_real += 1 200 | if matches[q_idx][g_idx] > 1e-31: 201 | meet_condition = 1 202 | 203 | if not meet_condition: 204 | # this condition is true when query identity does not appear in gallery 205 | continue 206 | 207 | # compute cmc 208 | function_cumsum(raw_cmc, cmc, num_g_real) 209 | for g_idx in range(num_g_real): 210 | if cmc[g_idx] > 1: 211 | cmc[g_idx] = 1 212 | 213 | for rank_idx in range(max_rank): 214 | all_cmc[q_idx, rank_idx] = cmc[rank_idx] 215 | num_valid_q += 1. 
216 | 217 | # compute average precision 218 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 219 | function_cumsum(raw_cmc, tmp_cmc, num_g_real) 220 | num_rel = 0 221 | tmp_cmc_sum = 0 222 | for g_idx in range(num_g_real): 223 | tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx] 224 | num_rel += raw_cmc[g_idx] 225 | all_AP[q_idx] = tmp_cmc_sum / num_rel 226 | 227 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 228 | 229 | # compute averaged cmc 230 | cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32) 231 | for rank_idx in range(max_rank): 232 | for q_idx in range(num_q): 233 | avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx] 234 | avg_cmc[rank_idx] /= num_valid_q 235 | 236 | cdef float mAP = 0 237 | for q_idx in range(num_q): 238 | mAP += all_AP[q_idx] 239 | mAP /= num_valid_q 240 | 241 | return np.asarray(avg_cmc).astype(np.float32), mAP 242 | 243 | 244 | # Compute the cumulative sum 245 | cdef void function_cumsum(cython.numeric[:] src, cython.numeric[:] dst, long n): 246 | cdef long i 247 | dst[0] = src[0] 248 | for i in range(1, n): 249 | dst[i] = src[i] + dst[i - 1] -------------------------------------------------------------------------------- /evaluation/post_processing/re_ranking_ranklist.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/zhunzhong07/person-re-ranking/blob/master/python-version/re_ranking_ranklist.py 3 | 4 | Created on Mon Jun 26 14:46:56 2017 5 | @author: luohao 6 | Modified by Houjing Huang, 2017-12-22. 7 | - This version accepts distance matrix instead of raw features. 8 | - The difference of `/` division between python 2 and 3 is handled. 9 | - numpy.float16 is replaced by numpy.float32 for numerical precision. 10 | 11 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 12 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 13 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 14 | 15 | API 16 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery] 17 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query] 18 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery] 19 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3) 20 | Returns: 21 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery] 22 | """ 23 | 24 | import numpy as np 25 | 26 | 27 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 28 | 29 | # The following naming, e.g. gallery_num, is different from outer scope. 30 | # Don't care about it. 31 | 32 | original_dist = np.concatenate([ 33 | np.concatenate([q_q_dist, q_g_dist], axis=1), 34 | np.concatenate([q_g_dist.T, g_g_dist], axis=1) 35 | ], 36 | axis=0) 37 | original_dist = np.power(original_dist, 2).astype(np.float32) 38 | original_dist = np.transpose(1. 
* original_dist / 39 | np.max(original_dist, axis=0)) 40 | V = np.zeros_like(original_dist).astype(np.float32) 41 | initial_rank = np.argsort(original_dist).astype(np.int32) 42 | 43 | query_num = q_g_dist.shape[0] 44 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] 45 | all_num = gallery_num 46 | 47 | for i in range(all_num): 48 | # k-reciprocal neighbors 49 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 50 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 51 | fi = np.where(backward_k_neigh_index == i)[0] 52 | k_reciprocal_index = forward_k_neigh_index[fi] 53 | k_reciprocal_expansion_index = k_reciprocal_index 54 | for j in range(len(k_reciprocal_index)): 55 | candidate = k_reciprocal_index[j] 56 | candidate_forward_k_neigh_index = initial_rank[ 57 | candidate, :int(np.around(k1 / 2.)) + 1] 58 | candidate_backward_k_neigh_index = initial_rank[ 59 | candidate_forward_k_neigh_index, :int(np.around(k1 / 2.)) + 1] 60 | fi_candidate = np.where( 61 | candidate_backward_k_neigh_index == candidate)[0] 62 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[ 63 | fi_candidate] 64 | if len( 65 | np.intersect1d(candidate_k_reciprocal_index, 66 | k_reciprocal_index) 67 | ) > 2. / 3 * len(candidate_k_reciprocal_index): 68 | k_reciprocal_expansion_index = np.append( 69 | k_reciprocal_expansion_index, candidate_k_reciprocal_index) 70 | 71 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 72 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 73 | V[i, k_reciprocal_expansion_index] = 1. * weight / np.sum(weight) 74 | original_dist = original_dist[:query_num,] 75 | if k2 != 1: 76 | V_qe = np.zeros_like(V, dtype=np.float32) 77 | for i in range(all_num): 78 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 79 | V = V_qe 80 | del V_qe 81 | del initial_rank 82 | invIndex = [] 83 | for i in range(gallery_num): 84 | invIndex.append(np.where(V[:, i] != 0)[0]) 85 | 86 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float32) 87 | 88 | for i in range(query_num): 89 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float32) 90 | indNonZero = np.where(V[i, :] != 0)[0] 91 | indImages = [] 92 | indImages = [invIndex[ind] for ind in indNonZero] 93 | for j in range(len(indNonZero)): 94 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum( 95 | V[i, indNonZero[j]], V[indImages[j], indNonZero[j]]) 96 | jaccard_dist[i] = 1 - temp_min / (2. 
- temp_min) 97 | 98 | final_dist = jaccard_dist * (1 - 99 | lambda_value) + original_dist * lambda_value 100 | del original_dist 101 | del V 102 | del jaccard_dist 103 | final_dist = final_dist[:query_num, query_num:] 104 | return final_dist 105 | -------------------------------------------------------------------------------- /image_augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from urllib.request import urlopen 2 | 3 | import cv2 4 | import numpy as np 5 | from albumentations import (Compose, HorizontalFlip, PadIfNeeded, RandomCrop, 6 | Rotate) 7 | 8 | from .random_erasing import RandomErasing 9 | 10 | 11 | class BaseImageAugmentor(object): 12 | 13 | def __init__(self, 14 | horizontal_flip_probability=0.5, 15 | rotate_limit=0, 16 | image_height=224, 17 | image_width=224, 18 | padding_length=20, 19 | padding_ratio=0): 20 | # Initiation 21 | self.transforms = [] 22 | self.transformer = None 23 | 24 | # Flip the input horizontally 25 | if horizontal_flip_probability > 0: 26 | self.transforms.append( 27 | HorizontalFlip(p=horizontal_flip_probability)) 28 | 29 | # Rotate the input by an angle selected randomly from the uniform distribution 30 | if rotate_limit > 0: 31 | self.transforms.append( 32 | Rotate(limit=rotate_limit, 33 | border_mode=cv2.BORDER_CONSTANT, 34 | value=0, 35 | p=1.0)) 36 | 37 | # Pad side of the image and crop a random part of it 38 | if padding_length > 0 or padding_ratio > 0: 39 | min_height = image_height + int( 40 | max(padding_length, image_height * padding_ratio)) 41 | min_width = image_width + int( 42 | max(padding_length, image_width * padding_ratio)) 43 | self.transforms.append( 44 | PadIfNeeded(min_height=min_height, 45 | min_width=min_width, 46 | border_mode=cv2.BORDER_CONSTANT, 47 | value=0)) 48 | self.transforms.append( 49 | RandomCrop(height=image_height, width=image_width)) 50 | 51 | def add_transforms(self, additional_transforms): 52 | self.transforms += additional_transforms 53 | 54 | def compose_transforms(self): 55 | self.transformer = Compose(transforms=self.transforms) 56 | 57 | def apply_augmentation(self, image_content_array): 58 | transformed_image_content_list = [] 59 | for image_content in image_content_array: 60 | transformed_image_content = self.transformer( 61 | image=image_content)["image"] 62 | transformed_image_content_list.append(transformed_image_content) 63 | return np.array(transformed_image_content_list, dtype=np.uint8) 64 | 65 | 66 | class RandomErasingImageAugmentor(BaseImageAugmentor): 67 | 68 | def __init__(self, **kwargs): 69 | super(RandomErasingImageAugmentor, self).__init__(**kwargs) 70 | additional_transforms = [RandomErasing()] 71 | self.add_transforms(additional_transforms) 72 | 73 | 74 | def example(): 75 | print("Loading the image content ...") 76 | raw_data = urlopen( 77 | url="https://avatars3.githubusercontent.com/u/15064790").read() 78 | raw_data = np.frombuffer(raw_data, np.uint8) 79 | image_content = cv2.imdecode(raw_data, cv2.IMREAD_COLOR) 80 | image_content = cv2.cvtColor(image_content, cv2.COLOR_BGR2RGB) 81 | image_height, image_width = image_content.shape[:2] 82 | 83 | print("Initiating the image augmentor ...") 84 | image_augmentor = RandomErasingImageAugmentor(image_height=image_height, 85 | image_width=image_width) 86 | image_augmentor.compose_transforms() 87 | 88 | print("Generating the batch ...") 89 | image_content_list = [image_content] * 8 90 | image_content_array = np.array(image_content_list) 91 | 92 | print("Applying data augmentation 
...") 93 | image_content_array = image_augmentor.apply_augmentation( 94 | image_content_array) 95 | 96 | print("Visualization ...") 97 | for image_index, image_content in enumerate(image_content_array, start=1): 98 | image_content = cv2.cvtColor(image_content, cv2.COLOR_RGB2BGR) 99 | cv2.imshow("image {}".format(image_index), image_content) 100 | cv2.waitKey(0) 101 | cv2.destroyAllWindows() 102 | 103 | print("All done!") 104 | 105 | 106 | if __name__ == "__main__": 107 | example() 108 | -------------------------------------------------------------------------------- /image_augmentation/random_erasing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from albumentations import ImageOnlyTransform 3 | 4 | 5 | def apply_random_erasing(image_content, sl, sh, r1, mean, max_attempt_num): 6 | # Make a copy of the input image since we don't want to modify it directly 7 | image_content = image_content.copy() 8 | image_height, image_width = image_content.shape[:-1] 9 | image_area = image_height * image_width 10 | for _ in range(max_attempt_num): 11 | target_area = np.random.uniform(sl, sh) * image_area 12 | aspect_ratio = np.random.uniform(r1, 1 / r1) 13 | erasing_height = int(np.round(np.sqrt(target_area * aspect_ratio))) 14 | erasing_width = int(np.round(np.sqrt(target_area / aspect_ratio))) 15 | if erasing_width < image_width and erasing_height < image_height: 16 | starting_height = np.random.randint(0, 17 | image_height - erasing_height) 18 | starting_width = np.random.randint(0, image_width - erasing_width) 19 | image_content[starting_height:starting_height + erasing_height, 20 | starting_width:starting_width + 21 | erasing_width] = np.array(mean, 22 | dtype=np.float32) * 255 23 | break 24 | return image_content 25 | 26 | 27 | class RandomErasing(ImageOnlyTransform): 28 | """ 29 | References: 30 | https://arxiv.org/abs/1708.04896 31 | https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py 32 | https://github.com/albumentations-team/albumentations/blob/0.4.0/albumentations/augmentations/transforms.py#L1492-L1569 33 | """ 34 | 35 | def __init__(self, 36 | sl=0.02, 37 | sh=0.4, 38 | r1=0.3, 39 | mean=(0.4914, 0.4822, 0.4465), 40 | max_attempt_num=100, 41 | always_apply=False, 42 | p=0.5): 43 | super(RandomErasing, self).__init__(always_apply, p) 44 | self.sl = sl 45 | self.sh = sh 46 | self.r1 = r1 47 | self.mean = mean 48 | self.max_attempt_num = max_attempt_num 49 | 50 | def apply(self, image, sl, sh, r1, mean, max_attempt_num, **params): # pylint: disable=arguments-differ 51 | return apply_random_erasing(image, sl, sh, r1, mean, max_attempt_num) 52 | 53 | def get_params_dependent_on_targets(self, params): 54 | return { 55 | "sl": self.sl, 56 | "sh": self.sh, 57 | "r1": self.r1, 58 | "mean": self.mean, 59 | "max_attempt_num": self.max_attempt_num 60 | } 61 | 62 | @property 63 | def targets_as_params(self): 64 | return ["image"] 65 | 66 | def get_transform_init_args_names(self): 67 | return ("sl", "sh", "r1", "mean", "max_attempt_num") 68 | -------------------------------------------------------------------------------- /metric_learning/triplet_hermans.py: -------------------------------------------------------------------------------- 1 | # https://github.com/VisualComputingInstitute/triplet-reid 2 | # https://blog.csdn.net/lwplwf/article/details/84562494 3 | # https://omoindrot.github.io/triplet-loss 4 | # https://github.com/omoindrot/tensorflow-triplet-loss 5 | import numbers 6 | 7 | import tensorflow as tf 8 | 9 | 
10 | def all_diffs(a, b): 11 | """Returns a tensor of all combinations of a - b. 12 | 13 | Args: 14 | a (2D tensor): A batch of vectors shaped (B1, F). 15 | b (2D tensor): A batch of vectors shaped (B2, F). 16 | 17 | Returns: 18 | The matrix of all pairwise differences between all vectors in `a` and in 19 | `b`, will be of shape (B1, B2). 20 | 21 | Note: 22 | For convenience, if either `a` or `b` is a `Distribution` object, its 23 | mean is used. 24 | """ 25 | return tf.expand_dims(a, axis=1) - tf.expand_dims(b, axis=0) 26 | 27 | 28 | def cdist(a, b, metric='euclidean'): 29 | """Similar to scipy.spatial's cdist, but symbolic. 30 | 31 | The currently supported metrics can be listed as `cdist.supported_metrics` and are: 32 | - 'euclidean', although with a fudge-factor epsilon. 33 | - 'sqeuclidean', the squared euclidean. 34 | - 'cityblock', the manhattan or L1 distance. 35 | 36 | Args: 37 | a (2D tensor): The left-hand side, shaped (B1, F). 38 | b (2D tensor): The right-hand side, shaped (B2, F). 39 | metric (string): Which distance metric to use, see notes. 40 | 41 | Returns: 42 | The matrix of all pairwise distances between all vectors in `a` and in 43 | `b`, will be of shape (B1, B2). 44 | 45 | Note: 46 | When a square root is taken (such as in the Euclidean case), a small 47 | epsilon is added because the gradient of the square-root at zero is 48 | undefined. Thus, it will never return exact zero in these cases. 49 | """ 50 | with tf.name_scope("cdist"): 51 | diffs = all_diffs(a, b) 52 | if metric == 'sqeuclidean': 53 | return tf.reduce_sum(tf.square(diffs), axis=-1) 54 | elif metric == 'euclidean': 55 | return tf.sqrt(tf.reduce_sum(tf.square(diffs), axis=-1) + 1e-12) 56 | elif metric == 'cityblock': 57 | return tf.reduce_sum(tf.abs(diffs), axis=-1) 58 | else: 59 | raise NotImplementedError( 60 | 'The following metric is not implemented by `cdist` yet: {}'. 61 | format(metric)) 62 | 63 | 64 | cdist.supported_metrics = [ 65 | 'euclidean', 66 | 'sqeuclidean', 67 | 'cityblock', 68 | ] 69 | 70 | 71 | def get_at_indices(tensor, indices): 72 | """Like `tensor[np.arange(len(tensor)), indices]` in numpy.""" 73 | counter = tf.range(tf.shape(indices, out_type=indices.dtype)[0]) 74 | return tf.gather_nd(tensor, tf.stack((counter, indices), -1)) 75 | 76 | 77 | def batch_hard(dists, pids, margin, batch_precision_at_k=None): 78 | """Computes the batch-hard loss from arxiv.org/abs/1703.07737. 79 | 80 | Args: 81 | dists (2D tensor): A square all-to-all distance matrix as given by cdist. 82 | pids (1D tensor): The identities of the entries in `batch`, shape (B,). 83 | This can be of any type that can be compared, thus also a string. 84 | margin: The value of the margin if a number, alternatively the string 85 | 'soft' for using the soft-margin formulation, or `None` for not 86 | using a margin at all. 87 | 88 | Returns: 89 | A 1D tensor of shape (B,) containing the loss value for each sample. 
90 | """ 91 | with tf.name_scope("batch_hard"): 92 | same_identity_mask = tf.equal(tf.expand_dims(pids, axis=1), 93 | tf.expand_dims(pids, axis=0)) 94 | negative_mask = tf.math.logical_not(same_identity_mask) 95 | positive_mask = tf.math.logical_xor( 96 | same_identity_mask, tf.eye(tf.shape(pids)[0], dtype=tf.bool)) 97 | 98 | furthest_positive = tf.reduce_max(dists * 99 | tf.cast(positive_mask, tf.float32), 100 | axis=1) 101 | closest_negative = tf.map_fn( 102 | lambda x: tf.reduce_min(tf.boolean_mask(x[0], x[1])), 103 | (dists, negative_mask), tf.float32) 104 | # Another way of achieving the same, though more hacky: 105 | # closest_negative = tf.reduce_min(dists + 1e5*tf.cast(same_identity_mask, tf.float32), axis=1) 106 | 107 | diff = furthest_positive - closest_negative 108 | if isinstance(margin, numbers.Real): 109 | diff = tf.maximum(diff + margin, 0.0) 110 | elif margin == 'soft': 111 | diff = tf.nn.softplus(diff) 112 | elif margin.lower() == 'none': 113 | pass 114 | else: 115 | raise NotImplementedError( 116 | 'The margin {} is not implemented in batch_hard'.format(margin)) 117 | 118 | if batch_precision_at_k is None: 119 | return diff 120 | 121 | # For monitoring, compute the within-batch top-1 accuracy and the 122 | # within-batch precision-at-k, which is somewhat more expressive. 123 | with tf.name_scope("monitoring"): 124 | # This is like argsort along the last axis. Add one to K as we'll 125 | # drop the diagonal. 126 | _, indices = tf.nn.top_k(-dists, k=batch_precision_at_k + 1) 127 | 128 | # Drop the diagonal (distance to self is always least). 129 | indices = indices[:, 1:] 130 | 131 | # Generate the index indexing into the batch dimension. 132 | # This is simething like [[0,0,0],[1,1,1],...,[B,B,B]] 133 | batch_index = tf.tile(tf.expand_dims(tf.range(tf.shape(indices)[0]), 1), 134 | (1, tf.shape(indices)[1])) 135 | 136 | # Stitch the above together with the argsort indices to get the 137 | # indices of the top-k of each row. 138 | topk_indices = tf.stack((batch_index, indices), -1) 139 | 140 | # See if the topk belong to the same person as they should, or not. 141 | topk_is_same = tf.gather_nd(same_identity_mask, topk_indices) 142 | 143 | # All of the above could be reduced to the simpler following if k==1 144 | # top1_is_same = get_at_indices(same_identity_mask, top_idxs[:,1]) 145 | 146 | topk_is_same_f32 = tf.cast(topk_is_same, tf.float32) 147 | top1 = tf.reduce_mean(topk_is_same_f32[:, 0]) 148 | prec_at_k = tf.reduce_mean(topk_is_same_f32) 149 | 150 | # Finally, let's get some more info that can help in debugging while 151 | # we're at it! 
152 | negative_dists = tf.boolean_mask(dists, negative_mask) 153 | positive_dists = tf.boolean_mask(dists, positive_mask) 154 | 155 | return diff, top1, prec_at_k, topk_is_same, negative_dists, positive_dists 156 | -------------------------------------------------------------------------------- /regularizers/adaptation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.keras import backend as K 4 | from tensorflow.keras.callbacks import Callback 5 | from tensorflow.keras.initializers import Constant 6 | from tensorflow.keras.layers import Layer 7 | from tensorflow.keras.models import Model 8 | 9 | 10 | class AdaptiveL1L2(Layer): 11 | 12 | def __init__(self, 13 | amplitude_l1=None, 14 | amplitude_l2=None, 15 | omega_l1=1, 16 | omega_l2=1, 17 | **kwargs): 18 | super(AdaptiveL1L2, self).__init__(**kwargs) 19 | self.amplitude_l1 = amplitude_l1 20 | self.amplitude_l2 = amplitude_l2 21 | self.omega_l1 = omega_l1 22 | self.omega_l2 = omega_l2 23 | self.preprocess_function = lambda x: amplitude_l2 * np.clip( 24 | 0.2 * omega_l2 * x + 0.5, 0.0, 1.0 25 | ) # NB: This function is only applicable to L2. 26 | 27 | def build(self, input_shape): 28 | self.l1_regularization_factor = None # pylint: disable=attribute-defined-outside-init 29 | if self.amplitude_l1 is not None: 30 | self.l1_regularization_factor = self.add_weight( # pylint: disable=attribute-defined-outside-init 31 | name="l1_regularization_factor", 32 | initializer=Constant(0)) 33 | self.l2_regularization_factor = None # pylint: disable=attribute-defined-outside-init 34 | if self.amplitude_l2 is not None: 35 | self.l2_regularization_factor = self.add_weight( # pylint: disable=attribute-defined-outside-init 36 | name="l2_regularization_factor", 37 | initializer=Constant(0)) 38 | super(AdaptiveL1L2, self).build(input_shape) 39 | 40 | def call(self, inputs): # pylint: disable=arguments-differ 41 | regularization = 0. 
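# Each trainable factor below is squashed through hard_sigmoid into [0, 1]
# (with omega scaling the slope) and multiplied by the fixed amplitude, so
# the effective penalty adapts during training but stays in [0, amplitude].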
42 | if self.l1_regularization_factor is not None: 43 | regularization += self.amplitude_l1 * K.hard_sigmoid( 44 | self.omega_l1 * self.l1_regularization_factor) * tf.reduce_sum( 45 | tf.abs(inputs)) 46 | if self.l2_regularization_factor is not None: 47 | regularization += self.amplitude_l2 * K.hard_sigmoid( 48 | self.omega_l2 * self.l2_regularization_factor) * tf.reduce_sum( 49 | tf.square(inputs)) 50 | return regularization 51 | 52 | def get_config(self): 53 | config = { 54 | "amplitude_l1": self.amplitude_l1, 55 | "amplitude_l2": self.amplitude_l2, 56 | "omega_l1": self.omega_l1, 57 | "omega_l2": self.omega_l2 58 | } 59 | base_config = super(AdaptiveL1L2, self).get_config() 60 | return dict(list(base_config.items()) + list(config.items())) 61 | 62 | 63 | class InspectRegularizationFactors(Callback): 64 | 65 | def forward(self, model, logs): 66 | for item in model.layers: 67 | if isinstance(item, Model): 68 | self.forward(item, logs) 69 | for regularizer_name in [ 70 | "kernel_regularizer", "bias_regularizer", 71 | "gamma_regularizer", "beta_regularizer" 72 | ]: 73 | if not hasattr(item, regularizer_name): 74 | continue 75 | regularizer = getattr(item, regularizer_name) 76 | if not isinstance(regularizer, AdaptiveL1L2): 77 | continue 78 | preprocess_function = regularizer.preprocess_function 79 | for variable_name in [ 80 | "l1_regularization_factor", "l2_regularization_factor" 81 | ]: 82 | regularization_factor = getattr(regularizer, variable_name) 83 | if regularization_factor is not None: 84 | logs["{}_{}_{}".format( 85 | item.name, regularizer_name, 86 | variable_name)] = preprocess_function( 87 | K.get_value(regularization_factor)) 88 | 89 | def on_epoch_end(self, epoch, logs=None): # @UnusedVariable 90 | self.forward(self.model, logs) 91 | -------------------------------------------------------------------------------- /solution.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import sys 4 | import time 5 | from datetime import datetime 6 | 7 | import cv2 8 | import numpy as np 9 | import pandas as pd 10 | import tensorflow as tf 11 | from absl import app, flags 12 | from sklearn.metrics import pairwise_distances 13 | from sklearn.model_selection import GroupShuffleSplit, StratifiedShuffleSplit 14 | from tensorflow.keras import backend as K 15 | from tensorflow.keras.callbacks import (Callback, LearningRateScheduler, 16 | ModelCheckpoint) 17 | from tensorflow.keras.initializers import RandomNormal 18 | from tensorflow.keras.layers import (Activation, BatchNormalization, 19 | Concatenate, Conv2D, Dense, 20 | GlobalAveragePooling2D, Input, Lambda) 21 | from tensorflow.keras.losses import categorical_crossentropy 22 | from tensorflow.keras.models import Model 23 | from tensorflow.keras.optimizers import Adam 24 | from tensorflow.keras.utils import Sequence 25 | 26 | import applications 27 | import image_augmentation 28 | from callbacks import HistoryLogger 29 | from datasets import load_accumulated_info_of_dataset 30 | from evaluation.metrics import compute_CMC_mAP 31 | from evaluation.post_processing.re_ranking_ranklist import re_ranking 32 | from metric_learning.triplet_hermans import batch_hard, cdist 33 | from regularizers.adaptation import InspectRegularizationFactors 34 | from utils.model_utils import replicate_model, specify_regularizers 35 | from utils.vis_utils import summarize_model, visualize_model 36 | 37 | flags.DEFINE_string("root_folder_path", "", "Folder path of the dataset.") 38 | 
flags.DEFINE_string("dataset_name", "Market1501", "Name of the dataset.") 39 | # ["Market1501", "DukeMTMC_reID", "MSMT17"] 40 | flags.DEFINE_string("backbone_model_name", "ResNet50", 41 | "Name of the backbone model.") 42 | # ["ResNet50", "ResNet101", "ResNet152", 43 | # "ResNet50V2", "ResNet101V2", "ResNet152V2", 44 | # "ResNeXt50", "ResNeXt101"] 45 | flags.DEFINE_integer("freeze_backbone_for_N_epochs", 20, 46 | "Freeze layers in the backbone model for N epochs.") 47 | flags.DEFINE_integer("image_width", 128, "Width of the images.") 48 | flags.DEFINE_integer("image_height", 384, "Height of the images.") 49 | flags.DEFINE_integer("region_num", 2, 50 | "Number of regions in the regional branch.") 51 | flags.DEFINE_float("kernel_regularization_factor", 0.005, 52 | "Regularization factor of kernel.") 53 | flags.DEFINE_float("bias_regularization_factor", 0.005, 54 | "Regularization factor of bias.") 55 | flags.DEFINE_float("gamma_regularization_factor", 0.005, 56 | "Regularization factor of gamma.") 57 | flags.DEFINE_float("beta_regularization_factor", 0.005, 58 | "Regularization factor of beta.") 59 | flags.DEFINE_bool("use_adaptive_l1_l2_regularizer", True, 60 | "Use the adaptive L1L2 regularizer.") 61 | flags.DEFINE_float("min_value_in_clipping", 0.0, 62 | "Minimum value when using the clipping function.") 63 | flags.DEFINE_float("max_value_in_clipping", 1.0, 64 | "Maximum value when using the clipping function.") 65 | flags.DEFINE_float("validation_size", 0.0, 66 | "Proportion or absolute number of validation samples.") 67 | flags.DEFINE_float("testing_size", 1.0, 68 | "Proportion or absolute number of testing groups.") 69 | flags.DEFINE_integer( 70 | "evaluate_validation_every_N_epochs", 1, 71 | "Evaluate the performance on validation samples every N epochs.") 72 | flags.DEFINE_integer( 73 | "evaluate_testing_every_N_epochs", 10, 74 | "Evaluate the performance on testing samples every N epochs.") 75 | flags.DEFINE_integer("identity_num_per_batch", 16, 76 | "Number of identities in one batch.") 77 | flags.DEFINE_integer("image_num_per_identity", 4, 78 | "Number of images of one identity.") 79 | flags.DEFINE_string("learning_rate_mode", "default", 80 | "Mode of the learning rate scheduler.") 81 | # ["constant", "linear", "cosine", "warmup", "default"] 82 | flags.DEFINE_float("learning_rate_start", 2e-4, "Starting learning rate.") 83 | flags.DEFINE_float("learning_rate_end", 2e-4, "Ending learning rate.") 84 | flags.DEFINE_float("learning_rate_base", 2e-4, "Base learning rate.") 85 | flags.DEFINE_integer("learning_rate_warmup_epochs", 10, 86 | "Number of epochs to warmup the learning rate.") 87 | flags.DEFINE_integer("learning_rate_steady_epochs", 30, 88 | "Number of epochs to keep the learning rate steady.") 89 | flags.DEFINE_float("learning_rate_drop_factor", 10, 90 | "Factor to decrease the learning rate.") 91 | flags.DEFINE_float("learning_rate_lower_bound", 2e-6, 92 | "Lower bound of the learning rate.") 93 | flags.DEFINE_integer("steps_per_epoch", 200, "Number of steps per epoch.") 94 | flags.DEFINE_integer("epoch_num", 200, "Number of epochs.") 95 | flags.DEFINE_integer("workers", 5, 96 | "Number of processes to spin up for data generator.") 97 | flags.DEFINE_string("image_augmentor_name", "RandomErasingImageAugmentor", 98 | "Name of image augmentor.") 99 | # ["BaseImageAugmentor", "RandomErasingImageAugmentor"] 100 | flags.DEFINE_bool("use_data_augmentation_in_training", True, 101 | "Use data augmentation in training.") 102 | flags.DEFINE_bool("use_data_augmentation_in_evaluation", 
False, 103 | "Use data augmentation in evaluation.") 104 | flags.DEFINE_integer("augmentation_num", 1, 105 | "Number of augmented samples to use in evaluation.") 106 | flags.DEFINE_bool("use_horizontal_flipping_in_evaluation", True, 107 | "Use horizontal flipping in evaluation.") 108 | flags.DEFINE_bool("use_identity_balancing_in_training", False, 109 | "Use identity balancing in training.") 110 | flags.DEFINE_bool("use_re_ranking", False, "Use the re-ranking method.") 111 | flags.DEFINE_bool("evaluation_only", False, "Only perform evaluation.") 112 | flags.DEFINE_bool("save_data_to_disk", False, 113 | "Save image features, identity ID and camera ID to disk.") 114 | flags.DEFINE_string("pretrained_model_file_path", "", 115 | "File path of the pretrained model.") 116 | flags.DEFINE_string( 117 | "output_folder_path", 118 | os.path.abspath( 119 | os.path.join(__file__, "../output_{}".format( 120 | datetime.now().strftime("%Y_%m_%d")))), 121 | "Path to directory to output files.") 122 | FLAGS = flags.FLAGS 123 | 124 | 125 | def apply_stratifiedshufflesplit(y, test_size, random_state=0): 126 | if test_size == 1: 127 | train_indexes = np.arange(len(y)) # Hacky snippet 128 | test_indexes = np.arange(len(y)) 129 | else: 130 | shufflesplit_instance = StratifiedShuffleSplit( 131 | n_splits=1, test_size=test_size, random_state=random_state) 132 | train_indexes, test_indexes = next( 133 | shufflesplit_instance.split(np.arange(len(y)), y=y)) 134 | return train_indexes, test_indexes 135 | 136 | 137 | def apply_groupshufflesplit(groups, test_size, random_state=0): 138 | groupshufflesplit_instance = GroupShuffleSplit(n_splits=1, 139 | test_size=test_size, 140 | random_state=random_state) 141 | train_indexes, test_indexes = next( 142 | groupshufflesplit_instance.split(np.arange(len(groups)), groups=groups)) 143 | return train_indexes, test_indexes 144 | 145 | 146 | def init_model(backbone_model_name, 147 | freeze_backbone_for_N_epochs, 148 | input_shape, 149 | region_num, 150 | attribute_name_to_label_encoder_dict, 151 | kernel_regularization_factor, 152 | bias_regularization_factor, 153 | gamma_regularization_factor, 154 | beta_regularization_factor, 155 | use_adaptive_l1_l2_regularizer, 156 | min_value_in_clipping, 157 | max_value_in_clipping, 158 | share_last_block=False): 159 | 160 | def _add_objective_module(input_tensor): 161 | # Add a pooling layer if needed 162 | if len(K.int_shape(input_tensor)) == 4: 163 | global_pooling_tensor = GlobalAveragePooling2D()(input_tensor) 164 | else: 165 | global_pooling_tensor = input_tensor 166 | if min_value_in_clipping is not None and max_value_in_clipping is not None: 167 | global_pooling_tensor = Lambda(lambda x: K.clip( 168 | x, 169 | min_value=min_value_in_clipping, 170 | max_value=max_value_in_clipping))(global_pooling_tensor) 171 | 172 | # https://arxiv.org/abs/1801.07698v1 Section 3.2.2 Output setting 173 | # https://arxiv.org/abs/1807.11042 174 | classification_input_tensor = global_pooling_tensor 175 | classification_embedding_tensor = BatchNormalization( 176 | scale=True, epsilon=2e-5)(classification_input_tensor) 177 | 178 | # Add categorical crossentropy loss 179 | label_encoder = attribute_name_to_label_encoder_dict["identity_ID"] 180 | class_num = len(label_encoder.classes_) 181 | classification_output_tensor = Dense( 182 | units=class_num, 183 | use_bias=False, 184 | kernel_initializer=RandomNormal( 185 | mean=0.0, stddev=0.001))(classification_embedding_tensor) 186 | classification_output_tensor = Activation("softmax")( 187 | 
classification_output_tensor) 188 | 189 | # Add miscellaneous loss 190 | miscellaneous_input_tensor = global_pooling_tensor 191 | miscellaneous_embedding_tensor = miscellaneous_input_tensor 192 | miscellaneous_output_tensor = miscellaneous_input_tensor 193 | 194 | return classification_output_tensor, classification_embedding_tensor, miscellaneous_output_tensor, miscellaneous_embedding_tensor 195 | 196 | def _apply_concatenation(tensor_list): 197 | if len(tensor_list) == 1: 198 | return tensor_list[0] 199 | else: 200 | return Concatenate()(tensor_list) 201 | 202 | def _triplet_hermans_loss(y_true, 203 | y_pred, 204 | metric="euclidean", 205 | margin="soft"): 206 | # Create the loss in two steps: 207 | # 1. Compute all pairwise distances according to the specified metric. 208 | # 2. For each anchor along the first dimension, compute its loss. 209 | dists = cdist(y_pred, y_pred, metric=metric) 210 | loss = batch_hard(dists=dists, 211 | pids=tf.argmax(y_true, axis=-1), 212 | margin=margin) 213 | return loss 214 | 215 | # Initiation 216 | classification_output_tensor_list = [] 217 | classification_embedding_tensor_list = [] 218 | miscellaneous_output_tensor_list = [] 219 | miscellaneous_embedding_tensor_list = [] 220 | 221 | # Initiate the early blocks 222 | model_instantiation = getattr(applications, backbone_model_name, None) 223 | assert model_instantiation is not None, "Backbone {} is not supported.".format( 224 | backbone_model_name) 225 | submodel_list, preprocess_input = model_instantiation( 226 | input_shape=input_shape) 227 | vanilla_input_tensor = Input(shape=K.int_shape(submodel_list[0].input)[1:]) 228 | intermediate_output_tensor = vanilla_input_tensor 229 | for submodel in submodel_list[:-1]: 230 | if freeze_backbone_for_N_epochs > 0: 231 | submodel.trainable = False 232 | intermediate_output_tensor = submodel(intermediate_output_tensor) 233 | 234 | # Initiate the last blocks 235 | last_block = submodel_list[-1] 236 | last_block_for_global_branch_model = replicate_model( 237 | last_block, name="last_block_for_global_branch") 238 | if freeze_backbone_for_N_epochs > 0: 239 | last_block_for_global_branch_model.trainable = False 240 | if share_last_block: 241 | last_block_for_regional_branch_model = last_block_for_global_branch_model 242 | else: 243 | last_block_for_regional_branch_model = replicate_model( 244 | last_block, name="last_block_for_regional_branch") 245 | if freeze_backbone_for_N_epochs > 0: 246 | last_block_for_regional_branch_model.trainable = False 247 | 248 | # Add the global branch 249 | classification_output_tensor, classification_embedding_tensor, miscellaneous_output_tensor, miscellaneous_embedding_tensor = _add_objective_module( 250 | last_block_for_global_branch_model(intermediate_output_tensor)) 251 | classification_output_tensor_list.append(classification_output_tensor) 252 | classification_embedding_tensor_list.append(classification_embedding_tensor) 253 | miscellaneous_output_tensor_list.append(miscellaneous_output_tensor) 254 | miscellaneous_embedding_tensor_list.append(miscellaneous_embedding_tensor) 255 | 256 | # Add the regional branch 257 | if region_num > 0: 258 | # Process each region 259 | regional_branch_output_tensor = last_block_for_regional_branch_model( 260 | intermediate_output_tensor) 261 | total_height = K.int_shape(regional_branch_output_tensor)[1] 262 | region_size = total_height // region_num 263 | for region_index in np.arange(region_num): 264 | # Get a slice of feature maps 265 | start_index = region_index * region_size 266 | end_index = 
(region_index + 1) * region_size 267 | if region_index == region_num - 1: 268 | end_index = total_height 269 | sliced_regional_branch_output_tensor = Lambda( 270 | lambda x, start_index=start_index, end_index=end_index: 271 | x[:, start_index:end_index])(regional_branch_output_tensor) 272 | 273 | # Downsampling 274 | sliced_regional_branch_output_tensor = Conv2D( 275 | filters=K.int_shape(sliced_regional_branch_output_tensor)[-1] // 276 | region_num, 277 | kernel_size=3, 278 | padding="same")(sliced_regional_branch_output_tensor) 279 | sliced_regional_branch_output_tensor = Activation("relu")( 280 | sliced_regional_branch_output_tensor) 281 | 282 | # Add the regional branch 283 | classification_output_tensor, classification_embedding_tensor, miscellaneous_output_tensor, miscellaneous_embedding_tensor = _add_objective_module( 284 | sliced_regional_branch_output_tensor) 285 | classification_output_tensor_list.append( 286 | classification_output_tensor) 287 | classification_embedding_tensor_list.append( 288 | classification_embedding_tensor) 289 | miscellaneous_output_tensor_list.append(miscellaneous_output_tensor) 290 | miscellaneous_embedding_tensor_list.append( 291 | miscellaneous_embedding_tensor) 292 | 293 | # Define the merged model 294 | embedding_tensor_list = [ 295 | _apply_concatenation(miscellaneous_embedding_tensor_list) 296 | ] 297 | embedding_size_list = [ 298 | K.int_shape(embedding_tensor)[1] 299 | for embedding_tensor in embedding_tensor_list 300 | ] 301 | merged_embedding_tensor = _apply_concatenation(embedding_tensor_list) 302 | merged_model = Model(inputs=[vanilla_input_tensor], 303 | outputs=classification_output_tensor_list + 304 | miscellaneous_output_tensor_list + 305 | [merged_embedding_tensor]) 306 | merged_model = specify_regularizers(merged_model, 307 | kernel_regularization_factor, 308 | bias_regularization_factor, 309 | gamma_regularization_factor, 310 | beta_regularization_factor, 311 | use_adaptive_l1_l2_regularizer) 312 | 313 | # Define the models for training/inference 314 | training_model = Model(inputs=[merged_model.input], 315 | outputs=merged_model.output[:-1], 316 | name="training_model") 317 | inference_model = Model(inputs=[merged_model.input], 318 | outputs=[merged_model.output[-1]], 319 | name="inference_model") 320 | inference_model.embedding_size_list = embedding_size_list 321 | 322 | # Compile the model 323 | categorical_crossentropy_loss_function = lambda y_true, y_pred: 1.0 * categorical_crossentropy( 324 | y_true, y_pred, from_logits=False, label_smoothing=0.1) 325 | classification_loss_function_list = [ 326 | categorical_crossentropy_loss_function 327 | ] * len(classification_output_tensor_list) 328 | triplet_hermans_loss_function = lambda y_true, y_pred: 1.0 * _triplet_hermans_loss( 329 | y_true, y_pred) 330 | miscellaneous_loss_function_list = [triplet_hermans_loss_function 331 | ] * len(miscellaneous_output_tensor_list) 332 | training_model.compile_kwargs = { 333 | "optimizer": 334 | Adam(), 335 | "loss": 336 | classification_loss_function_list + miscellaneous_loss_function_list 337 | } 338 | training_model.compile(**training_model.compile_kwargs) 339 | 340 | # Print the summary of the models 341 | summarize_model(training_model) 342 | summarize_model(inference_model) 343 | 344 | return training_model, inference_model, preprocess_input 345 | 346 | 347 | def read_image_file(image_file_path, input_shape): 348 | # Read image file 349 | image_content = cv2.imread(image_file_path) 350 | 351 | # Resize the image 352 | image_content = 
cv2.resize(image_content, input_shape[:2][::-1]) 353 | 354 | # Convert from BGR to RGB 355 | image_content = cv2.cvtColor(image_content, cv2.COLOR_BGR2RGB) 356 | 357 | return image_content 358 | 359 | 360 | class TrainDataSequence(Sequence): 361 | 362 | def __init__(self, accumulated_info_dataframe, 363 | attribute_name_to_label_encoder_dict, preprocess_input, 364 | input_shape, image_augmentor, use_data_augmentation, 365 | use_identity_balancing, label_repetition_num, 366 | identity_num_per_batch, image_num_per_identity, 367 | steps_per_epoch): 368 | super(TrainDataSequence, self).__init__() 369 | 370 | # Save as variables 371 | self.accumulated_info_dataframe, self.attribute_name_to_label_encoder_dict, self.preprocess_input, self.input_shape = accumulated_info_dataframe, attribute_name_to_label_encoder_dict, preprocess_input, input_shape 372 | self.image_augmentor, self.use_data_augmentation, self.use_identity_balancing = image_augmentor, use_data_augmentation, use_identity_balancing 373 | self.label_repetition_num = label_repetition_num 374 | self.identity_num_per_batch, self.image_num_per_identity, self.steps_per_epoch = identity_num_per_batch, image_num_per_identity, steps_per_epoch 375 | 376 | # Unpack image_file_path and identity_ID 377 | self.image_file_path_array, self.identity_ID_array = self.accumulated_info_dataframe[ 378 | ["image_file_path", "identity_ID"]].values.transpose() 379 | self.image_file_path_to_record_index_dict = dict([ 380 | (image_file_path, record_index) 381 | for record_index, image_file_path in enumerate( 382 | self.image_file_path_array) 383 | ]) 384 | self.batch_size = identity_num_per_batch * image_num_per_identity 385 | self.image_num_per_epoch = self.batch_size * steps_per_epoch 386 | 387 | # Initiation 388 | self.image_file_path_list_generator = self._get_image_file_path_list_generator( 389 | ) 390 | self.image_file_path_list = next(self.image_file_path_list_generator) 391 | 392 | def _get_image_file_path_list_generator(self): 393 | # Map identity ID to image file paths 394 | identity_ID_to_image_file_paths_dict = {} 395 | for image_file_path, identity_ID in zip(self.image_file_path_array, 396 | self.identity_ID_array): 397 | if identity_ID not in identity_ID_to_image_file_paths_dict: 398 | identity_ID_to_image_file_paths_dict[identity_ID] = [] 399 | identity_ID_to_image_file_paths_dict[identity_ID].append( 400 | image_file_path) 401 | 402 | image_file_path_list = [] 403 | while True: 404 | # Split image file paths into multiple sections 405 | identity_ID_to_image_file_paths_in_sections_dict = {} 406 | for identity_ID in identity_ID_to_image_file_paths_dict: 407 | image_file_paths = np.array( 408 | identity_ID_to_image_file_paths_dict[identity_ID]) 409 | if len(image_file_paths) < self.image_num_per_identity: 410 | continue 411 | np.random.shuffle(image_file_paths) 412 | section_num = int( 413 | len(image_file_paths) / self.image_num_per_identity) 414 | image_file_paths = image_file_paths[:section_num * 415 | self.image_num_per_identity] 416 | image_file_paths_in_sections = np.split(image_file_paths, 417 | section_num) 418 | identity_ID_to_image_file_paths_in_sections_dict[ 419 | identity_ID] = image_file_paths_in_sections 420 | 421 | while len(identity_ID_to_image_file_paths_in_sections_dict 422 | ) >= self.identity_num_per_batch: 423 | # Choose identity_num_per_batch identity_IDs 424 | identity_IDs = np.random.choice( 425 | list(identity_ID_to_image_file_paths_in_sections_dict.keys( 426 | )), 427 | size=self.identity_num_per_batch, 428 | 
replace=False) 429 | for identity_ID in identity_IDs: 430 | # Get one section 431 | image_file_paths_in_sections = identity_ID_to_image_file_paths_in_sections_dict[ 432 | identity_ID] 433 | image_file_paths = image_file_paths_in_sections.pop(-1) 434 | if self.use_identity_balancing or len( 435 | image_file_paths_in_sections) == 0: 436 | del identity_ID_to_image_file_paths_in_sections_dict[ 437 | identity_ID] 438 | 439 | # Add the entries 440 | image_file_path_list += image_file_paths.tolist() 441 | 442 | if len(image_file_path_list) == self.image_num_per_epoch: 443 | yield image_file_path_list 444 | image_file_path_list = [] 445 | 446 | def __len__(self): 447 | return self.steps_per_epoch 448 | 449 | def __getitem__(self, index): 450 | label_encoder = self.attribute_name_to_label_encoder_dict["identity_ID"] 451 | image_content_list, one_hot_encoding_list = [], [] 452 | image_file_path_list = self.image_file_path_list[index * 453 | self.batch_size: 454 | (index + 1) * 455 | self.batch_size] 456 | for image_file_path in image_file_path_list: 457 | # Read image 458 | image_content = read_image_file(image_file_path, self.input_shape) 459 | image_content_list.append(image_content) 460 | 461 | # Get current record from accumulated_info_dataframe 462 | record_index = self.image_file_path_to_record_index_dict[ 463 | image_file_path] 464 | accumulated_info = self.accumulated_info_dataframe.iloc[ 465 | record_index] 466 | assert image_file_path == accumulated_info["image_file_path"] 467 | 468 | # Get the one hot encoding vector 469 | identity_ID = accumulated_info["identity_ID"] 470 | one_hot_encoding = np.zeros(len(label_encoder.classes_)) 471 | one_hot_encoding[label_encoder.transform([identity_ID])[0]] = 1 472 | one_hot_encoding_list.append(one_hot_encoding) 473 | 474 | # Construct image_content_array 475 | image_content_array = np.array(image_content_list) 476 | if self.use_data_augmentation: 477 | # Apply data augmentation 478 | image_content_array = self.image_augmentor.apply_augmentation( 479 | image_content_array) 480 | # Apply preprocess_input function 481 | image_content_array = self.preprocess_input(image_content_array) 482 | 483 | # Construct one_hot_encoding_array_list 484 | one_hot_encoding_array = np.array(one_hot_encoding_list) 485 | one_hot_encoding_array_list = [one_hot_encoding_array 486 | ] * self.label_repetition_num 487 | 488 | return image_content_array, one_hot_encoding_array_list 489 | 490 | def on_epoch_end(self): 491 | self.image_file_path_list = next(self.image_file_path_list_generator) 492 | 493 | 494 | class TestDataSequence(Sequence): 495 | 496 | def __init__(self, accumulated_info_dataframe, preprocess_input, 497 | input_shape, image_augmentor, use_data_augmentation, 498 | batch_size): 499 | super(TestDataSequence, self).__init__() 500 | 501 | # Save as variables 502 | self.accumulated_info_dataframe, self.preprocess_input, self.input_shape = accumulated_info_dataframe, preprocess_input, input_shape 503 | self.image_augmentor, self.use_data_augmentation = image_augmentor, use_data_augmentation 504 | 505 | # Unpack image_file_path and identity_ID 506 | self.image_file_path_array = self.accumulated_info_dataframe[ 507 | "image_file_path"].values 508 | self.batch_size = batch_size 509 | self.steps_per_epoch = int( 510 | np.ceil(len(self.image_file_path_array) / self.batch_size)) 511 | 512 | # Initiation 513 | self.image_file_path_list = self.image_file_path_array.tolist() 514 | self.use_horizontal_flipping = False 515 | 516 | def enable_horizontal_flipping(self): 517 | 
self.use_horizontal_flipping = True 518 | 519 | def disable_horizontal_flipping(self): 520 | self.use_horizontal_flipping = False 521 | 522 | def __len__(self): 523 | return self.steps_per_epoch 524 | 525 | def __getitem__(self, index): 526 | image_content_list = [] 527 | image_file_path_list = self.image_file_path_list[index * 528 | self.batch_size: 529 | (index + 1) * 530 | self.batch_size] 531 | for image_file_path in image_file_path_list: 532 | # Read image 533 | image_content = read_image_file(image_file_path, self.input_shape) 534 | if self.use_horizontal_flipping: 535 | image_content = cv2.flip(image_content, 1) 536 | image_content_list.append(image_content) 537 | 538 | # Construct image_content_array 539 | image_content_array = np.array(image_content_list) 540 | if self.use_data_augmentation: 541 | # Apply data augmentation 542 | image_content_array = self.image_augmentor.apply_augmentation( 543 | image_content_array) 544 | # Apply preprocess_input function 545 | image_content_array = self.preprocess_input(image_content_array) 546 | 547 | return image_content_array 548 | 549 | 550 | class Evaluator(Callback): 551 | 552 | def __init__(self, 553 | inference_model, 554 | split_name, 555 | query_accumulated_info_dataframe, 556 | gallery_accumulated_info_dataframe, 557 | preprocess_input, 558 | input_shape, 559 | image_augmentor, 560 | use_data_augmentation, 561 | augmentation_num, 562 | use_horizontal_flipping, 563 | use_re_ranking, 564 | batch_size, 565 | workers, 566 | use_multiprocessing, 567 | rank_list=(1, 5, 10, 20), 568 | every_N_epochs=1, 569 | output_folder_path=None): 570 | super(Evaluator, self).__init__() 571 | if hasattr(self, "_supports_tf_logs"): 572 | self._supports_tf_logs = True 573 | 574 | self.callback_disabled = query_accumulated_info_dataframe is None or gallery_accumulated_info_dataframe is None 575 | if self.callback_disabled: 576 | return 577 | 578 | self.inference_model = inference_model 579 | self.split_name = split_name 580 | self.query_generator = TestDataSequence( 581 | query_accumulated_info_dataframe, preprocess_input, input_shape, 582 | image_augmentor, use_data_augmentation, batch_size) 583 | self.gallery_generator = TestDataSequence( 584 | gallery_accumulated_info_dataframe, preprocess_input, input_shape, 585 | image_augmentor, use_data_augmentation, batch_size) 586 | self.query_identity_ID_array, self.query_camera_ID_array = query_accumulated_info_dataframe[ 587 | ["identity_ID", "camera_ID"]].values.transpose() 588 | self.gallery_identity_ID_array, self.gallery_camera_ID_array = gallery_accumulated_info_dataframe[ 589 | ["identity_ID", "camera_ID"]].values.transpose() 590 | self.preprocess_input, self.input_shape, self.image_augmentor, self.use_data_augmentation, self.augmentation_num, self.use_horizontal_flipping = \ 591 | preprocess_input, input_shape, image_augmentor, use_data_augmentation, augmentation_num, use_horizontal_flipping 592 | self.use_re_ranking, self.batch_size = use_re_ranking, batch_size 593 | self.workers, self.use_multiprocessing = workers, use_multiprocessing 594 | self.rank_list, self.every_N_epochs = rank_list, every_N_epochs 595 | self.output_file_path = None if output_folder_path is None else os.path.join( 596 | output_folder_path, "{}.npz".format(split_name)) 597 | 598 | if not use_data_augmentation and augmentation_num != 1: 599 | print( 600 | "Set augmentation_num to 1 since use_data_augmentation is False." 
601 | ) 602 | self.augmentation_num = 1 603 | 604 | self.metrics = ["cosine"] 605 | 606 | def extract_features(self, data_generator): 607 | # Extract the accumulated_feature_array 608 | accumulated_feature_array = None 609 | for _ in np.arange(self.augmentation_num): 610 | data_generator.disable_horizontal_flipping() 611 | feature_array = self.inference_model.predict( 612 | x=data_generator, 613 | workers=self.workers, 614 | use_multiprocessing=self.use_multiprocessing) 615 | if self.use_horizontal_flipping: 616 | data_generator.enable_horizontal_flipping() 617 | feature_array += self.inference_model.predict( 618 | x=data_generator, 619 | workers=self.workers, 620 | use_multiprocessing=self.use_multiprocessing) 621 | feature_array /= 2 622 | if accumulated_feature_array is None: 623 | accumulated_feature_array = feature_array / self.augmentation_num 624 | else: 625 | accumulated_feature_array += feature_array / self.augmentation_num 626 | return accumulated_feature_array 627 | 628 | def split_features(self, accumulated_feature_array): 629 | # Split the accumulated_feature_array into separate slices 630 | feature_array_list = [] 631 | for embedding_size_index in np.arange( 632 | len(self.inference_model.embedding_size_list)): 633 | if embedding_size_index == 0: 634 | start_index = 0 635 | end_index = self.inference_model.embedding_size_list[0] 636 | else: 637 | start_index = np.sum(self.inference_model. 638 | embedding_size_list[:embedding_size_index]) 639 | end_index = np.sum(self.inference_model. 640 | embedding_size_list[:embedding_size_index + 641 | 1]) 642 | feature_array = accumulated_feature_array[:, start_index:end_index] 643 | feature_array_list.append(feature_array) 644 | return feature_array_list 645 | 646 | def compute_distance_matrix(self, query_image_features, 647 | gallery_image_features, metric, use_re_ranking): 648 | # Compute the distance matrix 649 | query_gallery_distance = pairwise_distances(query_image_features, 650 | gallery_image_features, 651 | metric=metric) 652 | distance_matrix = query_gallery_distance 653 | 654 | # Use the re-ranking method 655 | if use_re_ranking: 656 | query_query_distance = pairwise_distances(query_image_features, 657 | query_image_features, 658 | metric=metric) 659 | gallery_gallery_distance = pairwise_distances( 660 | gallery_image_features, gallery_image_features, metric=metric) 661 | distance_matrix = re_ranking(query_gallery_distance, 662 | query_query_distance, 663 | gallery_gallery_distance) 664 | 665 | return distance_matrix 666 | 667 | def on_epoch_end(self, epoch, logs=None): 668 | if self.callback_disabled or (epoch + 1) % self.every_N_epochs != 0: 669 | return 670 | 671 | # Extract features 672 | feature_extraction_start = time.time() 673 | query_image_features_array = self.extract_features(self.query_generator) 674 | gallery_image_features_array = self.extract_features( 675 | self.gallery_generator) 676 | feature_extraction_end = time.time() 677 | feature_extraction_speed = ( 678 | len(query_image_features_array) + len(gallery_image_features_array) 679 | ) / (feature_extraction_end - feature_extraction_start) 680 | print("Speed of feature extraction: {:.2f} images per second.".format( 681 | feature_extraction_speed)) 682 | 683 | # Check unique values in the features array 684 | query_unique_values = np.unique(query_image_features_array) 685 | additional_metrics = [] 686 | if len(query_unique_values) <= 2**8: 687 | print("Unique values in query_image_features_array: {}".format( 688 | query_unique_values)) 689 | 
additional_metrics.append("hamming") 690 | 691 | # Save image features, identity ID and camera ID to disk 692 | if self.output_file_path is not None: 693 | np.savez(self.output_file_path, 694 | query_image_features_array=query_image_features_array, 695 | gallery_image_features_array=gallery_image_features_array, 696 | query_identity_ID_array=self.query_identity_ID_array, 697 | gallery_identity_ID_array=self.gallery_identity_ID_array, 698 | query_camera_ID_array=self.query_camera_ID_array, 699 | gallery_camera_ID_array=self.gallery_camera_ID_array) 700 | 701 | # Split features 702 | print("embedding_size_list:", self.inference_model.embedding_size_list) 703 | query_image_features_list = self.split_features( 704 | query_image_features_array) 705 | gallery_image_features_list = self.split_features( 706 | gallery_image_features_array) 707 | 708 | for metric in self.metrics + additional_metrics: 709 | distance_matrix_list = [] 710 | for query_image_features, gallery_image_features in zip( 711 | query_image_features_list, gallery_image_features_list): 712 | distance_matrix = self.compute_distance_matrix( 713 | query_image_features, gallery_image_features, metric, 714 | self.use_re_ranking) 715 | distance_matrix_list.append(distance_matrix) 716 | 717 | method_name_list = (np.arange(len(distance_matrix_list)) + 718 | 1).tolist() 719 | for distance_matrix, method_name in zip(distance_matrix_list, 720 | method_name_list): 721 | # Compute the CMC and mAP scores 722 | CMC_score_array, mAP_score = compute_CMC_mAP( 723 | distmat=distance_matrix, 724 | q_pids=self.query_identity_ID_array, 725 | g_pids=self.gallery_identity_ID_array, 726 | q_camids=self.query_camera_ID_array, 727 | g_camids=self.gallery_camera_ID_array) 728 | 729 | # Append the CMC and mAP scores 730 | logs["{}_{}_{}_{}_rank_to_accuracy_dict".format( 731 | self.split_name, metric, self.use_re_ranking, 732 | method_name)] = dict([("rank-{} accuracy".format(rank), 733 | CMC_score_array[rank - 1]) 734 | for rank in self.rank_list]) 735 | logs["{}_{}_{}_{}_mAP_score".format(self.split_name, metric, 736 | self.use_re_ranking, 737 | method_name)] = mAP_score 738 | 739 | 740 | def learning_rate_scheduler(epoch_index, epoch_num, learning_rate_mode, 741 | learning_rate_start, learning_rate_end, 742 | learning_rate_base, learning_rate_warmup_epochs, 743 | learning_rate_steady_epochs, 744 | learning_rate_drop_factor, 745 | learning_rate_lower_bound): 746 | learning_rate = None 747 | if learning_rate_mode == "constant": 748 | assert learning_rate_start == learning_rate_end, "starting and ending learning rates should be equal!" 749 | learning_rate = learning_rate_start 750 | elif learning_rate_mode == "linear": 751 | learning_rate = (learning_rate_end - learning_rate_start) / ( 752 | epoch_num - 1) * epoch_index + learning_rate_start 753 | elif learning_rate_mode == "cosine": 754 | assert learning_rate_start > learning_rate_end, "starting learning rate should be higher than ending learning rate!" 
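# Half-cosine annealing: yields learning_rate_start at epoch_index == 0 and
# learning_rate_end at epoch_index == epoch_num - 1.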
755 | learning_rate = (learning_rate_start - learning_rate_end) / 2 * np.cos( 756 | np.pi * epoch_index / 757 | (epoch_num - 1)) + (learning_rate_start + learning_rate_end) / 2 758 | elif learning_rate_mode == "warmup": 759 | learning_rate = (learning_rate_end - learning_rate_start) / ( 760 | learning_rate_warmup_epochs - 1) * epoch_index + learning_rate_start 761 | learning_rate = np.min((learning_rate, learning_rate_end)) 762 | elif learning_rate_mode == "default": 763 | if epoch_index < learning_rate_warmup_epochs: 764 | learning_rate = (learning_rate_base - learning_rate_lower_bound) / ( 765 | learning_rate_warmup_epochs - 766 | 1) * epoch_index + learning_rate_lower_bound 767 | else: 768 | if learning_rate_drop_factor == 0: 769 | learning_rate_drop_factor = np.exp( 770 | learning_rate_steady_epochs / 771 | (epoch_num - learning_rate_warmup_epochs * 2) * 772 | np.log(learning_rate_base / learning_rate_lower_bound)) 773 | learning_rate = learning_rate_base / np.power( 774 | learning_rate_drop_factor, 775 | int((epoch_index - learning_rate_warmup_epochs) / 776 | learning_rate_steady_epochs)) 777 | else: 778 | assert False, "{} is an invalid argument!".format(learning_rate_mode) 779 | learning_rate = np.max((learning_rate, learning_rate_lower_bound)) 780 | return learning_rate 781 | 782 | 783 | def main(_): 784 | print("Getting hyperparameters ...") 785 | print("Using command {}".format(" ".join(sys.argv))) 786 | flag_values_dict = FLAGS.flag_values_dict() 787 | for flag_name in sorted(flag_values_dict.keys()): 788 | flag_value = flag_values_dict[flag_name] 789 | print(flag_name, flag_value) 790 | root_folder_path, dataset_name = FLAGS.root_folder_path, FLAGS.dataset_name 791 | backbone_model_name, freeze_backbone_for_N_epochs = FLAGS.backbone_model_name, FLAGS.freeze_backbone_for_N_epochs 792 | image_height, image_width = FLAGS.image_height, FLAGS.image_width 793 | input_shape = (image_height, image_width, 3) 794 | region_num = FLAGS.region_num 795 | kernel_regularization_factor = FLAGS.kernel_regularization_factor 796 | bias_regularization_factor = FLAGS.bias_regularization_factor 797 | gamma_regularization_factor = FLAGS.gamma_regularization_factor 798 | beta_regularization_factor = FLAGS.beta_regularization_factor 799 | use_adaptive_l1_l2_regularizer = FLAGS.use_adaptive_l1_l2_regularizer 800 | min_value_in_clipping, max_value_in_clipping = FLAGS.min_value_in_clipping, FLAGS.max_value_in_clipping 801 | validation_size = FLAGS.validation_size 802 | validation_size = int( 803 | validation_size) if validation_size > 1 else validation_size 804 | use_validation = validation_size != 0 805 | testing_size = FLAGS.testing_size 806 | testing_size = int(testing_size) if testing_size > 1 else testing_size 807 | use_testing = testing_size != 0 808 | evaluate_validation_every_N_epochs = FLAGS.evaluate_validation_every_N_epochs 809 | evaluate_testing_every_N_epochs = FLAGS.evaluate_testing_every_N_epochs 810 | identity_num_per_batch, image_num_per_identity = FLAGS.identity_num_per_batch, FLAGS.image_num_per_identity 811 | batch_size = identity_num_per_batch * image_num_per_identity 812 | learning_rate_mode, learning_rate_start, learning_rate_end = FLAGS.learning_rate_mode, FLAGS.learning_rate_start, FLAGS.learning_rate_end 813 | learning_rate_base, learning_rate_warmup_epochs, learning_rate_steady_epochs = FLAGS.learning_rate_base, FLAGS.learning_rate_warmup_epochs, FLAGS.learning_rate_steady_epochs 814 | learning_rate_drop_factor, learning_rate_lower_bound = FLAGS.learning_rate_drop_factor, 
FLAGS.learning_rate_lower_bound 815 | steps_per_epoch = FLAGS.steps_per_epoch 816 | epoch_num = FLAGS.epoch_num 817 | workers = FLAGS.workers 818 | use_multiprocessing = workers > 1 819 | image_augmentor_name = FLAGS.image_augmentor_name 820 | use_data_augmentation_in_training = FLAGS.use_data_augmentation_in_training 821 | use_data_augmentation_in_evaluation = FLAGS.use_data_augmentation_in_evaluation 822 | augmentation_num = FLAGS.augmentation_num 823 | use_horizontal_flipping_in_evaluation = FLAGS.use_horizontal_flipping_in_evaluation 824 | use_identity_balancing_in_training = FLAGS.use_identity_balancing_in_training 825 | use_re_ranking = FLAGS.use_re_ranking 826 | evaluation_only, save_data_to_disk = FLAGS.evaluation_only, FLAGS.save_data_to_disk 827 | pretrained_model_file_path = FLAGS.pretrained_model_file_path 828 | 829 | output_folder_path = os.path.abspath( 830 | os.path.join( 831 | FLAGS.output_folder_path, 832 | "{}_{}x{}".format(dataset_name, input_shape[0], input_shape[1]), 833 | "{}_{}_{}".format(backbone_model_name, identity_num_per_batch, 834 | image_num_per_identity))) 835 | shutil.rmtree(output_folder_path, ignore_errors=True) 836 | os.makedirs(output_folder_path) 837 | print("Recreating the output folder at {} ...".format(output_folder_path)) 838 | 839 | print("Loading the annotations of the {} dataset ...".format(dataset_name)) 840 | train_and_valid_accumulated_info_dataframe, test_query_accumulated_info_dataframe, \ 841 | test_gallery_accumulated_info_dataframe, train_and_valid_attribute_name_to_label_encoder_dict = \ 842 | load_accumulated_info_of_dataset(root_folder_path=root_folder_path, dataset_name=dataset_name) 843 | 844 | if use_validation: 845 | print("Using customized cross validation splits ...") 846 | train_and_valid_identity_ID_array = train_and_valid_accumulated_info_dataframe[ 847 | "identity_ID"].values 848 | train_indexes, valid_indexes = apply_stratifiedshufflesplit( 849 | y=train_and_valid_identity_ID_array, test_size=validation_size) 850 | train_accumulated_info_dataframe = train_and_valid_accumulated_info_dataframe.iloc[ 851 | train_indexes] 852 | valid_accumulated_info_dataframe = train_and_valid_accumulated_info_dataframe.iloc[ 853 | valid_indexes] 854 | 855 | print("Splitting the validation dataset ...") 856 | valid_identity_ID_array = valid_accumulated_info_dataframe[ 857 | "identity_ID"].values 858 | gallery_size = len(test_gallery_accumulated_info_dataframe) / ( 859 | len(test_query_accumulated_info_dataframe) + 860 | len(test_gallery_accumulated_info_dataframe)) 861 | valid_query_indexes, valid_gallery_indexes = apply_stratifiedshufflesplit( 862 | y=valid_identity_ID_array, test_size=gallery_size) 863 | valid_query_accumulated_info_dataframe = valid_accumulated_info_dataframe.iloc[ 864 | valid_query_indexes] 865 | valid_gallery_accumulated_info_dataframe = valid_accumulated_info_dataframe.iloc[ 866 | valid_gallery_indexes] 867 | else: 868 | train_accumulated_info_dataframe = train_and_valid_accumulated_info_dataframe 869 | valid_query_accumulated_info_dataframe, valid_gallery_accumulated_info_dataframe = None, None 870 | 871 | if use_testing: 872 | if testing_size != 1: 873 | print("Using a subset from the testing dataset ...") 874 | test_accumulated_info_dataframe = pd.concat([ 875 | test_query_accumulated_info_dataframe, 876 | test_gallery_accumulated_info_dataframe 877 | ], 878 | ignore_index=True) 879 | test_identity_ID_array = test_accumulated_info_dataframe[ 880 | "identity_ID"].values 881 | _, test_query_and_gallery_indexes = 
apply_groupshufflesplit( 882 | groups=test_identity_ID_array, test_size=testing_size) 883 | test_query_mask = test_query_and_gallery_indexes < len( 884 | test_query_accumulated_info_dataframe) 885 | test_gallery_mask = np.logical_not(test_query_mask) 886 | test_query_indexes, test_gallery_indexes = test_query_and_gallery_indexes[ 887 | test_query_mask], test_query_and_gallery_indexes[ 888 | test_gallery_mask] 889 | test_query_accumulated_info_dataframe = test_accumulated_info_dataframe.iloc[ 890 | test_query_indexes] 891 | test_gallery_accumulated_info_dataframe = test_accumulated_info_dataframe.iloc[ 892 | test_gallery_indexes] 893 | else: 894 | test_query_accumulated_info_dataframe, test_gallery_accumulated_info_dataframe = None, None 895 | 896 | print("Initiating the model ...") 897 | training_model, inference_model, preprocess_input = init_model( 898 | backbone_model_name=backbone_model_name, 899 | freeze_backbone_for_N_epochs=freeze_backbone_for_N_epochs, 900 | input_shape=input_shape, 901 | region_num=region_num, 902 | attribute_name_to_label_encoder_dict= 903 | train_and_valid_attribute_name_to_label_encoder_dict, 904 | kernel_regularization_factor=kernel_regularization_factor, 905 | bias_regularization_factor=bias_regularization_factor, 906 | gamma_regularization_factor=gamma_regularization_factor, 907 | beta_regularization_factor=beta_regularization_factor, 908 | use_adaptive_l1_l2_regularizer=use_adaptive_l1_l2_regularizer, 909 | min_value_in_clipping=min_value_in_clipping, 910 | max_value_in_clipping=max_value_in_clipping) 911 | visualize_model(model=training_model, output_folder_path=output_folder_path) 912 | 913 | print("Initiating the image augmentor {} ...".format(image_augmentor_name)) 914 | image_augmentor = getattr(image_augmentation, 915 | image_augmentor_name)(image_height=image_height, 916 | image_width=image_width) 917 | image_augmentor.compose_transforms() 918 | 919 | print("Perform training ...") 920 | train_generator = TrainDataSequence( 921 | accumulated_info_dataframe=train_accumulated_info_dataframe, 922 | attribute_name_to_label_encoder_dict= 923 | train_and_valid_attribute_name_to_label_encoder_dict, 924 | preprocess_input=preprocess_input, 925 | input_shape=input_shape, 926 | image_augmentor=image_augmentor, 927 | use_data_augmentation=use_data_augmentation_in_training, 928 | use_identity_balancing=use_identity_balancing_in_training, 929 | label_repetition_num=len(training_model.outputs), 930 | identity_num_per_batch=identity_num_per_batch, 931 | image_num_per_identity=image_num_per_identity, 932 | steps_per_epoch=steps_per_epoch) 933 | valid_evaluator_callback = Evaluator( 934 | inference_model=inference_model, 935 | split_name="valid", 936 | query_accumulated_info_dataframe=valid_query_accumulated_info_dataframe, 937 | gallery_accumulated_info_dataframe= 938 | valid_gallery_accumulated_info_dataframe, 939 | preprocess_input=preprocess_input, 940 | input_shape=input_shape, 941 | image_augmentor=image_augmentor, 942 | use_data_augmentation=use_data_augmentation_in_evaluation, 943 | augmentation_num=augmentation_num, 944 | use_horizontal_flipping=use_horizontal_flipping_in_evaluation, 945 | use_re_ranking=use_re_ranking, 946 | batch_size=batch_size, 947 | workers=workers, 948 | use_multiprocessing=use_multiprocessing, 949 | every_N_epochs=evaluate_validation_every_N_epochs, 950 | output_folder_path=output_folder_path if save_data_to_disk else None) 951 | test_evaluator_callback = Evaluator( 952 | inference_model=inference_model, 953 | split_name="test", 954 | 
query_accumulated_info_dataframe=test_query_accumulated_info_dataframe, 955 | gallery_accumulated_info_dataframe= 956 | test_gallery_accumulated_info_dataframe, 957 | preprocess_input=preprocess_input, 958 | input_shape=input_shape, 959 | image_augmentor=image_augmentor, 960 | use_data_augmentation=use_data_augmentation_in_evaluation, 961 | augmentation_num=augmentation_num, 962 | use_horizontal_flipping=use_horizontal_flipping_in_evaluation, 963 | use_re_ranking=use_re_ranking, 964 | batch_size=batch_size, 965 | workers=workers, 966 | use_multiprocessing=use_multiprocessing, 967 | every_N_epochs=evaluate_testing_every_N_epochs, 968 | output_folder_path=output_folder_path if save_data_to_disk else None) 969 | inspect_regularization_factors_callback = InspectRegularizationFactors() 970 | optimal_model_file_path = os.path.join(output_folder_path, 971 | "training_model.h5") 972 | modelcheckpoint_monitor = "{}_cosine_{}_1_mAP_score".format("test" if use_testing else "valid", use_re_ranking)  # Match the log key emitted by Evaluator. 973 | modelcheckpoint_callback = ModelCheckpoint(filepath=optimal_model_file_path, 974 | monitor=modelcheckpoint_monitor, 975 | mode="max", 976 | save_best_only=True, 977 | save_weights_only=False, 978 | verbose=1) 979 | learningratescheduler_callback = LearningRateScheduler( 980 | schedule=lambda epoch_index: learning_rate_scheduler( 981 | epoch_index=epoch_index, 982 | epoch_num=epoch_num, 983 | learning_rate_mode=learning_rate_mode, 984 | learning_rate_start=learning_rate_start, 985 | learning_rate_end=learning_rate_end, 986 | learning_rate_base=learning_rate_base, 987 | learning_rate_warmup_epochs=learning_rate_warmup_epochs, 988 | learning_rate_steady_epochs=learning_rate_steady_epochs, 989 | learning_rate_drop_factor=learning_rate_drop_factor, 990 | learning_rate_lower_bound=learning_rate_lower_bound), 991 | verbose=1) 992 | if len(pretrained_model_file_path) > 0: 993 | assert os.path.isfile(pretrained_model_file_path) 994 | print("Loading weights from {} ...".format(pretrained_model_file_path)) 995 | # Hacky workaround for the issue with "load_weights" 996 | if use_adaptive_l1_l2_regularizer: 997 | _ = training_model.test_on_batch(train_generator[0]) 998 | # Load weights from the pretrained model 999 | training_model.load_weights(pretrained_model_file_path) 1000 | if evaluation_only: 1001 | print("Freezing the whole model in the evaluation_only mode ...") 1002 | training_model.trainable = False 1003 | training_model.compile(**training_model.compile_kwargs) 1004 | 1005 | assert testing_size == 1, "Use all testing samples for evaluation!"
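# Run a single one-step "epoch" so that the evaluator callbacks above fire
# once; the model has been frozen and recompiled, so no weights are updated.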
1006 | historylogger_callback = HistoryLogger( 1007 | output_folder_path=os.path.join(output_folder_path, "evaluation")) 1008 | training_model.fit(x=train_generator, 1009 | steps_per_epoch=1, 1010 | callbacks=[ 1011 | inspect_regularization_factors_callback, 1012 | valid_evaluator_callback, 1013 | test_evaluator_callback, historylogger_callback 1014 | ], 1015 | epochs=1, 1016 | workers=workers, 1017 | use_multiprocessing=use_multiprocessing, 1018 | verbose=2) 1019 | else: 1020 | if freeze_backbone_for_N_epochs > 0: 1021 | print("Freeze layers in the backbone model for {} epochs.".format( 1022 | freeze_backbone_for_N_epochs)) 1023 | historylogger_callback = HistoryLogger( 1024 | output_folder_path=os.path.join(output_folder_path, 1025 | "training_A")) 1026 | training_model.fit(x=train_generator, 1027 | steps_per_epoch=steps_per_epoch, 1028 | callbacks=[ 1029 | valid_evaluator_callback, 1030 | test_evaluator_callback, 1031 | learningratescheduler_callback, 1032 | historylogger_callback 1033 | ], 1034 | epochs=freeze_backbone_for_N_epochs, 1035 | workers=workers, 1036 | use_multiprocessing=use_multiprocessing, 1037 | verbose=2) 1038 | 1039 | print("Unfreeze layers in the backbone model.") 1040 | for item in training_model.layers: 1041 | item.trainable = True 1042 | training_model.compile(**training_model.compile_kwargs) 1043 | 1044 | print("Perform conventional training for {} epochs.".format(epoch_num)) 1045 | historylogger_callback = HistoryLogger( 1046 | output_folder_path=os.path.join(output_folder_path, "training_B")) 1047 | training_model.fit(x=train_generator, 1048 | steps_per_epoch=steps_per_epoch, 1049 | callbacks=[ 1050 | inspect_regularization_factors_callback, 1051 | valid_evaluator_callback, 1052 | test_evaluator_callback, 1053 | modelcheckpoint_callback, 1054 | learningratescheduler_callback, 1055 | historylogger_callback 1056 | ], 1057 | epochs=epoch_num, 1058 | workers=workers, 1059 | use_multiprocessing=use_multiprocessing, 1060 | verbose=2) 1061 | 1062 | if not os.path.isfile(optimal_model_file_path): 1063 | print("Saving model to {} ...".format(optimal_model_file_path)) 1064 | training_model.save(optimal_model_file_path) 1065 | 1066 | print("All done!") 1067 | 1068 | 1069 | if __name__ == "__main__": 1070 | app.run(main) 1071 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from tensorflow.keras.layers import BatchNormalization, Conv1D, Conv2D, Dense 5 | from tensorflow.keras.models import Model, clone_model, model_from_json 6 | from tensorflow.keras.regularizers import l2 7 | 8 | sys.path.append(os.path.abspath(os.path.join(__file__, "../.."))) 9 | from regularizers.adaptation import AdaptiveL1L2 10 | 11 | 12 | def replicate_model(model, name): 13 | vanilla_weights = model.get_weights() 14 | model = clone_model(model) 15 | model = Model(inputs=model.input, outputs=model.output, name=name) 16 | model.set_weights(vanilla_weights) 17 | return model 18 | 19 | 20 | def specify_regularizers(model, 21 | kernel_regularization_factor=0, 22 | bias_regularization_factor=0, 23 | gamma_regularization_factor=0, 24 | beta_regularization_factor=0, 25 | use_adaptive_l1_l2_regularizer=False, 26 | omitted_layer_name_prefix_tuple=()): 27 | 28 | def _init_regularizers(model, kernel_regularization_factor, 29 | bias_regularization_factor, 30 | gamma_regularization_factor, 31 | beta_regularization_factor, 32 | 
use_adaptive_l1_l2_regularizer, 33 | omitted_layer_name_prefix_tuple): 34 | print("Initializing regularizers for {} ...".format(model.name)) 35 | for item in model.layers: 36 | if isinstance(item, Model): 37 | _init_regularizers(item, kernel_regularization_factor, 38 | bias_regularization_factor, 39 | gamma_regularization_factor, 40 | beta_regularization_factor, 41 | use_adaptive_l1_l2_regularizer, 42 | omitted_layer_name_prefix_tuple) 43 | omit_layer = False 44 | for omitted_layer_name_prefix in omitted_layer_name_prefix_tuple: 45 | if item.name.startswith(omitted_layer_name_prefix): 46 | omit_layer = True 47 | break 48 | if omit_layer: 49 | continue 50 | if isinstance(item, (Conv1D, Conv2D, Dense)): 51 | if kernel_regularization_factor >= 0 and hasattr( 52 | item, "kernel_regularizer"): 53 | item.kernel_regularizer = AdaptiveL1L2( 54 | amplitude_l2=kernel_regularization_factor 55 | ) if use_adaptive_l1_l2_regularizer else l2( 56 | l=kernel_regularization_factor) 57 | if bias_regularization_factor >= 0 and hasattr( 58 | item, "bias_regularizer"): 59 | if item.use_bias: 60 | item.bias_regularizer = AdaptiveL1L2( 61 | amplitude_l2=bias_regularization_factor 62 | ) if use_adaptive_l1_l2_regularizer else l2( 63 | l=bias_regularization_factor) 64 | elif isinstance(item, BatchNormalization): 65 | if gamma_regularization_factor >= 0 and hasattr( 66 | item, "gamma_regularizer"): 67 | if item.scale: 68 | item.gamma_regularizer = AdaptiveL1L2( 69 | amplitude_l2=gamma_regularization_factor 70 | ) if use_adaptive_l1_l2_regularizer else l2( 71 | l=gamma_regularization_factor) 72 | if beta_regularization_factor >= 0 and hasattr( 73 | item, "beta_regularizer"): 74 | if item.center: 75 | item.beta_regularizer = AdaptiveL1L2( 76 | amplitude_l2=beta_regularization_factor 77 | ) if use_adaptive_l1_l2_regularizer else l2( 78 | l=beta_regularization_factor) 79 | 80 | # Initialize regularizers 81 | _init_regularizers(model, kernel_regularization_factor, 82 | bias_regularization_factor, gamma_regularization_factor, 83 | beta_regularization_factor, 84 | use_adaptive_l1_l2_regularizer, 85 | omitted_layer_name_prefix_tuple) 86 | 87 | # Reload the model 88 | # https://github.com/keras-team/keras/issues/2717#issuecomment-447570737 89 | vanilla_weights = model.get_weights() 90 | model = model_from_json(json_string=model.to_json(), 91 | custom_objects={"AdaptiveL1L2": AdaptiveL1L2}) 92 | model.set_weights(vanilla_weights) 93 | 94 | return model 95 | -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras.models import Model 5 | from tensorflow.keras.utils import plot_model 6 | 7 | 8 | def summarize_model(model): 9 | # Summarize the model at hand 10 | identifier = "{}_{}".format(model.name, id(model)) 11 | print("Summarizing {} ...".format(identifier)) 12 | model.summary() 13 | 14 | # Summarize submodels 15 | for item in model.layers: 16 | if isinstance(item, Model): 17 | summarize_model(item) 18 | 19 | 20 | def visualize_model(model, output_folder_path): 21 | # Visualize the model at hand 22 | identifier = "{}_{}".format(model.name, id(model)) 23 | print("Visualizing {} ...".format(identifier)) 24 | try: 25 | # TODO: Wait for patches from upstream. 
26 | # https://github.com/tensorflow/tensorflow/issues/38988 27 | model._layers = [ # pylint: disable=protected-access 28 | item for item in model._layers # pylint: disable=protected-access 29 | if isinstance(item, Layer) 30 | ] 31 | plot_model(model, 32 | show_shapes=True, 33 | show_layer_names=True, 34 | to_file=os.path.join(output_folder_path, 35 | "{}.png".format(identifier))) 36 | except Exception as exception: # pylint: disable=broad-except 37 | print(exception) 38 | print("Failed to plot {}.".format(identifier)) 39 | 40 | # Visualize submodels 41 | for item in model.layers: 42 | if isinstance(item, Model): 43 | visualize_model(item, output_folder_path) 44 | --------------------------------------------------------------------------------
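For reference, a minimal usage sketch of the adaptive-regularization utilities above. This is illustrative only and not part of the repository: the toy model, the input shape, and the random data are stand-ins.

import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

from regularizers.adaptation import InspectRegularizationFactors
from utils.model_utils import specify_regularizers

# Build a stand-in model and attach AdaptiveL1L2 regularizers to it. Each
# trainable factor starts at hard_sigmoid(0) = 0.5 times its amplitude and
# is then updated by the optimizer together with the other weights.
inputs = Input(shape=(16,))
outputs = Dense(4, activation="softmax")(inputs)
model = specify_regularizers(Model(inputs=inputs, outputs=outputs),
                             kernel_regularization_factor=0.005,
                             bias_regularization_factor=0.005,
                             use_adaptive_l1_l2_regularizer=True)
model.compile(optimizer="adam", loss="categorical_crossentropy")

# InspectRegularizationFactors writes the current effective factors into the
# training logs at the end of every epoch.
model.fit(np.random.rand(32, 16),
          np.eye(4)[np.random.randint(0, 4, size=32)],
          epochs=1,
          callbacks=[InspectRegularizationFactors()])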