├── .flake8 ├── .gitignore ├── LICENSE ├── README.md ├── ThirdPartyNotices.txt ├── configs ├── example.yaml ├── example_margin.yaml └── example_resnet50.yaml ├── misc └── ms_loss.png ├── requirements.txt ├── ret_benchmark ├── config │ ├── __init__.py │ ├── defaults.py │ └── model_path.py ├── data │ ├── __init__.py │ ├── build.py │ ├── collate_batch.py │ ├── datasets │ │ ├── __init__.py │ │ └── base_dataset.py │ ├── evaluations │ │ ├── __init__.py │ │ └── ret_metric.py │ ├── samplers │ │ ├── __init__.py │ │ └── random_identity_sampler.py │ └── transforms │ │ ├── __init__.py │ │ └── build.py ├── engine │ ├── __init__.py │ └── trainer.py ├── losses │ ├── __init__.py │ ├── build.py │ ├── margin_loss.py │ ├── multi_similarity_loss.py │ └── registry.py ├── modeling │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ ├── bninception.py │ │ ├── build.py │ │ └── resnet.py │ ├── build.py │ ├── heads │ │ ├── __init__.py │ │ ├── build.py │ │ └── linear_norm.py │ ├── registry.py │ └── xbm.py ├── solver │ ├── __init__.py │ ├── build.py │ └── lr_scheduler.py └── utils │ ├── checkpoint.py │ ├── config_util.py │ ├── feat_extractor.py │ ├── freeze_bn.py │ ├── img_reader.py │ ├── init_methods.py │ ├── logger.py │ ├── metric_logger.py │ ├── model_serialization.py │ └── registry.py ├── scripts ├── prepare_cub.sh ├── run_cub.sh ├── run_cub_margin.sh └── split_cub_for_ms_loss.py ├── setup.py └── tools └── main.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = F401, F841, E402, E722, E999 3 | max-line-length = 128 4 | max-complexity=18 5 | format=pylint 6 | show_source = True 7 | statistics = True 8 | count = True 9 | exclude = tests,ret_benchmark/modeling/backbone -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | resource 2 | build 3 | *.pyc 4 | *.zip 5 | */__pycache__ 6 | __pycache__ 7 | 8 | # Package Files # 9 | *.pkl 10 | *.log 11 | *.jar 12 | *.war 13 | *.nar 14 | *.ear 15 | *.zip 16 | *.tar.gz 17 | *.rar 18 | *.egg-info 19 | 20 | #some local files 21 | */.settings/ 22 | */.DS_Store 23 | .DS_Store 24 | */.idea/ 25 | .idea/ 26 | gradlew 27 | gradlew.bat 28 | unused.txt 29 | output/ 30 | *.egg-info/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Attribution-NonCommercial 4.0 International (CC-BY-NC-4.0) 2 | Public License 3 | 4 | For Multi-Similarity Loss for Deep Metric Learning (MS-Loss) 5 | 6 | Copyright (c) 2014-present, Malong Technologies Co., Ltd. All rights reserved. 7 | 8 | 9 | By exercising the Licensed Rights (defined below), You accept and agree 10 | to be bound by the terms and conditions of this Creative Commons 11 | Attribution-NonCommercial 4.0 International Public License ("Public 12 | License"). To the extent this Public License may be interpreted as a 13 | contract, You are granted the Licensed Rights in consideration of Your 14 | acceptance of these terms and conditions, and the Licensor grants You 15 | such rights in consideration of benefits the Licensor receives from 16 | making the Licensed Material available under these terms and 17 | conditions. 18 | 19 | 20 | Section 1 -- Definitions. 21 | 22 | a. 
Adapted Material means material subject to Copyright and Similar 23 | Rights that is derived from or based upon the Licensed Material 24 | and in which the Licensed Material is translated, altered, 25 | arranged, transformed, or otherwise modified in a manner requiring 26 | permission under the Copyright and Similar Rights held by the 27 | Licensor. For purposes of this Public License, where the Licensed 28 | Material is a musical work, performance, or sound recording, 29 | Adapted Material is always produced where the Licensed Material is 30 | synched in timed relation with a moving image. 31 | 32 | b. Adapter's License means the license You apply to Your Copyright 33 | and Similar Rights in Your contributions to Adapted Material in 34 | accordance with the terms and conditions of this Public License. 35 | 36 | c. Copyright and Similar Rights means copyright and/or similar rights 37 | closely related to copyright including, without limitation, 38 | performance, broadcast, sound recording, and Sui Generis Database 39 | Rights, without regard to how the rights are labeled or 40 | categorized. For purposes of this Public License, the rights 41 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 42 | Rights. 43 | d. Effective Technological Measures means those measures that, in the 44 | absence of proper authority, may not be circumvented under laws 45 | fulfilling obligations under Article 11 of the WIPO Copyright 46 | Treaty adopted on December 20, 1996, and/or similar international 47 | agreements. 48 | 49 | e. Exceptions and Limitations means fair use, fair dealing, and/or 50 | any other exception or limitation to Copyright and Similar Rights 51 | that applies to Your use of the Licensed Material. 52 | 53 | f. Licensed Material means the artistic or literary work, database, 54 | or other material to which the Licensor applied this Public 55 | License. 56 | 57 | g. Licensed Rights means the rights granted to You subject to the 58 | terms and conditions of this Public License, which are limited to 59 | all Copyright and Similar Rights that apply to Your use of the 60 | Licensed Material and that the Licensor has authority to license. 61 | 62 | h. Licensor means the individual(s) or entity(ies) granting rights 63 | under this Public License. 64 | 65 | i. NonCommercial means not primarily intended for or directed towards 66 | commercial advantage or monetary compensation. For purposes of 67 | this Public License, the exchange of the Licensed Material for 68 | other material subject to Copyright and Similar Rights by digital 69 | file-sharing or similar means is NonCommercial provided there is 70 | no payment of monetary compensation in connection with the 71 | exchange. 72 | 73 | j. Share means to provide material to the public by any means or 74 | process that requires permission under the Licensed Rights, such 75 | as reproduction, public display, public performance, distribution, 76 | dissemination, communication, or importation, and to make material 77 | available to the public including in ways that members of the 78 | public may access the material from a place and at a time 79 | individually chosen by them. 80 | 81 | k. Sui Generis Database Rights means rights other than copyright 82 | resulting from Directive 96/9/EC of the European Parliament and of 83 | the Council of 11 March 1996 on the legal protection of databases, 84 | as amended and/or succeeded, as well as other essentially 85 | equivalent rights anywhere in the world. 86 | 87 | l. 
You means the individual or entity exercising the Licensed Rights 88 | under this Public License. Your has a corresponding meaning. 89 | 90 | 91 | Section 2 -- Scope. 92 | 93 | a. License grant. 94 | 95 | 1. Subject to the terms and conditions of this Public License, 96 | the Licensor hereby grants You a worldwide, royalty-free, 97 | non-sublicensable, non-exclusive, irrevocable license to 98 | exercise the Licensed Rights in the Licensed Material to: 99 | 100 | a. reproduce and Share the Licensed Material, in whole or 101 | in part, for NonCommercial purposes only; and 102 | 103 | b. produce, reproduce, and Share Adapted Material for 104 | NonCommercial purposes only. 105 | 106 | 2. Exceptions and Limitations. For the avoidance of doubt, where 107 | Exceptions and Limitations apply to Your use, this Public 108 | License does not apply, and You do not need to comply with 109 | its terms and conditions. 110 | 111 | 3. Term. The term of this Public License is specified in Section 112 | 6(a). 113 | 114 | 4. Media and formats; technical modifications allowed. The 115 | Licensor authorizes You to exercise the Licensed Rights in 116 | all media and formats whether now known or hereafter created, 117 | and to make technical modifications necessary to do so. The 118 | Licensor waives and/or agrees not to assert any right or 119 | authority to forbid You from making technical modifications 120 | necessary to exercise the Licensed Rights, including 121 | technical modifications necessary to circumvent Effective 122 | Technological Measures. For purposes of this Public License, 123 | simply making modifications authorized by this Section 2(a) 124 | (4) never produces Adapted Material. 125 | 126 | 5. Downstream recipients. 127 | 128 | a. Offer from the Licensor -- Licensed Material. Every 129 | recipient of the Licensed Material automatically 130 | receives an offer from the Licensor to exercise the 131 | Licensed Rights under the terms and conditions of this 132 | Public License. 133 | 134 | b. No downstream restrictions. You may not offer or impose 135 | any additional or different terms or conditions on, or 136 | apply any Effective Technological Measures to, the 137 | Licensed Material if doing so restricts exercise of the 138 | Licensed Rights by any recipient of the Licensed 139 | Material. 140 | 141 | 6. No endorsement. Nothing in this Public License constitutes or 142 | may be construed as permission to assert or imply that You 143 | are, or that Your use of the Licensed Material is, connected 144 | with, or sponsored, endorsed, or granted official status by, 145 | the Licensor or others designated to receive attribution as 146 | provided in Section 3(a)(1)(A)(i). 147 | 148 | b. Other rights. 149 | 150 | 1. Moral rights, such as the right of integrity, are not 151 | licensed under this Public License, nor are publicity, 152 | privacy, and/or other similar personality rights; however, to 153 | the extent possible, the Licensor waives and/or agrees not to 154 | assert any such rights held by the Licensor to the limited 155 | extent necessary to allow You to exercise the Licensed 156 | Rights, but not otherwise. 157 | 158 | 2. Patent and trademark rights are not licensed under this 159 | Public License. 160 | 161 | 3. To the extent possible, the Licensor waives any right to 162 | collect royalties from You for the exercise of the Licensed 163 | Rights, whether directly or through a collecting society 164 | under any voluntary or waivable statutory or compulsory 165 | licensing scheme. 
In all other cases the Licensor expressly 166 | reserves any right to collect such royalties, including when 167 | the Licensed Material is used other than for NonCommercial 168 | purposes. 169 | 170 | 171 | Section 3 -- License Conditions. 172 | 173 | Your exercise of the Licensed Rights is expressly made subject to the 174 | following conditions. 175 | 176 | a. Attribution. 177 | 178 | 1. If You Share the Licensed Material (including in modified 179 | form), You must: 180 | 181 | a. retain the following if it is supplied by the Licensor 182 | with the Licensed Material: 183 | 184 | i. identification of the creator(s) of the Licensed 185 | Material and any others designated to receive 186 | attribution, in any reasonable manner requested by 187 | the Licensor (including by pseudonym if 188 | designated); 189 | 190 | ii. a copyright notice; 191 | 192 | iii. a notice that refers to this Public License; 193 | 194 | iv. a notice that refers to the disclaimer of 195 | warranties; 196 | 197 | v. a URI or hyperlink to the Licensed Material to the 198 | extent reasonably practicable; 199 | 200 | b. indicate if You modified the Licensed Material and 201 | retain an indication of any previous modifications; and 202 | 203 | c. indicate the Licensed Material is licensed under this 204 | Public License, and include the text of, or the URI or 205 | hyperlink to, this Public License. 206 | 207 | 2. You may satisfy the conditions in Section 3(a)(1) in any 208 | reasonable manner based on the medium, means, and context in 209 | which You Share the Licensed Material. For example, it may be 210 | reasonable to satisfy the conditions by providing a URI or 211 | hyperlink to a resource that includes the required 212 | information. 213 | 214 | 3. If requested by the Licensor, You must remove any of the 215 | information required by Section 3(a)(1)(A) to the extent 216 | reasonably practicable. 217 | 218 | 4. If You Share Adapted Material You produce, the Adapter's 219 | License You apply must not prevent recipients of the Adapted 220 | Material from complying with this Public License. 221 | 222 | 223 | Section 4 -- Sui Generis Database Rights. 224 | 225 | Where the Licensed Rights include Sui Generis Database Rights that 226 | apply to Your use of the Licensed Material: 227 | 228 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 229 | to extract, reuse, reproduce, and Share all or a substantial 230 | portion of the contents of the database for NonCommercial purposes 231 | only; 232 | 233 | b. if You include all or a substantial portion of the database 234 | contents in a database in which You have Sui Generis Database 235 | Rights, then the database in which You have Sui Generis Database 236 | Rights (but not its individual contents) is Adapted Material; and 237 | 238 | c. You must comply with the conditions in Section 3(a) if You Share 239 | all or a substantial portion of the contents of the database. 240 | 241 | For the avoidance of doubt, this Section 4 supplements and does not 242 | replace Your obligations under this Public License where the Licensed 243 | Rights include other Copyright and Similar Rights. 244 | 245 | 246 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 247 | 248 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 249 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 250 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 251 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 252 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 253 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 254 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 255 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 256 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 257 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 258 | 259 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 260 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 261 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 262 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 263 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 264 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 265 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 266 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 267 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 268 | 269 | c. The disclaimer of warranties and limitation of liability provided 270 | above shall be interpreted in a manner that, to the extent 271 | possible, most closely approximates an absolute disclaimer and 272 | waiver of all liability. 273 | 274 | 275 | Section 6 -- Term and Termination. 276 | 277 | a. This Public License applies for the term of the Copyright and 278 | Similar Rights licensed here. However, if You fail to comply with 279 | this Public License, then Your rights under this Public License 280 | terminate automatically. 281 | 282 | b. Where Your right to use the Licensed Material has terminated under 283 | Section 6(a), it reinstates: 284 | 285 | 1. automatically as of the date the violation is cured, provided 286 | it is cured within 30 days of Your discovery of the 287 | violation; or 288 | 289 | 2. upon express reinstatement by the Licensor. 290 | 291 | For the avoidance of doubt, this Section 6(b) does not affect any 292 | right the Licensor may have to seek remedies for Your violations 293 | of this Public License. 294 | 295 | c. For the avoidance of doubt, the Licensor may also offer the 296 | Licensed Material under separate terms or conditions or stop 297 | distributing the Licensed Material at any time; however, doing so 298 | will not terminate this Public License. 299 | 300 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 301 | License. 302 | 303 | 304 | Section 7 -- Other Terms and Conditions. 305 | 306 | a. The Licensor shall not be bound by any additional or different 307 | terms or conditions communicated by You unless expressly agreed. 308 | 309 | b. Any arrangements, understandings, or agreements regarding the 310 | Licensed Material not stated herein are separate from and 311 | independent of the terms and conditions of this Public License. 312 | 313 | 314 | Section 8 -- Interpretation. 315 | 316 | a. For the avoidance of doubt, this Public License does not, and 317 | shall not be interpreted to, reduce, limit, restrict, or impose 318 | conditions on any use of the Licensed Material that could lawfully 319 | be made without permission under this Public License. 320 | 321 | b. 
To the extent possible, if any provision of this Public License is 322 | deemed unenforceable, it shall be automatically reformed to the 323 | minimum extent necessary to make it enforceable. If the provision 324 | cannot be reformed, it shall be severed from this Public License 325 | without affecting the enforceability of the remaining terms and 326 | conditions. 327 | 328 | c. No term or condition of this Public License will be waived and no 329 | failure to comply consented to unless expressly agreed to by the 330 | Licensor. 331 | 332 | d. Nothing in this Public License constitutes or may be interpreted 333 | as a limitation upon, or waiver of, any privileges and immunities 334 | that apply to the Licensor or You, including from the legal 335 | processes of any jurisdiction or authority. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License: CC BY-NC 4.0](https://licensebuttons.net/l/by-nc/4.0/80x15.png)](https://creativecommons.org/licenses/by-nc/4.0/) 2 | 3 | 4 | # Multi-Similarity Loss for Deep Metric Learning (MS-Loss) 5 | 6 | Code for the CVPR 2019 paper [Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf) 7 | 8 | 9 | 10 | ### Performance compared with SOTA methods on CUB-200-2011 11 | 12 | |Rank@K | 1 | 2 | 4 | 8 | 16 | 32 | 13 | |:--- |:-:|:-:|:-:|:-:|:-: |:-: | 14 | |Clustering64 | 48.2 | 61.4 | 71.8 | 81.9 | - | - | 15 | |ProxyNCA64 | 49.2 | 61.9 | 67.9 | 72.4 | - | - | 16 | |Smart Mining64 | 49.8 | 62.3 | 74.1 | 83.3 | - | - | 17 | |Our MS-Loss64| **57.4** |**69.8** |**80.0** |**87.8** |93.2 |96.4| 18 | |HTL512 | 57.1| 68.8| 78.7| 86.5| 92.5| 95.5 | 19 | |ABIER512 |57.5 |68.7 |78.3 |86.2 |91.9 |95.5 | 20 | |Our MS-Loss512|**65.7** |**77.0** |**86.3**|**91.2** |**95.0** |**97.3**| 21 | 22 | The number after each method name (64 or 512) is the embedding dimension. 23 | ### Prepare the data and the pretrained model 24 | 25 | The following script prepares the [CUB](http://www.vision.caltech.edu.s3-us-west-2.amazonaws.com/visipedia-data/CUB-200-2011/CUB_200_2011.tgz) dataset for training by downloading it to the ./resource/datasets/ folder and then building the data lists (train.txt and test.txt): 26 | 27 | ```bash 28 | ./scripts/prepare_cub.sh 29 | ``` 30 | 31 | Download the ImageNet-pretrained 32 | [bninception](http://data.lip6.fr/cadene/pretrainedmodels/bn_inception-52deb4733.pth) model and put it in the folder ~/.torch/models/. 33 | 34 | 35 | ### Installation 36 | 37 | ```bash 38 | pip install -r requirements.txt 39 | python setup.py develop build 40 | ``` 41 | ### Train and Test on CUB-200-2011 with MS-Loss 42 | 43 | ```bash 44 | ./scripts/run_cub.sh 45 | ``` 46 | Trained models will be saved in the ./output/ folder when using the default config. 47 | 48 | The best recall@1 should be higher than 66 (65.7 is reported in the paper).
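As a quick sanity check, the loss can also be exercised outside the full training pipeline. The snippet below is a minimal sketch (not part of the released scripts) that assumes the package has been installed via `python setup.py develop`; it builds the default `ms_loss` criterion with `build_loss` and feeds it random L2-normalized embeddings:

```python
import torch
import torch.nn.functional as F

from ret_benchmark.config import cfg          # default config: LOSSES.NAME = 'ms_loss'
from ret_benchmark.losses import build_loss   # looks the loss up in the LOSS registry

criterion = build_loss(cfg)

# 16 classes x 5 instances = 80 embeddings, mirroring the default NUM_INSTANCES = 5
feats = F.normalize(torch.randn(80, 512), dim=1)   # unit-length embeddings, dim = MODEL.HEAD.DIM
labels = torch.arange(16).repeat_interleave(5)

loss = criterion(feats, labels)
print(loss.item())
```

This mirrors the call in ret_benchmark/engine/trainer.py, where the same criterion is applied to the embeddings produced by the model; only the data pipeline and backbone are stripped away here.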
49 | 50 | ### Contact 51 | 52 | For any questions, please feel free to reach 53 | ``` 54 | github@malongtech.com 55 | ``` 56 | 57 | ### Citation 58 | 59 | If you use this method or this code in your research, please cite as: 60 | 61 | @inproceedings{wang2019multi, 62 | title={Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning}, 63 | author={Wang, Xun and Han, Xintong and Huang, Weilin and Dong, Dengke and Scott, Matthew R}, 64 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 65 | pages={5022--5030}, 66 | year={2019} 67 | } 68 | 69 | ## License 70 | 71 | MS-Loss is CC-BY-NC 4.0 licensed, as found in the [LICENSE](LICENSE) file. It is released for academic research / non-commercial use only. If you wish to use for commercial purposes, please contact sales@malongtech.com. 72 | 73 | -------------------------------------------------------------------------------- /ThirdPartyNotices.txt: -------------------------------------------------------------------------------- 1 | THIRD PARTY SOFTWARE NOTICES AND INFORMATION 2 | 3 | Do Not Translate or Localize 4 | 5 | This software incorporates material from the following third parties. 6 | 7 | _____ 8 | 9 | Cadene/pretrained-models.pytorch 10 | 11 | BSD 3-Clause License 12 | 13 | Copyright (c) 2017, Remi Cadene 14 | All rights reserved. 15 | 16 | Redistribution and use in source and binary forms, with or without 17 | modification, are permitted provided that the following conditions are met: 18 | 19 | * Redistributions of source code must retain the above copyright notice, this 20 | list of conditions and the following disclaimer. 21 | 22 | * Redistributions in binary form must reproduce the above copyright notice, 23 | this list of conditions and the following disclaimer in the documentation 24 | and/or other materials provided with the distribution. 25 | 26 | * Neither the name of the copyright holder nor the names of its 27 | contributors may be used to endorse or promote products derived from 28 | this software without specific prior written permission. 29 | 30 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 34 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 35 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 36 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 37 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 38 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 39 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
40 | 41 | _____ 42 | 43 | facebookresearch/maskrcnn-benchmark 44 | 45 | MIT License 46 | 47 | Copyright (c) 2018 Facebook 48 | 49 | Permission is hereby granted, free of charge, to any person obtaining a copy 50 | of this software and associated documentation files (the "Software"), to deal 51 | in the Software without restriction, including without limitation the rights 52 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 53 | copies of the Software, and to permit persons to whom the Software is 54 | furnished to do so, subject to the following conditions: 55 | 56 | The above copyright notice and this permission notice shall be included in all 57 | copies or substantial portions of the Software. 58 | 59 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 60 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 61 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 62 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 63 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 64 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 65 | SOFTWARE. -------------------------------------------------------------------------------- /configs/example.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: bninception 11 | 12 | SOLVER: 13 | MAX_ITERS: 3000 14 | STEPS: [1200, 2400] 15 | OPTIMIZER_NAME: Adam 16 | BASE_LR: 0.00003 17 | WARMUP_ITERS: 0 18 | WEIGHT_DECAY: 0.0005 19 | 20 | DATA: 21 | TRAIN_IMG_SOURCE: resource/datasets/CUB_200_2011/train.txt 22 | TEST_IMG_SOURCE: resource/datasets/CUB_200_2011/test.txt 23 | TRAIN_BATCHSIZE: 80 24 | TEST_BATCHSIZE: 256 25 | NUM_WORKERS: 8 26 | NUM_INSTANCES: 5 27 | 28 | VALIDATION: 29 | VERBOSE: 200 -------------------------------------------------------------------------------- /configs/example_margin.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: bninception 11 | 12 | LOSSES: 13 | NAME: margin_loss 14 | MARGIN_LOSS: 15 | N_CLASSES: 100 16 | BETA_CONSTANT: False # if False (i.e. class specific beta) train.txt should have labels 0 .... N_CLASSES -1 17 | 18 | SOLVER: 19 | MAX_ITERS: 3000 20 | STEPS: [1200, 2400] 21 | OPTIMIZER_NAME: Adam 22 | BASE_LR: 0.00003 23 | WARMUP_ITERS: 0 24 | WEIGHT_DECAY: 0.0005 25 | 26 | DATA: 27 | TRAIN_IMG_SOURCE: resource/datasets/CUB_200_2011/train.txt 28 | TEST_IMG_SOURCE: resource/datasets/CUB_200_2011/test.txt 29 | TRAIN_BATCHSIZE: 120 30 | TEST_BATCHSIZE: 256 31 | NUM_WORKERS: 8 32 | NUM_INSTANCES: 5 33 | 34 | VALIDATION: 35 | VERBOSE: 200 36 | 37 | SAVE_DIR: output_margin 38 | 39 | -------------------------------------------------------------------------------- /configs/example_resnet50.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 
3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: resnet50 11 | 12 | INPUT: 13 | MODE: 'RGB' 14 | PIXEL_MEAN: [0.485, 0.456, 0.406] 15 | PIXEL_STD: [0.229, 0.224, 0.225] 16 | 17 | SOLVER: 18 | MAX_ITERS: 3000 19 | STEPS: [1200, 2400] 20 | OPTIMIZER_NAME: Adam 21 | BASE_LR: 0.00003 22 | WARMUP_ITERS: 0 23 | WEIGHT_DECAY: 0.0005 24 | 25 | DATA: 26 | TRAIN_IMG_SOURCE: resource/datasets/CUB_200_2011/train.txt 27 | TEST_IMG_SOURCE: resource/datasets/CUB_200_2011/test.txt 28 | TRAIN_BATCHSIZE: 80 29 | TEST_BATCHSIZE: 256 30 | NUM_WORKERS: 8 31 | NUM_INSTANCES: 5 32 | 33 | VALIDATION: 34 | VERBOSE: 200 35 | -------------------------------------------------------------------------------- /misc/ms_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msight-tech/research-ms-loss/b68507d4e22d8a6d3d3c0e6c31be708f9dcd20ee/misc/ms_loss.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.1.0 2 | numpy==1.15.4 3 | yacs==0.1.4 4 | setuptools==40.6.2 5 | pytest==4.4.0 6 | Pillow==8.3.2 7 | torchvision==0.3.0 8 | -------------------------------------------------------------------------------- /ret_benchmark/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | from .defaults import _C as cfg 9 | -------------------------------------------------------------------------------- /ret_benchmark/config/defaults.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 
7 | 8 | from yacs.config import CfgNode as CN 9 | from .model_path import MODEL_PATH 10 | 11 | # ----------------------------------------------------------------------------- 12 | # Config definition 13 | # ----------------------------------------------------------------------------- 14 | 15 | _C = CN() 16 | 17 | _C.MODEL = CN() 18 | _C.MODEL.DEVICE = "cuda" 19 | 20 | _C.MODEL.BACKBONE = CN() 21 | _C.MODEL.BACKBONE.NAME = "bninception" 22 | 23 | _C.MODEL.PRETRAIN = 'imagenet' 24 | _C.MODEL.PRETRIANED_PATH = MODEL_PATH 25 | 26 | _C.MODEL.HEAD = CN() 27 | _C.MODEL.HEAD.NAME = "linear_norm" 28 | _C.MODEL.HEAD.DIM = 512 29 | 30 | _C.MODEL.WEIGHT = "" 31 | 32 | # Checkpoint save dir 33 | _C.SAVE_DIR = 'output' 34 | 35 | # Loss 36 | _C.LOSSES = CN() 37 | _C.LOSSES.NAME = 'ms_loss' 38 | 39 | # ms loss 40 | _C.LOSSES.MULTI_SIMILARITY_LOSS = CN() 41 | _C.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_POS = 2.0 42 | _C.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_NEG = 40.0 43 | _C.LOSSES.MULTI_SIMILARITY_LOSS.HARD_MINING = True 44 | 45 | # margin loss 46 | _C.LOSSES.MARGIN_LOSS = CN() 47 | _C.LOSSES.MARGIN_LOSS.BETA_CONSTANT = False 48 | _C.LOSSES.MARGIN_LOSS.N_CLASSES = 100 49 | _C.LOSSES.MARGIN_LOSS.BETA_CONSTANT = False 50 | _C.LOSSES.MARGIN_LOSS.CUTOFF = 0.5 51 | _C.LOSSES.MARGIN_LOSS.UPPER_CUTOFF = 1.4 52 | 53 | # Data option 54 | _C.DATA = CN() 55 | _C.DATA.TRAIN_IMG_SOURCE = 'resource/datasets/CUB_200_2011/train.txt' 56 | _C.DATA.TEST_IMG_SOURCE = 'resource/datasets/CUB_200_2011/test.txt' 57 | _C.DATA.TRAIN_BATCHSIZE = 70 58 | _C.DATA.TEST_BATCHSIZE = 256 59 | _C.DATA.NUM_WORKERS = 8 60 | _C.DATA.NUM_INSTANCES = 5 61 | 62 | # Input option 63 | _C.INPUT = CN() 64 | 65 | # INPUT CONFIG 66 | _C.INPUT.MODE = 'BGR' 67 | _C.INPUT.PIXEL_MEAN = [104. / 255, 117. / 255, 128. / 255] 68 | _C.INPUT.PIXEL_STD = 3 * [1. / 255] 69 | 70 | _C.INPUT.FLIP_PROB = 0.5 71 | _C.INPUT.ORIGIN_SIZE = 256 72 | _C.INPUT.CROP_SCALE = [0.16, 1] 73 | _C.INPUT.CROP_SIZE = 227 74 | 75 | # SOLVER 76 | _C.SOLVER = CN() 77 | _C.SOLVER.IS_FINETURN = False 78 | _C.SOLVER.FINETURN_MODE_PATH = '' 79 | _C.SOLVER.MAX_ITERS = 4000 80 | _C.SOLVER.STEPS = [1000, 2000, 3000] 81 | _C.SOLVER.OPTIMIZER_NAME = 'SGD' 82 | _C.SOLVER.BASE_LR = 0.01 83 | _C.SOLVER.BIAS_LR_FACTOR = 1 84 | _C.SOLVER.WEIGHT_DECAY = 0.0005 85 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.0005 86 | _C.SOLVER.MOMENTUM = 0.9 87 | _C.SOLVER.GAMMA = 0.1 88 | _C.SOLVER.WARMUP_FACTOR = 0.01 89 | _C.SOLVER.WARMUP_ITERS = 200 90 | _C.SOLVER.WARMUP_METHOD = 'linear' 91 | _C.SOLVER.CHECKPOINT_PERIOD = 200 92 | _C.SOLVER.RNG_SEED = 1 93 | 94 | # Logger 95 | _C.LOGGER = CN() 96 | _C.LOGGER.LEVEL = 20 97 | _C.LOGGER.STREAM = 'stdout' 98 | 99 | # Validation 100 | _C.VALIDATION = CN() 101 | _C.VALIDATION.VERBOSE = 200 102 | _C.VALIDATION.IS_VALIDATION = True 103 | -------------------------------------------------------------------------------- /ret_benchmark/config/model_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 
7 | 8 | # ----------------------------------------------------------------------------- 9 | # Config definition of imagenet pretrained model path 10 | # ----------------------------------------------------------------------------- 11 | 12 | 13 | from yacs.config import CfgNode as CN 14 | 15 | MODEL_PATH = { 16 | 'bninception': "~/.torch/models/bn_inception-52deb4733.pth", 17 | 'resnet50': "~/.torch/models/resnet50-19c8e357.pth", 18 | } 19 | 20 | MODEL_PATH = CN(MODEL_PATH) 21 | -------------------------------------------------------------------------------- /ret_benchmark/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | from .build import build_data 9 | -------------------------------------------------------------------------------- /ret_benchmark/data/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | from torch.utils.data import DataLoader 9 | 10 | from .collate_batch import collate_fn 11 | from .datasets import BaseDataSet 12 | from .samplers import RandomIdentitySampler 13 | from .transforms import build_transforms 14 | 15 | 16 | def build_data(cfg, is_train=True): 17 | transforms = build_transforms(cfg, is_train=is_train) 18 | if is_train: 19 | dataset = BaseDataSet(cfg.DATA.TRAIN_IMG_SOURCE, transforms=transforms, mode=cfg.INPUT.MODE) 20 | sampler = RandomIdentitySampler(dataset=dataset, 21 | batch_size=cfg.DATA.TRAIN_BATCHSIZE, 22 | num_instances=cfg.DATA.NUM_INSTANCES, 23 | max_iters=cfg.SOLVER.MAX_ITERS 24 | ) 25 | data_loader = DataLoader(dataset, 26 | collate_fn=collate_fn, 27 | batch_sampler=sampler, 28 | num_workers=cfg.DATA.NUM_WORKERS, 29 | pin_memory=True 30 | ) 31 | else: 32 | dataset = BaseDataSet(cfg.DATA.TEST_IMG_SOURCE, transforms=transforms, mode=cfg.INPUT.MODE) 33 | data_loader = DataLoader(dataset, 34 | collate_fn=collate_fn, 35 | shuffle=False, 36 | batch_size=cfg.DATA.TEST_BATCHSIZE, 37 | num_workers=cfg.DATA.NUM_WORKERS 38 | ) 39 | return data_loader 40 | -------------------------------------------------------------------------------- /ret_benchmark/data/collate_batch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | 10 | 11 | def collate_fn(batch): 12 | imgs, labels = zip(*batch) 13 | labels = [int(k) for k in labels] 14 | labels = torch.tensor(labels, dtype=torch.int64) 15 | return torch.stack(imgs, dim=0), labels 16 | -------------------------------------------------------------------------------- /ret_benchmark/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 
7 | 8 | from .base_dataset import BaseDataSet 9 | -------------------------------------------------------------------------------- /ret_benchmark/data/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | import re 10 | from collections import defaultdict 11 | 12 | from torch.utils.data import Dataset 13 | from ret_benchmark.utils.img_reader import read_image 14 | 15 | 16 | class BaseDataSet(Dataset): 17 | """ 18 | Basic Dataset read image path from img_source 19 | img_source: list of img_path and label 20 | """ 21 | 22 | def __init__(self, img_source, transforms=None, mode="RGB"): 23 | self.mode = mode 24 | self.transforms = transforms 25 | self.root = os.path.dirname(img_source) 26 | assert os.path.exists(img_source), f"{img_source} NOT found." 27 | self.img_source = img_source 28 | 29 | self.label_list = list() 30 | self.path_list = list() 31 | self._load_data() 32 | self.label_index_dict = self._build_label_index_dict() 33 | 34 | def __len__(self): 35 | return len(self.label_list) 36 | 37 | def __repr__(self): 38 | return self.__str__() 39 | 40 | def __str__(self): 41 | return f"| Dataset Info |datasize: {self.__len__()}|num_labels: {len(set(self.label_list))}|" 42 | 43 | def _load_data(self): 44 | with open(self.img_source, 'r') as f: 45 | for line in f: 46 | _path, _label = re.split(r",| ", line.strip()) 47 | self.path_list.append(_path) 48 | self.label_list.append(_label) 49 | 50 | def _build_label_index_dict(self): 51 | index_dict = defaultdict(list) 52 | for i, label in enumerate(self.label_list): 53 | index_dict[label].append(i) 54 | return index_dict 55 | 56 | def __getitem__(self, index): 57 | path = self.path_list[index] 58 | img_path = os.path.join(self.root, path) 59 | label = self.label_list[index] 60 | 61 | img = read_image(img_path, mode=self.mode) 62 | if self.transforms is not None: 63 | img = self.transforms(img) 64 | return img, label 65 | -------------------------------------------------------------------------------- /ret_benchmark/data/evaluations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | from .ret_metric import RetMetric 9 | -------------------------------------------------------------------------------- /ret_benchmark/data/evaluations/ret_metric.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 
7 | 8 | import numpy as np 9 | 10 | 11 | class RetMetric(object): 12 | def __init__(self, feats, labels): 13 | 14 | if len(feats) == 2 and type(feats) == list: 15 | """ 16 | feats = [gallery_feats, query_feats] 17 | labels = [gallery_labels, query_labels] 18 | """ 19 | self.is_equal_query = False 20 | 21 | self.gallery_feats, self.query_feats = feats 22 | self.gallery_labels, self.query_labels = labels 23 | 24 | else: 25 | self.is_equal_query = True 26 | self.gallery_feats = self.query_feats = feats 27 | self.gallery_labels = self.query_labels = labels 28 | 29 | self.sim_mat = np.matmul(self.query_feats, np.transpose(self.gallery_feats)) 30 | 31 | def recall_k(self, k=1): 32 | m = len(self.sim_mat) 33 | 34 | match_counter = 0 35 | 36 | for i in range(m): 37 | pos_sim = self.sim_mat[i][self.gallery_labels == self.query_labels[i]] 38 | neg_sim = self.sim_mat[i][self.gallery_labels != self.query_labels[i]] 39 | 40 | thresh = np.sort(pos_sim)[-2] if self.is_equal_query else np.max(pos_sim) 41 | 42 | if np.sum(neg_sim > thresh) < k: 43 | match_counter += 1 44 | return float(match_counter) / m 45 | -------------------------------------------------------------------------------- /ret_benchmark/data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | from .random_identity_sampler import RandomIdentitySampler 9 | -------------------------------------------------------------------------------- /ret_benchmark/data/samplers/random_identity_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | import copy 9 | import random 10 | from collections import defaultdict 11 | 12 | import numpy as np 13 | import torch 14 | from torch.utils.data.sampler import Sampler 15 | 16 | 17 | class RandomIdentitySampler(Sampler): 18 | """ 19 | Randomly sample N identities, then for each identity, 20 | randomly sample K instances, therefore batch size is N*K. 21 | Args: 22 | - dataset (BaseDataSet). 23 | - num_instances (int): number of instances per identity in a batch. 24 | - batch_size (int): number of examples in a batch. 
25 | """ 26 | 27 | def __init__(self, dataset, batch_size, num_instances, max_iters): 28 | self.label_index_dict = dataset.label_index_dict 29 | self.batch_size = batch_size 30 | self.K = num_instances 31 | self.num_labels_per_batch = self.batch_size // self.K 32 | self.max_iters = max_iters 33 | self.labels = list(self.label_index_dict.keys()) 34 | 35 | def __len__(self): 36 | return self.max_iters 37 | 38 | def __repr__(self): 39 | return self.__str__() 40 | 41 | def __str__(self): 42 | return f"|Sampler| iters {self.max_iters}| K {self.K}| M {self.batch_size}|" 43 | 44 | def _prepare_batch(self): 45 | batch_idxs_dict = defaultdict(list) 46 | 47 | for label in self.labels: 48 | idxs = copy.deepcopy(self.label_index_dict[label]) 49 | if len(idxs) < self.K: 50 | idxs.extend(np.random.choice(idxs, size=self.K - len(idxs), replace=True)) 51 | random.shuffle(idxs) 52 | 53 | batch_idxs_dict[label] = [idxs[i * self.K: (i + 1) * self.K] for i in range(len(idxs) // self.K)] 54 | 55 | avai_labels = copy.deepcopy(self.labels) 56 | return batch_idxs_dict, avai_labels 57 | 58 | def __iter__(self): 59 | batch_idxs_dict, avai_labels = self._prepare_batch() 60 | for _ in range(self.max_iters): 61 | batch = [] 62 | if len(avai_labels) < self.num_labels_per_batch: 63 | batch_idxs_dict, avai_labels = self._prepare_batch() 64 | 65 | selected_labels = random.sample(avai_labels, self.num_labels_per_batch) 66 | for label in selected_labels: 67 | batch_idxs = batch_idxs_dict[label].pop(0) 68 | batch.extend(batch_idxs) 69 | if len(batch_idxs_dict[label]) == 0: 70 | avai_labels.remove(label) 71 | yield batch 72 | -------------------------------------------------------------------------------- /ret_benchmark/data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | from .build import build_transforms 9 | -------------------------------------------------------------------------------- /ret_benchmark/data/transforms/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | import torchvision.transforms as T 9 | 10 | 11 | def build_transforms(cfg, is_train=True): 12 | normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, 13 | std=cfg.INPUT.PIXEL_STD) 14 | if is_train: 15 | transform = T.Compose([ 16 | T.Resize(size=cfg.INPUT.ORIGIN_SIZE), 17 | T.RandomResizedCrop( 18 | scale=cfg.INPUT.CROP_SCALE, 19 | size=cfg.INPUT.CROP_SIZE 20 | ), 21 | T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB), 22 | T.ToTensor(), 23 | normalize_transform, 24 | ]) 25 | else: 26 | transform = T.Compose([ 27 | T.Resize(size=cfg.INPUT.ORIGIN_SIZE), 28 | T.CenterCrop(cfg.INPUT.CROP_SIZE), 29 | T.ToTensor(), 30 | normalize_transform 31 | ]) 32 | return transform 33 | -------------------------------------------------------------------------------- /ret_benchmark/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 
3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | from .trainer import do_train 9 | -------------------------------------------------------------------------------- /ret_benchmark/engine/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | import datetime 9 | import time 10 | 11 | import numpy as np 12 | import torch 13 | 14 | from ret_benchmark.data.evaluations import RetMetric 15 | from ret_benchmark.utils.feat_extractor import feat_extractor 16 | from ret_benchmark.utils.freeze_bn import set_bn_eval 17 | from ret_benchmark.utils.metric_logger import MetricLogger 18 | 19 | 20 | def do_train( 21 | cfg, 22 | model, 23 | train_loader, 24 | val_loader, 25 | optimizer, 26 | scheduler, 27 | criterion, 28 | checkpointer, 29 | device, 30 | checkpoint_period, 31 | arguments, 32 | logger 33 | ): 34 | logger.info("Start training") 35 | meters = MetricLogger(delimiter=" ") 36 | max_iter = len(train_loader) 37 | 38 | start_iter = arguments["iteration"] 39 | best_iteration = -1 40 | best_recall = 0 41 | 42 | start_training_time = time.time() 43 | end = time.time() 44 | for iteration, (images, targets) in enumerate(train_loader, start_iter): 45 | 46 | if iteration % cfg.VALIDATION.VERBOSE == 0 or iteration == max_iter: 47 | model.eval() 48 | logger.info('Validation') 49 | labels = val_loader.dataset.label_list 50 | labels = np.array([int(k) for k in labels]) 51 | feats = feat_extractor(model, val_loader, logger=logger) 52 | 53 | ret_metric = RetMetric(feats=feats, labels=labels) 54 | recall_curr = ret_metric.recall_k(1) 55 | 56 | if recall_curr > best_recall: 57 | best_recall = recall_curr 58 | best_iteration = iteration 59 | logger.info(f'Best iteration {iteration}: recall@1: {best_recall:.3f}') 60 | checkpointer.save(f"best_model") 61 | else: 62 | logger.info(f'Recall@1 at iteration {iteration:06d}: {recall_curr:.3f}') 63 | 64 | model.train() 65 | model.apply(set_bn_eval) 66 | 67 | data_time = time.time() - end 68 | iteration = iteration + 1 69 | arguments["iteration"] = iteration 70 | 71 | scheduler.step() 72 | 73 | images = images.to(device) 74 | targets = torch.stack([target.to(device) for target in targets]) 75 | 76 | feats = model(images) 77 | loss = criterion(feats, targets) 78 | optimizer.zero_grad() 79 | loss.backward() 80 | optimizer.step() 81 | 82 | batch_time = time.time() - end 83 | end = time.time() 84 | meters.update(time=batch_time, data=data_time, loss=loss.item()) 85 | 86 | eta_seconds = meters.time.global_avg * (max_iter - iteration) 87 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 88 | 89 | if iteration % 20 == 0 or iteration == max_iter: 90 | logger.info( 91 | meters.delimiter.join( 92 | [ 93 | "eta: {eta}", 94 | "iter: {iter}", 95 | "{meters}", 96 | "lr: {lr:.6f}", 97 | "max mem: {memory:.1f} GB", 98 | ] 99 | ).format( 100 | eta=eta_string, 101 | iter=iteration, 102 | meters=str(meters), 103 | lr=optimizer.param_groups[0]["lr"], 104 | memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0, 105 | ) 106 | ) 107 | 108 | if iteration % checkpoint_period == 0: 109 | checkpointer.save("model_{:06d}".format(iteration)) 110 | 111 | total_training_time = time.time() - start_training_time 
112 | total_time_str = str(datetime.timedelta(seconds=total_training_time)) 113 | logger.info( 114 | "Total training time: {} ({:.4f} s / it)".format( 115 | total_time_str, total_training_time / (max_iter) 116 | ) 117 | ) 118 | 119 | logger.info(f"Best iteration: {best_iteration :06d} | best recall {best_recall} ") 120 | -------------------------------------------------------------------------------- /ret_benchmark/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | from .build import build_loss 9 | -------------------------------------------------------------------------------- /ret_benchmark/losses/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | from .multi_similarity_loss import MultiSimilarityLoss 9 | from .margin_loss import MarginLoss 10 | from .registry import LOSS 11 | 12 | 13 | def build_loss(cfg): 14 | loss_name = cfg.LOSSES.NAME 15 | assert loss_name in LOSS, \ 16 | f'loss name {loss_name} is not registered in registry :{LOSS.keys()}' 17 | return LOSS[loss_name](cfg) 18 | -------------------------------------------------------------------------------- /ret_benchmark/losses/margin_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from ret_benchmark.losses.registry import LOSS 7 | 8 | 9 | class DistanceWeightedSampling(object): 10 | """ 11 | """ 12 | def __init__(self, cfg): 13 | super(DistanceWeightedSampling, self).__init__() 14 | self.cutoff = cfg.LOSSES.MARGIN_LOSS.CUTOFF 15 | self.upper_cutoff = cfg.LOSSES.MARGIN_LOSS.UPPER_CUTOFF 16 | 17 | def sample(self, batch, labels): 18 | 19 | if isinstance(labels, torch.Tensor): 20 | labels = labels.detach().cpu().numpy() 21 | bs = batch.shape[0] 22 | distances = self.p_dist(batch.detach()).clamp(min=self.cutoff) 23 | 24 | positives, negatives = [], [] 25 | 26 | for i in range(bs): 27 | pos = labels == labels[i] 28 | q_d_inv = self.inverse_sphere_distances(batch, distances[i], labels, labels[i]) 29 | # sample positives randomly 30 | pos[i] = 0 31 | positives.append(np.random.choice(np.where(pos)[0])) 32 | # sample negatives by distance 33 | negatives.append(np.random.choice(bs, p=q_d_inv)) 34 | 35 | sampled_triplets = [[a, p, n] for a, p, n in zip(list(range(bs)), positives, negatives)] 36 | return sampled_triplets 37 | 38 | @staticmethod 39 | def p_dist(A, eps=1e-4): 40 | prod = torch.mm(A, A.t()) 41 | norm = prod.diag().unsqueeze(1).expand_as(prod) 42 | res = (norm + norm.t() - 2 * prod).clamp(min=0) 43 | return res.clamp(min=eps).sqrt() 44 | 45 | def inverse_sphere_distances(self, batch, dist, labels, anchor_label): 46 | bs, dim = len(dist), batch.shape[-1] 47 | # negated log-distribution of distances of unit sphere in dimension 48 | log_q_d_inv = ((2.0 - float(dim)) * torch.log(dist) - (float(dim-3) / 2) 49 | * torch.log(1.0 - 0.25 * (dist.pow(2)))) 50 | # set sampling probabilities of positives to zero 51 | log_q_d_inv[np.where(labels == 
anchor_label)[0]] = 0 52 | 53 | q_d_inv = torch.exp(log_q_d_inv - torch.max(log_q_d_inv)) # - max(log) for stability 54 | # set sampling probabilities of positives to zero 55 | q_d_inv[np.where(labels == anchor_label)[0]] = 0 56 | 57 | # NOTE: Cutting of values with high distances made the results slightly worse. 58 | # q_d_inv[np.where(dist > self.upper_cutoff)[0]] = 0 59 | 60 | q_d_inv = q_d_inv/q_d_inv.sum() 61 | return q_d_inv.detach().cpu().numpy() 62 | 63 | 64 | @LOSS.register("margin_loss") 65 | class MarginLoss(nn.Module): 66 | """Margin based loss with DistanceWeightedSampling 67 | """ 68 | def __init__(self, cfg): 69 | super(MarginLoss, self).__init__() 70 | self.beta_val = 1.2 71 | self.margin = 0.2 72 | self.nu = 0.0 73 | self.n_classes = cfg.LOSSES.MARGIN_LOSS.N_CLASSES 74 | self.beta_constant = cfg.LOSSES.MARGIN_LOSS.BETA_CONSTANT 75 | if self.beta_constant: 76 | self.beta = self.beta_val 77 | else: 78 | self.beta = torch.nn.Parameter(torch.ones(self.n_classes)*self.beta_val) 79 | self.sampler = DistanceWeightedSampling(cfg) 80 | 81 | def forward(self, batch, labels): 82 | if isinstance(labels, torch.Tensor): 83 | labels = labels.detach().cpu().numpy() 84 | sampled_triplets = self.sampler.sample(batch, labels) 85 | 86 | # compute distances between anchor-positive and anchor-negative. 87 | d_ap, d_an = [], [] 88 | for triplet in sampled_triplets: 89 | train_triplet = {'Anchor': batch[triplet[0], :], 90 | 'Positive': batch[triplet[1], :], 'Negative': batch[triplet[2]]} 91 | pos_dist = ((train_triplet['Anchor']-train_triplet['Positive']).pow(2).sum()+1e-8).pow(1/2) 92 | neg_dist = ((train_triplet['Anchor']-train_triplet['Negative']).pow(2).sum()+1e-8).pow(1/2) 93 | 94 | d_ap.append(pos_dist) 95 | d_an.append(neg_dist) 96 | d_ap, d_an = torch.stack(d_ap), torch.stack(d_an) 97 | 98 | # group betas together by anchor class in sampled triplets (as each beta belongs to one class). 99 | if self.beta_constant: 100 | beta = self.beta 101 | else: 102 | beta = torch.stack([self.beta[labels[triplet[0]]] for 103 | triplet in sampled_triplets]).type(torch.cuda.FloatTensor) 104 | # compute actual margin positive and margin negative loss 105 | pos_loss = F.relu(d_ap-beta+self.margin) 106 | neg_loss = F.relu(beta-d_an+self.margin) 107 | 108 | # compute normalization constant 109 | pair_count = torch.sum((pos_loss > 0.)+(neg_loss > 0.)).type(torch.cuda.FloatTensor) 110 | # actual Margin Loss 111 | loss = torch.sum(pos_loss+neg_loss) if pair_count == 0. else torch.sum(pos_loss+neg_loss)/pair_count 112 | 113 | # (Optional) Add regularization penalty on betas. 114 | # if self.nu: loss = loss + beta_regularisation_loss.type(torch.cuda.FloatTensor) 115 | return loss 116 | -------------------------------------------------------------------------------- /ret_benchmark/losses/multi_similarity_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 
7 | 8 | import torch 9 | from torch import nn 10 | 11 | from ret_benchmark.losses.registry import LOSS 12 | 13 | 14 | @LOSS.register('ms_loss') 15 | class MultiSimilarityLoss(nn.Module): 16 | def __init__(self, cfg): 17 | super(MultiSimilarityLoss, self).__init__() 18 | self.thresh = 0.5 19 | self.margin = 0.1 20 | 21 | self.scale_pos = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_POS 22 | self.scale_neg = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_NEG 23 | 24 | def forward(self, feats, labels): 25 | assert feats.size(0) == labels.size(0), \ 26 | f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}" 27 | batch_size = feats.size(0) 28 | sim_mat = torch.matmul(feats, torch.t(feats)) 29 | 30 | epsilon = 1e-5 31 | loss = list() 32 | 33 | for i in range(batch_size): 34 | pos_pair_ = sim_mat[i][labels == labels[i]] 35 | pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon] 36 | neg_pair_ = sim_mat[i][labels != labels[i]] 37 | 38 | neg_pair = neg_pair_[neg_pair_ + self.margin > min(pos_pair_)] 39 | pos_pair = pos_pair_[pos_pair_ - self.margin < max(neg_pair_)] 40 | 41 | if len(neg_pair) < 1 or len(pos_pair) < 1: 42 | continue 43 | 44 | # weighting step 45 | pos_loss = 1.0 / self.scale_pos * torch.log( 46 | 1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh)))) 47 | neg_loss = 1.0 / self.scale_neg * torch.log( 48 | 1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh)))) 49 | loss.append(pos_loss + neg_loss) 50 | 51 | if len(loss) == 0: 52 | return torch.zeros([], requires_grad=True) 53 | 54 | loss = sum(loss) / batch_size 55 | return loss 56 | -------------------------------------------------------------------------------- /ret_benchmark/losses/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | from ret_benchmark.utils.registry import Registry 9 | 10 | LOSS = Registry() 11 | -------------------------------------------------------------------------------- /ret_benchmark/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 
7 | 8 | from .backbone import build_backbone 9 | from .build import build_model 10 | from .heads import build_head 11 | from .registry import BACKBONES, HEADS 12 | -------------------------------------------------------------------------------- /ret_benchmark/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_backbone 2 | -------------------------------------------------------------------------------- /ret_benchmark/modeling/backbone/bninception.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from ret_benchmark.modeling import registry 8 | 9 | @registry.BACKBONES.register('bninception') 10 | class BNInception(nn.Module): 11 | 12 | def __init__(self): 13 | super(BNInception, self).__init__() 14 | inplace = True 15 | self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)) 16 | self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, affine=True) 17 | self.conv1_relu_7x7 = nn.ReLU(inplace) 18 | self.pool1_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 19 | self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) 20 | self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 21 | self.conv2_relu_3x3_reduce = nn.ReLU(inplace) 22 | self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 23 | self.conv2_3x3_bn = nn.BatchNorm2d(192, affine=True) 24 | self.conv2_relu_3x3 = nn.ReLU(inplace) 25 | self.pool2_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 26 | self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 27 | self.inception_3a_1x1_bn = nn.BatchNorm2d(64, affine=True) 28 | self.inception_3a_relu_1x1 = nn.ReLU(inplace) 29 | self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 30 | self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 31 | self.inception_3a_relu_3x3_reduce = nn.ReLU(inplace) 32 | self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 33 | self.inception_3a_3x3_bn = nn.BatchNorm2d(64, affine=True) 34 | self.inception_3a_relu_3x3 = nn.ReLU(inplace) 35 | self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 36 | self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 37 | self.inception_3a_relu_double_3x3_reduce = nn.ReLU(inplace) 38 | self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 39 | self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) 40 | self.inception_3a_relu_double_3x3_1 = nn.ReLU(inplace) 41 | self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 42 | self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) 43 | self.inception_3a_relu_double_3x3_2 = nn.ReLU(inplace) 44 | self.inception_3a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 45 | self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) 46 | self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, affine=True) 47 | self.inception_3a_relu_pool_proj = nn.ReLU(inplace) 48 | self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 49 | 
self.inception_3b_1x1_bn = nn.BatchNorm2d(64, affine=True) 50 | self.inception_3b_relu_1x1 = nn.ReLU(inplace) 51 | self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 52 | self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 53 | self.inception_3b_relu_3x3_reduce = nn.ReLU(inplace) 54 | self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 55 | self.inception_3b_3x3_bn = nn.BatchNorm2d(96, affine=True) 56 | self.inception_3b_relu_3x3 = nn.ReLU(inplace) 57 | self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 58 | self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 59 | self.inception_3b_relu_double_3x3_reduce = nn.ReLU(inplace) 60 | self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 61 | self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) 62 | self.inception_3b_relu_double_3x3_1 = nn.ReLU(inplace) 63 | self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 64 | self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) 65 | self.inception_3b_relu_double_3x3_2 = nn.ReLU(inplace) 66 | self.inception_3b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 67 | self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 68 | self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, affine=True) 69 | self.inception_3b_relu_pool_proj = nn.ReLU(inplace) 70 | self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1)) 71 | self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 72 | self.inception_3c_relu_3x3_reduce = nn.ReLU(inplace) 73 | self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 74 | self.inception_3c_3x3_bn = nn.BatchNorm2d(160, affine=True) 75 | self.inception_3c_relu_3x3 = nn.ReLU(inplace) 76 | self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1)) 77 | self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 78 | self.inception_3c_relu_double_3x3_reduce = nn.ReLU(inplace) 79 | self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 80 | self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) 81 | self.inception_3c_relu_double_3x3_1 = nn.ReLU(inplace) 82 | self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 83 | self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) 84 | self.inception_3c_relu_double_3x3_2 = nn.ReLU(inplace) 85 | self.inception_3c_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 86 | self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1)) 87 | self.inception_4a_1x1_bn = nn.BatchNorm2d(224, affine=True) 88 | self.inception_4a_relu_1x1 = nn.ReLU(inplace) 89 | self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1)) 90 | self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 91 | self.inception_4a_relu_3x3_reduce = nn.ReLU(inplace) 92 | self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 93 | self.inception_4a_3x3_bn = nn.BatchNorm2d(96, affine=True) 94 | self.inception_4a_relu_3x3 = nn.ReLU(inplace) 95 | self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), 
stride=(1, 1)) 96 | self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True) 97 | self.inception_4a_relu_double_3x3_reduce = nn.ReLU(inplace) 98 | self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 99 | self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True) 100 | self.inception_4a_relu_double_3x3_1 = nn.ReLU(inplace) 101 | self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 102 | self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True) 103 | self.inception_4a_relu_double_3x3_2 = nn.ReLU(inplace) 104 | self.inception_4a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 105 | self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 106 | self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 107 | self.inception_4a_relu_pool_proj = nn.ReLU(inplace) 108 | self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1)) 109 | self.inception_4b_1x1_bn = nn.BatchNorm2d(192, affine=True) 110 | self.inception_4b_relu_1x1 = nn.ReLU(inplace) 111 | self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 112 | self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True) 113 | self.inception_4b_relu_3x3_reduce = nn.ReLU(inplace) 114 | self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 115 | self.inception_4b_3x3_bn = nn.BatchNorm2d(128, affine=True) 116 | self.inception_4b_relu_3x3 = nn.ReLU(inplace) 117 | self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 118 | self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True) 119 | self.inception_4b_relu_double_3x3_reduce = nn.ReLU(inplace) 120 | self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 121 | self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True) 122 | self.inception_4b_relu_double_3x3_1 = nn.ReLU(inplace) 123 | self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 124 | self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True) 125 | self.inception_4b_relu_double_3x3_2 = nn.ReLU(inplace) 126 | self.inception_4b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 127 | self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 128 | self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 129 | self.inception_4b_relu_pool_proj = nn.ReLU(inplace) 130 | self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1)) 131 | self.inception_4c_1x1_bn = nn.BatchNorm2d(160, affine=True) 132 | self.inception_4c_relu_1x1 = nn.ReLU(inplace) 133 | self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 134 | self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 135 | self.inception_4c_relu_3x3_reduce = nn.ReLU(inplace) 136 | self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 137 | self.inception_4c_3x3_bn = nn.BatchNorm2d(160, affine=True) 138 | self.inception_4c_relu_3x3 = nn.ReLU(inplace) 139 | self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 140 | self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 141 | self.inception_4c_relu_double_3x3_reduce 
= nn.ReLU(inplace) 142 | self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 143 | self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, affine=True) 144 | self.inception_4c_relu_double_3x3_1 = nn.ReLU(inplace) 145 | self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 146 | self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, affine=True) 147 | self.inception_4c_relu_double_3x3_2 = nn.ReLU(inplace) 148 | self.inception_4c_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 149 | self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 150 | self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 151 | self.inception_4c_relu_pool_proj = nn.ReLU(inplace) 152 | self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1)) 153 | self.inception_4d_1x1_bn = nn.BatchNorm2d(96, affine=True) 154 | self.inception_4d_relu_1x1 = nn.ReLU(inplace) 155 | self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 156 | self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 157 | self.inception_4d_relu_3x3_reduce = nn.ReLU(inplace) 158 | self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 159 | self.inception_4d_3x3_bn = nn.BatchNorm2d(192, affine=True) 160 | self.inception_4d_relu_3x3 = nn.ReLU(inplace) 161 | self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1)) 162 | self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True) 163 | self.inception_4d_relu_double_3x3_reduce = nn.ReLU(inplace) 164 | self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 165 | self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, affine=True) 166 | self.inception_4d_relu_double_3x3_1 = nn.ReLU(inplace) 167 | self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 168 | self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, affine=True) 169 | self.inception_4d_relu_double_3x3_2 = nn.ReLU(inplace) 170 | self.inception_4d_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 171 | self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 172 | self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 173 | self.inception_4d_relu_pool_proj = nn.ReLU(inplace) 174 | self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 175 | self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 176 | self.inception_4e_relu_3x3_reduce = nn.ReLU(inplace) 177 | self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 178 | self.inception_4e_3x3_bn = nn.BatchNorm2d(192, affine=True) 179 | self.inception_4e_relu_3x3 = nn.ReLU(inplace) 180 | self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1)) 181 | self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) 182 | self.inception_4e_relu_double_3x3_reduce = nn.ReLU(inplace) 183 | self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 184 | self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, affine=True) 185 | self.inception_4e_relu_double_3x3_1 = nn.ReLU(inplace) 186 | self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, 
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 187 | self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, affine=True) 188 | self.inception_4e_relu_double_3x3_2 = nn.ReLU(inplace) 189 | self.inception_4e_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 190 | self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1)) 191 | self.inception_5a_1x1_bn = nn.BatchNorm2d(352, affine=True) 192 | self.inception_5a_relu_1x1 = nn.ReLU(inplace) 193 | self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1)) 194 | self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) 195 | self.inception_5a_relu_3x3_reduce = nn.ReLU(inplace) 196 | self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 197 | self.inception_5a_3x3_bn = nn.BatchNorm2d(320, affine=True) 198 | self.inception_5a_relu_3x3 = nn.ReLU(inplace) 199 | self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1)) 200 | self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True) 201 | self.inception_5a_relu_double_3x3_reduce = nn.ReLU(inplace) 202 | self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 203 | self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True) 204 | self.inception_5a_relu_double_3x3_1 = nn.ReLU(inplace) 205 | self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 206 | self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True) 207 | self.inception_5a_relu_double_3x3_2 = nn.ReLU(inplace) 208 | self.inception_5a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 209 | self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1)) 210 | self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 211 | self.inception_5a_relu_pool_proj = nn.ReLU(inplace) 212 | self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1)) 213 | self.inception_5b_1x1_bn = nn.BatchNorm2d(352, affine=True) 214 | self.inception_5b_relu_1x1 = nn.ReLU(inplace) 215 | self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) 216 | self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) 217 | self.inception_5b_relu_3x3_reduce = nn.ReLU(inplace) 218 | self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 219 | self.inception_5b_3x3_bn = nn.BatchNorm2d(320, affine=True) 220 | self.inception_5b_relu_3x3 = nn.ReLU(inplace) 221 | self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) 222 | self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) 223 | self.inception_5b_relu_double_3x3_reduce = nn.ReLU(inplace) 224 | self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 225 | self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True) 226 | self.inception_5b_relu_double_3x3_1 = nn.ReLU(inplace) 227 | self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 228 | self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True) 229 | self.inception_5b_relu_double_3x3_2 = nn.ReLU(inplace) 230 | self.inception_5b_pool = nn.MaxPool2d((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True) 231 | self.inception_5b_pool_proj = 
nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1)) 232 | self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 233 | self.inception_5b_relu_pool_proj = nn.ReLU(inplace) 234 | 235 | def features(self, input): 236 | conv1_7x7_s2_out = self.conv1_7x7_s2(input) 237 | conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out) 238 | conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out) 239 | pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_relu_7x7_out) 240 | conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out) 241 | conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out) 242 | conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out) 243 | conv2_3x3_out = self.conv2_3x3(conv2_relu_3x3_reduce_out) 244 | conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out) 245 | conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out) 246 | pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_relu_3x3_out) 247 | inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out) 248 | inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out) 249 | inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out) 250 | inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out) 251 | inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out) 252 | inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out) 253 | inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_relu_3x3_reduce_out) 254 | inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out) 255 | inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out) 256 | inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out) 257 | inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn( 258 | inception_3a_double_3x3_reduce_out) 259 | inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce( 260 | inception_3a_double_3x3_reduce_bn_out) 261 | inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_relu_double_3x3_reduce_out) 262 | inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out) 263 | inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out) 264 | inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_relu_double_3x3_1_out) 265 | inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out) 266 | inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out) 267 | inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out) 268 | inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out) 269 | inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out) 270 | inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out) 271 | inception_3a_output_out = torch.cat( 272 | [inception_3a_relu_1x1_out, inception_3a_relu_3x3_out, inception_3a_relu_double_3x3_2_out, 273 | inception_3a_relu_pool_proj_out], 1) 274 | inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out) 275 | inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out) 276 | inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out) 277 | 
inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out) 278 | inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out) 279 | inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out) 280 | inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_relu_3x3_reduce_out) 281 | inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out) 282 | inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out) 283 | inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out) 284 | inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn( 285 | inception_3b_double_3x3_reduce_out) 286 | inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce( 287 | inception_3b_double_3x3_reduce_bn_out) 288 | inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_relu_double_3x3_reduce_out) 289 | inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out) 290 | inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out) 291 | inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_relu_double_3x3_1_out) 292 | inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out) 293 | inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out) 294 | inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out) 295 | inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out) 296 | inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out) 297 | inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out) 298 | inception_3b_output_out = torch.cat( 299 | [inception_3b_relu_1x1_out, inception_3b_relu_3x3_out, inception_3b_relu_double_3x3_2_out, 300 | inception_3b_relu_pool_proj_out], 1) 301 | inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out) 302 | inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out) 303 | inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out) 304 | inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_relu_3x3_reduce_out) 305 | inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out) 306 | inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out) 307 | inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out) 308 | inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn( 309 | inception_3c_double_3x3_reduce_out) 310 | inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce( 311 | inception_3c_double_3x3_reduce_bn_out) 312 | inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_relu_double_3x3_reduce_out) 313 | inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out) 314 | inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out) 315 | inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_relu_double_3x3_1_out) 316 | inception_3c_double_3x3_2_bn_out = 
self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out) 317 | inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out) 318 | inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out) 319 | inception_3c_output_out = torch.cat( 320 | [inception_3c_relu_3x3_out, inception_3c_relu_double_3x3_2_out, inception_3c_pool_out], 1) 321 | inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out) 322 | inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out) 323 | inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out) 324 | inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out) 325 | inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out) 326 | inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out) 327 | inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_relu_3x3_reduce_out) 328 | inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out) 329 | inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out) 330 | inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out) 331 | inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn( 332 | inception_4a_double_3x3_reduce_out) 333 | inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce( 334 | inception_4a_double_3x3_reduce_bn_out) 335 | inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_relu_double_3x3_reduce_out) 336 | inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out) 337 | inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out) 338 | inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_relu_double_3x3_1_out) 339 | inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out) 340 | inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out) 341 | inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out) 342 | inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out) 343 | inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out) 344 | inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out) 345 | inception_4a_output_out = torch.cat( 346 | [inception_4a_relu_1x1_out, inception_4a_relu_3x3_out, inception_4a_relu_double_3x3_2_out, 347 | inception_4a_relu_pool_proj_out], 1) 348 | inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out) 349 | inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out) 350 | inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out) 351 | inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out) 352 | inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out) 353 | inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out) 354 | inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_relu_3x3_reduce_out) 355 | inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out) 356 | inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out) 357 | 
inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out) 358 | inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn( 359 | inception_4b_double_3x3_reduce_out) 360 | inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce( 361 | inception_4b_double_3x3_reduce_bn_out) 362 | inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_relu_double_3x3_reduce_out) 363 | inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out) 364 | inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out) 365 | inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_relu_double_3x3_1_out) 366 | inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out) 367 | inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out) 368 | inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out) 369 | inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out) 370 | inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out) 371 | inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out) 372 | inception_4b_output_out = torch.cat( 373 | [inception_4b_relu_1x1_out, inception_4b_relu_3x3_out, inception_4b_relu_double_3x3_2_out, 374 | inception_4b_relu_pool_proj_out], 1) 375 | inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out) 376 | inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out) 377 | inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out) 378 | inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out) 379 | inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out) 380 | inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out) 381 | inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_relu_3x3_reduce_out) 382 | inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out) 383 | inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out) 384 | inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out) 385 | inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn( 386 | inception_4c_double_3x3_reduce_out) 387 | inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce( 388 | inception_4c_double_3x3_reduce_bn_out) 389 | inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_relu_double_3x3_reduce_out) 390 | inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out) 391 | inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out) 392 | inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_relu_double_3x3_1_out) 393 | inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out) 394 | inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out) 395 | inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out) 396 | inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out) 397 | 
inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out) 398 | inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out) 399 | inception_4c_output_out = torch.cat( 400 | [inception_4c_relu_1x1_out, inception_4c_relu_3x3_out, inception_4c_relu_double_3x3_2_out, 401 | inception_4c_relu_pool_proj_out], 1) 402 | inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out) 403 | inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out) 404 | inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out) 405 | inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out) 406 | inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out) 407 | inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out) 408 | inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_relu_3x3_reduce_out) 409 | inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out) 410 | inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out) 411 | inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out) 412 | inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn( 413 | inception_4d_double_3x3_reduce_out) 414 | inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce( 415 | inception_4d_double_3x3_reduce_bn_out) 416 | inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_relu_double_3x3_reduce_out) 417 | inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out) 418 | inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out) 419 | inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_relu_double_3x3_1_out) 420 | inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out) 421 | inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out) 422 | inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out) 423 | inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out) 424 | inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out) 425 | inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out) 426 | inception_4d_output_out = torch.cat( 427 | [inception_4d_relu_1x1_out, inception_4d_relu_3x3_out, inception_4d_relu_double_3x3_2_out, 428 | inception_4d_relu_pool_proj_out], 1) 429 | inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out) 430 | inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out) 431 | inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out) 432 | inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_relu_3x3_reduce_out) 433 | inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out) 434 | inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out) 435 | inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out) 436 | inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn( 437 | inception_4e_double_3x3_reduce_out) 438 | 
inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce( 439 | inception_4e_double_3x3_reduce_bn_out) 440 | inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_relu_double_3x3_reduce_out) 441 | inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out) 442 | inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out) 443 | inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_relu_double_3x3_1_out) 444 | inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out) 445 | inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out) 446 | inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out) 447 | inception_4e_output_out = torch.cat( 448 | [inception_4e_relu_3x3_out, inception_4e_relu_double_3x3_2_out, inception_4e_pool_out], 1) 449 | inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out) 450 | inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out) 451 | inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out) 452 | inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out) 453 | inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out) 454 | inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out) 455 | inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_relu_3x3_reduce_out) 456 | inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out) 457 | inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out) 458 | inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out) 459 | inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn( 460 | inception_5a_double_3x3_reduce_out) 461 | inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce( 462 | inception_5a_double_3x3_reduce_bn_out) 463 | inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_relu_double_3x3_reduce_out) 464 | inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out) 465 | inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out) 466 | inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_relu_double_3x3_1_out) 467 | inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out) 468 | inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out) 469 | inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out) 470 | inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out) 471 | inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out) 472 | inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out) 473 | inception_5a_output_out = torch.cat( 474 | [inception_5a_relu_1x1_out, inception_5a_relu_3x3_out, inception_5a_relu_double_3x3_2_out, 475 | inception_5a_relu_pool_proj_out], 1) 476 | inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out) 477 | inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out) 478 | inception_5b_relu_1x1_out = 
self.inception_5b_relu_1x1(inception_5b_1x1_bn_out) 479 | inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out) 480 | inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out) 481 | inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out) 482 | inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_relu_3x3_reduce_out) 483 | inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out) 484 | inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out) 485 | inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out) 486 | inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn( 487 | inception_5b_double_3x3_reduce_out) 488 | inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce( 489 | inception_5b_double_3x3_reduce_bn_out) 490 | inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_relu_double_3x3_reduce_out) 491 | inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out) 492 | inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out) 493 | inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_relu_double_3x3_1_out) 494 | inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out) 495 | inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out) 496 | inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out) 497 | inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out) 498 | inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out) 499 | inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out) 500 | inception_5b_output_out = torch.cat( 501 | [inception_5b_relu_1x1_out, inception_5b_relu_3x3_out, inception_5b_relu_double_3x3_2_out, 502 | inception_5b_relu_pool_proj_out], 1) 503 | return inception_5b_output_out 504 | 505 | def logits(self, features): 506 | x = F.adaptive_max_pool2d(features, output_size=1) 507 | x = x.view(x.size(0), -1) 508 | return x 509 | 510 | def forward(self, input): 511 | x = self.features(input) 512 | x = self.logits(x) 513 | return x 514 | 515 | def load_param(self, model_path): 516 | param_dict = torch.load(model_path) 517 | for i in param_dict: 518 | if 'last_linear' in i: 519 | continue 520 | self.state_dict()[i].copy_(param_dict[i]) 521 | -------------------------------------------------------------------------------- /ret_benchmark/modeling/backbone/build.py: -------------------------------------------------------------------------------- 1 | from ret_benchmark.modeling.registry import BACKBONES 2 | 3 | from .bninception import BNInception 4 | from .resnet import ResNet50 5 | 6 | 7 | def build_backbone(cfg): 8 | assert cfg.MODEL.BACKBONE.NAME in BACKBONES, \ 9 | f"backbone {cfg.MODEL.BACKBONE} is not registered in registry : {BACKBONES.keys()}" 10 | return BACKBONES[cfg.MODEL.BACKBONE.NAME]() 11 | -------------------------------------------------------------------------------- /ret_benchmark/modeling/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import torch 4 | import 
torch.nn as nn 5 | import torchvision.models as models 6 | from ret_benchmark.modeling import registry 7 | 8 | 9 | @registry.BACKBONES.register('resnet50') 10 | class ResNet50(nn.Module): 11 | 12 | def __init__(self): 13 | super(ResNet50, self).__init__() 14 | self.model = models.resnet50(pretrained=True) 15 | 16 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 17 | module.eval() 18 | module.train = lambda _: None 19 | 20 | def forward(self, x): 21 | x = self.model.conv1(x) 22 | x = self.model.bn1(x) 23 | x = self.model.relu(x) 24 | x = self.model.maxpool(x) 25 | 26 | x = self.model.layer1(x) 27 | x = self.model.layer2(x) 28 | x = self.model.layer3(x) 29 | x = self.model.layer4(x) 30 | 31 | x = self.model.avgpool(x) 32 | x = x.view(x.size(0), -1) 33 | # x = self.model.fc(x) --remove 34 | return x 35 | 36 | def load_param(self, model_path): 37 | param_dict = torch.load(model_path) 38 | for i in param_dict: 39 | if 'last_linear' in i: 40 | continue 41 | self.model.state_dict()[i].copy_(param_dict[i]) 42 | 43 | -------------------------------------------------------------------------------- /ret_benchmark/modeling/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import os 10 | from collections import OrderedDict 11 | 12 | import torch 13 | from torch.nn.modules import Sequential 14 | 15 | from .backbone import build_backbone 16 | from .heads import build_head 17 | 18 | 19 | def build_model(cfg): 20 | backbone = build_backbone(cfg) 21 | head = build_head(cfg) 22 | 23 | model = Sequential(OrderedDict([ 24 | ('backbone', backbone), 25 | ('head', head) 26 | ])) 27 | 28 | if cfg.MODEL.PRETRAIN == 'imagenet': 29 | print('Loading imagenet pretrianed model ...') 30 | pretrained_path = os.path.expanduser(cfg.MODEL.PRETRIANED_PATH[cfg.MODEL.BACKBONE.NAME]) 31 | model.backbone.load_param(pretrained_path) 32 | elif os.path.exists(cfg.MODEL.PRETRAIN): 33 | ckp = torch.load(cfg.MODEL.PRETRAIN) 34 | model.load_state_dict(ckp['model']) 35 | return model 36 | -------------------------------------------------------------------------------- /ret_benchmark/modeling/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | from .build import build_head 9 | -------------------------------------------------------------------------------- /ret_benchmark/modeling/heads/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 
7 | 8 | from ret_benchmark.modeling.registry import HEADS 9 | 10 | from .linear_norm import LinearNorm 11 | 12 | 13 | def build_head(cfg): 14 | assert cfg.MODEL.HEAD.NAME in HEADS, f"head {cfg.MODEL.HEAD.NAME} is not defined" 15 | return HEADS[cfg.MODEL.HEAD.NAME](cfg, in_channels=1024 if cfg.MODEL.BACKBONE.NAME == 'bninception' else 2048) 16 | 17 | -------------------------------------------------------------------------------- /ret_benchmark/modeling/heads/linear_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | from torch import nn 10 | 11 | from ret_benchmark.modeling.registry import HEADS 12 | from ret_benchmark.utils.init_methods import weights_init_kaiming 13 | 14 | 15 | @HEADS.register('linear_norm') 16 | class LinearNorm(nn.Module): 17 | def __init__(self, cfg, in_channels): 18 | super(LinearNorm, self).__init__() 19 | self.fc = nn.Linear(in_channels, cfg.MODEL.HEAD.DIM) 20 | self.fc.apply(weights_init_kaiming) 21 | 22 | def forward(self, x): 23 | x = self.fc(x) 24 | x = nn.functional.normalize(x, p=2, dim=1) 25 | return x 26 | -------------------------------------------------------------------------------- /ret_benchmark/modeling/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | 9 | from ret_benchmark.utils.registry import Registry 10 | 11 | BACKBONES = Registry() 12 | HEADS = Registry() 13 | -------------------------------------------------------------------------------- /ret_benchmark/modeling/xbm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 
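To see how the registry, backbone, and head pieces above fit together, here is a hedged sketch of assembling a model by hand; it is not part of the repository. The SimpleNamespace stub is hypothetical and only carries the fields read by build_backbone, build_head, and build_model (the real cfg comes from ret_benchmark.config and the YAML files), and the resnet50 branch will download torchvision's ImageNet weights on first use.

from types import SimpleNamespace

import torch

from ret_benchmark.modeling import build_model

# Hypothetical cfg stub; only fields consumed by build_backbone/build_head/build_model are set.
cfg_stub = SimpleNamespace(
    MODEL=SimpleNamespace(
        BACKBONE=SimpleNamespace(NAME="resnet50"),
        HEAD=SimpleNamespace(NAME="linear_norm", DIM=128),
        PRETRAIN="",                      # skips both the 'imagenet' branch and checkpoint loading
    )
)
model = build_model(cfg_stub).eval()
with torch.no_grad():
    emb = model(torch.randn(2, 3, 224, 224))
print(emb.shape, emb.norm(dim=1))         # (2, 128) embeddings with unit L2 norm
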
7 | 8 | import torch 9 | import tqdm 10 | from ret_benchmark.data.build import build_memory_data 11 | 12 | 13 | class XBM: 14 | def __init__(self, cfg, model): 15 | self.ratio = cfg.MEMORY.RATIO 16 | # init memory 17 | self.feats = list() 18 | self.labels = list() 19 | self.indices = list() 20 | model.train() 21 | for images, labels, indices in build_memory_data(cfg): 22 | with torch.no_grad(): 23 | feat = model(images.cuda()) 24 | self.feats.append(feat) 25 | self.labels.append(labels.cuda()) 26 | self.indices.append(indices.cuda()) 27 | self.feats = torch.cat(self.feats, dim=0) 28 | self.labels = torch.cat(self.labels, dim=0) 29 | self.indices = torch.cat(self.indices, dim=0) 30 | # if memory_ratio != 1.0 -> random sample init queue_mask to mimic fixed queue size 31 | if self.ratio != 1.0: 32 | rand_init_idx = torch.randperm(int(self.indices.shape[0] * self.ratio)).cuda() 33 | self.queue_mask = self.indices[rand_init_idx] 34 | 35 | def enqueue_dequeue(self, feats, indices): 36 | self.feats.data[indices] = feats 37 | if self.ratio != 1.0: 38 | # enqueue 39 | self.queue_mask = torch.cat((self.queue_mask, indices.cuda()), dim=0) 40 | # dequeue 41 | self.queue_mask = self.queue_mask[-int(self.indices.shape[0] * self.ratio):] 42 | 43 | def get(self): 44 | if self.ratio != 1.0: 45 | return self.feats[self.queue_mask], self.labels[self.queue_mask] 46 | else: 47 | return self.feats, self.labels 48 | -------------------------------------------------------------------------------- /ret_benchmark/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from .build import build_optimizer 3 | from .build import build_lr_scheduler 4 | from .lr_scheduler import WarmupMultiStepLR 5 | -------------------------------------------------------------------------------- /ret_benchmark/solver/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .lr_scheduler import WarmupMultiStepLR 4 | 5 | 6 | def build_optimizer(cfg, model): 7 | params = [] 8 | for key, value in model.named_parameters(): 9 | if not value.requires_grad: 10 | continue 11 | lr_mul = 1.0 12 | if "backbone" in key: 13 | lr_mul = 0.1 14 | params += [{"params": [value], "lr_mul": lr_mul}] 15 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, 16 | lr=cfg.SOLVER.BASE_LR, 17 | weight_decay=cfg.SOLVER.WEIGHT_DECAY) 18 | return optimizer 19 | 20 | 21 | def build_lr_scheduler(cfg, optimizer): 22 | return WarmupMultiStepLR( 23 | optimizer, 24 | cfg.SOLVER.STEPS, 25 | cfg.SOLVER.GAMMA, 26 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 27 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 28 | warmup_method=cfg.SOLVER.WARMUP_METHOD, 29 | ) 30 | -------------------------------------------------------------------------------- /ret_benchmark/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_right 2 | 3 | import torch 4 | 5 | 6 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 7 | def __init__( 8 | self, 9 | optimizer, 10 | milestones, 11 | gamma=0.1, 12 | warmup_factor=1.0 / 3, 13 | warmup_iters=500, 14 | warmup_method="linear", 15 | last_epoch=-1, 16 | ): 17 | if not list(milestones) == sorted(milestones): 18 | raise ValueError( 19 | "Milestones should be a list of" " increasing integers. 
Got {}", 20 | milestones, 21 | ) 22 | 23 | if warmup_method not in ("constant", "linear"): 24 | raise ValueError( 25 | "Only 'constant' or 'linear' warmup_method accepted" 26 | "got {}".format(warmup_method) 27 | ) 28 | self.milestones = milestones 29 | self.gamma = gamma 30 | self.warmup_factor = warmup_factor 31 | self.warmup_iters = warmup_iters 32 | self.warmup_method = warmup_method 33 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 34 | 35 | def get_lr(self): 36 | warmup_factor = 1 37 | if self.last_epoch < self.warmup_iters: 38 | if self.warmup_method == "constant": 39 | warmup_factor = self.warmup_factor 40 | elif self.warmup_method == "linear": 41 | alpha = float(self.last_epoch) / self.warmup_iters 42 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 43 | return [ 44 | base_lr * warmup_factor * self.gamma ** bisect_right( 45 | self.milestones, 46 | self.last_epoch 47 | ) 48 | for base_lr in self.base_lrs 49 | ] 50 | -------------------------------------------------------------------------------- /ret_benchmark/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import logging 3 | import os 4 | 5 | import torch 6 | from ret_benchmark.utils.model_serialization import load_state_dict 7 | 8 | 9 | class Checkpointer(object): 10 | def __init__( 11 | self, 12 | model, 13 | optimizer=None, 14 | scheduler=None, 15 | save_dir="", 16 | save_to_disk=None, 17 | logger=None, 18 | ): 19 | self.model = model 20 | self.optimizer = optimizer 21 | self.scheduler = scheduler 22 | self.save_dir = save_dir 23 | self.save_to_disk = save_to_disk 24 | if logger is None: 25 | logger = logging.getLogger(__name__) 26 | self.logger = logger 27 | 28 | def save(self, name): 29 | if not self.save_dir: 30 | return 31 | 32 | data = {} 33 | data["model"] = self.model.state_dict() 34 | if self.optimizer is not None: 35 | data["optimizer"] = self.optimizer.state_dict() 36 | if self.scheduler is not None: 37 | data["scheduler"] = self.scheduler.state_dict() 38 | 39 | save_file = os.path.join(self.save_dir, "{}.pth".format(name)) 40 | self.logger.info("Saving checkpoint to {}".format(save_file)) 41 | torch.save(data, save_file) 42 | 43 | def load(self, f=None): 44 | if self.has_checkpoint(): 45 | # override argument with existing checkpoint 46 | f = self.get_checkpoint_file() 47 | if not f: 48 | # no checkpoint could be found 49 | self.logger.info("No checkpoint found. 
Initializing model from scratch") 50 | return {} 51 | self.logger.info("Loading checkpoint from {}".format(f)) 52 | checkpoint = self._load_file(f) 53 | self._load_model(checkpoint) 54 | if "optimizer" in checkpoint and self.optimizer: 55 | self.logger.info("Loading optimizer from {}".format(f)) 56 | self.optimizer.load_state_dict(checkpoint.pop("optimizer")) 57 | if "scheduler" in checkpoint and self.scheduler: 58 | self.logger.info("Loading scheduler from {}".format(f)) 59 | self.scheduler.load_state_dict(checkpoint.pop("scheduler")) 60 | 61 | # return any further checkpoint data 62 | return checkpoint 63 | 64 | def has_checkpoint(self): 65 | save_file = os.path.join(self.save_dir, "last_checkpoint") 66 | return os.path.exists(save_file) 67 | 68 | def get_checkpoint_file(self): 69 | save_file = os.path.join(self.save_dir, "last_checkpoint") 70 | try: 71 | with open(save_file, "r") as f: 72 | last_saved = f.read() 73 | last_saved = last_saved.strip() 74 | except IOError: 75 | # if file doesn't exist, maybe because it has just been 76 | # deleted by a separate process 77 | last_saved = "" 78 | return last_saved 79 | 80 | def tag_last_checkpoint(self, last_filename): 81 | save_file = os.path.join(self.save_dir, "last_checkpoint") 82 | with open(save_file, "w") as f: 83 | f.write(last_filename) 84 | 85 | def _load_file(self, f): 86 | return torch.load(f, map_location=torch.device("cpu")) 87 | 88 | def _load_model(self, checkpoint): 89 | load_state_dict(self.model, checkpoint.pop("model")) 90 | -------------------------------------------------------------------------------- /ret_benchmark/utils/config_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, print_function, 2 | unicode_literals) 3 | 4 | import copy 5 | import os 6 | 7 | from ret_benchmark.config import cfg as g_cfg 8 | 9 | 10 | def get_config_root_path(): 11 | ''' Path to configs for unit tests ''' 12 | # cur_file_dir is root/tests/env_tests 13 | cur_file_dir = os.path.dirname(os.path.abspath(os.path.realpath(__file__))) 14 | ret = os.path.dirname(os.path.dirname(cur_file_dir)) 15 | ret = os.path.join(ret, "configs") 16 | return ret 17 | 18 | 19 | def load_config(rel_path): 20 | ''' Load config from file path specified as path relative to config_root ''' 21 | cfg_path = os.path.join(get_config_root_path(), rel_path) 22 | return load_config_from_file(cfg_path) 23 | 24 | 25 | def load_config_from_file(file_path): 26 | ''' Load config from file path specified as absolute path ''' 27 | ret = copy.deepcopy(g_cfg) 28 | ret.merge_from_file(file_path) 29 | return ret 30 | -------------------------------------------------------------------------------- /ret_benchmark/utils/feat_extractor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 
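As a quick illustration of the Checkpointer above (a sketch, not part of the repository; the toy model, directory, and file names are made up): save() writes the .pth file but does not tag it, so resuming via load() relies on tag_last_checkpoint having recorded the checkpoint path.

import os

import torch
from torch import nn

from ret_benchmark.utils.checkpoint import Checkpointer

save_dir = "output/ckpt_demo"                      # hypothetical directory
os.makedirs(save_dir, exist_ok=True)

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ckpt = Checkpointer(model, optimizer=optimizer, save_dir=save_dir)

ckpt.save("model_0001")                            # writes output/ckpt_demo/model_0001.pth
ckpt.tag_last_checkpoint(os.path.join(save_dir, "model_0001.pth"))

# Later (e.g. after a restart): picks up the tagged file and restores model/optimizer state.
extra = Checkpointer(model, optimizer=optimizer, save_dir=save_dir).load()
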
7 | 8 | import torch 9 | import numpy as np 10 | 11 | 12 | def feat_extractor(model, data_loader, logger=None): 13 | model.eval() 14 | feats = list() 15 | 16 | for i, batch in enumerate(data_loader): 17 | imgs = batch[0].cuda() 18 | 19 | with torch.no_grad(): 20 | out = model(imgs).data.cpu().numpy() 21 | feats.append(out) 22 | 23 | if logger is not None and (i + 1) % 100 == 0: 24 | logger.debug(f'Extract Features: [{i + 1}/{len(data_loader)}]') 25 | del out 26 | feats = np.vstack(feats) 27 | return feats 28 | -------------------------------------------------------------------------------- /ret_benchmark/utils/freeze_bn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | # Batch Norm Freezer 9 | # Note: freezing BN gives an additional 2% improvement on CUB (it has no effect on the other benchmarks) 10 | 11 | def set_bn_eval(m): 12 | classname = m.__class__.__name__ 13 | if classname.find('BatchNorm') != -1: 14 | m.eval() 15 | -------------------------------------------------------------------------------- /ret_benchmark/utils/img_reader.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from PIL import Image 3 | 4 | 5 | def read_image(img_path, mode='RGB'): 6 | """Keep trying to read an image until it succeeds. 7 | This avoids IOErrors caused by heavy IO load.""" 8 | got_img = False 9 | if not osp.exists(img_path): 10 | raise IOError(f"{img_path} does not exist") 11 | while not got_img: 12 | try: 13 | img = Image.open(img_path).convert("RGB") 14 | if mode == "BGR": 15 | r, g, b = img.split() 16 | img = Image.merge("RGB", (b, g, r)) 17 | got_img = True 18 | except IOError: 19 | print(f"IOError incurred when reading '{img_path}'. Will redo.") 20 | pass 21 | return img 22 | -------------------------------------------------------------------------------- /ret_benchmark/utils/init_methods.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | def weights_init_kaiming(m): 13 | classname = m.__class__.__name__ 14 | if classname.find('Linear') != -1: 15 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 16 | nn.init.constant_(m.bias, 0.0) 17 | elif classname.find('Conv') != -1: 18 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 19 | if m.bias is not None: 20 | nn.init.constant_(m.bias, 0.0) 21 | elif classname.find('BatchNorm') != -1: 22 | if m.affine: 23 | nn.init.constant_(m.weight, 1.0) 24 | nn.init.constant_(m.bias, 0.0) 25 | 26 | 27 | def weights_init_classifier(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Linear') != -1: 30 | nn.init.normal_(m.weight, std=0.001) 31 | if m.bias is not None: 32 | nn.init.constant_(m.bias, 0.0) 33 | -------------------------------------------------------------------------------- /ret_benchmark/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 
3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | import sys 10 | import logging 11 | 12 | _streams = { 13 | "stdout": sys.stdout 14 | } 15 | 16 | 17 | def setup_logger(name: str, level: int, stream: str = "stdout") -> logging.Logger: 18 | global _streams 19 | if stream not in _streams: 20 | log_folder = os.path.dirname(stream) 21 | os.makedirs(log_folder, exist_ok=True) 22 | _streams[stream] = open(stream, 'w') 23 | logger = logging.getLogger(name) 24 | logger.propagate = False 25 | logger.setLevel(level) 26 | 27 | sh = logging.StreamHandler(stream=_streams[stream]) 28 | sh.setLevel(level) 29 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 30 | sh.setFormatter(formatter) 31 | logger.addHandler(sh) 32 | return logger 33 | -------------------------------------------------------------------------------- /ret_benchmark/utils/metric_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from collections import defaultdict 3 | from collections import deque 4 | 5 | import torch 6 | 7 | 8 | class SmoothedValue(object): 9 | """Track a series of values and provide access to smoothed values over a 10 | window or the global series average. 11 | """ 12 | 13 | def __init__(self, window_size=20): 14 | self.deque = deque(maxlen=window_size) 15 | self.series = [] 16 | self.total = 0.0 17 | self.count = 0 18 | 19 | def update(self, value): 20 | self.deque.append(value) 21 | self.series.append(value) 22 | self.count += 1 23 | self.total += value 24 | 25 | @property 26 | def median(self): 27 | d = torch.tensor(list(self.deque)) 28 | return d.median().item() 29 | 30 | @property 31 | def avg(self): 32 | d = torch.tensor(list(self.deque)) 33 | return d.mean().item() 34 | 35 | @property 36 | def global_avg(self): 37 | return self.total / self.count 38 | 39 | 40 | class MetricLogger(object): 41 | def __init__(self, delimiter="\t"): 42 | self.meters = defaultdict(SmoothedValue) 43 | self.delimiter = delimiter 44 | 45 | def update(self, **kwargs): 46 | for k, v in kwargs.items(): 47 | if isinstance(v, torch.Tensor): 48 | v = v.item() 49 | assert isinstance(v, (float, int)) 50 | self.meters[k].update(v) 51 | 52 | def __getattr__(self, attr): 53 | if attr in self.meters: 54 | return self.meters[attr] 55 | if attr in self.__dict__: 56 | return self.__dict__[attr] 57 | raise AttributeError("'{}' object has no attribute '{}'".format( 58 | type(self).__name__, attr)) 59 | 60 | def __str__(self): 61 | loss_str = [] 62 | for name, meter in self.meters.items(): 63 | loss_str.append( 64 | "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg) 65 | ) 66 | return self.delimiter.join(loss_str) 67 | -------------------------------------------------------------------------------- /ret_benchmark/utils/model_serialization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | from collections import OrderedDict 3 | import logging 4 | 5 | import torch 6 | 7 | 8 | def align_and_update_state_dicts(model_state_dict, loaded_state_dict): 9 | """ 10 | Strategy: suppose that the models that we will create will have prefixes appended 11 | to each of their keys, for example due to an extra level of nesting that the original 12 | pre-trained weights from ImageNet won't contain. For example, model.state_dict() 13 | might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains 14 | res2.conv1.weight. We thus want to match both parameters together. 15 | For that, for each model weight we look among all loaded keys for one 16 | that is a suffix of the current weight name, and use it if that's the case. 17 | If multiple matches exist, take the one with the longest 18 | matching name. For example, for the same model as before, the pretrained 19 | weight file can contain both res2.conv1.weight and conv1.weight. In this case, 20 | we want to match backbone[0].body.conv1.weight to conv1.weight, and 21 | backbone[0].body.res2.conv1.weight to res2.conv1.weight. 22 | """ 23 | current_keys = sorted(list(model_state_dict.keys())) 24 | loaded_keys = sorted(list(loaded_state_dict.keys())) 25 | # get a matrix of string matches, where each (i, j) entry corresponds to the length of the 26 | # loaded_key string, if it matches 27 | match_matrix = [ 28 | len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys 29 | ] 30 | match_matrix = torch.as_tensor(match_matrix).view( 31 | len(current_keys), len(loaded_keys) 32 | ) 33 | max_match_size, idxs = match_matrix.max(1) 34 | # remove indices that correspond to no-match 35 | idxs[max_match_size == 0] = -1 36 | 37 | # used for logging 38 | max_size = max([len(key) for key in current_keys]) if current_keys else 1 39 | max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 40 | log_str_template = "{: <{}} loaded from {: <{}} of shape {}" 41 | logger = logging.getLogger(__name__) 42 | for idx_new, idx_old in enumerate(idxs.tolist()): 43 | if idx_old == -1: 44 | continue 45 | key = current_keys[idx_new] 46 | key_old = loaded_keys[idx_old] 47 | model_state_dict[key] = loaded_state_dict[key_old] 48 | logger.info( 49 | log_str_template.format( 50 | key, 51 | max_size, 52 | key_old, 53 | max_size_loaded, 54 | tuple(loaded_state_dict[key_old].shape), 55 | ) 56 | ) 57 | 58 | 59 | def strip_prefix_if_present(state_dict, prefix): 60 | keys = sorted(state_dict.keys()) 61 | if not all(key.startswith(prefix) for key in keys): 62 | return state_dict 63 | stripped_state_dict = OrderedDict() 64 | for key, value in state_dict.items(): 65 | stripped_state_dict[key.replace(prefix, "")] = value 66 | return stripped_state_dict 67 | 68 | 69 | def load_state_dict(model, loaded_state_dict): 70 | model_state_dict = model.state_dict() 71 | # if the state_dict comes from a model that was wrapped in a 72 | # DataParallel or DistributedDataParallel during serialization, 73 | # remove the "module" prefix before performing the matching 74 | loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") 75 | align_and_update_state_dicts(model_state_dict, loaded_state_dict) 76 | 77 | # use strict loading 78 | model.load_state_dict(model_state_dict) 79 | -------------------------------------------------------------------------------- /ret_benchmark/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | 4 | def _register_generic(module_dict, module_name, module): 5 | assert module_name not in module_dict 6 | module_dict[module_name] = module 7 | 8 | 9 | class Registry(dict): 10 | ''' 11 | A helper class for registering and managing modules; it extends a dictionary 12 | and provides a register function. 13 | 14 | E.g. creating a registry: 15 | some_registry = Registry({"default": default_module}) 16 | 17 | There are two ways of registering new modules: 18 | 1): the normal way is just calling the register function: 19 | def foo(): 20 | ... 21 | some_registry.register("foo_module", foo) 22 | 2): used as a decorator when declaring the module: 23 | @some_registry.register("foo_module") 24 | @some_registry.register("foo_module_nickname") 25 | def foo(): 26 | ... 27 | 28 | Modules are accessed just like a dictionary, e.g.: 29 | f = some_registry["foo_module"] 30 | ''' 31 | 32 | def __init__(self, *args, **kwargs): 33 | super(Registry, self).__init__(*args, **kwargs) 34 | 35 | def register(self, module_name, module=None): 36 | # used as function call 37 | if module is not None: 38 | _register_generic(self, module_name, module) 39 | return 40 | 41 | # used as decorator 42 | def register_fn(fn): 43 | _register_generic(self, module_name, fn) 44 | return fn 45 | 46 | return register_fn 47 | -------------------------------------------------------------------------------- /scripts/prepare_cub.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | CUB_ROOT='resource/datasets/CUB_200_2011/' 5 | CUB_DATA='http://www.vision.caltech.edu.s3-us-west-2.amazonaws.com/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' 6 | 7 | 8 | if [[ ! -d "${CUB_ROOT}" ]]; then 9 | mkdir -p resource/datasets 10 | pushd resource/datasets 11 | echo "Downloading CUB_200_2011 data-set..." 12 | wget ${CUB_DATA} 13 | tar -zxf CUB_200_2011.tgz 14 | popd 15 | fi 16 | # Generate train.txt and test.txt splits 17 | echo "Generating the train.txt/test.txt split files" 18 | python scripts/split_cub_for_ms_loss.py 19 | 20 | 21 | -------------------------------------------------------------------------------- /scripts/run_cub.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OUT_DIR="output" 4 | if [[ ! -d "${OUT_DIR}" ]]; then 5 | echo "Creating output dir for training : ${OUT_DIR}" 6 | mkdir ${OUT_DIR} 7 | fi 8 | CUDA_VISIBLE_DEVICES=0 python3.6 tools/main.py --cfg configs/example.yaml 9 | -------------------------------------------------------------------------------- /scripts/run_cub_margin.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OUT_DIR="output_margin" 4 | if [[ ! 
-d "${OUT_DIR}" ]]; then 5 | echo "Creating output dir for training : ${OUT_DIR}" 6 | mkdir ${OUT_DIR} 7 | fi 8 | CUDA_VISIBLE_DEVICES=0 python3.6 tools/main.py --cfg configs/example_margin.yaml 9 | -------------------------------------------------------------------------------- /scripts/split_cub_for_ms_loss.py: -------------------------------------------------------------------------------- 1 | 2 | cub_root = 'resource/datasets/CUB_200_2011/' 3 | images_file = cub_root + 'images.txt' 4 | train_file = cub_root + 'train.txt' 5 | test_file = cub_root + 'test.txt' 6 | 7 | 8 | def main(): 9 | train = [] 10 | test = [] 11 | with open(images_file) as f_img: 12 | for l_img in f_img: 13 | i, fname = l_img.split() 14 | label = int(fname.split('.', 1)[0]) 15 | if label <= 100: 16 | train.append((fname, label - 1)) # labels 0 ... 99 (0-based labels for margin_loss) 17 | else: 18 | test.append((fname, label - 1)) # labels 100 ... 199 19 | 20 | for f, v in [(train_file, train), (test_file, test)]: 21 | with open(f, 'w') as tf: 22 | for fname, label in v: 23 | print("images/{},{}".format(fname, label), file=tf) 24 | 25 | 26 | if __name__ == '__main__': 27 | main() 28 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import torch 10 | from setuptools import find_packages, setup 11 | from torch.utils.cpp_extension import CppExtension 12 | 13 | 14 | requirements = ["torch", "torchvision"] 15 | 16 | setup( 17 | name="ret_benchmark", 18 | version="0.1", 19 | author="Malong Technologies", 20 | url="https://github.com/MalongTech/research-ms-loss", 21 | description="ms-loss", 22 | packages=find_packages(exclude=("configs", "tests")), 23 | install_requires=requirements, 24 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 25 | ) 26 | -------------------------------------------------------------------------------- /tools/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in the root directory of this source tree. 
7 | 8 | import argparse 9 | import torch 10 | 11 | from ret_benchmark.config import cfg 12 | from ret_benchmark.data import build_data 13 | from ret_benchmark.engine.trainer import do_train 14 | from ret_benchmark.losses import build_loss 15 | from ret_benchmark.modeling import build_model 16 | from ret_benchmark.solver import build_lr_scheduler, build_optimizer 17 | from ret_benchmark.utils.logger import setup_logger 18 | from ret_benchmark.utils.checkpoint import Checkpointer 19 | 20 | 21 | def train(cfg): 22 | logger = setup_logger(name='Train', level=cfg.LOGGER.LEVEL) 23 | logger.info(cfg) 24 | model = build_model(cfg) 25 | device = torch.device(cfg.MODEL.DEVICE) 26 | model.to(device) 27 | 28 | criterion = build_loss(cfg) 29 | 30 | optimizer = build_optimizer(cfg, model) 31 | scheduler = build_lr_scheduler(cfg, optimizer) 32 | 33 | train_loader = build_data(cfg, is_train=True) 34 | val_loader = build_data(cfg, is_train=False) 35 | 36 | logger.info(train_loader.dataset) 37 | logger.info(val_loader.dataset) 38 | 39 | arguments = dict() 40 | arguments["iteration"] = 0 41 | 42 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 43 | checkpointer = Checkpointer(model, optimizer, scheduler, cfg.SAVE_DIR) 44 | 45 | do_train( 46 | cfg, 47 | model, 48 | train_loader, 49 | val_loader, 50 | optimizer, 51 | scheduler, 52 | criterion, 53 | checkpointer, 54 | device, 55 | checkpoint_period, 56 | arguments, 57 | logger 58 | ) 59 | 60 | 61 | def parse_args(): 62 | """ 63 | Parse input arguments 64 | """ 65 | parser = argparse.ArgumentParser(description='Train a retrieval network') 66 | parser.add_argument( 67 | '--cfg', 68 | dest='cfg_file', 69 | help='config file', 70 | default=None, 71 | type=str) 72 | return parser.parse_args() 73 | 74 | 75 | if __name__ == '__main__': 76 | args = parse_args() 77 | cfg.merge_from_file(args.cfg_file) 78 | train(cfg) 79 | --------------------------------------------------------------------------------
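The WarmupMultiStepLR scheduler defined in ret_benchmark/solver/lr_scheduler.py above multiplies every base learning rate by a warmup factor that ramps from warmup_factor to 1 over the first warmup_iters iterations, and then decays the result by gamma at every milestone that has been passed. A minimal standalone sketch of that formula for a single parameter group, using hypothetical hyperparameter values rather than anything taken from the repository's configs, looks like this:

from bisect import bisect_right


def lr_at(iteration, base_lr=0.01, milestones=(2000, 4000), gamma=0.1,
          warmup_factor=0.01, warmup_iters=500, warmup_method="linear"):
    # Mirrors WarmupMultiStepLR.get_lr() for one parameter group; all values
    # passed here are illustrative defaults, not values from the configs.
    factor = 1.0
    if iteration < warmup_iters:
        if warmup_method == "constant":
            factor = warmup_factor
        else:  # "linear": ramp from warmup_factor up to 1.0
            alpha = float(iteration) / warmup_iters
            factor = warmup_factor * (1 - alpha) + alpha
    return base_lr * factor * gamma ** bisect_right(milestones, iteration)


# e.g. lr_at(0) -> 1e-04, lr_at(250) -> ~5.05e-03, lr_at(500) -> 1e-02, lr_at(2500) -> 1e-03

Because bisect_right counts how many milestones the current iteration has already passed, the decay compounds: after the second milestone the multiplier becomes gamma ** 2, and so on.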