├── .flake8
├── .gitignore
├── LICENSE
├── README.md
├── ThirdPartyNotices.txt
├── configs
│   ├── example.yaml
│   ├── example_margin.yaml
│   └── example_resnet50.yaml
├── misc
│   └── ms_loss.png
├── requirements.txt
├── ret_benchmark
│   ├── config
│   │   ├── __init__.py
│   │   ├── defaults.py
│   │   └── model_path.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── collate_batch.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   └── base_dataset.py
│   │   ├── evaluations
│   │   │   ├── __init__.py
│   │   │   └── ret_metric.py
│   │   ├── samplers
│   │   │   ├── __init__.py
│   │   │   └── random_identity_sampler.py
│   │   └── transforms
│   │       ├── __init__.py
│   │       └── build.py
│   ├── engine
│   │   ├── __init__.py
│   │   └── trainer.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── margin_loss.py
│   │   ├── multi_similarity_loss.py
│   │   └── registry.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── backbone
│   │   │   ├── __init__.py
│   │   │   ├── bninception.py
│   │   │   ├── build.py
│   │   │   └── resnet.py
│   │   ├── build.py
│   │   ├── heads
│   │   │   ├── __init__.py
│   │   │   ├── build.py
│   │   │   └── linear_norm.py
│   │   ├── registry.py
│   │   └── xbm.py
│   ├── solver
│   │   ├── __init__.py
│   │   ├── build.py
│   │   └── lr_scheduler.py
│   └── utils
│       ├── checkpoint.py
│       ├── config_util.py
│       ├── feat_extractor.py
│       ├── freeze_bn.py
│       ├── img_reader.py
│       ├── init_methods.py
│       ├── logger.py
│       ├── metric_logger.py
│       ├── model_serialization.py
│       └── registry.py
├── scripts
│   ├── prepare_cub.sh
│   ├── run_cub.sh
│   ├── run_cub_margin.sh
│   └── split_cub_for_ms_loss.py
├── setup.py
└── tools
    └── main.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = F401, F841, E402, E722, E999
3 | max-line-length = 128
4 | max-complexity = 18
5 | format = pylint
6 | show_source = True
7 | statistics = True
8 | count = True
9 | exclude = tests,ret_benchmark/modeling/backbone
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | resource
2 | build
3 | *.pyc
4 | *.zip
5 | */__pycache__
6 | __pycache__
7 |
8 | # Package Files #
9 | *.pkl
10 | *.log
11 | *.jar
12 | *.war
13 | *.nar
14 | *.ear
16 | *.tar.gz
17 | *.rar
18 | *.egg-info
19 |
20 | # some local files
21 | */.settings/
22 | */.DS_Store
23 | .DS_Store
24 | */.idea/
25 | .idea/
26 | gradlew
27 | gradlew.bat
28 | unused.txt
29 | output/
30 | *.egg-info/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Creative Commons Attribution-NonCommercial 4.0 International (CC-BY-NC-4.0)
2 | Public License
3 |
4 | For Multi-Similarity Loss for Deep Metric Learning (MS-Loss)
5 |
6 | Copyright (c) 2014-present, Malong Technologies Co., Ltd. All rights reserved.
7 |
8 |
9 | By exercising the Licensed Rights (defined below), You accept and agree
10 | to be bound by the terms and conditions of this Creative Commons
11 | Attribution-NonCommercial 4.0 International Public License ("Public
12 | License"). To the extent this Public License may be interpreted as a
13 | contract, You are granted the Licensed Rights in consideration of Your
14 | acceptance of these terms and conditions, and the Licensor grants You
15 | such rights in consideration of benefits the Licensor receives from
16 | making the Licensed Material available under these terms and
17 | conditions.
18 |
19 |
20 | Section 1 -- Definitions.
21 |
22 | a. Adapted Material means material subject to Copyright and Similar
23 | Rights that is derived from or based upon the Licensed Material
24 | and in which the Licensed Material is translated, altered,
25 | arranged, transformed, or otherwise modified in a manner requiring
26 | permission under the Copyright and Similar Rights held by the
27 | Licensor. For purposes of this Public License, where the Licensed
28 | Material is a musical work, performance, or sound recording,
29 | Adapted Material is always produced where the Licensed Material is
30 | synched in timed relation with a moving image.
31 |
32 | b. Adapter's License means the license You apply to Your Copyright
33 | and Similar Rights in Your contributions to Adapted Material in
34 | accordance with the terms and conditions of this Public License.
35 |
36 | c. Copyright and Similar Rights means copyright and/or similar rights
37 | closely related to copyright including, without limitation,
38 | performance, broadcast, sound recording, and Sui Generis Database
39 | Rights, without regard to how the rights are labeled or
40 | categorized. For purposes of this Public License, the rights
41 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
42 | Rights.
43 | d. Effective Technological Measures means those measures that, in the
44 | absence of proper authority, may not be circumvented under laws
45 | fulfilling obligations under Article 11 of the WIPO Copyright
46 | Treaty adopted on December 20, 1996, and/or similar international
47 | agreements.
48 |
49 | e. Exceptions and Limitations means fair use, fair dealing, and/or
50 | any other exception or limitation to Copyright and Similar Rights
51 | that applies to Your use of the Licensed Material.
52 |
53 | f. Licensed Material means the artistic or literary work, database,
54 | or other material to which the Licensor applied this Public
55 | License.
56 |
57 | g. Licensed Rights means the rights granted to You subject to the
58 | terms and conditions of this Public License, which are limited to
59 | all Copyright and Similar Rights that apply to Your use of the
60 | Licensed Material and that the Licensor has authority to license.
61 |
62 | h. Licensor means the individual(s) or entity(ies) granting rights
63 | under this Public License.
64 |
65 | i. NonCommercial means not primarily intended for or directed towards
66 | commercial advantage or monetary compensation. For purposes of
67 | this Public License, the exchange of the Licensed Material for
68 | other material subject to Copyright and Similar Rights by digital
69 | file-sharing or similar means is NonCommercial provided there is
70 | no payment of monetary compensation in connection with the
71 | exchange.
72 |
73 | j. Share means to provide material to the public by any means or
74 | process that requires permission under the Licensed Rights, such
75 | as reproduction, public display, public performance, distribution,
76 | dissemination, communication, or importation, and to make material
77 | available to the public including in ways that members of the
78 | public may access the material from a place and at a time
79 | individually chosen by them.
80 |
81 | k. Sui Generis Database Rights means rights other than copyright
82 | resulting from Directive 96/9/EC of the European Parliament and of
83 | the Council of 11 March 1996 on the legal protection of databases,
84 | as amended and/or succeeded, as well as other essentially
85 | equivalent rights anywhere in the world.
86 |
87 | l. You means the individual or entity exercising the Licensed Rights
88 | under this Public License. Your has a corresponding meaning.
89 |
90 |
91 | Section 2 -- Scope.
92 |
93 | a. License grant.
94 |
95 | 1. Subject to the terms and conditions of this Public License,
96 | the Licensor hereby grants You a worldwide, royalty-free,
97 | non-sublicensable, non-exclusive, irrevocable license to
98 | exercise the Licensed Rights in the Licensed Material to:
99 |
100 | a. reproduce and Share the Licensed Material, in whole or
101 | in part, for NonCommercial purposes only; and
102 |
103 | b. produce, reproduce, and Share Adapted Material for
104 | NonCommercial purposes only.
105 |
106 | 2. Exceptions and Limitations. For the avoidance of doubt, where
107 | Exceptions and Limitations apply to Your use, this Public
108 | License does not apply, and You do not need to comply with
109 | its terms and conditions.
110 |
111 | 3. Term. The term of this Public License is specified in Section
112 | 6(a).
113 |
114 | 4. Media and formats; technical modifications allowed. The
115 | Licensor authorizes You to exercise the Licensed Rights in
116 | all media and formats whether now known or hereafter created,
117 | and to make technical modifications necessary to do so. The
118 | Licensor waives and/or agrees not to assert any right or
119 | authority to forbid You from making technical modifications
120 | necessary to exercise the Licensed Rights, including
121 | technical modifications necessary to circumvent Effective
122 | Technological Measures. For purposes of this Public License,
123 | simply making modifications authorized by this Section 2(a)
124 | (4) never produces Adapted Material.
125 |
126 | 5. Downstream recipients.
127 |
128 | a. Offer from the Licensor -- Licensed Material. Every
129 | recipient of the Licensed Material automatically
130 | receives an offer from the Licensor to exercise the
131 | Licensed Rights under the terms and conditions of this
132 | Public License.
133 |
134 | b. No downstream restrictions. You may not offer or impose
135 | any additional or different terms or conditions on, or
136 | apply any Effective Technological Measures to, the
137 | Licensed Material if doing so restricts exercise of the
138 | Licensed Rights by any recipient of the Licensed
139 | Material.
140 |
141 | 6. No endorsement. Nothing in this Public License constitutes or
142 | may be construed as permission to assert or imply that You
143 | are, or that Your use of the Licensed Material is, connected
144 | with, or sponsored, endorsed, or granted official status by,
145 | the Licensor or others designated to receive attribution as
146 | provided in Section 3(a)(1)(A)(i).
147 |
148 | b. Other rights.
149 |
150 | 1. Moral rights, such as the right of integrity, are not
151 | licensed under this Public License, nor are publicity,
152 | privacy, and/or other similar personality rights; however, to
153 | the extent possible, the Licensor waives and/or agrees not to
154 | assert any such rights held by the Licensor to the limited
155 | extent necessary to allow You to exercise the Licensed
156 | Rights, but not otherwise.
157 |
158 | 2. Patent and trademark rights are not licensed under this
159 | Public License.
160 |
161 | 3. To the extent possible, the Licensor waives any right to
162 | collect royalties from You for the exercise of the Licensed
163 | Rights, whether directly or through a collecting society
164 | under any voluntary or waivable statutory or compulsory
165 | licensing scheme. In all other cases the Licensor expressly
166 | reserves any right to collect such royalties, including when
167 | the Licensed Material is used other than for NonCommercial
168 | purposes.
169 |
170 |
171 | Section 3 -- License Conditions.
172 |
173 | Your exercise of the Licensed Rights is expressly made subject to the
174 | following conditions.
175 |
176 | a. Attribution.
177 |
178 | 1. If You Share the Licensed Material (including in modified
179 | form), You must:
180 |
181 | a. retain the following if it is supplied by the Licensor
182 | with the Licensed Material:
183 |
184 | i. identification of the creator(s) of the Licensed
185 | Material and any others designated to receive
186 | attribution, in any reasonable manner requested by
187 | the Licensor (including by pseudonym if
188 | designated);
189 |
190 | ii. a copyright notice;
191 |
192 | iii. a notice that refers to this Public License;
193 |
194 | iv. a notice that refers to the disclaimer of
195 | warranties;
196 |
197 | v. a URI or hyperlink to the Licensed Material to the
198 | extent reasonably practicable;
199 |
200 | b. indicate if You modified the Licensed Material and
201 | retain an indication of any previous modifications; and
202 |
203 | c. indicate the Licensed Material is licensed under this
204 | Public License, and include the text of, or the URI or
205 | hyperlink to, this Public License.
206 |
207 | 2. You may satisfy the conditions in Section 3(a)(1) in any
208 | reasonable manner based on the medium, means, and context in
209 | which You Share the Licensed Material. For example, it may be
210 | reasonable to satisfy the conditions by providing a URI or
211 | hyperlink to a resource that includes the required
212 | information.
213 |
214 | 3. If requested by the Licensor, You must remove any of the
215 | information required by Section 3(a)(1)(A) to the extent
216 | reasonably practicable.
217 |
218 | 4. If You Share Adapted Material You produce, the Adapter's
219 | License You apply must not prevent recipients of the Adapted
220 | Material from complying with this Public License.
221 |
222 |
223 | Section 4 -- Sui Generis Database Rights.
224 |
225 | Where the Licensed Rights include Sui Generis Database Rights that
226 | apply to Your use of the Licensed Material:
227 |
228 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
229 | to extract, reuse, reproduce, and Share all or a substantial
230 | portion of the contents of the database for NonCommercial purposes
231 | only;
232 |
233 | b. if You include all or a substantial portion of the database
234 | contents in a database in which You have Sui Generis Database
235 | Rights, then the database in which You have Sui Generis Database
236 | Rights (but not its individual contents) is Adapted Material; and
237 |
238 | c. You must comply with the conditions in Section 3(a) if You Share
239 | all or a substantial portion of the contents of the database.
240 |
241 | For the avoidance of doubt, this Section 4 supplements and does not
242 | replace Your obligations under this Public License where the Licensed
243 | Rights include other Copyright and Similar Rights.
244 |
245 |
246 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
247 |
248 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
249 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
250 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
251 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
252 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
253 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
254 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
255 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
256 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
257 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
258 |
259 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
260 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
261 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
262 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
263 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
264 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
265 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
266 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
267 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
268 |
269 | c. The disclaimer of warranties and limitation of liability provided
270 | above shall be interpreted in a manner that, to the extent
271 | possible, most closely approximates an absolute disclaimer and
272 | waiver of all liability.
273 |
274 |
275 | Section 6 -- Term and Termination.
276 |
277 | a. This Public License applies for the term of the Copyright and
278 | Similar Rights licensed here. However, if You fail to comply with
279 | this Public License, then Your rights under this Public License
280 | terminate automatically.
281 |
282 | b. Where Your right to use the Licensed Material has terminated under
283 | Section 6(a), it reinstates:
284 |
285 | 1. automatically as of the date the violation is cured, provided
286 | it is cured within 30 days of Your discovery of the
287 | violation; or
288 |
289 | 2. upon express reinstatement by the Licensor.
290 |
291 | For the avoidance of doubt, this Section 6(b) does not affect any
292 | right the Licensor may have to seek remedies for Your violations
293 | of this Public License.
294 |
295 | c. For the avoidance of doubt, the Licensor may also offer the
296 | Licensed Material under separate terms or conditions or stop
297 | distributing the Licensed Material at any time; however, doing so
298 | will not terminate this Public License.
299 |
300 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
301 | License.
302 |
303 |
304 | Section 7 -- Other Terms and Conditions.
305 |
306 | a. The Licensor shall not be bound by any additional or different
307 | terms or conditions communicated by You unless expressly agreed.
308 |
309 | b. Any arrangements, understandings, or agreements regarding the
310 | Licensed Material not stated herein are separate from and
311 | independent of the terms and conditions of this Public License.
312 |
313 |
314 | Section 8 -- Interpretation.
315 |
316 | a. For the avoidance of doubt, this Public License does not, and
317 | shall not be interpreted to, reduce, limit, restrict, or impose
318 | conditions on any use of the Licensed Material that could lawfully
319 | be made without permission under this Public License.
320 |
321 | b. To the extent possible, if any provision of this Public License is
322 | deemed unenforceable, it shall be automatically reformed to the
323 | minimum extent necessary to make it enforceable. If the provision
324 | cannot be reformed, it shall be severed from this Public License
325 | without affecting the enforceability of the remaining terms and
326 | conditions.
327 |
328 | c. No term or condition of this Public License will be waived and no
329 | failure to comply consented to unless expressly agreed to by the
330 | Licensor.
331 |
332 | d. Nothing in this Public License constitutes or may be interpreted
333 | as a limitation upon, or waiver of, any privileges and immunities
334 | that apply to the Licensor or You, including from the legal
335 | processes of any jurisdiction or authority.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/)
2 |
3 |
4 | # Multi-Similarity Loss for Deep Metric Learning (MS-Loss)
5 |
6 | Code for the CVPR 2019 paper [Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf)
7 |
8 |
9 |
10 | ### Performance compared with SOTA methods on CUB-200-2011
11 |
12 | |Rank@K | 1 | 2 | 4 | 8 | 16 | 32 |
13 | |:--- |:-:|:-:|:-:|:-:|:-: |:-: |
14 | |Clustering64 | 48.2 | 61.4 | 71.8 | 81.9 | - | - |
15 | |ProxyNCA64 | 49.2 | 61.9 | 67.9 | 72.4 | - | - |
16 | |Smart Mining64 | 49.8 | 62.3 | 74.1 | 83.3 | - | - |
17 | |Our MS-Loss64| **57.4** |**69.8** |**80.0** |**87.8** |93.2 |96.4|
18 | |HTL512 | 57.1| 68.8| 78.7| 86.5| 92.5| 95.5 |
19 | |ABIER512 |57.5 |68.7 |78.3 |86.2 |91.9 |95.5 |
20 | |Our MS-Loss512|**65.7** |**77.0** |**86.3**|**91.2** |**95.0** |**97.3**|
21 |
22 |
23 | ### Prepare the data and the pretrained model
24 |
25 | The following script prepares the [CUB](http://www.vision.caltech.edu.s3-us-west-2.amazonaws.com/visipedia-data/CUB-200-2011/CUB_200_2011.tgz) dataset for training: it downloads the archive to the ./resource/datasets/ folder and then builds the data lists (train.txt and test.txt):
26 |
27 | ```bash
28 | ./scripts/prepare_cub.sh
29 | ```
30 |
31 | Download the ImageNet-pretrained
32 | [bninception](http://data.lip6.fr/cadene/pretrainedmodels/bn_inception-52deb4733.pth) model and put it in the ~/.torch/models/ folder.
33 |
34 |
35 | ### Installation
36 |
37 | ```bash
38 | pip install -r requirements.txt
39 | python setup.py develop build
40 | ```
41 | ### Train and Test on CUB200-2011 with MS-Loss
42 |
43 | ```bash
44 | ./scripts/run_cub.sh
45 | ```
46 | Trained models will be saved in the ./output/ folder when using the default config.
47 |
48 | The best recall@1 should be higher than 66 (the paper reports 65.7).
49 |
50 | ### Contact
51 |
52 | For any questions, please feel free to reach us at
53 | ```
54 | github@malongtech.com
55 | ```
56 |
57 | ### Citation
58 |
59 | If you use this method or this code in your research, please cite as:
60 |
61 | @inproceedings{wang2019multi,
62 | title={Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning},
63 | author={Wang, Xun and Han, Xintong and Huang, Weilin and Dong, Dengke and Scott, Matthew R},
64 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
65 | pages={5022--5030},
66 | year={2019}
67 | }
68 |
69 | ## License
70 |
71 | MS-Loss is CC-BY-NC 4.0 licensed, as found in the [LICENSE](LICENSE) file. It is released for academic research / non-commercial use only. If you wish to use it for commercial purposes, please contact sales@malongtech.com.
72 |
73 |
--------------------------------------------------------------------------------
/ThirdPartyNotices.txt:
--------------------------------------------------------------------------------
1 | THIRD PARTY SOFTWARE NOTICES AND INFORMATION
2 |
3 | Do Not Translate or Localize
4 |
5 | This software incorporates material from the following third parties.
6 |
7 | _____
8 |
9 | Cadene/pretrained-models.pytorch
10 |
11 | BSD 3-Clause License
12 |
13 | Copyright (c) 2017, Remi Cadene
14 | All rights reserved.
15 |
16 | Redistribution and use in source and binary forms, with or without
17 | modification, are permitted provided that the following conditions are met:
18 |
19 | * Redistributions of source code must retain the above copyright notice, this
20 | list of conditions and the following disclaimer.
21 |
22 | * Redistributions in binary form must reproduce the above copyright notice,
23 | this list of conditions and the following disclaimer in the documentation
24 | and/or other materials provided with the distribution.
25 |
26 | * Neither the name of the copyright holder nor the names of its
27 | contributors may be used to endorse or promote products derived from
28 | this software without specific prior written permission.
29 |
30 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
31 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
32 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
33 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
34 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
35 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
36 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
37 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
38 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
39 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
40 |
41 | _____
42 |
43 | facebookresearch/maskrcnn-benchmark
44 |
45 | MIT License
46 |
47 | Copyright (c) 2018 Facebook
48 |
49 | Permission is hereby granted, free of charge, to any person obtaining a copy
50 | of this software and associated documentation files (the "Software"), to deal
51 | in the Software without restriction, including without limitation the rights
52 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
53 | copies of the Software, and to permit persons to whom the Software is
54 | furnished to do so, subject to the following conditions:
55 |
56 | The above copyright notice and this permission notice shall be included in all
57 | copies or substantial portions of the Software.
58 |
59 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
60 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
61 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
62 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
63 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
64 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
65 | SOFTWARE.
--------------------------------------------------------------------------------
/configs/example.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | MODEL:
9 | BACKBONE:
10 | NAME: bninception
11 |
12 | SOLVER:
13 | MAX_ITERS: 3000
14 | STEPS: [1200, 2400]
15 | OPTIMIZER_NAME: Adam
16 | BASE_LR: 0.00003
17 | WARMUP_ITERS: 0
18 | WEIGHT_DECAY: 0.0005
19 |
20 | DATA:
21 | TRAIN_IMG_SOURCE: resource/datasets/CUB_200_2011/train.txt
22 | TEST_IMG_SOURCE: resource/datasets/CUB_200_2011/test.txt
23 | TRAIN_BATCHSIZE: 80
24 | TEST_BATCHSIZE: 256
25 | NUM_WORKERS: 8
26 | NUM_INSTANCES: 5
27 |
28 | VALIDATION:
29 | VERBOSE: 200
--------------------------------------------------------------------------------
/configs/example_margin.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | MODEL:
9 | BACKBONE:
10 | NAME: bninception
11 |
12 | LOSSES:
13 | NAME: margin_loss
14 | MARGIN_LOSS:
15 | N_CLASSES: 100
16 |     BETA_CONSTANT: False # if False (i.e. class-specific beta), labels in train.txt must be 0 ... N_CLASSES - 1
17 |
18 | SOLVER:
19 | MAX_ITERS: 3000
20 | STEPS: [1200, 2400]
21 | OPTIMIZER_NAME: Adam
22 | BASE_LR: 0.00003
23 | WARMUP_ITERS: 0
24 | WEIGHT_DECAY: 0.0005
25 |
26 | DATA:
27 | TRAIN_IMG_SOURCE: resource/datasets/CUB_200_2011/train.txt
28 | TEST_IMG_SOURCE: resource/datasets/CUB_200_2011/test.txt
29 | TRAIN_BATCHSIZE: 120
30 | TEST_BATCHSIZE: 256
31 | NUM_WORKERS: 8
32 | NUM_INSTANCES: 5
33 |
34 | VALIDATION:
35 | VERBOSE: 200
36 |
37 | SAVE_DIR: output_margin
38 |
39 |
--------------------------------------------------------------------------------
/configs/example_resnet50.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | MODEL:
9 | BACKBONE:
10 | NAME: resnet50
11 |
12 | INPUT:
13 | MODE: 'RGB'
14 | PIXEL_MEAN: [0.485, 0.456, 0.406]
15 | PIXEL_STD: [0.229, 0.224, 0.225]
16 |
17 | SOLVER:
18 | MAX_ITERS: 3000
19 | STEPS: [1200, 2400]
20 | OPTIMIZER_NAME: Adam
21 | BASE_LR: 0.00003
22 | WARMUP_ITERS: 0
23 | WEIGHT_DECAY: 0.0005
24 |
25 | DATA:
26 | TRAIN_IMG_SOURCE: resource/datasets/CUB_200_2011/train.txt
27 | TEST_IMG_SOURCE: resource/datasets/CUB_200_2011/test.txt
28 | TRAIN_BATCHSIZE: 80
29 | TEST_BATCHSIZE: 256
30 | NUM_WORKERS: 8
31 | NUM_INSTANCES: 5
32 |
33 | VALIDATION:
34 | VERBOSE: 200
35 |
--------------------------------------------------------------------------------
/misc/ms_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msight-tech/research-ms-loss/b68507d4e22d8a6d3d3c0e6c31be708f9dcd20ee/misc/ms_loss.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.1.0
2 | numpy==1.15.4
3 | yacs==0.1.4
4 | setuptools==40.6.2
5 | pytest==4.4.0
6 | Pillow==8.3.2
7 | torchvision==0.3.0
8 |
--------------------------------------------------------------------------------
/ret_benchmark/config/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from .defaults import _C as cfg
9 |
--------------------------------------------------------------------------------
/ret_benchmark/config/defaults.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from yacs.config import CfgNode as CN
9 | from .model_path import MODEL_PATH
10 |
11 | # -----------------------------------------------------------------------------
12 | # Config definition
13 | # -----------------------------------------------------------------------------
14 |
15 | _C = CN()
16 |
17 | _C.MODEL = CN()
18 | _C.MODEL.DEVICE = "cuda"
19 |
20 | _C.MODEL.BACKBONE = CN()
21 | _C.MODEL.BACKBONE.NAME = "bninception"
22 |
23 | _C.MODEL.PRETRAIN = 'imagenet'
24 | _C.MODEL.PRETRIANED_PATH = MODEL_PATH
25 |
26 | _C.MODEL.HEAD = CN()
27 | _C.MODEL.HEAD.NAME = "linear_norm"
28 | _C.MODEL.HEAD.DIM = 512
29 |
30 | _C.MODEL.WEIGHT = ""
31 |
32 | # Checkpoint save dir
33 | _C.SAVE_DIR = 'output'
34 |
35 | # Loss
36 | _C.LOSSES = CN()
37 | _C.LOSSES.NAME = 'ms_loss'
38 |
39 | # ms loss
40 | _C.LOSSES.MULTI_SIMILARITY_LOSS = CN()
41 | _C.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_POS = 2.0
42 | _C.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_NEG = 40.0
43 | _C.LOSSES.MULTI_SIMILARITY_LOSS.HARD_MINING = True
44 |
45 | # margin loss
46 | _C.LOSSES.MARGIN_LOSS = CN()
47 | _C.LOSSES.MARGIN_LOSS.BETA_CONSTANT = False
48 | _C.LOSSES.MARGIN_LOSS.N_CLASSES = 100
50 | _C.LOSSES.MARGIN_LOSS.CUTOFF = 0.5
51 | _C.LOSSES.MARGIN_LOSS.UPPER_CUTOFF = 1.4
52 |
53 | # Data option
54 | _C.DATA = CN()
55 | _C.DATA.TRAIN_IMG_SOURCE = 'resource/datasets/CUB_200_2011/train.txt'
56 | _C.DATA.TEST_IMG_SOURCE = 'resource/datasets/CUB_200_2011/test.txt'
57 | _C.DATA.TRAIN_BATCHSIZE = 70
58 | _C.DATA.TEST_BATCHSIZE = 256
59 | _C.DATA.NUM_WORKERS = 8
60 | _C.DATA.NUM_INSTANCES = 5
61 |
62 | # Input option
63 | _C.INPUT = CN()
64 |
65 | # INPUT CONFIG
66 | _C.INPUT.MODE = 'BGR'
67 | _C.INPUT.PIXEL_MEAN = [104. / 255, 117. / 255, 128. / 255]
68 | _C.INPUT.PIXEL_STD = 3 * [1. / 255]
69 |
70 | _C.INPUT.FLIP_PROB = 0.5
71 | _C.INPUT.ORIGIN_SIZE = 256
72 | _C.INPUT.CROP_SCALE = [0.16, 1]
73 | _C.INPUT.CROP_SIZE = 227
74 |
75 | # SOLVER
76 | _C.SOLVER = CN()
77 | _C.SOLVER.IS_FINETURN = False
78 | _C.SOLVER.FINETURN_MODE_PATH = ''
79 | _C.SOLVER.MAX_ITERS = 4000
80 | _C.SOLVER.STEPS = [1000, 2000, 3000]
81 | _C.SOLVER.OPTIMIZER_NAME = 'SGD'
82 | _C.SOLVER.BASE_LR = 0.01
83 | _C.SOLVER.BIAS_LR_FACTOR = 1
84 | _C.SOLVER.WEIGHT_DECAY = 0.0005
85 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.0005
86 | _C.SOLVER.MOMENTUM = 0.9
87 | _C.SOLVER.GAMMA = 0.1
88 | _C.SOLVER.WARMUP_FACTOR = 0.01
89 | _C.SOLVER.WARMUP_ITERS = 200
90 | _C.SOLVER.WARMUP_METHOD = 'linear'
91 | _C.SOLVER.CHECKPOINT_PERIOD = 200
92 | _C.SOLVER.RNG_SEED = 1
93 |
94 | # Logger
95 | _C.LOGGER = CN()
96 | _C.LOGGER.LEVEL = 20
97 | _C.LOGGER.STREAM = 'stdout'
98 |
99 | # Validation
100 | _C.VALIDATION = CN()
101 | _C.VALIDATION.VERBOSE = 200
102 | _C.VALIDATION.IS_VALIDATION = True
103 |
--------------------------------------------------------------------------------
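
For orientation, here is a minimal sketch of how these defaults combine with one of the YAML files in configs/ (merge_from_file and freeze are standard yacs CfgNode methods; exactly how tools/main.py wires this up is assumed, since that file is not shown in this section):

```python
# Sketch: resolving a final config from the defaults above plus a YAML override.
from ret_benchmark.config import cfg  # the _C node exported in __init__.py above

cfg.merge_from_file("configs/example.yaml")  # YAML values override the defaults
cfg.freeze()                                 # make the config read-only

print(cfg.LOSSES.NAME)            # 'ms_loss' (default, untouched by example.yaml)
print(cfg.SOLVER.OPTIMIZER_NAME)  # 'Adam' (example.yaml overrides the 'SGD' default)
```
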
/ret_benchmark/config/model_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | # -----------------------------------------------------------------------------
9 | # Config definition of imagenet pretrained model path
10 | # -----------------------------------------------------------------------------
11 |
12 |
13 | from yacs.config import CfgNode as CN
14 |
15 | MODEL_PATH = {
16 | 'bninception': "~/.torch/models/bn_inception-52deb4733.pth",
17 | 'resnet50': "~/.torch/models/resnet50-19c8e357.pth",
18 | }
19 |
20 | MODEL_PATH = CN(MODEL_PATH)
21 |
--------------------------------------------------------------------------------
/ret_benchmark/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from .build import build_data
9 |
--------------------------------------------------------------------------------
/ret_benchmark/data/build.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from torch.utils.data import DataLoader
9 |
10 | from .collate_batch import collate_fn
11 | from .datasets import BaseDataSet
12 | from .samplers import RandomIdentitySampler
13 | from .transforms import build_transforms
14 |
15 |
16 | def build_data(cfg, is_train=True):
17 | transforms = build_transforms(cfg, is_train=is_train)
18 | if is_train:
19 | dataset = BaseDataSet(cfg.DATA.TRAIN_IMG_SOURCE, transforms=transforms, mode=cfg.INPUT.MODE)
20 | sampler = RandomIdentitySampler(dataset=dataset,
21 | batch_size=cfg.DATA.TRAIN_BATCHSIZE,
22 | num_instances=cfg.DATA.NUM_INSTANCES,
23 | max_iters=cfg.SOLVER.MAX_ITERS
24 | )
25 | data_loader = DataLoader(dataset,
26 | collate_fn=collate_fn,
27 | batch_sampler=sampler,
28 | num_workers=cfg.DATA.NUM_WORKERS,
29 | pin_memory=True
30 | )
31 | else:
32 | dataset = BaseDataSet(cfg.DATA.TEST_IMG_SOURCE, transforms=transforms, mode=cfg.INPUT.MODE)
33 | data_loader = DataLoader(dataset,
34 | collate_fn=collate_fn,
35 | shuffle=False,
36 | batch_size=cfg.DATA.TEST_BATCHSIZE,
37 | num_workers=cfg.DATA.NUM_WORKERS
38 | )
39 | return data_loader
40 |
--------------------------------------------------------------------------------
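
A usage sketch for build_data, assuming a resolved cfg as above and the CUB list files at their configured paths:

```python
# Sketch: building both loaders and inspecting one training batch.
from ret_benchmark.config import cfg
from ret_benchmark.data import build_data

train_loader = build_data(cfg, is_train=True)   # batched by RandomIdentitySampler
test_loader = build_data(cfg, is_train=False)   # plain sequential batches

images, labels = next(iter(train_loader))
# images: float tensor of shape [TRAIN_BATCHSIZE, 3, CROP_SIZE, CROP_SIZE]
# labels: int64 tensor of shape [TRAIN_BATCHSIZE] (collate_fn casts string labels)
```
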
/ret_benchmark/data/collate_batch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 |
10 |
11 | def collate_fn(batch):
12 | imgs, labels = zip(*batch)
13 | labels = [int(k) for k in labels]
14 | labels = torch.tensor(labels, dtype=torch.int64)
15 | return torch.stack(imgs, dim=0), labels
16 |
--------------------------------------------------------------------------------
/ret_benchmark/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from .base_dataset import BaseDataSet
9 |
--------------------------------------------------------------------------------
/ret_benchmark/data/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import os
9 | import re
10 | from collections import defaultdict
11 |
12 | from torch.utils.data import Dataset
13 | from ret_benchmark.utils.img_reader import read_image
14 |
15 |
16 | class BaseDataSet(Dataset):
17 | """
18 |     Basic dataset that reads (image path, label) pairs from img_source,
19 |     a text file with one comma- or space-separated pair per line.
20 | """
21 |
22 | def __init__(self, img_source, transforms=None, mode="RGB"):
23 | self.mode = mode
24 | self.transforms = transforms
25 | self.root = os.path.dirname(img_source)
26 | assert os.path.exists(img_source), f"{img_source} NOT found."
27 | self.img_source = img_source
28 |
29 | self.label_list = list()
30 | self.path_list = list()
31 | self._load_data()
32 | self.label_index_dict = self._build_label_index_dict()
33 |
34 | def __len__(self):
35 | return len(self.label_list)
36 |
37 | def __repr__(self):
38 | return self.__str__()
39 |
40 | def __str__(self):
41 | return f"| Dataset Info |datasize: {self.__len__()}|num_labels: {len(set(self.label_list))}|"
42 |
43 | def _load_data(self):
44 | with open(self.img_source, 'r') as f:
45 | for line in f:
46 | _path, _label = re.split(r",| ", line.strip())
47 | self.path_list.append(_path)
48 | self.label_list.append(_label)
49 |
50 | def _build_label_index_dict(self):
51 | index_dict = defaultdict(list)
52 | for i, label in enumerate(self.label_list):
53 | index_dict[label].append(i)
54 | return index_dict
55 |
56 | def __getitem__(self, index):
57 | path = self.path_list[index]
58 | img_path = os.path.join(self.root, path)
59 | label = self.label_list[index]
60 |
61 | img = read_image(img_path, mode=self.mode)
62 | if self.transforms is not None:
63 | img = self.transforms(img)
64 | return img, label
65 |
--------------------------------------------------------------------------------
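
For reference, _load_data accepts one comma- or space-separated `<relative_img_path> <label>` pair per line, with paths resolved against the directory containing img_source. A sketch with illustrative (not actual CUB) entries:

```python
# Sketch: expected list-file format and direct dataset use.
# A train.txt produced by scripts/prepare_cub.sh contains lines like
# the following (the paths here are illustrative):
#
#   images/001.Black_footed_Albatross/img_0001.jpg,0
#   images/002.Laysan_Albatross/img_0042.jpg,1
from ret_benchmark.data.datasets import BaseDataSet

dataset = BaseDataSet("resource/datasets/CUB_200_2011/train.txt")
print(dataset)           # | Dataset Info |datasize: ...|num_labels: ...|
img, label = dataset[0]  # untransformed image and its raw string label
```
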
/ret_benchmark/data/evaluations/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from .ret_metric import RetMetric
9 |
--------------------------------------------------------------------------------
/ret_benchmark/data/evaluations/ret_metric.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import numpy as np
9 |
10 |
11 | class RetMetric(object):
12 | def __init__(self, feats, labels):
13 |
14 |         if isinstance(feats, list) and len(feats) == 2:
15 |             # separate gallery/query evaluation:
16 |             #   feats = [gallery_feats, query_feats]
17 |             #   labels = [gallery_labels, query_labels]
19 | self.is_equal_query = False
20 |
21 | self.gallery_feats, self.query_feats = feats
22 | self.gallery_labels, self.query_labels = labels
23 |
24 | else:
25 | self.is_equal_query = True
26 | self.gallery_feats = self.query_feats = feats
27 | self.gallery_labels = self.query_labels = labels
28 |
29 | self.sim_mat = np.matmul(self.query_feats, np.transpose(self.gallery_feats))
30 |
31 | def recall_k(self, k=1):
32 | m = len(self.sim_mat)
33 |
34 | match_counter = 0
35 |
36 | for i in range(m):
37 | pos_sim = self.sim_mat[i][self.gallery_labels == self.query_labels[i]]
38 | neg_sim = self.sim_mat[i][self.gallery_labels != self.query_labels[i]]
39 |
40 | thresh = np.sort(pos_sim)[-2] if self.is_equal_query else np.max(pos_sim)
41 |
42 | if np.sum(neg_sim > thresh) < k:
43 | match_counter += 1
44 | return float(match_counter) / m
45 |
--------------------------------------------------------------------------------
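
A self-contained sketch of the self-query protocol (a single feature array, as the trainer uses for CUB, where gallery and query sets coincide):

```python
# Sketch: Recall@K on synthetic L2-normalized features.
import numpy as np
from ret_benchmark.data.evaluations import RetMetric

rng = np.random.RandomState(0)
feats = rng.randn(100, 512).astype(np.float32)
feats /= np.linalg.norm(feats, axis=1, keepdims=True)  # unit norm: dot == cosine
labels = rng.randint(0, 10, size=100)

metric = RetMetric(feats=feats, labels=labels)  # gallery == query here
for k in (1, 2, 4, 8):
    print(f"recall@{k}: {metric.recall_k(k):.3f}")
```
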
/ret_benchmark/data/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from .random_identity_sampler import RandomIdentitySampler
9 |
--------------------------------------------------------------------------------
/ret_benchmark/data/samplers/random_identity_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import copy
9 | import random
10 | from collections import defaultdict
11 |
12 | import numpy as np
13 | import torch
14 | from torch.utils.data.sampler import Sampler
15 |
16 |
17 | class RandomIdentitySampler(Sampler):
18 | """
19 | Randomly sample N identities, then for each identity,
20 | randomly sample K instances, therefore batch size is N*K.
21 | Args:
22 | - dataset (BaseDataSet).
23 |     - batch_size (int): number of examples in a batch.
24 |     - num_instances (int): number of instances per identity in a batch.
25 |     - max_iters (int): number of batches to yield.
25 | """
26 |
27 | def __init__(self, dataset, batch_size, num_instances, max_iters):
28 | self.label_index_dict = dataset.label_index_dict
29 | self.batch_size = batch_size
30 | self.K = num_instances
31 | self.num_labels_per_batch = self.batch_size // self.K
32 | self.max_iters = max_iters
33 | self.labels = list(self.label_index_dict.keys())
34 |
35 | def __len__(self):
36 | return self.max_iters
37 |
38 | def __repr__(self):
39 | return self.__str__()
40 |
41 | def __str__(self):
42 | return f"|Sampler| iters {self.max_iters}| K {self.K}| M {self.batch_size}|"
43 |
44 | def _prepare_batch(self):
45 | batch_idxs_dict = defaultdict(list)
46 |
47 | for label in self.labels:
48 | idxs = copy.deepcopy(self.label_index_dict[label])
49 | if len(idxs) < self.K:
50 | idxs.extend(np.random.choice(idxs, size=self.K - len(idxs), replace=True))
51 | random.shuffle(idxs)
52 |
53 | batch_idxs_dict[label] = [idxs[i * self.K: (i + 1) * self.K] for i in range(len(idxs) // self.K)]
54 |
55 | avai_labels = copy.deepcopy(self.labels)
56 | return batch_idxs_dict, avai_labels
57 |
58 | def __iter__(self):
59 | batch_idxs_dict, avai_labels = self._prepare_batch()
60 | for _ in range(self.max_iters):
61 | batch = []
62 | if len(avai_labels) < self.num_labels_per_batch:
63 | batch_idxs_dict, avai_labels = self._prepare_batch()
64 |
65 | selected_labels = random.sample(avai_labels, self.num_labels_per_batch)
66 | for label in selected_labels:
67 | batch_idxs = batch_idxs_dict[label].pop(0)
68 | batch.extend(batch_idxs)
69 | if len(batch_idxs_dict[label]) == 0:
70 | avai_labels.remove(label)
71 | yield batch
72 |
--------------------------------------------------------------------------------
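
To see the resulting P×K batch structure, the sampler can be iterated directly over indices (no transforms needed); a sketch, assuming train.txt exists at the default path:

```python
# Sketch: inspecting one sampled batch of indices.
from ret_benchmark.data.datasets import BaseDataSet
from ret_benchmark.data.samplers import RandomIdentitySampler

dataset = BaseDataSet("resource/datasets/CUB_200_2011/train.txt")
sampler = RandomIdentitySampler(dataset, batch_size=80, num_instances=5, max_iters=3000)

batch = next(iter(sampler))  # list of 80 dataset indices
labels = [dataset.label_list[i] for i in batch]
# 80 // 5 = 16 distinct labels, each appearing exactly 5 times
```
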
/ret_benchmark/data/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from .build import build_transforms
9 |
--------------------------------------------------------------------------------
/ret_benchmark/data/transforms/build.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import torchvision.transforms as T
9 |
10 |
11 | def build_transforms(cfg, is_train=True):
12 | normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN,
13 | std=cfg.INPUT.PIXEL_STD)
14 | if is_train:
15 | transform = T.Compose([
16 | T.Resize(size=cfg.INPUT.ORIGIN_SIZE),
17 | T.RandomResizedCrop(
18 | scale=cfg.INPUT.CROP_SCALE,
19 | size=cfg.INPUT.CROP_SIZE
20 | ),
21 | T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB),
22 | T.ToTensor(),
23 | normalize_transform,
24 | ])
25 | else:
26 | transform = T.Compose([
27 | T.Resize(size=cfg.INPUT.ORIGIN_SIZE),
28 | T.CenterCrop(cfg.INPUT.CROP_SIZE),
29 | T.ToTensor(),
30 | normalize_transform
31 | ])
32 | return transform
33 |
--------------------------------------------------------------------------------
/ret_benchmark/engine/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from .trainer import do_train
9 |
--------------------------------------------------------------------------------
/ret_benchmark/engine/trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import datetime
9 | import time
10 |
11 | import numpy as np
12 | import torch
13 |
14 | from ret_benchmark.data.evaluations import RetMetric
15 | from ret_benchmark.utils.feat_extractor import feat_extractor
16 | from ret_benchmark.utils.freeze_bn import set_bn_eval
17 | from ret_benchmark.utils.metric_logger import MetricLogger
18 |
19 |
20 | def do_train(
21 | cfg,
22 | model,
23 | train_loader,
24 | val_loader,
25 | optimizer,
26 | scheduler,
27 | criterion,
28 | checkpointer,
29 | device,
30 | checkpoint_period,
31 | arguments,
32 | logger
33 | ):
34 | logger.info("Start training")
35 | meters = MetricLogger(delimiter=" ")
36 | max_iter = len(train_loader)
37 |
38 | start_iter = arguments["iteration"]
39 | best_iteration = -1
40 | best_recall = 0
41 |
42 | start_training_time = time.time()
43 | end = time.time()
44 | for iteration, (images, targets) in enumerate(train_loader, start_iter):
45 |
46 | if iteration % cfg.VALIDATION.VERBOSE == 0 or iteration == max_iter:
47 | model.eval()
48 | logger.info('Validation')
49 | labels = val_loader.dataset.label_list
50 | labels = np.array([int(k) for k in labels])
51 | feats = feat_extractor(model, val_loader, logger=logger)
52 |
53 | ret_metric = RetMetric(feats=feats, labels=labels)
54 | recall_curr = ret_metric.recall_k(1)
55 |
56 | if recall_curr > best_recall:
57 | best_recall = recall_curr
58 | best_iteration = iteration
59 | logger.info(f'Best iteration {iteration}: recall@1: {best_recall:.3f}')
60 |                 checkpointer.save("best_model")
61 | else:
62 | logger.info(f'Recall@1 at iteration {iteration:06d}: {recall_curr:.3f}')
63 |
64 | model.train()
65 | model.apply(set_bn_eval)
66 |
67 | data_time = time.time() - end
68 | iteration = iteration + 1
69 | arguments["iteration"] = iteration
70 |
71 | scheduler.step()
72 |
73 | images = images.to(device)
74 | targets = torch.stack([target.to(device) for target in targets])
75 |
76 | feats = model(images)
77 | loss = criterion(feats, targets)
78 | optimizer.zero_grad()
79 | loss.backward()
80 | optimizer.step()
81 |
82 | batch_time = time.time() - end
83 | end = time.time()
84 | meters.update(time=batch_time, data=data_time, loss=loss.item())
85 |
86 | eta_seconds = meters.time.global_avg * (max_iter - iteration)
87 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
88 |
89 | if iteration % 20 == 0 or iteration == max_iter:
90 | logger.info(
91 | meters.delimiter.join(
92 | [
93 | "eta: {eta}",
94 | "iter: {iter}",
95 | "{meters}",
96 | "lr: {lr:.6f}",
97 | "max mem: {memory:.1f} GB",
98 | ]
99 | ).format(
100 | eta=eta_string,
101 | iter=iteration,
102 | meters=str(meters),
103 | lr=optimizer.param_groups[0]["lr"],
104 | memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0,
105 | )
106 | )
107 |
108 | if iteration % checkpoint_period == 0:
109 | checkpointer.save("model_{:06d}".format(iteration))
110 |
111 | total_training_time = time.time() - start_training_time
112 | total_time_str = str(datetime.timedelta(seconds=total_training_time))
113 | logger.info(
114 | "Total training time: {} ({:.4f} s / it)".format(
115 | total_time_str, total_training_time / (max_iter)
116 | )
117 | )
118 |
119 |     logger.info(f"Best iteration: {best_iteration:06d} | best recall {best_recall:.3f}")
120 |
--------------------------------------------------------------------------------
/ret_benchmark/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from .build import build_loss
9 |
--------------------------------------------------------------------------------
/ret_benchmark/losses/build.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from .multi_similarity_loss import MultiSimilarityLoss
9 | from .margin_loss import MarginLoss
10 | from .registry import LOSS
11 |
12 |
13 | def build_loss(cfg):
14 | loss_name = cfg.LOSSES.NAME
15 | assert loss_name in LOSS, \
16 | f'loss name {loss_name} is not registered in registry :{LOSS.keys()}'
17 | return LOSS[loss_name](cfg)
18 |
--------------------------------------------------------------------------------
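
The registry makes losses pluggable: any nn.Module registered under a name becomes selectable via LOSSES.NAME. A sketch with a hypothetical loss name:

```python
# Sketch: registering a custom loss (the name 'my_loss' is hypothetical).
from torch import nn

from ret_benchmark.losses.registry import LOSS


@LOSS.register("my_loss")
class MyLoss(nn.Module):
    def __init__(self, cfg):
        super(MyLoss, self).__init__()

    def forward(self, feats, labels):
        # any scalar loss over the embedding batch
        return feats.new_zeros(())
```

Selecting it then only requires setting `LOSSES.NAME` to `my_loss` in the YAML config.
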
/ret_benchmark/losses/margin_loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 | from ret_benchmark.losses.registry import LOSS
7 |
8 |
9 | class DistanceWeightedSampling(object):
10 | """
11 | """
12 | def __init__(self, cfg):
13 | super(DistanceWeightedSampling, self).__init__()
14 | self.cutoff = cfg.LOSSES.MARGIN_LOSS.CUTOFF
15 | self.upper_cutoff = cfg.LOSSES.MARGIN_LOSS.UPPER_CUTOFF
16 |
17 | def sample(self, batch, labels):
18 |
19 | if isinstance(labels, torch.Tensor):
20 | labels = labels.detach().cpu().numpy()
21 | bs = batch.shape[0]
22 | distances = self.p_dist(batch.detach()).clamp(min=self.cutoff)
23 |
24 | positives, negatives = [], []
25 |
26 | for i in range(bs):
27 | pos = labels == labels[i]
28 | q_d_inv = self.inverse_sphere_distances(batch, distances[i], labels, labels[i])
29 | # sample positives randomly
30 | pos[i] = 0
31 | positives.append(np.random.choice(np.where(pos)[0]))
32 | # sample negatives by distance
33 | negatives.append(np.random.choice(bs, p=q_d_inv))
34 |
35 | sampled_triplets = [[a, p, n] for a, p, n in zip(list(range(bs)), positives, negatives)]
36 | return sampled_triplets
37 |
38 | @staticmethod
39 | def p_dist(A, eps=1e-4):
40 | prod = torch.mm(A, A.t())
41 | norm = prod.diag().unsqueeze(1).expand_as(prod)
42 | res = (norm + norm.t() - 2 * prod).clamp(min=0)
43 | return res.clamp(min=eps).sqrt()
44 |
45 | def inverse_sphere_distances(self, batch, dist, labels, anchor_label):
46 | bs, dim = len(dist), batch.shape[-1]
47 | # negated log-distribution of distances of unit sphere in dimension
48 | log_q_d_inv = ((2.0 - float(dim)) * torch.log(dist) - (float(dim-3) / 2)
49 | * torch.log(1.0 - 0.25 * (dist.pow(2))))
50 | # set sampling probabilities of positives to zero
51 | log_q_d_inv[np.where(labels == anchor_label)[0]] = 0
52 |
53 | q_d_inv = torch.exp(log_q_d_inv - torch.max(log_q_d_inv)) # - max(log) for stability
54 | # set sampling probabilities of positives to zero
55 | q_d_inv[np.where(labels == anchor_label)[0]] = 0
56 |
57 | # NOTE: Cutting of values with high distances made the results slightly worse.
58 | # q_d_inv[np.where(dist > self.upper_cutoff)[0]] = 0
59 |
60 | q_d_inv = q_d_inv/q_d_inv.sum()
61 | return q_d_inv.detach().cpu().numpy()
62 |
63 |
64 | @LOSS.register("margin_loss")
65 | class MarginLoss(nn.Module):
66 | """Margin based loss with DistanceWeightedSampling
67 | """
68 | def __init__(self, cfg):
69 | super(MarginLoss, self).__init__()
70 | self.beta_val = 1.2
71 | self.margin = 0.2
72 | self.nu = 0.0
73 | self.n_classes = cfg.LOSSES.MARGIN_LOSS.N_CLASSES
74 | self.beta_constant = cfg.LOSSES.MARGIN_LOSS.BETA_CONSTANT
75 | if self.beta_constant:
76 | self.beta = self.beta_val
77 | else:
78 | self.beta = torch.nn.Parameter(torch.ones(self.n_classes)*self.beta_val)
79 | self.sampler = DistanceWeightedSampling(cfg)
80 |
81 | def forward(self, batch, labels):
82 | if isinstance(labels, torch.Tensor):
83 | labels = labels.detach().cpu().numpy()
84 | sampled_triplets = self.sampler.sample(batch, labels)
85 |
86 | # compute distances between anchor-positive and anchor-negative.
87 | d_ap, d_an = [], []
88 | for triplet in sampled_triplets:
89 | train_triplet = {'Anchor': batch[triplet[0], :],
90 | 'Positive': batch[triplet[1], :], 'Negative': batch[triplet[2]]}
91 | pos_dist = ((train_triplet['Anchor']-train_triplet['Positive']).pow(2).sum()+1e-8).pow(1/2)
92 | neg_dist = ((train_triplet['Anchor']-train_triplet['Negative']).pow(2).sum()+1e-8).pow(1/2)
93 |
94 | d_ap.append(pos_dist)
95 | d_an.append(neg_dist)
96 | d_ap, d_an = torch.stack(d_ap), torch.stack(d_an)
97 |
98 | # group betas together by anchor class in sampled triplets (as each beta belongs to one class).
99 | if self.beta_constant:
100 | beta = self.beta
101 | else:
102 | beta = torch.stack([self.beta[labels[triplet[0]]] for
103 | triplet in sampled_triplets]).type(torch.cuda.FloatTensor)
104 | # compute actual margin positive and margin negative loss
105 | pos_loss = F.relu(d_ap-beta+self.margin)
106 | neg_loss = F.relu(beta-d_an+self.margin)
107 |
108 | # compute normalization constant
109 | pair_count = torch.sum((pos_loss > 0.)+(neg_loss > 0.)).type(torch.cuda.FloatTensor)
110 | # actual Margin Loss
111 | loss = torch.sum(pos_loss+neg_loss) if pair_count == 0. else torch.sum(pos_loss+neg_loss)/pair_count
112 |
113 | # (Optional) Add regularization penalty on betas.
114 | # if self.nu: loss = loss + beta_regularisation_loss.type(torch.cuda.FloatTensor)
115 | return loss
116 |
--------------------------------------------------------------------------------
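
For reference, inverse_sphere_distances above implements the negative-sampling weight of distance-weighted sampling ('Sampling Matters in Deep Embedding Learning', Wu et al., 2017, the assumed source of this code). For unit-normalized embeddings in dimension n, pairwise distances d are distributed as

```latex
% density of pairwise distances between points on the unit sphere in R^n
q(d) \propto d^{\,n-2}\left(1 - \tfrac{d^{2}}{4}\right)^{\frac{n-3}{2}}
% the code samples negatives with probability proportional to q(d)^{-1}:
\log q(d)^{-1} = (2 - n)\log d - \frac{n-3}{2}\log\!\left(1 - \tfrac{d^{2}}{4}\right)
% margin loss per sampled triplet (alpha = margin = 0.2, beta learnable per class):
\ell = \big(d_{ap} - \beta + \alpha\big)_{+} + \big(\beta - d_{an} + \alpha\big)_{+}
```
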
/ret_benchmark/losses/multi_similarity_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 | from torch import nn
10 |
11 | from ret_benchmark.losses.registry import LOSS
12 |
13 |
14 | @LOSS.register('ms_loss')
15 | class MultiSimilarityLoss(nn.Module):
16 | def __init__(self, cfg):
17 | super(MultiSimilarityLoss, self).__init__()
18 | self.thresh = 0.5
19 | self.margin = 0.1
20 |
21 | self.scale_pos = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_POS
22 | self.scale_neg = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_NEG
23 |
24 | def forward(self, feats, labels):
25 | assert feats.size(0) == labels.size(0), \
26 | f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}"
27 | batch_size = feats.size(0)
28 | sim_mat = torch.matmul(feats, torch.t(feats))
29 |
30 | epsilon = 1e-5
31 | loss = list()
32 |
33 | for i in range(batch_size):
34 | pos_pair_ = sim_mat[i][labels == labels[i]]
35 | pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon]
36 | neg_pair_ = sim_mat[i][labels != labels[i]]
37 |
38 | neg_pair = neg_pair_[neg_pair_ + self.margin > min(pos_pair_)]
39 | pos_pair = pos_pair_[pos_pair_ - self.margin < max(neg_pair_)]
40 |
41 | if len(neg_pair) < 1 or len(pos_pair) < 1:
42 | continue
43 |
44 | # weighting step
45 | pos_loss = 1.0 / self.scale_pos * torch.log(
46 | 1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh))))
47 | neg_loss = 1.0 / self.scale_neg * torch.log(
48 | 1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh))))
49 | loss.append(pos_loss + neg_loss)
50 |
51 | if len(loss) == 0:
52 |             return torch.zeros([], device=feats.device, requires_grad=True)  # keep device consistent with feats
53 |
54 | loss = sum(loss) / batch_size
55 | return loss
56 |
--------------------------------------------------------------------------------
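
For reference, after the mining step (keeping negatives with similarity above min(pos_pair_) - self.margin and positives below max(neg_pair_) + self.margin), the weighting step above computes the multi-similarity loss of the paper, with alpha = SCALE_POS, beta = SCALE_NEG, and lambda = self.thresh:

```latex
\mathcal{L}_{MS} = \frac{1}{m}\sum_{i=1}^{m}\Bigg\{
  \frac{1}{\alpha}\log\Big[1 + \sum_{k \in \mathcal{P}_i} e^{-\alpha (S_{ik} - \lambda)}\Big]
+ \frac{1}{\beta}\log\Big[1 + \sum_{k \in \mathcal{N}_i} e^{\beta (S_{ik} - \lambda)}\Big]
\Bigg\}
```
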
/ret_benchmark/losses/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from ret_benchmark.utils.registry import Registry
9 |
10 | LOSS = Registry()
11 |
--------------------------------------------------------------------------------
/ret_benchmark/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from .backbone import build_backbone
9 | from .build import build_model
10 | from .heads import build_head
11 | from .registry import BACKBONES, HEADS
12 |
--------------------------------------------------------------------------------
/ret_benchmark/modeling/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_backbone
2 |
--------------------------------------------------------------------------------
/ret_benchmark/modeling/backbone/bninception.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from ret_benchmark.modeling import registry
8 |
9 | @registry.BACKBONES.register('bninception')
10 | class BNInception(nn.Module):
11 |
12 | def __init__(self):
13 | super(BNInception, self).__init__()
14 | inplace = True
15 | self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
16 | self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, affine=True)
17 | self.conv1_relu_7x7 = nn.ReLU(inplace)
18 | self.pool1_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
19 | self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
20 | self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
21 | self.conv2_relu_3x3_reduce = nn.ReLU(inplace)
22 | self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
23 | self.conv2_3x3_bn = nn.BatchNorm2d(192, affine=True)
24 | self.conv2_relu_3x3 = nn.ReLU(inplace)
25 | self.pool2_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
26 | self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
27 | self.inception_3a_1x1_bn = nn.BatchNorm2d(64, affine=True)
28 | self.inception_3a_relu_1x1 = nn.ReLU(inplace)
29 | self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
30 | self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
31 | self.inception_3a_relu_3x3_reduce = nn.ReLU(inplace)
32 | self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
33 | self.inception_3a_3x3_bn = nn.BatchNorm2d(64, affine=True)
34 | self.inception_3a_relu_3x3 = nn.ReLU(inplace)
35 | self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
36 | self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
37 | self.inception_3a_relu_double_3x3_reduce = nn.ReLU(inplace)
38 | self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
39 | self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
40 | self.inception_3a_relu_double_3x3_1 = nn.ReLU(inplace)
41 | self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
42 | self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
43 | self.inception_3a_relu_double_3x3_2 = nn.ReLU(inplace)
44 | self.inception_3a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
45 | self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
46 | self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, affine=True)
47 | self.inception_3a_relu_pool_proj = nn.ReLU(inplace)
48 | self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
49 | self.inception_3b_1x1_bn = nn.BatchNorm2d(64, affine=True)
50 | self.inception_3b_relu_1x1 = nn.ReLU(inplace)
51 | self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
52 | self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
53 | self.inception_3b_relu_3x3_reduce = nn.ReLU(inplace)
54 | self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
55 | self.inception_3b_3x3_bn = nn.BatchNorm2d(96, affine=True)
56 | self.inception_3b_relu_3x3 = nn.ReLU(inplace)
57 | self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
58 | self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
59 | self.inception_3b_relu_double_3x3_reduce = nn.ReLU(inplace)
60 | self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
61 | self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
62 | self.inception_3b_relu_double_3x3_1 = nn.ReLU(inplace)
63 | self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
64 | self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
65 | self.inception_3b_relu_double_3x3_2 = nn.ReLU(inplace)
66 | self.inception_3b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
67 | self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
68 | self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, affine=True)
69 | self.inception_3b_relu_pool_proj = nn.ReLU(inplace)
70 | self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1))
71 | self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
72 | self.inception_3c_relu_3x3_reduce = nn.ReLU(inplace)
73 | self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
74 | self.inception_3c_3x3_bn = nn.BatchNorm2d(160, affine=True)
75 | self.inception_3c_relu_3x3 = nn.ReLU(inplace)
76 | self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1))
77 | self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
78 | self.inception_3c_relu_double_3x3_reduce = nn.ReLU(inplace)
79 | self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
80 | self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
81 | self.inception_3c_relu_double_3x3_1 = nn.ReLU(inplace)
82 | self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
83 | self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
84 | self.inception_3c_relu_double_3x3_2 = nn.ReLU(inplace)
85 | self.inception_3c_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
86 | self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1))
87 | self.inception_4a_1x1_bn = nn.BatchNorm2d(224, affine=True)
88 | self.inception_4a_relu_1x1 = nn.ReLU(inplace)
89 | self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1))
90 | self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
91 | self.inception_4a_relu_3x3_reduce = nn.ReLU(inplace)
92 | self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
93 | self.inception_4a_3x3_bn = nn.BatchNorm2d(96, affine=True)
94 | self.inception_4a_relu_3x3 = nn.ReLU(inplace)
95 | self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
96 | self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
97 | self.inception_4a_relu_double_3x3_reduce = nn.ReLU(inplace)
98 | self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
99 | self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True)
100 | self.inception_4a_relu_double_3x3_1 = nn.ReLU(inplace)
101 | self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
102 | self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True)
103 | self.inception_4a_relu_double_3x3_2 = nn.ReLU(inplace)
104 | self.inception_4a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
105 | self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
106 | self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
107 | self.inception_4a_relu_pool_proj = nn.ReLU(inplace)
108 | self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1))
109 | self.inception_4b_1x1_bn = nn.BatchNorm2d(192, affine=True)
110 | self.inception_4b_relu_1x1 = nn.ReLU(inplace)
111 | self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
112 | self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
113 | self.inception_4b_relu_3x3_reduce = nn.ReLU(inplace)
114 | self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
115 | self.inception_4b_3x3_bn = nn.BatchNorm2d(128, affine=True)
116 | self.inception_4b_relu_3x3 = nn.ReLU(inplace)
117 | self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
118 | self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
119 | self.inception_4b_relu_double_3x3_reduce = nn.ReLU(inplace)
120 | self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
121 | self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True)
122 | self.inception_4b_relu_double_3x3_1 = nn.ReLU(inplace)
123 | self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
124 | self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True)
125 | self.inception_4b_relu_double_3x3_2 = nn.ReLU(inplace)
126 | self.inception_4b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
127 | self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
128 | self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
129 | self.inception_4b_relu_pool_proj = nn.ReLU(inplace)
130 | self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1))
131 | self.inception_4c_1x1_bn = nn.BatchNorm2d(160, affine=True)
132 | self.inception_4c_relu_1x1 = nn.ReLU(inplace)
133 | self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
134 | self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
135 | self.inception_4c_relu_3x3_reduce = nn.ReLU(inplace)
136 | self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
137 | self.inception_4c_3x3_bn = nn.BatchNorm2d(160, affine=True)
138 | self.inception_4c_relu_3x3 = nn.ReLU(inplace)
139 | self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
140 | self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
141 | self.inception_4c_relu_double_3x3_reduce = nn.ReLU(inplace)
142 | self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
143 | self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, affine=True)
144 | self.inception_4c_relu_double_3x3_1 = nn.ReLU(inplace)
145 | self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
146 | self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, affine=True)
147 | self.inception_4c_relu_double_3x3_2 = nn.ReLU(inplace)
148 | self.inception_4c_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
149 | self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
150 | self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
151 | self.inception_4c_relu_pool_proj = nn.ReLU(inplace)
152 | self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1))
153 | self.inception_4d_1x1_bn = nn.BatchNorm2d(96, affine=True)
154 | self.inception_4d_relu_1x1 = nn.ReLU(inplace)
155 | self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
156 | self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
157 | self.inception_4d_relu_3x3_reduce = nn.ReLU(inplace)
158 | self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
159 | self.inception_4d_3x3_bn = nn.BatchNorm2d(192, affine=True)
160 | self.inception_4d_relu_3x3 = nn.ReLU(inplace)
161 | self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1))
162 | self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True)
163 | self.inception_4d_relu_double_3x3_reduce = nn.ReLU(inplace)
164 | self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
165 | self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, affine=True)
166 | self.inception_4d_relu_double_3x3_1 = nn.ReLU(inplace)
167 | self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
168 | self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, affine=True)
169 | self.inception_4d_relu_double_3x3_2 = nn.ReLU(inplace)
170 | self.inception_4d_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
171 | self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
172 | self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
173 | self.inception_4d_relu_pool_proj = nn.ReLU(inplace)
174 | self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
175 | self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
176 | self.inception_4e_relu_3x3_reduce = nn.ReLU(inplace)
177 | self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
178 | self.inception_4e_3x3_bn = nn.BatchNorm2d(192, affine=True)
179 | self.inception_4e_relu_3x3 = nn.ReLU(inplace)
180 | self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1))
181 | self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
182 | self.inception_4e_relu_double_3x3_reduce = nn.ReLU(inplace)
183 | self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
184 | self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, affine=True)
185 | self.inception_4e_relu_double_3x3_1 = nn.ReLU(inplace)
186 | self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
187 | self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, affine=True)
188 | self.inception_4e_relu_double_3x3_2 = nn.ReLU(inplace)
189 | self.inception_4e_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
190 | self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1))
191 | self.inception_5a_1x1_bn = nn.BatchNorm2d(352, affine=True)
192 | self.inception_5a_relu_1x1 = nn.ReLU(inplace)
193 | self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1))
194 | self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
195 | self.inception_5a_relu_3x3_reduce = nn.ReLU(inplace)
196 | self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
197 | self.inception_5a_3x3_bn = nn.BatchNorm2d(320, affine=True)
198 | self.inception_5a_relu_3x3 = nn.ReLU(inplace)
199 | self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1))
200 | self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True)
201 | self.inception_5a_relu_double_3x3_reduce = nn.ReLU(inplace)
202 | self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
203 | self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True)
204 | self.inception_5a_relu_double_3x3_1 = nn.ReLU(inplace)
205 | self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
206 | self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True)
207 | self.inception_5a_relu_double_3x3_2 = nn.ReLU(inplace)
208 | self.inception_5a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
209 | self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1))
210 | self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
211 | self.inception_5a_relu_pool_proj = nn.ReLU(inplace)
212 | self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1))
213 | self.inception_5b_1x1_bn = nn.BatchNorm2d(352, affine=True)
214 | self.inception_5b_relu_1x1 = nn.ReLU(inplace)
215 | self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1))
216 | self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
217 | self.inception_5b_relu_3x3_reduce = nn.ReLU(inplace)
218 | self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
219 | self.inception_5b_3x3_bn = nn.BatchNorm2d(320, affine=True)
220 | self.inception_5b_relu_3x3 = nn.ReLU(inplace)
221 | self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1))
222 | self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
223 | self.inception_5b_relu_double_3x3_reduce = nn.ReLU(inplace)
224 | self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
225 | self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True)
226 | self.inception_5b_relu_double_3x3_1 = nn.ReLU(inplace)
227 | self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
228 | self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True)
229 | self.inception_5b_relu_double_3x3_2 = nn.ReLU(inplace)
230 | self.inception_5b_pool = nn.MaxPool2d((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True)
231 | self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1))
232 | self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
233 | self.inception_5b_relu_pool_proj = nn.ReLU(inplace)
234 |
235 | def features(self, input):
236 | conv1_7x7_s2_out = self.conv1_7x7_s2(input)
237 | conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out)
238 | conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out)
239 | pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_relu_7x7_out)
240 | conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out)
241 | conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out)
242 | conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out)
243 | conv2_3x3_out = self.conv2_3x3(conv2_relu_3x3_reduce_out)
244 | conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out)
245 | conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out)
246 | pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_relu_3x3_out)
247 | inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out)
248 | inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out)
249 | inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out)
250 | inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out)
251 | inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out)
252 | inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out)
253 | inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_relu_3x3_reduce_out)
254 | inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out)
255 | inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out)
256 | inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out)
257 | inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn(
258 | inception_3a_double_3x3_reduce_out)
259 | inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce(
260 | inception_3a_double_3x3_reduce_bn_out)
261 | inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_relu_double_3x3_reduce_out)
262 | inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out)
263 | inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out)
264 | inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_relu_double_3x3_1_out)
265 | inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out)
266 | inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out)
267 | inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out)
268 | inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out)
269 | inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out)
270 | inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out)
271 | inception_3a_output_out = torch.cat(
272 | [inception_3a_relu_1x1_out, inception_3a_relu_3x3_out, inception_3a_relu_double_3x3_2_out,
273 | inception_3a_relu_pool_proj_out], 1)
274 | inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out)
275 | inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out)
276 | inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out)
277 | inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out)
278 | inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out)
279 | inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out)
280 | inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_relu_3x3_reduce_out)
281 | inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out)
282 | inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out)
283 | inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out)
284 | inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn(
285 | inception_3b_double_3x3_reduce_out)
286 | inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce(
287 | inception_3b_double_3x3_reduce_bn_out)
288 | inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_relu_double_3x3_reduce_out)
289 | inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out)
290 | inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out)
291 | inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_relu_double_3x3_1_out)
292 | inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out)
293 | inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out)
294 | inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out)
295 | inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out)
296 | inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out)
297 | inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out)
298 | inception_3b_output_out = torch.cat(
299 | [inception_3b_relu_1x1_out, inception_3b_relu_3x3_out, inception_3b_relu_double_3x3_2_out,
300 | inception_3b_relu_pool_proj_out], 1)
301 | inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out)
302 | inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out)
303 | inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out)
304 | inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_relu_3x3_reduce_out)
305 | inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out)
306 | inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out)
307 | inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out)
308 | inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn(
309 | inception_3c_double_3x3_reduce_out)
310 | inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce(
311 | inception_3c_double_3x3_reduce_bn_out)
312 | inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_relu_double_3x3_reduce_out)
313 | inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out)
314 | inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out)
315 | inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_relu_double_3x3_1_out)
316 | inception_3c_double_3x3_2_bn_out = self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out)
317 | inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out)
318 | inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out)
319 | inception_3c_output_out = torch.cat(
320 | [inception_3c_relu_3x3_out, inception_3c_relu_double_3x3_2_out, inception_3c_pool_out], 1)
321 | inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out)
322 | inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out)
323 | inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out)
324 | inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out)
325 | inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out)
326 | inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out)
327 | inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_relu_3x3_reduce_out)
328 | inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out)
329 | inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out)
330 | inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out)
331 | inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn(
332 | inception_4a_double_3x3_reduce_out)
333 | inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce(
334 | inception_4a_double_3x3_reduce_bn_out)
335 | inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_relu_double_3x3_reduce_out)
336 | inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out)
337 | inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out)
338 | inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_relu_double_3x3_1_out)
339 | inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out)
340 | inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out)
341 | inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out)
342 | inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out)
343 | inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out)
344 | inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out)
345 | inception_4a_output_out = torch.cat(
346 | [inception_4a_relu_1x1_out, inception_4a_relu_3x3_out, inception_4a_relu_double_3x3_2_out,
347 | inception_4a_relu_pool_proj_out], 1)
348 | inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out)
349 | inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out)
350 | inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out)
351 | inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out)
352 | inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out)
353 | inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out)
354 | inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_relu_3x3_reduce_out)
355 | inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out)
356 | inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out)
357 | inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out)
358 | inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn(
359 | inception_4b_double_3x3_reduce_out)
360 | inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce(
361 | inception_4b_double_3x3_reduce_bn_out)
362 | inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_relu_double_3x3_reduce_out)
363 | inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out)
364 | inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out)
365 | inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_relu_double_3x3_1_out)
366 | inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out)
367 | inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out)
368 | inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out)
369 | inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out)
370 | inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out)
371 | inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out)
372 | inception_4b_output_out = torch.cat(
373 | [inception_4b_relu_1x1_out, inception_4b_relu_3x3_out, inception_4b_relu_double_3x3_2_out,
374 | inception_4b_relu_pool_proj_out], 1)
375 | inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out)
376 | inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out)
377 | inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out)
378 | inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out)
379 | inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out)
380 | inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out)
381 | inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_relu_3x3_reduce_out)
382 | inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out)
383 | inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out)
384 | inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out)
385 | inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn(
386 | inception_4c_double_3x3_reduce_out)
387 | inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce(
388 | inception_4c_double_3x3_reduce_bn_out)
389 | inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_relu_double_3x3_reduce_out)
390 | inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out)
391 | inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out)
392 | inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_relu_double_3x3_1_out)
393 | inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out)
394 | inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out)
395 | inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out)
396 | inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out)
397 | inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out)
398 | inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out)
399 | inception_4c_output_out = torch.cat(
400 | [inception_4c_relu_1x1_out, inception_4c_relu_3x3_out, inception_4c_relu_double_3x3_2_out,
401 | inception_4c_relu_pool_proj_out], 1)
402 | inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out)
403 | inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out)
404 | inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out)
405 | inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out)
406 | inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out)
407 | inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out)
408 | inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_relu_3x3_reduce_out)
409 | inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out)
410 | inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out)
411 | inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out)
412 | inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn(
413 | inception_4d_double_3x3_reduce_out)
414 | inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce(
415 | inception_4d_double_3x3_reduce_bn_out)
416 | inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_relu_double_3x3_reduce_out)
417 | inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out)
418 | inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out)
419 | inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_relu_double_3x3_1_out)
420 | inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out)
421 | inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out)
422 | inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out)
423 | inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out)
424 | inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out)
425 | inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out)
426 | inception_4d_output_out = torch.cat(
427 | [inception_4d_relu_1x1_out, inception_4d_relu_3x3_out, inception_4d_relu_double_3x3_2_out,
428 | inception_4d_relu_pool_proj_out], 1)
429 | inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out)
430 | inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out)
431 | inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out)
432 | inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_relu_3x3_reduce_out)
433 | inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out)
434 | inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out)
435 | inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out)
436 | inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn(
437 | inception_4e_double_3x3_reduce_out)
438 | inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce(
439 | inception_4e_double_3x3_reduce_bn_out)
440 | inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_relu_double_3x3_reduce_out)
441 | inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out)
442 | inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out)
443 | inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_relu_double_3x3_1_out)
444 | inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out)
445 | inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out)
446 | inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out)
447 | inception_4e_output_out = torch.cat(
448 | [inception_4e_relu_3x3_out, inception_4e_relu_double_3x3_2_out, inception_4e_pool_out], 1)
449 | inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out)
450 | inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out)
451 | inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out)
452 | inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out)
453 | inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out)
454 | inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out)
455 | inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_relu_3x3_reduce_out)
456 | inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out)
457 | inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out)
458 | inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out)
459 | inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn(
460 | inception_5a_double_3x3_reduce_out)
461 | inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce(
462 | inception_5a_double_3x3_reduce_bn_out)
463 | inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_relu_double_3x3_reduce_out)
464 | inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out)
465 | inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out)
466 | inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_relu_double_3x3_1_out)
467 | inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out)
468 | inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out)
469 | inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out)
470 | inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out)
471 | inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out)
472 | inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out)
473 | inception_5a_output_out = torch.cat(
474 | [inception_5a_relu_1x1_out, inception_5a_relu_3x3_out, inception_5a_relu_double_3x3_2_out,
475 | inception_5a_relu_pool_proj_out], 1)
476 | inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out)
477 | inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out)
478 | inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out)
479 | inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out)
480 | inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out)
481 | inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out)
482 | inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_relu_3x3_reduce_out)
483 | inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out)
484 | inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out)
485 | inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out)
486 | inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn(
487 | inception_5b_double_3x3_reduce_out)
488 | inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce(
489 | inception_5b_double_3x3_reduce_bn_out)
490 | inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_relu_double_3x3_reduce_out)
491 | inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out)
492 | inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out)
493 | inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_relu_double_3x3_1_out)
494 | inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out)
495 | inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out)
496 | inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out)
497 | inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out)
498 | inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out)
499 | inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out)
500 | inception_5b_output_out = torch.cat(
501 | [inception_5b_relu_1x1_out, inception_5b_relu_3x3_out, inception_5b_relu_double_3x3_2_out,
502 | inception_5b_relu_pool_proj_out], 1)
503 | return inception_5b_output_out
504 |
505 | def logits(self, features):
506 | x = F.adaptive_max_pool2d(features, output_size=1)
507 | x = x.view(x.size(0), -1)
508 | return x
509 |
510 | def forward(self, input):
511 | x = self.features(input)
512 | x = self.logits(x)
513 | return x
514 |
515 | def load_param(self, model_path):
516 | param_dict = torch.load(model_path)
517 | for i in param_dict:
518 | if 'last_linear' in i:
519 | continue
520 | self.state_dict()[i].copy_(param_dict[i])
521 |
--------------------------------------------------------------------------------
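A quick shape check of the backbone above; the 1024 channels come from concatenating the four inception_5b branches (352 + 320 + 224 + 128):

    import torch

    net = BNInception().eval()
    with torch.no_grad():
        out = net(torch.randn(2, 3, 224, 224))
    print(out.shape)  # torch.Size([2, 1024]) -- pooled, un-normalized features for the head
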
/ret_benchmark/modeling/backbone/build.py:
--------------------------------------------------------------------------------
1 | from ret_benchmark.modeling.registry import BACKBONES
2 |
3 | from .bninception import BNInception
4 | from .resnet import ResNet50
5 |
6 |
7 | def build_backbone(cfg):
8 | assert cfg.MODEL.BACKBONE.NAME in BACKBONES, \
9 |         f"backbone {cfg.MODEL.BACKBONE.NAME} is not registered in registry: {BACKBONES.keys()}"
10 | return BACKBONES[cfg.MODEL.BACKBONE.NAME]()
11 |
--------------------------------------------------------------------------------
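The registry lookup above is what makes backbones pluggable: a new one only needs the decorator. A hypothetical sketch (the 'identity' name and class are illustrative, not part of the repo):

    import torch.nn as nn
    from ret_benchmark.modeling import registry

    @registry.BACKBONES.register('identity')  # hypothetical backbone name
    class IdentityBackbone(nn.Module):
        def forward(self, x):
            return x.flatten(1)  # (N, C, H, W) -> (N, C*H*W)

    # with MODEL.BACKBONE.NAME set to 'identity' in a config,
    # build_backbone(cfg) would return an IdentityBackbone instance
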
/ret_benchmark/modeling/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.models as models
6 | from ret_benchmark.modeling import registry
7 |
8 |
9 | @registry.BACKBONES.register('resnet50')
10 | class ResNet50(nn.Module):
11 |
12 | def __init__(self):
13 | super(ResNet50, self).__init__()
14 | self.model = models.resnet50(pretrained=True)
15 |
16 |         for module in filter(lambda m: isinstance(m, nn.BatchNorm2d), self.model.modules()):
17 |             module.eval()
18 |             module.train = lambda _: None  # freeze BN permanently: later model.train() calls are no-ops here
19 |
20 | def forward(self, x):
21 | x = self.model.conv1(x)
22 | x = self.model.bn1(x)
23 | x = self.model.relu(x)
24 | x = self.model.maxpool(x)
25 |
26 | x = self.model.layer1(x)
27 | x = self.model.layer2(x)
28 | x = self.model.layer3(x)
29 | x = self.model.layer4(x)
30 |
31 | x = self.model.avgpool(x)
32 | x = x.view(x.size(0), -1)
33 |         # the classification fc layer is intentionally skipped: x = self.model.fc(x)
34 | return x
35 |
36 | def load_param(self, model_path):
37 | param_dict = torch.load(model_path)
38 | for i in param_dict:
39 | if 'last_linear' in i:
40 | continue
41 | self.model.state_dict()[i].copy_(param_dict[i])
42 |
43 |
--------------------------------------------------------------------------------
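The analogous shape check for ResNet50 (instantiating it downloads the torchvision ImageNet weights on first use):

    import torch

    net = ResNet50().eval()
    with torch.no_grad():
        out = net(torch.randn(2, 3, 224, 224))
    print(out.shape)  # torch.Size([2, 2048]) -- hence in_channels=2048 in the head builder
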
/ret_benchmark/modeling/build.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import os
10 | from collections import OrderedDict
11 |
12 | import torch
13 | from torch.nn.modules import Sequential
14 |
15 | from .backbone import build_backbone
16 | from .heads import build_head
17 |
18 |
19 | def build_model(cfg):
20 | backbone = build_backbone(cfg)
21 | head = build_head(cfg)
22 |
23 | model = Sequential(OrderedDict([
24 | ('backbone', backbone),
25 | ('head', head)
26 | ]))
27 |
28 | if cfg.MODEL.PRETRAIN == 'imagenet':
29 |         print('Loading ImageNet pretrained model ...')
30 | pretrained_path = os.path.expanduser(cfg.MODEL.PRETRIANED_PATH[cfg.MODEL.BACKBONE.NAME])
31 | model.backbone.load_param(pretrained_path)
32 | elif os.path.exists(cfg.MODEL.PRETRAIN):
33 | ckp = torch.load(cfg.MODEL.PRETRAIN)
34 | model.load_state_dict(ckp['model'])
35 | return model
36 |
--------------------------------------------------------------------------------
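End to end, build_model returns a plain nn.Sequential, so inference is a single call. A sketch, assuming the default cfg and a prepared image batch `images`:

    import torch
    from ret_benchmark.config import cfg  # the repo's default config object

    model = build_model(cfg).cuda().eval()
    with torch.no_grad():
        embeddings = model(images.cuda())  # (N, 3, H, W) -> (N, cfg.MODEL.HEAD.DIM)
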
/ret_benchmark/modeling/heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from .build import build_head
9 |
--------------------------------------------------------------------------------
/ret_benchmark/modeling/heads/build.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | from ret_benchmark.modeling.registry import HEADS
9 |
10 | from .linear_norm import LinearNorm
11 |
12 |
13 | def build_head(cfg):
14 | assert cfg.MODEL.HEAD.NAME in HEADS, f"head {cfg.MODEL.HEAD.NAME} is not defined"
15 | return HEADS[cfg.MODEL.HEAD.NAME](cfg, in_channels=1024 if cfg.MODEL.BACKBONE.NAME == 'bninception' else 2048)
16 |
17 |
--------------------------------------------------------------------------------
/ret_benchmark/modeling/heads/linear_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 | from torch import nn
10 |
11 | from ret_benchmark.modeling.registry import HEADS
12 | from ret_benchmark.utils.init_methods import weights_init_kaiming
13 |
14 |
15 | @HEADS.register('linear_norm')
16 | class LinearNorm(nn.Module):
17 | def __init__(self, cfg, in_channels):
18 | super(LinearNorm, self).__init__()
19 | self.fc = nn.Linear(in_channels, cfg.MODEL.HEAD.DIM)
20 | self.fc.apply(weights_init_kaiming)
21 |
22 | def forward(self, x):
23 | x = self.fc(x)
24 | x = nn.functional.normalize(x, p=2, dim=1)
25 | return x
26 |
--------------------------------------------------------------------------------
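Because of the final normalize call, every embedding the head emits has unit L2 norm, so cosine similarity downstream reduces to a plain dot product. A sketch with a hypothetical cfg stand-in:

    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(MODEL=SimpleNamespace(HEAD=SimpleNamespace(DIM=512)))  # illustrative
    head = LinearNorm(cfg, in_channels=2048)
    emb = head(torch.randn(4, 2048))
    print(emb.norm(dim=1))  # all ~1.0
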
/ret_benchmark/modeling/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 |
9 | from ret_benchmark.utils.registry import Registry
10 |
11 | BACKBONES = Registry()
12 | HEADS = Registry()
13 |
--------------------------------------------------------------------------------
/ret_benchmark/modeling/xbm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 | import tqdm
10 | from ret_benchmark.data.build import build_memory_data
11 |
12 |
13 | class XBM:
14 | def __init__(self, cfg, model):
15 | self.ratio = cfg.MEMORY.RATIO
16 | # init memory
17 | self.feats = list()
18 | self.labels = list()
19 | self.indices = list()
20 | model.train()
21 | for images, labels, indices in build_memory_data(cfg):
22 | with torch.no_grad():
23 | feat = model(images.cuda())
24 | self.feats.append(feat)
25 | self.labels.append(labels.cuda())
26 | self.indices.append(indices.cuda())
27 | self.feats = torch.cat(self.feats, dim=0)
28 | self.labels = torch.cat(self.labels, dim=0)
29 | self.indices = torch.cat(self.indices, dim=0)
30 | # if memory_ratio != 1.0 -> random sample init queue_mask to mimic fixed queue size
31 | if self.ratio != 1.0:
32 | rand_init_idx = torch.randperm(int(self.indices.shape[0] * self.ratio)).cuda()
33 | self.queue_mask = self.indices[rand_init_idx]
34 |
35 | def enqueue_dequeue(self, feats, indices):
36 | self.feats.data[indices] = feats
37 | if self.ratio != 1.0:
38 | # enqueue
39 | self.queue_mask = torch.cat((self.queue_mask, indices.cuda()), dim=0)
40 | # dequeue
41 | self.queue_mask = self.queue_mask[-int(self.indices.shape[0] * self.ratio):]
42 |
43 | def get(self):
44 | if self.ratio != 1.0:
45 | return self.feats[self.queue_mask], self.labels[self.queue_mask]
46 | else:
47 | return self.feats, self.labels
48 |
--------------------------------------------------------------------------------
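How the memory is meant to be driven from a training loop, sketched under the assumption that model, cfg and a train_loader yielding (images, labels, indices) come from the builders elsewhere in the repo:

    xbm = XBM(cfg, model)  # fills the memory with one full pass over the data

    for images, labels, indices in train_loader:
        feats = model(images.cuda())
        # refresh this batch's memory slots, then read back the (possibly subsampled) bank
        xbm.enqueue_dequeue(feats.detach(), indices.cuda())
        mem_feats, mem_labels = xbm.get()
        # a cross-batch loss term would now score `feats` against the whole bank,
        # in addition to the usual in-batch loss
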
/ret_benchmark/solver/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from .build import build_optimizer
3 | from .build import build_lr_scheduler
4 | from .lr_scheduler import WarmupMultiStepLR
5 |
--------------------------------------------------------------------------------
/ret_benchmark/solver/build.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .lr_scheduler import WarmupMultiStepLR
4 |
5 |
6 | def build_optimizer(cfg, model):
7 | params = []
8 | for key, value in model.named_parameters():
9 | if not value.requires_grad:
10 | continue
11 | lr_mul = 1.0
12 |         if "backbone" in key:
13 |             lr_mul = 0.1  # pretrained backbone trains with a 10x smaller rate via the custom "lr_mul" key
14 | params += [{"params": [value], "lr_mul": lr_mul}]
15 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params,
16 | lr=cfg.SOLVER.BASE_LR,
17 | weight_decay=cfg.SOLVER.WEIGHT_DECAY)
18 | return optimizer
19 |
20 |
21 | def build_lr_scheduler(cfg, optimizer):
22 | return WarmupMultiStepLR(
23 | optimizer,
24 | cfg.SOLVER.STEPS,
25 | cfg.SOLVER.GAMMA,
26 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
27 | warmup_iters=cfg.SOLVER.WARMUP_ITERS,
28 | warmup_method=cfg.SOLVER.WARMUP_METHOD,
29 | )
30 |
--------------------------------------------------------------------------------
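The "lr_mul" entry above is not a native optimizer option; it is a per-group multiplier the training loop is expected to apply on top of the scheduled rate, roughly as in this sketch (illustrative, not the repo's trainer code):

    for group in optimizer.param_groups:
        group["lr"] = scheduled_lr * group.get("lr_mul", 1.0)  # scheduled_lr: current base rate
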
/ret_benchmark/solver/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from bisect import bisect_right
2 |
3 | import torch
4 |
5 |
6 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
7 | def __init__(
8 | self,
9 | optimizer,
10 | milestones,
11 | gamma=0.1,
12 | warmup_factor=1.0 / 3,
13 | warmup_iters=500,
14 | warmup_method="linear",
15 | last_epoch=-1,
16 | ):
17 | if not list(milestones) == sorted(milestones):
18 | raise ValueError(
19 |                 "Milestones should be a list of"
20 |                 " increasing integers. Got {}".format(milestones)
21 |             )
22 |
23 | if warmup_method not in ("constant", "linear"):
24 | raise ValueError(
25 |                 "Only 'constant' or 'linear' warmup_method accepted, "
26 |                 "got {}".format(warmup_method)
27 | )
28 | self.milestones = milestones
29 | self.gamma = gamma
30 | self.warmup_factor = warmup_factor
31 | self.warmup_iters = warmup_iters
32 | self.warmup_method = warmup_method
33 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
34 |
35 | def get_lr(self):
36 | warmup_factor = 1
37 | if self.last_epoch < self.warmup_iters:
38 | if self.warmup_method == "constant":
39 | warmup_factor = self.warmup_factor
40 | elif self.warmup_method == "linear":
41 | alpha = float(self.last_epoch) / self.warmup_iters
42 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
43 | return [
44 | base_lr * warmup_factor * self.gamma ** bisect_right(
45 | self.milestones,
46 | self.last_epoch
47 | )
48 | for base_lr in self.base_lrs
49 | ]
50 |
--------------------------------------------------------------------------------
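A runnable sketch of the warmup behaviour: the rate starts at warmup_factor * base_lr, climbs linearly to base_lr over warmup_iters steps, then decays by gamma at each milestone:

    import torch

    param = torch.zeros(1, requires_grad=True)
    opt = torch.optim.SGD([param], lr=0.1)
    sched = WarmupMultiStepLR(opt, milestones=[3000, 6000], gamma=0.1,
                              warmup_factor=1.0 / 3, warmup_iters=500)
    for _ in range(5):
        opt.step()
        sched.step()
        print(opt.param_groups[0]["lr"])  # climbs from ~0.033 toward 0.1
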
/ret_benchmark/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import logging
3 | import os
4 |
5 | import torch
6 | from ret_benchmark.utils.model_serialization import load_state_dict
7 |
8 |
9 | class Checkpointer(object):
10 | def __init__(
11 | self,
12 | model,
13 | optimizer=None,
14 | scheduler=None,
15 | save_dir="",
16 | save_to_disk=None,
17 | logger=None,
18 | ):
19 | self.model = model
20 | self.optimizer = optimizer
21 | self.scheduler = scheduler
22 | self.save_dir = save_dir
23 | self.save_to_disk = save_to_disk
24 | if logger is None:
25 | logger = logging.getLogger(__name__)
26 | self.logger = logger
27 |
28 | def save(self, name):
29 | if not self.save_dir:
30 | return
31 |
32 | data = {}
33 | data["model"] = self.model.state_dict()
34 | if self.optimizer is not None:
35 | data["optimizer"] = self.optimizer.state_dict()
36 | if self.scheduler is not None:
37 | data["scheduler"] = self.scheduler.state_dict()
38 |
39 | save_file = os.path.join(self.save_dir, "{}.pth".format(name))
40 | self.logger.info("Saving checkpoint to {}".format(save_file))
41 | torch.save(data, save_file)
42 |
43 | def load(self, f=None):
44 | if self.has_checkpoint():
45 | # override argument with existing checkpoint
46 | f = self.get_checkpoint_file()
47 | if not f:
48 | # no checkpoint could be found
49 | self.logger.info("No checkpoint found. Initializing model from scratch")
50 | return {}
51 | self.logger.info("Loading checkpoint from {}".format(f))
52 | checkpoint = self._load_file(f)
53 | self._load_model(checkpoint)
54 | if "optimizer" in checkpoint and self.optimizer:
55 | self.logger.info("Loading optimizer from {}".format(f))
56 | self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
57 | if "scheduler" in checkpoint and self.scheduler:
58 | self.logger.info("Loading scheduler from {}".format(f))
59 | self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
60 |
61 | # return any further checkpoint data
62 | return checkpoint
63 |
64 | def has_checkpoint(self):
65 | save_file = os.path.join(self.save_dir, "last_checkpoint")
66 | return os.path.exists(save_file)
67 |
68 | def get_checkpoint_file(self):
69 | save_file = os.path.join(self.save_dir, "last_checkpoint")
70 | try:
71 | with open(save_file, "r") as f:
72 | last_saved = f.read()
73 | last_saved = last_saved.strip()
74 | except IOError:
75 | # if file doesn't exist, maybe because it has just been
76 | # deleted by a separate process
77 | last_saved = ""
78 | return last_saved
79 |
80 | def tag_last_checkpoint(self, last_filename):
81 | save_file = os.path.join(self.save_dir, "last_checkpoint")
82 | with open(save_file, "w") as f:
83 | f.write(last_filename)
84 |
85 | def _load_file(self, f):
86 | return torch.load(f, map_location=torch.device("cpu"))
87 |
88 | def _load_model(self, checkpoint):
89 | load_state_dict(self.model, checkpoint.pop("model"))
90 |
--------------------------------------------------------------------------------
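A typical save/resume cycle with the class above (model, optimizer and scheduler assumed built elsewhere):

    ckpt = Checkpointer(model, optimizer, scheduler, save_dir="output")
    extra = ckpt.load()  # resumes from output/last_checkpoint when one exists
    # ... train ...
    ckpt.save("iter_001000")                            # writes output/iter_001000.pth
    ckpt.tag_last_checkpoint("output/iter_001000.pth")  # so the next load() resumes from it
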
/ret_benchmark/utils/config_util.py:
--------------------------------------------------------------------------------
1 | from __future__ import (absolute_import, division, print_function,
2 | unicode_literals)
3 |
4 | import copy
5 | import os
6 |
7 | from ret_benchmark.config import cfg as g_cfg
8 |
9 |
10 | def get_config_root_path():
11 | ''' Path to configs for unit tests '''
12 |     # cur_file_dir is <repo_root>/ret_benchmark/utils
13 | cur_file_dir = os.path.dirname(os.path.abspath(os.path.realpath(__file__)))
14 | ret = os.path.dirname(os.path.dirname(cur_file_dir))
15 | ret = os.path.join(ret, "configs")
16 | return ret
17 |
18 |
19 | def load_config(rel_path):
20 | ''' Load config from file path specified as path relative to config_root '''
21 | cfg_path = os.path.join(get_config_root_path(), rel_path)
22 | return load_config_from_file(cfg_path)
23 |
24 |
25 | def load_config_from_file(file_path):
26 | ''' Load config from file path specified as absolute path '''
27 | ret = copy.deepcopy(g_cfg)
28 | ret.merge_from_file(file_path)
29 | return ret
30 |
--------------------------------------------------------------------------------
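A hedged example of the two loaders; load_config resolves names against the repository's configs/ directory regardless of the current working directory:

    from ret_benchmark.utils.config_util import load_config, load_config_from_file

    cfg = load_config("example.yaml")                 # <repo_root>/configs/example.yaml
    cfg2 = load_config_from_file("/tmp/my_cfg.yaml")  # hypothetical absolute path
    print(cfg.MODEL.DEVICE)                           # keys defined in ret_benchmark/config
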
/ret_benchmark/utils/feat_extractor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 | import numpy as np
10 |
11 |
12 | def feat_extractor(model, data_loader, logger=None):
13 | model.eval()
14 | feats = list()
15 |
16 | for i, batch in enumerate(data_loader):
17 | imgs = batch[0].cuda()
18 |
19 | with torch.no_grad():
20 | out = model(imgs).data.cpu().numpy()
21 | feats.append(out)
22 |
23 | if logger is not None and (i + 1) % 100 == 0:
24 | logger.debug(f'Extract Features: [{i + 1}/{len(data_loader)}]')
25 | del out
26 | feats = np.vstack(feats)
27 | return feats
28 |
--------------------------------------------------------------------------------
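A sketch of how the extractor slots into evaluation, assuming cfg has already been loaded as in tools/main.py and a CUDA device is available:

    from ret_benchmark.data import build_data
    from ret_benchmark.modeling import build_model
    from ret_benchmark.utils.feat_extractor import feat_extractor

    model = build_model(cfg).cuda()
    val_loader = build_data(cfg, is_train=False)
    feats = feat_extractor(model, val_loader)  # numpy array, shape (num_images, embed_dim)
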
/ret_benchmark/utils/freeze_bn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | # Batch Norm freezer
9 | # Note: freezing BN yields roughly an additional 2% improvement on CUB; on the other benchmarks it has no effect
10 |
11 | def set_bn_eval(m):
12 | classname = m.__class__.__name__
13 | if classname.find('BatchNorm') != -1:
14 | m.eval()
15 |
--------------------------------------------------------------------------------
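Because nn.Module.apply visits every submodule recursively, freezing all BatchNorm layers is a one-liner. A typical call site (the exact hook in the trainer may differ):

    model.train()             # keep the rest of the network in training mode
    model.apply(set_bn_eval)  # switch every BatchNorm* layer back to eval mode
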
/ret_benchmark/utils/img_reader.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from PIL import Image
3 |
4 |
5 | def read_image(img_path, mode='RGB'):
 6 |     """Keep reading the image until it succeeds.
 7 |     This avoids transient IOErrors caused by heavy I/O load."""
8 | got_img = False
9 | if not osp.exists(img_path):
10 | raise IOError(f"{img_path} does not exist")
11 | while not got_img:
12 | try:
13 | img = Image.open(img_path).convert("RGB")
14 | if mode == "BGR":
15 | r, g, b = img.split()
16 | img = Image.merge("RGB", (b, g, r))
17 | got_img = True
18 | except IOError:
19 |             print(f"IOError incurred when reading '{img_path}'. Will retry.")
20 |     return img
21 |
--------------------------------------------------------------------------------
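Illustrative calls (both paths are hypothetical):

    from ret_benchmark.utils.img_reader import read_image

    img = read_image("resource/datasets/CUB_200_2011/images/some_bird.jpg")
    bgr = read_image("some_image.jpg", mode="BGR")  # channel-swapped copy for BGR pipelines
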
/ret_benchmark/utils/init_methods.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 | from torch import nn
10 |
11 |
12 | def weights_init_kaiming(m):
13 | classname = m.__class__.__name__
14 | if classname.find('Linear') != -1:
15 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
16 |         if m.bias is not None:
17 |             nn.init.constant_(m.bias, 0.0)
18 |     elif classname.find('Conv') != -1:
19 |         nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
20 |         if m.bias is not None:
21 |             nn.init.constant_(m.bias, 0.0)
22 |     elif classname.find('BatchNorm') != -1:
23 |         if m.affine:
24 |             nn.init.constant_(m.weight, 1.0)
25 |             nn.init.constant_(m.bias, 0.0)
26 |
27 |
28 | def weights_init_classifier(m):
29 |     classname = m.__class__.__name__
30 |     if classname.find('Linear') != -1:
31 |         nn.init.normal_(m.weight, std=0.001)
32 |         if m.bias is not None:
33 |             nn.init.constant_(m.bias, 0.0)
34 |
--------------------------------------------------------------------------------
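A minimal sketch of applying these initializers to a hypothetical embedding head (dimensions are illustrative):

    from torch import nn

    head = nn.Sequential(nn.Linear(2048, 512), nn.BatchNorm1d(512))
    head.apply(weights_init_kaiming)           # Kaiming init for Linear/Conv, unit BN affine
    classifier = nn.Linear(512, 100)
    classifier.apply(weights_init_classifier)  # small-std normal init for the classifier
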
/ret_benchmark/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import os
9 | import sys
10 | import logging
11 |
12 | _streams = {
13 | "stdout": sys.stdout
14 | }
15 |
16 |
17 | def setup_logger(name: str, level: int, stream: str = "stdout") -> logging.Logger:
18 | global _streams
19 | if stream not in _streams:
20 |         log_folder = os.path.dirname(stream)
21 |         if log_folder:  # a bare filename has no directory component
22 |             os.makedirs(log_folder, exist_ok=True)
23 |         _streams[stream] = open(stream, 'w')
24 |     logger = logging.getLogger(name)
25 |     logger.propagate = False
26 |     logger.setLevel(level)
27 |
28 |     sh = logging.StreamHandler(stream=_streams[stream])
29 |     sh.setLevel(level)
30 |     formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
31 |     sh.setFormatter(formatter)
32 |     logger.addHandler(sh)
33 |     return logger
34 |
--------------------------------------------------------------------------------
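Two hedged examples: logging to stdout (the default) and to a file path passed as the stream:

    import logging
    from ret_benchmark.utils.logger import setup_logger

    logger = setup_logger(name="Train", level=logging.INFO)            # stdout
    f_logger = setup_logger("Eval", logging.DEBUG, "output/eval.log")  # parent dir is created
    logger.info("training started")
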
/ret_benchmark/utils/metric_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from collections import defaultdict
3 | from collections import deque
4 |
5 | import torch
6 |
7 |
8 | class SmoothedValue(object):
9 | """Track a series of values and provide access to smoothed values over a
10 | window or the global series average.
11 | """
12 |
13 | def __init__(self, window_size=20):
14 | self.deque = deque(maxlen=window_size)
15 | self.series = []
16 | self.total = 0.0
17 | self.count = 0
18 |
19 | def update(self, value):
20 | self.deque.append(value)
21 | self.series.append(value)
22 | self.count += 1
23 | self.total += value
24 |
25 | @property
26 | def median(self):
27 | d = torch.tensor(list(self.deque))
28 | return d.median().item()
29 |
30 | @property
31 | def avg(self):
32 | d = torch.tensor(list(self.deque))
33 | return d.mean().item()
34 |
35 | @property
36 | def global_avg(self):
37 | return self.total / self.count
38 |
39 |
40 | class MetricLogger(object):
41 | def __init__(self, delimiter="\t"):
42 | self.meters = defaultdict(SmoothedValue)
43 | self.delimiter = delimiter
44 |
45 | def update(self, **kwargs):
46 | for k, v in kwargs.items():
47 | if isinstance(v, torch.Tensor):
48 | v = v.item()
49 | assert isinstance(v, (float, int))
50 | self.meters[k].update(v)
51 |
52 | def __getattr__(self, attr):
53 | if attr in self.meters:
54 | return self.meters[attr]
55 | if attr in self.__dict__:
56 | return self.__dict__[attr]
57 | raise AttributeError("'{}' object has no attribute '{}'".format(
58 | type(self).__name__, attr))
59 |
60 | def __str__(self):
61 | loss_str = []
62 | for name, meter in self.meters.items():
63 | loss_str.append(
64 | "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
65 | )
66 | return self.delimiter.join(loss_str)
67 |
--------------------------------------------------------------------------------
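A small sketch of the intended use (values are illustrative):

    meters = MetricLogger(delimiter="  ")
    for it in range(100):
        meters.update(loss=1.0 / (it + 1), lr=0.01)
    print(meters)           # e.g. "loss: 0.0110 (0.0519)  lr: 0.0100 (0.0100)"
    print(meters.loss.avg)  # windowed mean over the last 20 updates, via __getattr__
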
/ret_benchmark/utils/model_serialization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from collections import OrderedDict
3 | import logging
4 |
5 | import torch
6 |
7 |
8 | def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
 9 |     """
10 |     Strategy: suppose that the models we create have prefixes prepended to
11 |     each of their keys, for example due to an extra level of nesting that the
12 |     original ImageNet pre-trained weights do not contain. model.state_dict()
13 |     might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
14 |     res2.conv1.weight. We thus want to match these two parameters together.
15 |     For that, for each model weight we look among all loaded keys for one
16 |     that is a suffix of the current weight name, and use it if that's the case.
17 |     If multiple matches exist, we take the one with the longest matching name.
18 |     For example, for the same model as before, the pretrained
19 |     weight file can contain both res2.conv1.weight and conv1.weight. In this case,
20 |     we want to match backbone[0].body.conv1.weight to conv1.weight, and
21 |     backbone[0].body.res2.conv1.weight to res2.conv1.weight.
22 |     """
23 | current_keys = sorted(list(model_state_dict.keys()))
24 | loaded_keys = sorted(list(loaded_state_dict.keys()))
25 | # get a matrix of string matches, where each (i, j) entry correspond to the size of the
26 | # loaded_key string, if it matches
27 | match_matrix = [
28 | len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys
29 | ]
30 | match_matrix = torch.as_tensor(match_matrix).view(
31 | len(current_keys), len(loaded_keys)
32 | )
33 | max_match_size, idxs = match_matrix.max(1)
34 | # remove indices that correspond to no-match
35 | idxs[max_match_size == 0] = -1
36 |
37 | # used for logging
38 | max_size = max([len(key) for key in current_keys]) if current_keys else 1
39 | max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1
40 | log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
41 | logger = logging.getLogger(__name__)
42 | for idx_new, idx_old in enumerate(idxs.tolist()):
43 | if idx_old == -1:
44 | continue
45 | key = current_keys[idx_new]
46 | key_old = loaded_keys[idx_old]
47 | model_state_dict[key] = loaded_state_dict[key_old]
48 | logger.info(
49 | log_str_template.format(
50 | key,
51 | max_size,
52 | key_old,
53 | max_size_loaded,
54 | tuple(loaded_state_dict[key_old].shape),
55 | )
56 | )
57 |
58 |
59 | def strip_prefix_if_present(state_dict, prefix):
60 | keys = sorted(state_dict.keys())
61 | if not all(key.startswith(prefix) for key in keys):
62 | return state_dict
63 | stripped_state_dict = OrderedDict()
64 | for key, value in state_dict.items():
65 |         stripped_state_dict[key[len(prefix):]] = value  # strip only the leading prefix
66 | return stripped_state_dict
67 |
68 |
69 | def load_state_dict(model, loaded_state_dict):
70 | model_state_dict = model.state_dict()
71 | # if the state_dict comes from a model that was wrapped in a
72 | # DataParallel or DistributedDataParallel during serialization,
73 | # remove the "module" prefix before performing the matching
74 | loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.")
75 | align_and_update_state_dicts(model_state_dict, loaded_state_dict)
76 |
77 | # use strict loading
78 | model.load_state_dict(model_state_dict)
79 |
--------------------------------------------------------------------------------
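A hedged sketch of loading pre-trained weights through this matcher; the file name is hypothetical, and model is assumed to be an already-built nn.Module:

    import torch

    loaded = torch.load("pretrained.pth", map_location="cpu")
    # Works whether the checkpoint was saved from a bare model or from one wrapped in
    # (Distributed)DataParallel ("module." prefixes are stripped), and whether the
    # current model adds extra nesting to the parameter names.
    load_state_dict(model, loaded)
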
/ret_benchmark/utils/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
3 |
4 | def _register_generic(module_dict, module_name, module):
5 | assert module_name not in module_dict
6 | module_dict[module_name] = module
7 |
8 |
9 | class Registry(dict):
10 |     '''
11 |     A helper class for managing module registration; it extends a dictionary
12 |     and provides a register function.
13 |
14 |     E.g. creating a registry:
15 |         some_registry = Registry({"default": default_module})
16 |
17 |     There are two ways of registering new modules:
18 |     1): the normal way is just calling the register function:
19 |         def foo():
20 |             ...
21 |         some_registry.register("foo_module", foo)
22 |     2): used as a decorator when declaring the module:
23 |         @some_registry.register("foo_module")
24 |         @some_registry.register("foo_module_nickname")
25 |         def foo():
26 |             ...
27 |
28 |     Modules are accessed just like a dictionary, e.g.:
29 |         f = some_registry["foo_module"]
30 |     '''
31 |
32 | def __init__(self, *args, **kwargs):
33 | super(Registry, self).__init__(*args, **kwargs)
34 |
35 | def register(self, module_name, module=None):
36 | # used as function call
37 | if module is not None:
38 | _register_generic(self, module_name, module)
39 | return
40 |
41 | # used as decorator
42 | def register_fn(fn):
43 | _register_generic(self, module_name, fn)
44 | return fn
45 |
46 | return register_fn
47 |
--------------------------------------------------------------------------------
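A self-contained sketch of both registration styles (registry and entry names are made up):

    BACKBONES = Registry()

    def resnet50():
        ...

    BACKBONES.register("R-50", resnet50)  # plain function-call style

    @BACKBONES.register("toy")
    def toy_backbone():
        ...

    build_fn = BACKBONES["toy"]           # dict-style access
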
/scripts/prepare_cub.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | CUB_ROOT='resource/datasets/CUB_200_2011/'
5 | CUB_DATA='http://www.vision.caltech.edu.s3-us-west-2.amazonaws.com/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
6 |
7 |
8 | if [[ ! -d "${CUB_ROOT}" ]]; then
9 | mkdir -p resource/datasets
10 | pushd resource/datasets
11 | echo "Downloading CUB_200_2011 data-set..."
12 | wget ${CUB_DATA}
13 | tar -zxf CUB_200_2011.tgz
14 | popd
15 | fi
16 | # Generate train.txt and test.txt splits
17 | echo "Generating the train.txt/test.txt split files"
18 | python scripts/split_cub_for_ms_loss.py
19 |
20 |
21 |
--------------------------------------------------------------------------------
/scripts/run_cub.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | OUT_DIR="output"
4 | if [[ ! -d "${OUT_DIR}" ]]; then
5 | echo "Creating output dir for training : ${OUT_DIR}"
6 | mkdir ${OUT_DIR}
7 | fi
8 | CUDA_VISIBLE_DEVICES=0 python3.6 tools/main.py --cfg configs/example.yaml
9 |
--------------------------------------------------------------------------------
/scripts/run_cub_margin.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | OUT_DIR="output_margin"
4 | if [[ ! -d "${OUT_DIR}" ]]; then
5 | echo "Creating output dir for training : ${OUT_DIR}"
6 | mkdir ${OUT_DIR}
7 | fi
8 | CUDA_VISIBLE_DEVICES=0 python3.6 tools/main.py --cfg configs/example_margin.yaml
9 |
--------------------------------------------------------------------------------
/scripts/split_cub_for_ms_loss.py:
--------------------------------------------------------------------------------
1 |
2 | cub_root = 'resource/datasets/CUB_200_2011/'
3 | images_file = cub_root + 'images.txt'
4 | train_file = cub_root + 'train.txt'
5 | test_file = cub_root + 'test.txt'
6 |
7 |
8 | def main():
9 | train = []
10 | test = []
11 | with open(images_file) as f_img:
12 | for l_img in f_img:
13 | i, fname = l_img.split()
14 | label = int(fname.split('.', 1)[0])
15 | if label <= 100:
16 | train.append((fname, label - 1)) # labels 0 ... 99 (0-based labels for margin_loss)
17 | else:
18 | test.append((fname, label - 1)) # labels 100 ... 199
19 |
20 | for f, v in [(train_file, train), (test_file, test)]:
21 | with open(f, 'w') as tf:
22 | for fname, label in v:
23 | print("images/{},{}".format(fname, label), file=tf)
24 |
25 |
26 | if __name__ == '__main__':
27 | main()
28 |
--------------------------------------------------------------------------------
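A worked example of the parsing above; the line follows the CUB_200_2011 images.txt format, where the class id is the numeric prefix of the directory name:

    line = "1 001.Black_footed_Albatross/Black_Footed_Albatross_0046_18.jpg"
    i, fname = line.split()
    label = int(fname.split('.', 1)[0])             # "001" -> 1; classes 1..100 go to train
    print("images/{},{}".format(fname, label - 1))  # -> images/001.Black_footed_.../...jpg,0
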
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import torch
10 | from setuptools import find_packages, setup
11 | from torch.utils.cpp_extension import CppExtension
12 |
13 |
14 | requirements = ["torch", "torchvision"]
15 |
16 | setup(
17 | name="ret_benchmark",
18 | version="0.1",
19 | author="Malong Technologies",
20 | url="https://github.com/MalongTech/research-ms-loss",
21 | description="ms-loss",
22 | packages=find_packages(exclude=("configs", "tests")),
23 | install_requires=requirements,
24 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
25 | )
26 |
--------------------------------------------------------------------------------
/tools/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in the root directory of this source tree.
7 |
8 | import argparse
9 | import torch
10 |
11 | from ret_benchmark.config import cfg
12 | from ret_benchmark.data import build_data
13 | from ret_benchmark.engine.trainer import do_train
14 | from ret_benchmark.losses import build_loss
15 | from ret_benchmark.modeling import build_model
16 | from ret_benchmark.solver import build_lr_scheduler, build_optimizer
17 | from ret_benchmark.utils.logger import setup_logger
18 | from ret_benchmark.utils.checkpoint import Checkpointer
19 |
20 |
21 | def train(cfg):
22 | logger = setup_logger(name='Train', level=cfg.LOGGER.LEVEL)
23 | logger.info(cfg)
24 | model = build_model(cfg)
25 | device = torch.device(cfg.MODEL.DEVICE)
26 | model.to(device)
27 |
28 | criterion = build_loss(cfg)
29 |
30 | optimizer = build_optimizer(cfg, model)
31 | scheduler = build_lr_scheduler(cfg, optimizer)
32 |
33 | train_loader = build_data(cfg, is_train=True)
34 | val_loader = build_data(cfg, is_train=False)
35 |
36 | logger.info(train_loader.dataset)
37 | logger.info(val_loader.dataset)
38 |
39 | arguments = dict()
40 | arguments["iteration"] = 0
41 |
42 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
43 | checkpointer = Checkpointer(model, optimizer, scheduler, cfg.SAVE_DIR)
44 |
45 | do_train(
46 | cfg,
47 | model,
48 | train_loader,
49 | val_loader,
50 | optimizer,
51 | scheduler,
52 | criterion,
53 | checkpointer,
54 | device,
55 | checkpoint_period,
56 | arguments,
57 | logger
58 | )
59 |
60 |
61 | def parse_args():
62 | """
63 | Parse input arguments
64 | """
65 | parser = argparse.ArgumentParser(description='Train a retrieval network')
66 | parser.add_argument(
67 | '--cfg',
68 | dest='cfg_file',
69 | help='config file',
70 | default=None,
71 | type=str)
72 | return parser.parse_args()
73 |
74 |
75 | if __name__ == '__main__':
76 | args = parse_args()
77 | cfg.merge_from_file(args.cfg_file)
78 | train(cfg)
79 |
--------------------------------------------------------------------------------