├── .gitignore ├── DeepFtoCoco.py ├── LICENSE.md ├── README.md ├── datasets ├── DF2Dataset.py ├── MFDataset.py └── MultiDF2Dataset.py ├── evaluate_movingfashion.py ├── evaluate_multiDF2.py ├── models ├── match_head.py ├── matchrcnn.py ├── nlb.py └── video_matchrcnn.py ├── stuffs ├── engine.py ├── mask_utils.py ├── transform.py └── utils.py ├── train_matchrcnn.py ├── train_movingfashion.py └── train_multiDF2.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea -------------------------------------------------------------------------------- /DeepFtoCoco.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | dataset = { 9 | "info": {}, 10 | "licenses": [], 11 | "images": [], 12 | "annotations": [], 13 | "categories": [] 14 | } 15 | 16 | lst_name = ['short_sleeved_shirt', 'long_sleeved_shirt', 'short_sleeved_outwear', 'long_sleeved_outwear', 17 | 'vest', 'sling', 'shorts', 'trousers', 'skirt', 'short_sleeved_dress', 18 | 'long_sleeved_dress', 'vest_dress', 'sling_dress'] 19 | 20 | for idx, e in enumerate(lst_name): 21 | dataset['categories'].append({ 22 | 'id': idx + 1, 23 | 'name': e, 24 | 'supercategory': "clothes", 25 | 'keypoints': ['%i' % (i) for i in range(1, 295)], 26 | 'skeleton': [] 27 | }) 28 | 29 | path = '' 30 | 31 | imgs = glob.glob(path + 'image/*.jpg') 32 | bar = tqdm(total=len(imgs)) 33 | 34 | sub_index = 0 # the index of ground truth instance 35 | for num in range(1, len(imgs) + 1): 36 | bar.update(1) 37 | 38 | json_name = path + 'annos/' + str(num).zfill(6) + '.json' 39 | image_name = path + 'image/' + str(num).zfill(6) + '.jpg' 40 | 41 | if (num >= 0): 42 | imag = Image.open(image_name) 43 | width, height = imag.size 44 | with open(json_name, 'r') as f: 45 | temp = json.loads(f.read()) 46 | pair_id = temp['pair_id'] 47 | source = temp['source'] 48 | styles = {} 49 | for i in temp: 50 | if i == 'source' or i == 'pair_id': 51 | continue 52 | else: 53 | points = np.zeros((294, 3)) 54 | sub_index = sub_index + 1 55 | box = temp[i]['bounding_box'] 56 | w = box[2] - box[0] 57 | h = box[3] - box[1] 58 | x_1 = box[0] 59 | y_1 = box[1] 60 | bbox = [x_1, y_1, w, h] 61 | cat = temp[i]['category_id'] 62 | style = temp[i]['style'] 63 | styles.update({style: pair_id}) 64 | seg = temp[i]['segmentation'] 65 | landmarks = temp[i]['landmarks'] 66 | 67 | points_x = landmarks[0::3] 68 | points_y = landmarks[1::3] 69 | points_v = landmarks[2::3] 70 | points_x = np.array(points_x) 71 | points_y = np.array(points_y) 72 | points_v = np.array(points_v) 73 | case = [0, 25, 58, 89, 128, 143, 158, 168, 182, 190, 219, 256, 275, 294] 74 | idx_i, idx_j = case[cat - 1], case[cat] 75 | 76 | for n in range(idx_i, idx_j): 77 | points[n] = points_x[n - idx_i] 78 | points[n, 1] = points_y[n - idx_i] 79 | points[n, 2] = points_v[n - idx_i] 80 | 81 | num_points = len(np.where(points_v > 0)[0]) 82 | 83 | dataset['annotations'].append({ 84 | 'area': w * h, 85 | 'source': source, 86 | 'bbox': bbox, 87 | 'category_id': cat, 88 | 'id': sub_index, 89 | 'pair_id': pair_id, 90 | 'image_id': num, 91 | 'iscrowd': 0, 92 | 'style': style, 93 | 'num_keypoints': num_points, 94 | 'keypoints': points.tolist(), 95 | 'segmentation': seg, 96 | }) 97 | 98 | dataset['images'].append({ 99 | 'coco_url': '', 100 | 'date_captured': '', 101 | 'source': source, 102 | 'file_name': str(num).zfill(6) + '.jpg', 103 | 'flickr_url': '', 104 | 
'id': num, 105 | 'license': 0, 106 | 'width': width, 107 | 'height': height, 108 | 'match_desc': styles, 109 | }) 110 | 111 | json_name = path + 'annots.json' 112 | with open(json_name, 'w') as f: 113 | json.dump(dataset, f) 114 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 58 | Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 63 | ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-NC-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution, NonCommercial, and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. 
Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. NonCommercial means not primarily intended for or directed towards 126 | commercial advantage or monetary compensation. For purposes of 127 | this Public License, the exchange of the Licensed Material for 128 | other material subject to Copyright and Similar Rights by digital 129 | file-sharing or similar means is NonCommercial provided there is 130 | no payment of monetary compensation in connection with the 131 | exchange. 132 | 133 | l. Share means to provide material to the public by any means or 134 | process that requires permission under the Licensed Rights, such 135 | as reproduction, public display, public performance, distribution, 136 | dissemination, communication, or importation, and to make material 137 | available to the public including in ways that members of the 138 | public may access the material from a place and at a time 139 | individually chosen by them. 140 | 141 | m. Sui Generis Database Rights means rights other than copyright 142 | resulting from Directive 96/9/EC of the European Parliament and of 143 | the Council of 11 March 1996 on the legal protection of databases, 144 | as amended and/or succeeded, as well as other essentially 145 | equivalent rights anywhere in the world. 146 | 147 | n. You means the individual or entity exercising the Licensed Rights 148 | under this Public License. Your has a corresponding meaning. 149 | 150 | 151 | Section 2 -- Scope. 152 | 153 | a. License grant. 154 | 155 | 1. Subject to the terms and conditions of this Public License, 156 | the Licensor hereby grants You a worldwide, royalty-free, 157 | non-sublicensable, non-exclusive, irrevocable license to 158 | exercise the Licensed Rights in the Licensed Material to: 159 | 160 | a. reproduce and Share the Licensed Material, in whole or 161 | in part, for NonCommercial purposes only; and 162 | 163 | b. produce, reproduce, and Share Adapted Material for 164 | NonCommercial purposes only. 165 | 166 | 2. Exceptions and Limitations. For the avoidance of doubt, where 167 | Exceptions and Limitations apply to Your use, this Public 168 | License does not apply, and You do not need to comply with 169 | its terms and conditions. 170 | 171 | 3. Term. The term of this Public License is specified in Section 172 | 6(a). 173 | 174 | 4. Media and formats; technical modifications allowed. The 175 | Licensor authorizes You to exercise the Licensed Rights in 176 | all media and formats whether now known or hereafter created, 177 | and to make technical modifications necessary to do so. The 178 | Licensor waives and/or agrees not to assert any right or 179 | authority to forbid You from making technical modifications 180 | necessary to exercise the Licensed Rights, including 181 | technical modifications necessary to circumvent Effective 182 | Technological Measures. For purposes of this Public License, 183 | simply making modifications authorized by this Section 2(a) 184 | (4) never produces Adapted Material. 185 | 186 | 5. Downstream recipients. 187 | 188 | a. Offer from the Licensor -- Licensed Material. 
Every 189 | recipient of the Licensed Material automatically 190 | receives an offer from the Licensor to exercise the 191 | Licensed Rights under the terms and conditions of this 192 | Public License. 193 | 194 | b. Additional offer from the Licensor -- Adapted Material. 195 | Every recipient of Adapted Material from You 196 | automatically receives an offer from the Licensor to 197 | exercise the Licensed Rights in the Adapted Material 198 | under the conditions of the Adapter's License You apply. 199 | 200 | c. No downstream restrictions. You may not offer or impose 201 | any additional or different terms or conditions on, or 202 | apply any Effective Technological Measures to, the 203 | Licensed Material if doing so restricts exercise of the 204 | Licensed Rights by any recipient of the Licensed 205 | Material. 206 | 207 | 6. No endorsement. Nothing in this Public License constitutes or 208 | may be construed as permission to assert or imply that You 209 | are, or that Your use of the Licensed Material is, connected 210 | with, or sponsored, endorsed, or granted official status by, 211 | the Licensor or others designated to receive attribution as 212 | provided in Section 3(a)(1)(A)(i). 213 | 214 | b. Other rights. 215 | 216 | 1. Moral rights, such as the right of integrity, are not 217 | licensed under this Public License, nor are publicity, 218 | privacy, and/or other similar personality rights; however, to 219 | the extent possible, the Licensor waives and/or agrees not to 220 | assert any such rights held by the Licensor to the limited 221 | extent necessary to allow You to exercise the Licensed 222 | Rights, but not otherwise. 223 | 224 | 2. Patent and trademark rights are not licensed under this 225 | Public License. 226 | 227 | 3. To the extent possible, the Licensor waives any right to 228 | collect royalties from You for the exercise of the Licensed 229 | Rights, whether directly or through a collecting society 230 | under any voluntary or waivable statutory or compulsory 231 | licensing scheme. In all other cases the Licensor expressly 232 | reserves any right to collect such royalties, including when 233 | the Licensed Material is used other than for NonCommercial 234 | purposes. 235 | 236 | 237 | Section 3 -- License Conditions. 238 | 239 | Your exercise of the Licensed Rights is expressly made subject to the 240 | following conditions. 241 | 242 | a. Attribution. 243 | 244 | 1. If You Share the Licensed Material (including in modified 245 | form), You must: 246 | 247 | a. retain the following if it is supplied by the Licensor 248 | with the Licensed Material: 249 | 250 | i. identification of the creator(s) of the Licensed 251 | Material and any others designated to receive 252 | attribution, in any reasonable manner requested by 253 | the Licensor (including by pseudonym if 254 | designated); 255 | 256 | ii. a copyright notice; 257 | 258 | iii. a notice that refers to this Public License; 259 | 260 | iv. a notice that refers to the disclaimer of 261 | warranties; 262 | 263 | v. a URI or hyperlink to the Licensed Material to the 264 | extent reasonably practicable; 265 | 266 | b. indicate if You modified the Licensed Material and 267 | retain an indication of any previous modifications; and 268 | 269 | c. indicate the Licensed Material is licensed under this 270 | Public License, and include the text of, or the URI or 271 | hyperlink to, this Public License. 272 | 273 | 2. 
You may satisfy the conditions in Section 3(a)(1) in any 274 | reasonable manner based on the medium, means, and context in 275 | which You Share the Licensed Material. For example, it may be 276 | reasonable to satisfy the conditions by providing a URI or 277 | hyperlink to a resource that includes the required 278 | information. 279 | 3. If requested by the Licensor, You must remove any of the 280 | information required by Section 3(a)(1)(A) to the extent 281 | reasonably practicable. 282 | 283 | b. ShareAlike. 284 | 285 | In addition to the conditions in Section 3(a), if You Share 286 | Adapted Material You produce, the following conditions also apply. 287 | 288 | 1. The Adapter's License You apply must be a Creative Commons 289 | license with the same License Elements, this version or 290 | later, or a BY-NC-SA Compatible License. 291 | 292 | 2. You must include the text of, or the URI or hyperlink to, the 293 | Adapter's License You apply. You may satisfy this condition 294 | in any reasonable manner based on the medium, means, and 295 | context in which You Share Adapted Material. 296 | 297 | 3. You may not offer or impose any additional or different terms 298 | or conditions on, or apply any Effective Technological 299 | Measures to, Adapted Material that restrict exercise of the 300 | rights granted under the Adapter's License You apply. 301 | 302 | 303 | Section 4 -- Sui Generis Database Rights. 304 | 305 | Where the Licensed Rights include Sui Generis Database Rights that 306 | apply to Your use of the Licensed Material: 307 | 308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 309 | to extract, reuse, reproduce, and Share all or a substantial 310 | portion of the contents of the database for NonCommercial purposes 311 | only; 312 | 313 | b. if You include all or a substantial portion of the database 314 | contents in a database in which You have Sui Generis Database 315 | Rights, then the database in which You have Sui Generis Database 316 | Rights (but not its individual contents) is Adapted Material, 317 | including for purposes of Section 3(b); and 318 | 319 | c. You must comply with the conditions in Section 3(a) if You Share 320 | all or a substantial portion of the contents of the database. 321 | 322 | For the avoidance of doubt, this Section 4 supplements and does not 323 | replace Your obligations under this Public License where the Licensed 324 | Rights include other Copyright and Similar Rights. 325 | 326 | 327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 328 | 329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 339 | 340 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 349 | 350 | c. The disclaimer of warranties and limitation of liability provided 351 | above shall be interpreted in a manner that, to the extent 352 | possible, most closely approximates an absolute disclaimer and 353 | waiver of all liability. 354 | 355 | 356 | Section 6 -- Term and Termination. 357 | 358 | a. This Public License applies for the term of the Copyright and 359 | Similar Rights licensed here. However, if You fail to comply with 360 | this Public License, then Your rights under this Public License 361 | terminate automatically. 362 | 363 | b. Where Your right to use the Licensed Material has terminated under 364 | Section 6(a), it reinstates: 365 | 366 | 1. automatically as of the date the violation is cured, provided 367 | it is cured within 30 days of Your discovery of the 368 | violation; or 369 | 370 | 2. upon express reinstatement by the Licensor. 371 | 372 | For the avoidance of doubt, this Section 6(b) does not affect any 373 | right the Licensor may have to seek remedies for Your violations 374 | of this Public License. 375 | 376 | c. For the avoidance of doubt, the Licensor may also offer the 377 | Licensed Material under separate terms or conditions or stop 378 | distributing the Licensed Material at any time; however, doing so 379 | will not terminate this Public License. 380 | 381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 382 | License. 383 | 384 | 385 | Section 7 -- Other Terms and Conditions. 386 | 387 | a. The Licensor shall not be bound by any additional or different 388 | terms or conditions communicated by You unless expressly agreed. 389 | 390 | b. Any arrangements, understandings, or agreements regarding the 391 | Licensed Material not stated herein are separate from and 392 | independent of the terms and conditions of this Public License. 393 | 394 | 395 | Section 8 -- Interpretation. 396 | 397 | a. For the avoidance of doubt, this Public License does not, and 398 | shall not be interpreted to, reduce, limit, restrict, or impose 399 | conditions on any use of the Licensed Material that could lawfully 400 | be made without permission under this Public License. 401 | 402 | b. To the extent possible, if any provision of this Public License is 403 | deemed unenforceable, it shall be automatically reformed to the 404 | minimum extent necessary to make it enforceable. If the provision 405 | cannot be reformed, it shall be severed from this Public License 406 | without affecting the enforceability of the remaining terms and 407 | conditions. 408 | 409 | c. No term or condition of this Public License will be waived and no 410 | failure to comply consented to unless expressly agreed to by the 411 | Licensor. 412 | 413 | d. 
Nothing in this Public License constitutes or may be interpreted 414 | as a limitation upon, or waiver of, any privileges and immunities 415 | that apply to the Licensor or You, including from the legal 416 | processes of any jurisdiction or authority. 417 | 418 | ======================================================================= 419 | 420 | Creative Commons is not a party to its public 421 | licenses. Notwithstanding, Creative Commons may elect to apply one of 422 | its public licenses to material it publishes and in those instances 423 | will be considered the “Licensor.” The text of the Creative Commons 424 | public licenses is dedicated to the public domain under the CC0 Public 425 | Domain Dedication. Except for the limited purpose of indicating that 426 | material is shared under a Creative Commons public license or as 427 | otherwise permitted by the Creative Commons policies published at 428 | creativecommons.org/policies, Creative Commons does not authorize the 429 | use of the trademark "Creative Commons" or any other trademark or logo 430 | of Creative Commons without its prior written consent including, 431 | without limitation, in connection with any unauthorized modifications 432 | to any of its public licenses or any other arrangements, 433 | understandings, or agreements concerning use of licensed material. For 434 | the avoidance of doubt, this paragraph does not form part of the 435 | public licenses. 436 | 437 | Creative Commons may be contacted at creativecommons.org. 438 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/movingfashion-a-benchmark-for-the-video-to/video-to-shop-on-movingfashion)](https://paperswithcode.com/sota/video-to-shop-on-movingfashion?p=movingfashion-a-benchmark-for-the-video-to) 4 | 5 | 6 | # SEAM Match-RCNN 7 | Official code of [**MovingFashion: a Benchmark for the Video-to-Shop Challenge**](https://arxiv.org/abs/2110.02627) paper 8 | 9 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-shield]][cc-by-nc-sa] 10 | 11 | 12 | ## Installation 13 | 14 | ### Requirements: 15 | - Pytorch 1.5.1 or more recent, with cudatoolkit (10.2) 16 | - torchvision 17 | - tensorboard 18 | - cocoapi 19 | - OpenCV Python 20 | - tqdm 21 | - cython 22 | - CUDA >= 10 23 | 24 | ### Step-by-step installation 25 | 26 | ```bash 27 | # first, make sure that your conda is setup properly with the right environment 28 | # for that, check that `which conda`, `which pip` and `which python` points to the 29 | # right path. 
From a clean conda env, this is what you need to do 30 | 31 | conda create --name seam -y python=3 32 | conda activate seam 33 | 34 | pip install cython tqdm opencv-python 35 | 36 | # follow PyTorch installation in https://pytorch.org/get-started/locally/ 37 | conda install pytorch torchvision cudatoolkit=10.2 -c pytorch 38 | 39 | conda install tensorboard 40 | 41 | export INSTALL_DIR=$PWD 42 | 43 | # install pycocotools 44 | cd $INSTALL_DIR 45 | git clone https://github.com/cocodataset/cocoapi.git 46 | cd cocoapi/PythonAPI 47 | python setup.py build_ext install 48 | 49 | # download SEAM 50 | cd $INSTALL_DIR 51 | git clone https://github.com/VIPS4/SEAM-Match-RCNN.git 52 | cd SEAM-Match-RCNN 53 | mkdir data 54 | mkdir ckpt 55 | 56 | unset INSTALL_DIR 57 | ``` 58 | ## Dataset 59 | 60 | SEAM Match-RCNN has been trained and test on MovingFashion and DeepFashion2 datasets. 61 | Follow the instruction to download and extract the datasets. 62 | 63 | We suggest to download the datasets inside the folder **data**. 64 | 65 | ### MovingFashion 66 | 67 | MovingFashion dataset is available for academic purposes [here](https://bit.ly/4bTZGeS). 68 | 69 | 70 | ### Deepfashion2 71 | DeepFashion2 dataset is available [here](https://drive.google.com/drive/folders/125F48fsMBz2EF0Cpqk6aaHet5VH399Ok?usp=sharing). You need fill in the [form](https://docs.google.com/forms/d/e/1FAIpQLSeIoGaFfCQILrtIZPykkr8q_h9qQ5BoTYbjvf95aXbid0v2Bw/viewform?usp=sf_link) to get password for unzipping files. 72 | 73 | 74 | Once the dataset will be extracted, use the reserved DeepFtoCoco.py script to convert the annotations in COCO format, specifying dataset path. 75 | ```bash 76 | python DeepFtoCoco.py --path 77 | ``` 78 | 79 | 80 | 81 | ## Training 82 | We provide the scripts to train both Match-RCNN and SEAM Match-RCNN. Check the scripts for all the possible parameters. 83 | 84 | ### Single GPU 85 | ```bash 86 | #training of Match-RCNN 87 | python train_matchrcnn.py --root_train --train_annots --save_path 88 | 89 | #training on movingfashion 90 | python train_movingfashion.py --root --train_annots --test_annots --pretrained_path 91 | 92 | 93 | #training on multi-deepfashion2 94 | python train_multiDF2.py --root --train_annots --test_annots --pretrained_path 95 | ``` 96 | 97 | 98 | ### Multi GPU 99 | We use internally ```torch.distributed.launch``` in order to launch multi-gpu training. This utility function from PyTorch spawns as many Python processes as the number of GPUs we want to use, and each Python process will only use a single GPU. 100 | 101 | ```bash 102 | #training of Match-RCNN 103 | python -m torch.distributed.launch --nproc_per_node= train_matchrcnn.py --root_train --train_annots --save_path 104 | 105 | #training on movingfashion 106 | python -m torch.distributed.launch --nproc_per_node= train_movingfashion.py --root --train_annots --test_annots --pretrained_path 107 | 108 | #training on multi-deepfashion2 109 | python -m torch.distributed.launch --nproc_per_node= train_multiDF2.py --root --train_annots --test_annots --pretrained_path 110 | ``` 111 | 112 | 113 | ### Pre-Trained models 114 | It is possibile to start training using the MatchRCNN pre-trained model. 115 | 116 | **[MatchRCNN]** Pre-trained model on Deepfashion2 is available to download [here](https://bit.ly/3m3y6C4). This model can be used to start the training at the second phase (training directly SEAM Match-RCNN). 117 | 118 | 119 | 120 | 121 | 122 | We suggest to download the model inside the folder **ckpt**. 
123 | 124 | ## Evaluation 125 | To evaluate the models of SEAM Match-RCNN please use the following scripts. 126 | 127 | ```bash 128 | #evaluation on movingfashion 129 | python evaluate_movingfashion.py --root_test --test_annots --ckpt_path 130 | 131 | 132 | #evaluation on multi-deepfashion2 133 | python evaluate_multiDF2.py --root_test --test_annots --ckpt_path 134 | ``` 135 | 136 | ## Citation 137 | ``` 138 | @misc{godi2021movingfashion, 139 | title={MovingFashion: a Benchmark for the Video-to-Shop Challenge}, 140 | author={Marco Godi and Christian Joppi and Geri Skenderi and Marco Cristani}, 141 | year={2021}, 142 | eprint={2110.02627}, 143 | archivePrefix={arXiv}, 144 | primaryClass={cs.CV} 145 | } 146 | ``` 147 | 148 | This work is licensed under a 149 | [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa]. 150 | 151 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa] 152 | 153 | [cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/ 154 | [cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png 155 | [cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg 156 | 157 | 158 | -------------------------------------------------------------------------------- /datasets/DF2Dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import sys 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import torchvision 8 | from torch._six import int_classes as _int_classes 9 | from torch.utils.data.sampler import BatchSampler 10 | from torch.utils.data.sampler import Sampler 11 | 12 | from stuffs import utils 13 | from stuffs.mask_utils import annToMask 14 | 15 | 16 | def get_size(obj, seen=None): 17 | """Recursively finds size of objects""" 18 | size = sys.getsizeof(obj) 19 | if seen is None: 20 | seen = set() 21 | obj_id = id(obj) 22 | if obj_id in seen: 23 | return 0 24 | # Important mark as seen *before* entering recursion to gracefully handle 25 | # self-referential objects 26 | seen.add(obj_id) 27 | if isinstance(obj, dict): 28 | size += sum([get_size(v, seen) for v in obj.values()]) 29 | size += sum([get_size(k, seen) for k in obj.keys()]) 30 | elif hasattr(obj, '__dict__'): 31 | size += get_size(obj.__dict__, seen) 32 | elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)): 33 | size += sum([get_size(i, seen) for i in obj]) 34 | return size 35 | 36 | 37 | def _count_visible_keypoints(anno): 38 | return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) 39 | 40 | 41 | def _has_only_empty_bbox(anno): 42 | return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) 43 | 44 | 45 | def has_valid_annotation(anno): 46 | # if it's empty, there is no annotation 47 | if len(anno) == 0: 48 | return False 49 | # if all boxes have close to zero area, there is no annotation 50 | if _has_only_empty_bbox(anno): 51 | return False 52 | # keypoints task have a slight different critera for considering 53 | # if an annotation is valid 54 | if "keypoints" not in anno[0]: 55 | return True 56 | # for keypoint detection tasks, only consider valid images those 57 | # containing at least min_keypoints_per_image 58 | if _count_visible_keypoints(anno) >= 10: 59 | return True 60 | return False 61 | 62 | 63 | class DeepFashion2Dataset(torchvision.datasets.coco.CocoDetection): 64 | def __init__( 65 | self, ann_file, root, transforms=None 66 | ): 67 | super(DeepFashion2Dataset, 
self).__init__(root, ann_file) 68 | self.ids = sorted(self.ids) 69 | 70 | self.categories = {cat['id']: cat['name'] for cat in self.coco.cats.values()} 71 | 72 | self.json_category_id_to_contiguous_id = { 73 | v: i + 1 for i, v in enumerate(self.coco.getCatIds()) 74 | } 75 | self.contiguous_category_id_to_json_id = { 76 | v: k for k, v in self.json_category_id_to_contiguous_id.items() 77 | } 78 | self.id_to_img_map = {k: v for k, v in enumerate(self.ids)} 79 | self.idx_to_id_map = {v: k for k, v in enumerate(self.ids)} 80 | 81 | self._transforms = transforms 82 | self.street_inds = self._getTypeInds('user') 83 | self.shop_inds = self._getTypeInds('shop') 84 | 85 | self.match_map_shop = {} 86 | self.match_map_street = {} 87 | print("Computing Street Match Descriptors map") 88 | for i in self.street_inds: 89 | e = self.coco.imgs[i] 90 | for x in e['match_desc']: 91 | if x == '0': 92 | continue 93 | hashable_key = x + '_' + str(e['match_desc'].get(x)) 94 | inds = self.match_map_street.get(hashable_key) 95 | if inds is None: 96 | self.match_map_street.update({hashable_key: [i]}) 97 | else: 98 | inds.append(i) 99 | self.match_map_street.update({hashable_key: inds}) 100 | print("Computing Shop Match Descriptors map") 101 | for i in self.shop_inds: 102 | e = self.coco.imgs[i] 103 | for x in e['match_desc']: 104 | if x == '0': 105 | continue 106 | hashable_key = x + '_' + str(e['match_desc'].get(x)) 107 | inds = self.match_map_shop.get(hashable_key) 108 | if inds is None: 109 | self.match_map_shop.update({hashable_key: [i]}) 110 | else: 111 | inds.append(i) 112 | self.match_map_shop.update({hashable_key: inds}) 113 | 114 | print("Filtering images with no matches") 115 | street_match_keys = list(self.match_map_street.keys()) 116 | shop_match_keys = self.match_map_shop.keys() 117 | self.accepted_entries = [] 118 | for x in self.match_map_street: 119 | if x in shop_match_keys: 120 | self.accepted_entries = self.accepted_entries + self.match_map_street.get(x) 121 | 122 | for x in self.match_map_shop: 123 | if x in street_match_keys: 124 | self.accepted_entries = self.accepted_entries + self.match_map_shop.get(x) 125 | 126 | self.accepted_entries = list(set(self.accepted_entries)) 127 | print("Total images after filtering:" + str(len(self.accepted_entries))) 128 | 129 | def __getitem__(self, idx): 130 | img, anno = super(DeepFashion2Dataset, self).__getitem__(idx) 131 | 132 | anno = [obj for obj in anno if obj["iscrowd"] == 0] 133 | 134 | boxes = [obj["bbox"] for obj in anno if obj['area'] != 0] 135 | boxes = torch.as_tensor(boxes).reshape(-1, 4) # guard against no boxes 136 | boxes[:, 2] = boxes[:, 2] + boxes[:, 0] 137 | boxes[:, 3] = boxes[:, 3] + boxes[:, 1] 138 | 139 | classes = [obj["category_id"] for obj in anno if obj['area'] != 0] 140 | classes = [self.json_category_id_to_contiguous_id[c] for c in classes] 141 | classes = torch.tensor(classes) 142 | 143 | target = {} 144 | target["labels"] = classes 145 | target["boxes"] = boxes 146 | target["classes"] = classes 147 | 148 | if anno and "area" in anno[0]: 149 | area = torch.stack([torch.as_tensor(obj['area'], dtype=torch.float32) for obj in anno if obj['area'] != 0]) 150 | target["area"] = area 151 | 152 | if anno and "segmentation" in anno[0]: 153 | masks = torch.stack( 154 | [torch.as_tensor(annToMask(obj, size=[img.height, img.width]), dtype=torch.uint8) for obj in anno if obj['area'] != 0]) 155 | target["masks"] = masks 156 | 157 | if anno and "pair_id" in anno[0]: 158 | pair_ids = [obj['pair_id'] for obj in anno if obj['area'] != 0] 159 | 
pair_ids = torch.tensor(pair_ids) 160 | target["pair_ids"] = pair_ids 161 | 162 | 163 | if anno and "style" in anno[0]: 164 | styles = [obj['style'] for obj in anno if obj['area'] != 0] 165 | styles = torch.tensor(styles) 166 | target["styles"] = styles 167 | 168 | if anno and "source" in anno[0]: 169 | sources = [0 if obj['source'] == 'user' else 1 for obj in anno if obj['area'] != 0] 170 | # print("-->", idx, sources) 171 | sources = torch.tensor(sources) 172 | target["sources"] = sources 173 | 174 | if self._transforms is not None: 175 | img, target = self._transforms(img, target) 176 | 177 | return img, target, anno[0]['image_id'] 178 | 179 | def get_img_info(self, index): 180 | img_id = self.id_to_img_map[index] 181 | img_data = self.coco.imgs[img_id] 182 | return img_data 183 | 184 | def _getTypeInds(self, type_s): 185 | inds = [] 186 | N = len(self.coco.imgs) 187 | for i in self.ids: 188 | if self.coco.imgs[i]['source'] == type_s: 189 | inds.append(i) 190 | 191 | return inds 192 | 193 | 194 | def get_dataloader(dataset, batch_size, is_parallel, num_workers=0): 195 | if is_parallel: 196 | sampler = DistributedSampler(dataset, shuffle=True) 197 | else: 198 | sampler = RandomSampler(dataset) 199 | 200 | batch_sampler = DF2MatchingSampler(dataset, sampler, batch_size, drop_last=True) 201 | 202 | data_loader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_sampler=batch_sampler, 203 | collate_fn=utils.collate_fn) 204 | return data_loader 205 | 206 | 207 | 208 | class RandomSampler(Sampler): 209 | r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. 210 | If with replacement, then user can specify :attr:`num_samples` to draw. 211 | 212 | Arguments: 213 | data_source (Dataset): dataset to sample from 214 | replacement (bool): samples are drawn with replacement if ``True``, default=``False`` 215 | num_samples (int): number of samples to draw, default=`len(dataset)`. This argument 216 | is supposed to be specified only when `replacement` is ``True``. 217 | """ 218 | 219 | def __init__(self, data_source, replacement=False, num_samples=None): 220 | self.data_source = data_source.accepted_entries 221 | 222 | self.replacement = replacement 223 | self._num_samples = num_samples 224 | 225 | if not isinstance(self.replacement, bool): 226 | raise ValueError("replacement should be a boolean value, but got " 227 | "replacement={}".format(self.replacement)) 228 | 229 | if self._num_samples is not None and not replacement: 230 | raise ValueError("With replacement=False, num_samples should not be specified, " 231 | "since a random permute will be performed.") 232 | 233 | if not isinstance(self.num_samples, int) or self.num_samples <= 0: 234 | raise ValueError("num_samples should be a positive integer " 235 | "value, but got num_samples={}".format(self.num_samples)) 236 | 237 | @property 238 | def num_samples(self): 239 | # dataset size might change at runtime 240 | if self._num_samples is None: 241 | return len(self.data_source) 242 | return self._num_samples 243 | 244 | def __iter__(self): 245 | n = len(self.data_source) 246 | if self.replacement: 247 | return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist()) 248 | return iter(torch.randperm(n).tolist()) 249 | 250 | def __len__(self): 251 | return self.num_samples 252 | 253 | 254 | 255 | 256 | class DistributedSampler(Sampler): 257 | """Sampler that restricts data loading to a subset of the dataset. 
258 | It is especially useful in conjunction with 259 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 260 | process can pass a DistributedSampler instance as a DataLoader sampler, 261 | and load a subset of the original dataset that is exclusive to it. 262 | .. note:: 263 | Dataset is assumed to be of constant size. 264 | Arguments: 265 | dataset: Dataset used for sampling. 266 | num_replicas (optional): Number of processes participating in 267 | distributed training. 268 | rank (optional): Rank of the current process within num_replicas. 269 | """ 270 | 271 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 272 | if num_replicas is None: 273 | if not dist.is_available(): 274 | raise RuntimeError("Requires distributed package to be available") 275 | num_replicas = dist.get_world_size() 276 | if rank is None: 277 | if not dist.is_available(): 278 | raise RuntimeError("Requires distributed package to be available") 279 | rank = dist.get_rank() 280 | self.dataset = dataset.accepted_entries 281 | 282 | self.num_replicas = num_replicas 283 | self.rank = rank 284 | self.epoch = 0 285 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 286 | self.total_size = self.num_samples * self.num_replicas 287 | self.shuffle = shuffle 288 | 289 | def __iter__(self): 290 | if self.shuffle: 291 | # deterministically shuffle based on epoch 292 | g = torch.Generator() 293 | g.manual_seed(self.epoch) 294 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 295 | else: 296 | indices = torch.arange(len(self.dataset)).tolist() 297 | 298 | # add extra samples to make it evenly divisible 299 | indices += indices[: (self.total_size - len(indices))] 300 | assert len(indices) == self.total_size 301 | 302 | # subsample 303 | offset = self.num_samples * self.rank 304 | indices = indices[offset: offset + self.num_samples] 305 | assert len(indices) == self.num_samples 306 | 307 | return iter(indices) 308 | 309 | def __len__(self): 310 | return self.num_samples 311 | 312 | def set_epoch(self, epoch): 313 | self.epoch = epoch 314 | 315 | 316 | class DF2MatchingSampler(Sampler): 317 | r"""Wraps another sampler to yield a mini-batch of indices. 318 | 319 | Args: 320 | sampler (Sampler): Base sampler. 321 | batch_size (int): Size of mini-batch. 
322 | drop_last (bool): If ``True``, the sampler will drop the last batch if 323 | its size would be less than ``batch_size`` 324 | 325 | Example: 326 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) 327 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 328 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) 329 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 330 | """ 331 | 332 | def __init__(self, dataset, sampler, batch_size, drop_last): 333 | if not isinstance(sampler, Sampler): 334 | raise ValueError("sampler should be an instance of " 335 | "torch.utils.data.Sampler, but got sampler={}" 336 | .format(sampler)) 337 | if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ 338 | batch_size <= 0: 339 | raise ValueError("batch_size should be a positive integer value, " 340 | "but got batch_size={}".format(batch_size)) 341 | if not isinstance(drop_last, bool): 342 | raise ValueError("drop_last should be a boolean value, but got " 343 | "drop_last={}".format(drop_last)) 344 | self.data = dataset 345 | self.sampler = sampler 346 | self.batch_size = batch_size 347 | self.drop_last = drop_last 348 | self.customer_inds = self.data.street_inds 349 | self.shop_inds = self.data.shop_inds 350 | self.customer_used = torch.zeros((len(self.customer_inds, ))) 351 | self.shop_used = torch.zeros((len(self.shop_inds, ))) 352 | self.match_map_shop = self.data.match_map_shop 353 | self.match_map_street = self.data.match_map_street 354 | self.seed_dict = {} 355 | self.tmp_index = [] 356 | 357 | def __iter__(self): 358 | batch = [] 359 | for idx in self.sampler: 360 | ind = self.data.accepted_entries[idx] 361 | if ind in self.customer_inds: 362 | street_ind = ind 363 | shop_inds = self._getSamePairInShop(street_ind) 364 | if len(shop_inds) != 0: 365 | shop_ind = random.choice(shop_inds) 366 | batch.append(self.data.idx_to_id_map.get(street_ind)) 367 | batch.append(self.data.idx_to_id_map.get(shop_ind)) 368 | self.tmp_index.append(str(shop_ind) + '_' + str(street_ind)) 369 | else: 370 | print(idx) 371 | 372 | else: 373 | shop_ind = ind 374 | street_inds = self._getSamePairInStreet(shop_ind) 375 | 376 | if len(street_inds) != 0: 377 | street_ind = random.choice(street_inds) 378 | batch.append(self.data.idx_to_id_map.get(street_ind)) 379 | batch.append(self.data.idx_to_id_map.get(shop_ind)) 380 | self.tmp_index.append(str(shop_ind) + '_' + str(street_ind)) 381 | else: 382 | print(idx) 383 | if len(batch) == self.batch_size: 384 | yield batch 385 | batch = [] 386 | if len(batch) > 0 and not self.drop_last: 387 | yield batch 388 | 389 | def __len__(self): 390 | if self.drop_last: 391 | return len(self.sampler) // (self.batch_size // 2) 392 | else: 393 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 394 | 395 | def _getTypeInds(self, type_s): 396 | inds = [] 397 | N = len(self.data) 398 | for i in range(1, N + 1): 399 | if self.data.coco.imgs[i]['source'] == type_s: 400 | inds.append(i) 401 | return inds 402 | 403 | def _getSamePairInShop(self, id): 404 | match_desc = self.data.coco.imgs[id]['match_desc'] 405 | ids = [] 406 | for x in match_desc: 407 | hashable_key = x + '_' + str(match_desc.get(x)) 408 | matches = self.match_map_shop.get(hashable_key) 409 | if matches is not None: 410 | ids = ids + matches 411 | return ids 412 | 413 | def _getSamePairInStreet(self, id): 414 | match_desc = self.data.coco.imgs[id]['match_desc'] 415 | ids = [] 416 | 417 | for x in match_desc: 418 | hashable_key = x + '_' + 
str(match_desc.get(x)) 419 | matches = self.match_map_street.get(hashable_key) 420 | if matches is not None: 421 | ids = ids + matches 422 | 423 | return ids -------------------------------------------------------------------------------- /datasets/MFDataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | from torch.utils.data.distributed import DistributedSampler 11 | 12 | 13 | class MovingFashionDataset(Dataset): 14 | 15 | def __init__(self, jsonpath, transform=None, noise=True, root="", blacklist=None, whitelist=None): 16 | with open(jsonpath, "r") as fp: 17 | self.data = json.load(fp) 18 | if blacklist is not None: 19 | self.product_ids = sorted([k for k in self.data.keys() if k not in blacklist]) 20 | else: 21 | if whitelist is not None: 22 | self.product_ids = sorted([k for k in self.data.keys() if k in whitelist]) 23 | else: 24 | self.product_ids = sorted([k for k in self.data.keys()]) 25 | self.product_list = [self.data[k] for k in self.product_ids] 26 | self.noise = noise 27 | self.transform = transform 28 | self.root = root 29 | 30 | def __len__(self): 31 | return len(self.product_list) 32 | 33 | def __getitem__(self, x): 34 | if isinstance(x, int): 35 | i = x 36 | tag, index = None, None 37 | elif isinstance(x, tuple): 38 | if len(x) == 3: 39 | i, tag, index = x 40 | video_i = None 41 | else: 42 | i, tag, index, video_i = x 43 | ret = {} 44 | ret["paths"] = {'video_paths': self.product_list[i]['video_paths'], 45 | 'img_path': self.product_list[i]['img_path']} 46 | ret['source'] = self.product_list[i]['source'] 47 | ret['tracklet'] = None 48 | ret["i"] = i 49 | tmp_paths = self.product_list[i] 50 | ret["video_i"] = -1 51 | if tag == "video": 52 | video_paths = tmp_paths["video_paths"] 53 | if video_i is None: 54 | video_name = random.choice(video_paths) 55 | ret["video_i"] = video_paths.index(video_name) 56 | else: 57 | ret["video_i"] = video_i 58 | video_name = video_paths[ret["video_i"]] 59 | video = cv2.VideoCapture(os.path.join(self.root, video_name)) 60 | if isinstance(index, int): 61 | # if int you should find a value between 0.0 and 1.0 (index / fps * videolen) 62 | assert False 63 | n_frames = video.get(7) 64 | index2 = int(n_frames * index) 65 | video.set(1, index2) 66 | success, image = video.read() 67 | ret['valid'] = success 68 | if 'tracklets' in self.product_list[i]: 69 | if video_i == None: 70 | if str(index2) in self.product_list[i]['tracklets'][0]: 71 | ret['tracklet'] = self.product_list[i]['tracklets'][0][str(index2)] 72 | else: 73 | if str(index2) in self.product_list[i]['tracklets'][video_i]: 74 | ret['tracklet'] = self.product_list[i]['tracklets'][video_i][str(index2)] 75 | if ret['tracklet'] is not None: 76 | ret['tracklet'] = np.asarray(ret['tracklet']) 77 | else: 78 | ret['tracklet'] = np.asarray([-1, -1, -1, -1]) 79 | # assert success 80 | # from cv2 to PIL 81 | if success: 82 | image = image[:, :, ::-1] 83 | tmp_noise = 0.25 if random.random() > 0.75 else 0.05 84 | if self.noise: 85 | image = image / 255.0 86 | image += np.random.randn(*image.shape) * tmp_noise 87 | image = image * 255.0 88 | image = np.clip(image, 0, 255.0) 89 | image = np.asarray(image, dtype=np.uint8) 90 | img = Image.fromarray(image) 91 | if self.noise: 92 | # img = img.resize((image.shape[1] // 3, image.shape[0] // 3)) 93 | img = img.resize((image.shape[1] // 2, 
image.shape[0] // 2)) 94 | else: 95 | img = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8)) 96 | ret['index2'] = index2 97 | video.release() 98 | else: 99 | tmp_path = tmp_paths["img_path"] 100 | img = Image.open(os.path.join(self.root, tmp_path)) 101 | ret['source'] = tmp_paths["source"] 102 | # ret['img'] = np.asarray(img) 103 | ret['index'] = index 104 | ret['tag'] = 1 if tag != "video" else 0 105 | ret['labels'] = torch.tensor([0]) 106 | ret['boxes'] = torch.tensor([[0.0, 0.0, img.size[0], img.size[1]]], dtype=torch.float32) 107 | ret['masks'] = torch.ones(1, img.size[1], img.size[0], dtype=torch.uint8) 108 | if self.transform is not None: 109 | img, ret = self.transform(img, ret) 110 | return img, ret 111 | 112 | def collate_fn(batch): 113 | return tuple(zip(*batch)) 114 | 115 | 116 | def get_dataloader(dataset, batch_size, is_parallel, n_products=1 117 | , first_n_withvideo=None, uniform_sampling=False, fixed_frame=None 118 | , is_seq=False, num_workers=8, fixed_ind=None, fixed_video_i=None): 119 | if is_parallel: 120 | sampler = DistributedSampler(dataset, shuffle=True) 121 | else: 122 | if is_seq: 123 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 124 | else: 125 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 126 | 127 | batch_sampler = MFBatchSampler(dataset, sampler, batch_size, drop_last=True, n_products=n_products 128 | , first_n_withvideo=first_n_withvideo, uniform_sampling=uniform_sampling 129 | , fixed_frame=fixed_frame, fixed_ind=fixed_ind, fixed_video_i=fixed_video_i) 130 | 131 | data_loader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_sampler=batch_sampler, 132 | collate_fn=collate_fn) 133 | # print("%d %d" % (rank, len(list(data_loader)))) 134 | return data_loader 135 | 136 | 137 | class MFBatchSampler(torch.utils.data.BatchSampler): 138 | def __init__(self, dataset, sampler, batch_size, drop_last, n_samples=100 139 | , n_products=1, first_n_withvideo=None, uniform_sampling=False, fixed_frame=None 140 | , fixed_ind=None, fixed_video_i=None): 141 | super(MFBatchSampler, self).__init__(sampler, batch_size, drop_last) 142 | self.data = dataset 143 | self.n_video_samples = n_samples 144 | self.n_products = n_products 145 | self.first_n_withvideo = first_n_withvideo 146 | self.uniform_sampling = uniform_sampling 147 | self.fixed_frame = fixed_frame 148 | self.fixed_ind = fixed_ind 149 | self.fixed_video_i = fixed_video_i 150 | 151 | def __iter__(self): 152 | batch = [] 153 | count = -1 154 | for idx in self.sampler: 155 | if self.fixed_ind is not None: 156 | idx = self.fixed_ind 157 | batch.append((idx, "in", None)) 158 | count += 1 159 | if self.batch_size == 1: 160 | tmp_video_samples = [x for x in np.linspace(0.0, 1.0, self.n_video_samples + 1)][:-1] 161 | else: 162 | if not self.uniform_sampling: 163 | if self.fixed_frame is None: 164 | tmp_video_samples = sorted([random.random() 165 | for _ in range((self.batch_size // self.n_products) - 1)]) 166 | # tmp_video_samples = sorted(random.choices([0.5, 0.5], k=(self.batch_size // self.n_products) - 1)) 167 | else: 168 | if isinstance(self.fixed_frame, list): 169 | tmp_video_samples = [x for x in self.fixed_frame] 170 | else: 171 | tmp_video_samples = [self.fixed_frame for _ in range((self.batch_size // self.n_products) - 1)] 172 | else: 173 | tmp_video_samples = [x for x in np.linspace(0.00, 1.0, (self.batch_size // self.n_products) - 1)] 174 | if self.first_n_withvideo is None or count < self.first_n_withvideo: 175 | for t in tmp_video_samples: 176 | if 
self.fixed_video_i is None: 177 | batch.append((idx, "video", t)) 178 | else: 179 | batch.append((idx, "video", t, self.fixed_video_i)) 180 | if self.batch_size == 1 or len(batch) == self.batch_size \ 181 | or self.first_n_withvideo is not None: 182 | yield batch 183 | batch = [] 184 | if not self.drop_last: 185 | yield batch 186 | batch = [] 187 | 188 | def __len__(self): 189 | if self.drop_last: 190 | return len(self.sampler) // self.n_products 191 | else: 192 | return 1 + (len(self.sampler) // self.n_products) 193 | -------------------------------------------------------------------------------- /datasets/MultiDF2Dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | import torchvision 7 | from PIL import Image 8 | from torch._six import int_classes as _int_classes 9 | from torch.utils.data import RandomSampler, DistributedSampler 10 | from torch.utils.data.sampler import BatchSampler 11 | from torch.utils.data.sampler import Sampler 12 | 13 | from stuffs import utils 14 | from stuffs.mask_utils import annToMask 15 | 16 | 17 | def get_size(obj, seen=None): 18 | """Recursively finds size of objects""" 19 | size = sys.getsizeof(obj) 20 | if seen is None: 21 | seen = set() 22 | obj_id = id(obj) 23 | if obj_id in seen: 24 | return 0 25 | # Important mark as seen *before* entering recursion to gracefully handle 26 | # self-referential objects 27 | seen.add(obj_id) 28 | if isinstance(obj, dict): 29 | size += sum([get_size(v, seen) for v in obj.values()]) 30 | size += sum([get_size(k, seen) for k in obj.keys()]) 31 | elif hasattr(obj, '__dict__'): 32 | size += get_size(obj.__dict__, seen) 33 | elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)): 34 | size += sum([get_size(i, seen) for i in obj]) 35 | return size 36 | 37 | 38 | def _count_visible_keypoints(anno): 39 | return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) 40 | 41 | 42 | def _has_only_empty_bbox(anno): 43 | return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) 44 | 45 | 46 | def has_valid_annotation(anno): 47 | # if it's empty, there is no annotation 48 | if len(anno) == 0: 49 | return False 50 | # if all boxes have close to zero area, there is no annotation 51 | if _has_only_empty_bbox(anno): 52 | return False 53 | # keypoints task have a slight different critera for considering 54 | # if an annotation is valid 55 | if "keypoints" not in anno[0]: 56 | return True 57 | # for keypoint detection tasks, only consider valid images those 58 | # containing at least min_keypoints_per_image 59 | if _count_visible_keypoints(anno) >= 10: 60 | return True 61 | return False 62 | 63 | 64 | class MultiDeepFashion2Dataset(torchvision.datasets.coco.CocoDetection): 65 | def __init__( 66 | self, ann_file, root, transforms=None, noise=False, filter_onestreet=False 67 | ): 68 | super(MultiDeepFashion2Dataset, self).__init__(root, ann_file) 69 | self.ids = sorted(self.ids) 70 | 71 | print(len(self.ids)) 72 | 73 | self.categories = {cat['id']: cat['name'] for cat in self.coco.cats.values()} 74 | 75 | self.json_category_id_to_contiguous_id = { 76 | v: i + 1 for i, v in enumerate(self.coco.getCatIds()) 77 | } 78 | self.contiguous_category_id_to_json_id = { 79 | v: k for k, v in self.json_category_id_to_contiguous_id.items() 80 | } 81 | self.id_to_img_map = {k: v for k, v in enumerate(self.ids)} 82 | self.idx_to_id_map = {v: k for k, v in enumerate(self.ids)} 83 | 84 | 
self._transforms = transforms 85 | self.street_inds = self._getTypeInds('user') 86 | self.shop_inds = self._getTypeInds('shop') 87 | 88 | self.match_map_shop = {} 89 | self.match_map_street = {} 90 | print("Computing Street Match Descriptors map") 91 | for i in self.street_inds: 92 | e = self.coco.imgs[i] 93 | for x in e['match_desc']: 94 | if x == '0': 95 | continue 96 | hashable_key = x + '_' + str(e['match_desc'].get(x)) 97 | inds = self.match_map_street.get(hashable_key) 98 | if inds is None: 99 | self.match_map_street.update({hashable_key: [i]}) 100 | else: 101 | inds.append(i) 102 | self.match_map_street.update({hashable_key: inds}) 103 | print("Computing Shop Match Descriptors map") 104 | for i in self.shop_inds: 105 | e = self.coco.imgs[i] 106 | for x in e['match_desc']: 107 | if x == '0': 108 | continue 109 | hashable_key = x + '_' + str(e['match_desc'].get(x)) 110 | inds = self.match_map_shop.get(hashable_key) 111 | if inds is None: 112 | self.match_map_shop.update({hashable_key: [i]}) 113 | else: 114 | inds.append(i) 115 | self.match_map_shop.update({hashable_key: inds}) 116 | 117 | if filter_onestreet: 118 | print("Filtering products with one street or less") 119 | 120 | to_del = [] 121 | self.shop_match_keys = self.match_map_shop.keys() 122 | for x in self.match_map_street: 123 | if x not in self.shop_match_keys or len(self.match_map_street[x]) < 2: 124 | to_del.append(x) 125 | for x in to_del: 126 | del self.match_map_street[x] 127 | 128 | to_del = [] 129 | self.street_match_keys = list(self.match_map_street.keys()) 130 | for x in self.match_map_shop: 131 | if x not in self.street_match_keys: 132 | to_del.append(x) 133 | for x in to_del: 134 | del self.match_map_shop[x] 135 | 136 | self.noise = noise 137 | 138 | 139 | def __len__(self): 140 | return len(self.match_map_street) 141 | 142 | 143 | def __getitem__(self, x): 144 | # i: product id 145 | # tag: "shop" or "street" 146 | # index: None if tag is "shop" else index of street 147 | i, tag, index = x 148 | if tag == "shop": 149 | idx = random.choice(self.match_map_shop[i]) 150 | else: 151 | index2 = int(len(self.match_map_street[i]) * index) 152 | idx = self.match_map_street[i][index2] 153 | 154 | img, anno = super(MultiDeepFashion2Dataset, self).__getitem__(self.idx_to_id_map[idx]) 155 | 156 | # **************************************************** 157 | image = np.array(img) 158 | if self.noise: 159 | tmp_noise = 0.1 if random.random() > 0.75 else 0.0 160 | else: 161 | tmp_noise = 0.0 162 | image = image / 255.0 163 | image += np.random.randn(*image.shape) * tmp_noise 164 | image = image * 255.0 165 | image = np.clip(image, 0, 255.0) 166 | image = np.asarray(image, dtype=np.uint8) 167 | img = Image.fromarray(image) 168 | # **************************************************** 169 | 170 | anno = [obj for obj in anno if obj["iscrowd"] == 0] 171 | 172 | boxes = [obj["bbox"] for obj in anno if obj['area'] != 0] 173 | boxes = torch.as_tensor(boxes).reshape(-1, 4) # guard against no boxes 174 | boxes[:, 2] = boxes[:, 2] + boxes[:, 0] 175 | boxes[:, 3] = boxes[:, 3] + boxes[:, 1] 176 | 177 | classes = [obj["category_id"] for obj in anno if obj['area'] != 0] 178 | classes = [self.json_category_id_to_contiguous_id[c] for c in classes] 179 | classes = torch.tensor(classes) 180 | 181 | target = {} 182 | target["labels"] = classes 183 | target["boxes"] = boxes 184 | target["classes"] = classes 185 | 186 | if anno and "area" in anno[0]: 187 | area = torch.stack([torch.as_tensor(obj['area'], dtype=torch.float32) for obj in anno if 
obj['area'] != 0]) 188 | target["area"] = area 189 | 190 | if anno and "segmentation" in anno[0]: 191 | masks = torch.stack( 192 | [torch.as_tensor(annToMask(obj, size=[img.height, img.width]), dtype=torch.uint8) for obj in anno if obj['area'] != 0]) 193 | target["masks"] = masks 194 | 195 | if anno and "pair_id" in anno[0]: 196 | pair_ids = [obj['pair_id'] for obj in anno if obj['area'] != 0] 197 | pair_ids = torch.tensor(pair_ids) 198 | target["pair_ids"] = pair_ids 199 | 200 | 201 | if anno and "style" in anno[0]: 202 | styles = [obj['style'] for obj in anno if obj['area'] != 0] 203 | styles = torch.tensor(styles) 204 | target["styles"] = styles 205 | 206 | if anno and "source" in anno[0]: 207 | sources = [0 if obj['source'] == 'user' else 1 for obj in anno if obj['area'] != 0] 208 | # print("-->", idx, sources) 209 | sources = torch.tensor(sources) 210 | target["sources"] = sources 211 | 212 | if self._transforms is not None: 213 | img, target = self._transforms(img, target) 214 | 215 | 216 | target["i"] = i 217 | target['tag'] = 1 if tag == "shop" else 0 218 | 219 | return img, target, anno[0]['image_id'] 220 | 221 | def get_img_info(self, index): 222 | img_id = self.id_to_img_map[index] 223 | img_data = self.coco.imgs[img_id] 224 | return img_data 225 | 226 | def _getTypeInds(self, type_s): 227 | inds = [] 228 | N = len(self.coco.imgs) 229 | for i in self.ids: 230 | if self.coco.imgs[i]['source'] == type_s: 231 | inds.append(i) 232 | 233 | return inds 234 | 235 | def get_dataloader(dataset, batch_size, is_parallel, n_products=0, n_workers=0): 236 | if is_parallel: 237 | sampler = DistributedSampler(dataset, shuffle=True) 238 | else: 239 | sampler = RandomSampler(dataset) 240 | 241 | batch_sampler = MultiDF2BatchSampler(dataset, sampler, batch_size, drop_last=True, n_products=n_products) 242 | 243 | data_loader = torch.utils.data.DataLoader(dataset, num_workers=n_workers, batch_sampler=batch_sampler, 244 | collate_fn=utils.collate_fn) 245 | return data_loader 246 | 247 | 248 | class MultiDF2BatchSampler(Sampler): 249 | r"""Wraps another sampler to yield a mini-batch of indices. 250 | 251 | Args: 252 | sampler (Sampler): Base sampler. 253 | batch_size (int): Size of mini-batch. 
254 | drop_last (bool): If ``True``, the sampler will drop the last batch if 255 | its size would be less than ``batch_size`` 256 | 257 | Example: 258 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) 259 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 260 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) 261 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 262 | """ 263 | 264 | def __init__(self, dataset, sampler, batch_size, drop_last, n_products=0): 265 | if not isinstance(sampler, Sampler): 266 | raise ValueError("sampler should be an instance of " 267 | "torch.utils.data.Sampler, but got sampler={}" 268 | .format(sampler)) 269 | if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ 270 | batch_size <= 0: 271 | raise ValueError("batch_size should be a positive integer value, " 272 | "but got batch_size={}".format(batch_size)) 273 | if not isinstance(drop_last, bool): 274 | raise ValueError("drop_last should be a boolean value, but got " 275 | "drop_last={}".format(drop_last)) 276 | self.data = dataset 277 | self.sampler = sampler 278 | self.batch_size = batch_size 279 | self.drop_last = drop_last 280 | self.customer_inds = self.data.street_inds 281 | self.shop_inds = self.data.shop_inds 282 | self.customer_used = torch.zeros((len(self.customer_inds, ))) 283 | self.shop_used = torch.zeros((len(self.shop_inds, ))) 284 | self.match_map_shop = self.data.match_map_shop 285 | self.match_map_street = self.data.match_map_street 286 | self.n_products = n_products 287 | self.seed_dict = {} 288 | self.tmp_index = [] 289 | pair_keys = [k for k in self.data.match_map_street.keys() if k in self.data.match_map_shop] 290 | pair_keys += [k for k in self.data.match_map_shop.keys() if k in self.data.match_map_street] 291 | self.pair_keys = list(set(pair_keys)) 292 | 293 | def __iter__(self): 294 | batch = [] 295 | count = -1 296 | pair_keys = self.pair_keys 297 | for idx in self.sampler: 298 | batch.append((pair_keys[idx], "shop", None)) 299 | count += 1 300 | tmp_video_samples = sorted([random.random() for x in range((self.batch_size // self.n_products) - 1)]) 301 | 302 | for t in tmp_video_samples: 303 | batch.append((pair_keys[idx], "street", t)) 304 | if self.batch_size == 1 or len(batch) == self.batch_size: 305 | yield batch 306 | batch = [] 307 | if not self.drop_last: 308 | yield batch 309 | batch = [] 310 | 311 | def __len__(self): 312 | if self.drop_last: 313 | return len(self.sampler) // self.n_products 314 | else: 315 | return 1 + (len(self.sampler) // self.n_products) 316 | 317 | def _getTypeInds(self, type_s): 318 | inds = [] 319 | N = len(self.data) 320 | for i in range(1, N + 1): 321 | if self.data.coco.imgs[i]['source'] == type_s: 322 | inds.append(i) 323 | 324 | return inds 325 | 326 | def _getSamePairInShop(self, id): 327 | match_desc = self.data.coco.imgs[id]['match_desc'] 328 | ids = [] 329 | 330 | for x in match_desc: 331 | hashable_key = x + '_' + str(match_desc.get(x)) 332 | matches = self.match_map_shop.get(hashable_key) 333 | if matches is not None: 334 | ids = ids + matches 335 | 336 | return ids 337 | 338 | def _getSamePairInStreet(self, id): 339 | match_desc = self.data.coco.imgs[id]['match_desc'] 340 | ids = [] 341 | 342 | for x in match_desc: 343 | hashable_key = x + '_' + str(match_desc.get(x)) 344 | matches = self.match_map_street.get(hashable_key) 345 | if matches is not None: 346 | ids = ids + matches 347 | 348 | return ids 349 | 
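
Because MultiDF2BatchSampler yields (product_key, tag, position) tuples rather than plain indices, MultiDeepFashion2Dataset is meant to be driven through the get_dataloader helper above. The sketch below is a minimal, non-authoritative usage example: the dataset paths simply mirror the argparse defaults of evaluate_multiDF2.py further below, and frames_per_shop is an illustrative placeholder rather than a prescribed setting.

# Minimal usage sketch (assumed paths; frames_per_shop is a placeholder).
from datasets.MultiDF2Dataset import MultiDeepFashion2Dataset, get_dataloader
from stuffs import transform as T

frames_per_shop = 10  # street images sampled per product in each batch
dataset = MultiDeepFashion2Dataset(
    ann_file="data/deepfashion2/validation/annots.json",  # assumed annotation path
    root="data/deepfashion2/validation/image",            # assumed image folder
    transforms=T.ToTensor(),
    filter_onestreet=True,  # keep only products with a shop image and >= 2 street images
)
# each batch holds 1 shop image plus frames_per_shop street images of the same product
loader = get_dataloader(dataset, batch_size=1 + frames_per_shop,
                        is_parallel=False, n_products=1, n_workers=4)
for images, targets, image_ids in loader:
    # targets[i]['tag'] is 1 for the shop image and 0 for the street images
    break
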
-------------------------------------------------------------------------------- /evaluate_multiDF2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from copy import deepcopy 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | from pycocotools import mask as maskUtils 9 | from tqdm import tqdm 10 | 11 | from datasets.MultiDF2Dataset import MultiDeepFashion2Dataset, get_dataloader 12 | from models.video_matchrcnn import videomatchrcnn_resnet50_fpn 13 | from stuffs import transform as T 14 | 15 | 16 | def evaluate(model, data_loader, device, strategy="best_match" 17 | , score_threshold=0.1, k_thresholds=[1, 5, 10, 20] 18 | , frames_per_product=3, tracking_threshold=0.7, first_n_withvideo=None, use_gt=False): 19 | count_products = 0 20 | count_street = 0 21 | shop_descrs = [] 22 | street_descrs = [] 23 | street_aggr_feats = [] 24 | w = None 25 | b = None 26 | temporal_aggregator = model.roi_heads.temporal_aggregator 27 | for images, targets, ids in tqdm(data_loader): 28 | count_products += 1 29 | images = list(image.to(device) for image in images) 30 | targets = [{k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in t.items()} for t in targets] 31 | targets = [{k: (v.float() if k == "boxes" else v) for k, v in t.items()} for t in targets] 32 | # forward 33 | targets2 = deepcopy(targets) 34 | with torch.no_grad(): 35 | step = 6 36 | if use_gt: 37 | output = [model(images[x:x + step], targets=targets2[x:x + step]) for x in range(0, len(images), step)] 38 | else: 39 | output = [model(images[x:x + step]) for x in range(0, len(images), step)] 40 | output = [y for x in output for y in x] 41 | 42 | if not any(output[0]["scores"] >= score_threshold): 43 | continue 44 | if w is None: 45 | w = output[0]["w"].detach().cpu().numpy() 46 | b = output[0]["b"].detach().cpu().numpy() 47 | indexes = (output[0]["scores"] >= score_threshold).nonzero().view(-1) 48 | pr_boxes = output[0]["boxes"][indexes].detach().cpu().numpy() 49 | gt_boxes = targets[0]["boxes"].detach().cpu().numpy() 50 | pr_boxes[:, 2] = pr_boxes[:, 2] - pr_boxes[:, 0] 51 | pr_boxes[:, 3] = pr_boxes[:, 3] - pr_boxes[:, 1] 52 | gt_boxes[:, 2] = gt_boxes[:, 2] - gt_boxes[:, 0] 53 | gt_boxes[:, 3] = gt_boxes[:, 3] - gt_boxes[:, 1] 54 | iou = maskUtils.iou(gt_boxes, pr_boxes, np.zeros((pr_boxes.shape[0]))) # gts x preds 55 | style, pair_id = [int(x) for x in targets[0]["i"].split("_")] 56 | prodind = -1 57 | for iind in range(gt_boxes.shape[0]): 58 | if targets[0]["styles"][iind] == style and targets[0]["pair_ids"][iind] == pair_id: 59 | prodind = iind 60 | break 61 | maxind = iou[prodind].argmax() 62 | 63 | tmp_descr = temporal_aggregator(output[0]['roi_features'][maxind].unsqueeze(0) 64 | , torch.IntTensor([1]).to(device) 65 | , torch.LongTensor([0]).to(device))[1].detach().cpu().numpy() 66 | shop_descrs.append((output[0]['match_features'][maxind].detach().cpu().numpy() 67 | , count_products - 1, tmp_descr, targets[0]["i"]) 68 | ) 69 | 70 | if first_n_withvideo is not None and count_products >= first_n_withvideo: 71 | continue 72 | 73 | count_street += 1 74 | 75 | current_start = len(street_descrs) 76 | tmp_roi_feats = [] 77 | for i, o in enumerate(output[1:]): 78 | if any(o['scores'] >= score_threshold): 79 | t = targets[i + 1] 80 | indexes = (o["scores"] >= score_threshold).nonzero().view(-1) 81 | pr_boxes = o["boxes"][indexes].detach().cpu().numpy() 82 | gt_boxes = t["boxes"].detach().cpu().numpy() 83 | pr_boxes[:, 2] = 
pr_boxes[:, 2] - pr_boxes[:, 0] 84 | pr_boxes[:, 3] = pr_boxes[:, 3] - pr_boxes[:, 1] 85 | gt_boxes[:, 2] = gt_boxes[:, 2] - gt_boxes[:, 0] 86 | gt_boxes[:, 3] = gt_boxes[:, 3] - gt_boxes[:, 1] 87 | iou = maskUtils.iou(gt_boxes, pr_boxes, np.zeros((pr_boxes.shape[0]))) # gts x preds 88 | prodind = -1 89 | for iind in range(gt_boxes.shape[0]): 90 | if targets[0]["styles"][iind] == style and targets[0]["pair_ids"][iind] == pair_id: 91 | prodind = iind 92 | break 93 | maxind = indexes[iou[prodind].argmax()] 94 | street_descrs.append((o['match_features'][maxind].detach().cpu().numpy() 95 | , count_products - 1 96 | , i 97 | , int(maxind.detach().cpu()) 98 | , float(o["scores"][maxind].detach().cpu()) 99 | , o["boxes"][maxind].detach().cpu() 100 | , 101 | )) 102 | tmp_roi_feats.append(o['roi_features'][maxind].unsqueeze(0)) 103 | 104 | current_end = len(street_descrs) 105 | current_street_descrs = street_descrs[current_start:current_end] 106 | street_mat = np.concatenate([x[0][np.newaxis] for x in current_street_descrs]) 107 | tmp_roi_feats = torch.cat(tmp_roi_feats, 0) 108 | 109 | 110 | aggr_feats = temporal_aggregator(tmp_roi_feats.to(device) 111 | , torch.IntTensor([0 for x in range(tmp_roi_feats.shape[0])]).to(device) 112 | , torch.LongTensor([0 for x in range(tmp_roi_feats.shape[0])]) 113 | )[3][1:] 114 | aggr_feats = aggr_feats.view(-1, aggr_feats.shape[-1]).detach().cpu().numpy() 115 | street_aggr_feats.append(aggr_feats) 116 | 117 | torch.cuda.empty_cache() 118 | 119 | shop_mat = np.concatenate([x[0][np.newaxis].astype(np.float16) for x in shop_descrs]) 120 | shop_prods = np.asarray([x[1] for x in shop_descrs]) 121 | shop_datais = np.asarray([x[3] for x in shop_descrs]) 122 | street_mat = np.concatenate([x[0][np.newaxis].astype(np.float16) for x in street_descrs]) 123 | street_prods = np.asarray([x[1] for x in street_descrs]) 124 | street_imgs = np.asarray([x[2] for x in street_descrs]) 125 | street_scores = np.asarray([x[4] for x in street_descrs]) 126 | street_aggr_feats = np.concatenate([x.astype(np.float16) for x in street_aggr_feats]) 127 | shop_aggregated_descrs = np.concatenate([x[2][np.newaxis].astype(np.float16) for x in shop_descrs]).squeeze() 128 | 129 | def compute_ranking(inds): 130 | sq_diffs = (shop_mat[np.newaxis] - street_mat[inds, np.newaxis]) ** 2 131 | match_scores_raw = sq_diffs @ w.transpose().astype(np.float16) + b.astype(np.float16) 132 | match_scores_cls = np.exp(match_scores_raw) / np.exp(match_scores_raw).sum(2)[:, :, np.newaxis] 133 | match_scores = match_scores_cls[:, :, 1] 134 | match_rankings = np.argsort(match_scores, 1)[:, ::-1] 135 | return match_rankings 136 | 137 | def compute_distances(inds): 138 | sq_diffs = (shop_mat[np.newaxis] - street_mat[inds, np.newaxis]) ** 2 139 | match_scores_raw = sq_diffs @ w.transpose().astype(np.float16) + b.astype(np.float16) 140 | match_scores_cls = np.exp(match_scores_raw) / np.exp(match_scores_raw).sum(2)[:, :, np.newaxis] 141 | match_scores = match_scores_cls[:, :, 1] 142 | return match_scores 143 | 144 | aggrW = temporal_aggregator.last.weight.detach().cpu().numpy().astype(np.float16) 145 | aggrB = temporal_aggregator.last.bias.detach().cpu().numpy().astype(np.float16) 146 | 147 | perf = np.zeros((8, len(k_thresholds))) 148 | 149 | k_accs = [0] * len(k_thresholds) 150 | k_accs_avg = [0] * len(k_thresholds) 151 | k_accs_avg_desc = [0] * len(k_thresholds) 152 | k_accs_aggr_desc = [0] * len(k_thresholds) 153 | total_querys = count_street * frames_per_product 154 | k_accs_avg_dist = [0] * len(k_thresholds) 155 | 
k_accs_max_dist = [0] * len(k_thresholds) 156 | k_accs_max_score = [0] * len(k_thresholds) 157 | 158 | accs_per_product = {} 159 | 160 | all_ranks_list = [] 161 | for p_i in tqdm(range(count_street)): 162 | if p_i in shop_prods: 163 | shop_prod_index = int((shop_prods == p_i).nonzero()[0][0]) 164 | street_prod_indexes = (street_prods == p_i).nonzero()[0] 165 | unique_imgs = np.unique(street_imgs[street_prod_indexes]) 166 | ranks_list = [] 167 | best_inds = [] 168 | distances = [] 169 | scores = [] 170 | 171 | datakey = shop_datais[shop_prod_index] 172 | accs_per_product[datakey] = { 173 | "sfmr": [0] * len(k_thresholds) 174 | , "seamrcnn": [0] * len(k_thresholds) 175 | , "bmfm": [0] * len(k_thresholds) 176 | , "avgdist": [0] * len(k_thresholds) 177 | , "maxdist": [0] * len(k_thresholds) 178 | , "maxscore": [0] * len(k_thresholds) 179 | } 180 | 181 | for i, ii in enumerate(unique_imgs): 182 | tmp_box_inds = ((street_prods == p_i) & (street_imgs == ii)).nonzero()[0] 183 | if strategy == "best_box_only": 184 | tmp_scores = street_scores[tmp_box_inds] 185 | tmp_box_inds = tmp_scores.argmax()[np.newaxis] 186 | 187 | tmp_ranks = (compute_ranking(tmp_box_inds) == shop_prod_index).nonzero()[1] 188 | assert (tmp_ranks.size == 1) 189 | tmp_best_rank = tmp_ranks.item() 190 | best_inds.append(tmp_box_inds[0]) 191 | ranks_list.append(tmp_best_rank) 192 | for j, k in enumerate(k_thresholds): 193 | if tmp_best_rank < k: 194 | accs_per_product[datakey]["sfmr"][j] += 1 195 | k_accs[j] += 1 196 | 197 | distances.append(compute_distances(tmp_box_inds)[tmp_ranks.argmin()]) 198 | scores.append(street_scores[tmp_box_inds[0]]) 199 | 200 | # MAX PER IMAGE 201 | tmp_best_rank = int(np.mean(np.asarray(ranks_list))) 202 | for j, k in enumerate(k_thresholds): 203 | if tmp_best_rank < k: 204 | k_accs_avg[j] += 1 205 | best_inds = np.asarray(best_inds) 206 | 207 | all_ranks_list.extend(ranks_list) 208 | 209 | # AGGR DESC 210 | seq_descs = torch.from_numpy(street_aggr_feats[best_inds]).unsqueeze(1).to(device) 211 | seq_mask = torch.zeros((1, 1 + seq_descs.shape[0]), device=seq_descs.device, dtype=torch.bool) 212 | new_seq_descs = torch.zeros((1 + seq_descs.shape[0], 1, 256) 213 | , device=seq_descs.device, dtype=seq_descs.dtype, requires_grad=False) 214 | new_seq_descs[1:] = seq_descs 215 | aggr_desc = temporal_aggregator(None, None, None 216 | , x3_1_seq=new_seq_descs.to(torch.float32) 217 | , x3_1_mask=seq_mask 218 | , x3_2=torch.from_numpy(shop_aggregated_descrs[shop_prod_index]) 219 | .to(device).to(torch.float32))[0][0].detach().cpu().numpy() 220 | sq_diffs = (shop_aggregated_descrs[np.newaxis] - aggr_desc[np.newaxis, np.newaxis]) ** 2 221 | tmp_aggr_match_scores_raw = sq_diffs @ aggrW.transpose() + aggrB 222 | tmp_aggr_match_scores_cls = np.exp(tmp_aggr_match_scores_raw) \ 223 | / np.exp(tmp_aggr_match_scores_raw).sum(2)[:, :, np.newaxis] 224 | tmp_aggr_match_scores = tmp_aggr_match_scores_cls[:, :, 1] 225 | tmp_aggr_match_rankings = np.argsort(tmp_aggr_match_scores, 1)[:, ::-1] 226 | aggr_desc_rank = (tmp_aggr_match_rankings == shop_prod_index).nonzero()[1].item() 227 | for j, k in enumerate(k_thresholds): 228 | if aggr_desc_rank < k: 229 | accs_per_product[datakey]["seamrcnn"][j] += 1 230 | k_accs_aggr_desc[j] += 1 231 | 232 | # AVG DESC 233 | avg_desc = street_mat[best_inds].mean(0) 234 | sq_diffs = (shop_mat[np.newaxis] - avg_desc[np.newaxis, np.newaxis]) ** 2 235 | match_scores_raw = sq_diffs @ w.transpose().astype(np.float16) + b.astype(np.float16) 236 | match_scores_cls = np.exp(match_scores_raw) / 
np.exp(match_scores_raw).sum(2)[:, :, np.newaxis] 237 | match_scores_cls = match_scores_cls[:, :, 1] 238 | avg_match_scores = match_scores_cls[0] 239 | tmp_ranks = np.argsort(avg_match_scores)[::-1] 240 | avg_desc_rank = (tmp_ranks == shop_prod_index).nonzero()[0].item() 241 | for j, k in enumerate(k_thresholds): 242 | if avg_desc_rank < k: 243 | accs_per_product[datakey]["bmfm"][j] += 1 244 | k_accs_avg_desc[j] += 1 245 | 246 | # AVG & MAX DISTANCE 247 | distances = np.stack(distances) 248 | avg_distances = distances.mean(0) 249 | tmp_ranks = np.argsort(avg_distances)[::-1] 250 | avg_dist_rank = (tmp_ranks == shop_prod_index).nonzero()[0].item() 251 | for j, k in enumerate(k_thresholds): 252 | if avg_dist_rank < k: 253 | accs_per_product[datakey]["avgdist"][j] += 1 254 | k_accs_avg_dist[j] += 1 255 | max_distances = distances.max(0) 256 | tmp_ranks = np.argsort(max_distances)[::-1] 257 | max_dist_rank = (tmp_ranks == shop_prod_index).nonzero()[0].item() 258 | for j, k in enumerate(k_thresholds): 259 | if max_dist_rank < k: 260 | accs_per_product[datakey]["maxscore"][j] += 1 261 | accs_per_product[datakey]["maxdist"][j] += 1 262 | k_accs_max_dist[j] += 1 263 | 264 | # MAX CONFIDENCE SCORE 265 | scores = np.asarray(scores) 266 | max_score_ind = best_inds[scores.argmax()][np.newaxis] 267 | tmp_ranks = (compute_ranking(max_score_ind) == shop_prod_index).nonzero()[1] 268 | tmp_best_rank = tmp_ranks.item() 269 | for j, k in enumerate(k_thresholds): 270 | if tmp_best_rank < k: 271 | k_accs_max_score[j] += 1 272 | 273 | # PER PRODUCT RESULTS 274 | accs_per_product[datakey]["sfmr"] = np.asarray( 275 | accs_per_product[datakey]["sfmr"]) / frames_per_product 276 | accs_per_product[datakey]["seamrcnn"] = np.asarray(accs_per_product[datakey]["seamrcnn"]) / 1.0 277 | accs_per_product[datakey]["bmfm"] = np.asarray(accs_per_product[datakey]["bmfm"]) / 1.0 278 | accs_per_product[datakey]["avgdist"] = np.asarray(accs_per_product[datakey]["avgdist"]) / 1.0 279 | accs_per_product[datakey]["maxdist"] = np.asarray(accs_per_product[datakey]["maxdist"]) / 1.0 280 | accs_per_product[datakey]["maxscore"] = np.asarray(accs_per_product[datakey]["maxscore"]) / 1.0 281 | 282 | torch.save(accs_per_product, "accs_per_product_10frame_df2.pth") 283 | 284 | for k, k_acc in zip(k_thresholds, k_accs): 285 | print("Top-%d Retrieval Accuracy: %1.4f" % (k, k_acc / total_querys)) 286 | ret1 = k_accs[0] / total_querys 287 | print("*" * 50) 288 | 289 | for k, k_acc in zip(k_thresholds, k_accs_avg_desc): 290 | print("Top-%d Retrieval Accuracy Product Avg Desc: %1.4f" % (k, k_acc / count_street)) 291 | ret2 = k_accs_avg_desc[0] / count_street 292 | print("*" * 50) 293 | 294 | for k, k_acc in zip(k_thresholds, k_accs_aggr_desc): 295 | print("Top-%d Retrieval Accuracy Product Aggr Desc: %1.4f" % (k, k_acc / count_street)) 296 | ret3 = k_accs_aggr_desc[0] / count_street 297 | print("*" * 50) 298 | 299 | for k, k_acc in zip(k_thresholds, k_accs_avg_dist): 300 | print("Top-%d Retrieval Accuracy Product Avg Dist: %1.4f" % (k, k_acc / count_street)) 301 | print("*" * 50) 302 | 303 | for k, k_acc in zip(k_thresholds, k_accs_max_dist): 304 | print("Top-%d Retrieval Accuracy Product Max Dist: %1.4f" % (k, k_acc / count_street)) 305 | print("*" * 50) 306 | 307 | for k, k_acc in zip(k_thresholds, k_accs_max_score): 308 | print("Top-%d Retrieval Accuracy Product Max Score: %1.4f" % (k, k_acc / count_street)) 309 | print("*" * 50) 310 | 311 | all_ranks_list = np.asarray(all_ranks_list) 312 | rm = np.median(all_ranks_list) 313 | rmq1 = 
np.percentile(all_ranks_list, 25) 314 | rmq3 = np.percentile(all_ranks_list, 75) 315 | print(f"Rank median: {rm}; rank 1st quartile: {rmq1}; rank 3rd quartile: {rmq3}") 316 | 317 | perf[0] = np.asarray(k_accs, dtype=np.float32) / total_querys 318 | perf[1] = np.asarray(k_accs_avg, dtype=np.float32) / count_street 319 | perf[2] = np.asarray(k_accs_avg_desc, dtype=np.float32) / count_street 320 | perf[3] = np.asarray(k_accs_aggr_desc, dtype=np.float32) / count_street 321 | 322 | import time 323 | perf = perf * 100 324 | os.makedirs("logs_mdf2", exist_ok=True) 325 | np.savetxt(os.path.join("logs_mdf2", str(time.time()) + ".csv"), perf, fmt="%02.2f", delimiter="\t") 326 | 327 | return ret1, ret2, ret3 328 | 329 | 330 | if __name__ == '__main__': 331 | 332 | parser = argparse.ArgumentParser(description="PyTorch Object Detection Testing") 333 | parser.add_argument("--local_rank", type=int, default=0) 334 | parser.add_argument("--gpus", type=str, default="0,1") 335 | parser.add_argument("--n_workers", type=int, default=8) 336 | 337 | parser.add_argument("--frames_per_shop_test", type=int, default=10) 338 | parser.add_argument("--first_n_withvideo", type=int, default=100) 339 | parser.add_argument("--fixed_frame", type=int, default=None) 340 | parser.add_argument("--score_threshold", type=float, default=0.0) 341 | 342 | parser.add_argument("--root_test", type=str, default='data/deepfashion2/validation/image') 343 | parser.add_argument("--test_annots", type=str, default='data/deepfashion2/validation/annots.json') 344 | parser.add_argument("--noise", type=bool, default=True) 345 | 346 | parser.add_argument('--ckpt_path', type=str, default="ckpt/SEAM/multiDF2/DF2_epoch031") 347 | 348 | args = parser.parse_args() 349 | 350 | args.batch_size = (1 + args.frames_per_shop_test) * 1 351 | 352 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 353 | gpu_map = [0, 1, 2, 3] 354 | 355 | if 'WORLD_SIZE' in os.environ: 356 | distributed = int(os.environ['WORLD_SIZE']) > 1 357 | rank = args.local_rank 358 | print("Distributed testing with %d processors. 
This is #%s" 359 | % (int(os.environ['WORLD_SIZE']), rank)) 360 | else: 361 | distributed = False 362 | rank = 0 363 | print("Not distributed testing") 364 | 365 | if distributed: 366 | os.environ['NCCL_BLOCKING_WAIT'] = "1" 367 | torch.cuda.set_device(gpu_map[rank]) 368 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 369 | device = torch.device(torch.cuda.current_device()) 370 | else: 371 | device = torch.device(gpu_map[0]) if torch.cuda.is_available() else torch.device('cpu') 372 | 373 | test_dataset = MultiDeepFashion2Dataset(root=args.root_test 374 | , ann_file=args.test_annots, 375 | transforms=T.ToTensor(), filter_onestreet=True) 376 | 377 | data_loader_test = get_dataloader(test_dataset, batch_size=args.batch_size, is_parallel=distributed, n_products=1, 378 | n_workers=args.n_workers) 379 | 380 | model = videomatchrcnn_resnet50_fpn(pretrained_backbone=True, num_classes=14) 381 | 382 | ckpt = torch.load(args.ckpt_path) 383 | model.load_state_dict(ckpt['model_state_dict']) 384 | 385 | model.to(device) 386 | model.eval() 387 | 388 | evaluate(model, data_loader_test, device, frames_per_product=args.frames_per_shop_test 389 | , first_n_withvideo=args.first_n_withvideo, score_threshold=args.score_threshold) 390 | -------------------------------------------------------------------------------- /models/match_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from pycocotools import mask as maskUtils 5 | from .nlb import NONLocalBlock1D 6 | 7 | 8 | def boxlist_iou(boxlist1, boxlist2): 9 | """Compute the intersection over union of two sets of boxes. 10 | The box order must be (xmin, ymin, xmax, ymax). 11 | 12 | Arguments: 13 | boxlist1: (BoxList) bounding boxes, sized [N,4]. 14 | boxlist2: (BoxList) bounding boxes, sized [M,4]. 15 | 16 | Returns: 17 | (tensor) iou, sized [N,M].
18 | 19 | Reference: 20 | https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py 21 | """ 22 | if boxlist1.size != boxlist2.size: 23 | raise RuntimeError( 24 | "boxlists should have same image size, got {}, {}".format(boxlist1, boxlist2)) 25 | boxlist1 = boxlist1.convert("xyxy") 26 | boxlist2 = boxlist2.convert("xyxy") 27 | N = len(boxlist1) 28 | M = len(boxlist2) 29 | 30 | area1 = boxlist1.area() 31 | area2 = boxlist2.area() 32 | 33 | box1, box2 = boxlist1.bbox, boxlist2.bbox 34 | 35 | lt = torch.max(box1[:, None, :2], box2[:, :2]) # [N,M,2] 36 | rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # [N,M,2] 37 | 38 | TO_REMOVE = 1 39 | 40 | wh = (rb - lt + TO_REMOVE).clamp(min=0) # [N,M,2] 41 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 42 | 43 | iou = inter / (area1[:, None] + area2 - inter) 44 | return iou 45 | 46 | 47 | class MatchPredictor(nn.Module): 48 | def __init__(self): 49 | super(MatchPredictor, self).__init__() 50 | self.conv_seq = nn.Sequential(nn.Conv2d(256, 256, 3), 51 | nn.ReLU(), 52 | nn.Conv2d(256, 256, 3), 53 | nn.ReLU(), 54 | nn.Conv2d(256, 256, 3), 55 | nn.ReLU(), 56 | nn.Conv2d(256, 1024, 3), 57 | nn.ReLU(), 58 | ) 59 | self.pool = nn.Sequential(nn.AvgPool2d((6, 6)), 60 | nn.ReLU(), ) 61 | self.linear = nn.Sequential(nn.Linear(1024, 256), 62 | nn.BatchNorm1d(256), ) 63 | 64 | self.last = nn.Linear(256, 2) 65 | 66 | def forward(self, x, types): 67 | x1 = self.conv_seq(x) 68 | x2 = self.pool(x1) 69 | x3 = self.linear(x2.view(x2.size(0), -1)) 70 | x3_1 = x3[types == 0].unsqueeze(1) 71 | x3_2 = x3[types == 1].unsqueeze(0) 72 | 73 | x4 = (x3_1 - x3_2) ** 2 74 | x5 = self.last(x4) 75 | # return x3, F.softmax(x5, dim=-1) 76 | return x3, x5 77 | 78 | 79 | class TemporalAggregationNLB(MatchPredictor): 80 | 81 | def __init__(self, d_model=256): 82 | super(TemporalAggregationNLB, self).__init__() 83 | # same parameters and same forward as standard MatchPredictor 84 | # except for temporal aggregation 85 | self.n_frames = -1 86 | self.attention_scorer = nn.Linear(d_model, 1) 87 | self.newnlb = NONLocalBlock1D(in_channels=d_model, sub_sample=False, bn_layer=False) 88 | self.nlb = True 89 | 90 | def forward(self, x, types, ids, x3_1_seq=None, x3_1_mask=None, x3_2=None, getatt=False): 91 | # x should be (K*(n_frames + 1))x256x14x14 where the one is shop and the n_frames are frames 92 | if x3_1_seq is None: 93 | x1 = self.conv_seq(x) 94 | x2 = self.pool(x1) 95 | x3 = self.linear(x2.view(x2.size(0), -1)) 96 | x3_1 = x3[types == 0] # should be (K*n_frames)x256 97 | x3_1_ids = ids[types == 0] 98 | if x3_1_ids.numel() > 0: 99 | maxlen = (x3_1_ids == x3_1_ids.mode()[0]).sum() 100 | n_seqs = x3_1_ids.unique().numel() 101 | # first token is a dummy where the output is going to be 102 | x3_1_seq = torch.zeros((1 + maxlen, n_seqs, 256), device=x3_1.device, dtype=x3_1.dtype, requires_grad=False) 103 | # True values are to be masked 104 | # https://github.com/pytorch/pytorch/blob/5f25e98fc758ab2f32791364d855be8ff9cb36e7/torch/nn/modules/transformer.py#L66 105 | x3_1_mask = torch.zeros((n_seqs, 1 + maxlen), device=x3_1.device, dtype=torch.bool) 106 | x3_1_list = [] 107 | for i, idd in enumerate(x3_1_ids.unique()): 108 | tmp_n = (x3_1_ids == idd).sum().item() 109 | x3_1_seq[1:tmp_n + 1, i] = x3_1[x3_1_ids == idd] 110 | x3_1_mask[i, tmp_n + 1:] = 1 111 | x3_1_list.append(x3_1[x3_1_ids == idd]) 112 | 113 | 114 | if self.nlb: 115 | x3_1_list = [self.newnlb(x.transpose(0, 1).unsqueeze(0))[0].transpose(0, 1) 116 | if x.shape[0] > 1 else x 117 | for x in x3_1_list] 118 | 119 | 
x3_1b = [(F.softmax(self.attention_scorer(x), 0) * x3_1_list[i]).sum(0).unsqueeze(0) 120 | for i, x in enumerate(x3_1_list)] 121 | x3_1b = torch.cat(x3_1b, 0) 122 | 123 | if getatt: 124 | attention_scores = [F.softmax(self.attention_scorer(x), 0) for x in x3_1_list] 125 | 126 | x3_1c = x3_1b.unsqueeze(1) 127 | else: 128 | x3_1b = None 129 | x3_1c = None 130 | 131 | x3_2 = x3[types == 1] 132 | x3_2b = x3_2.unsqueeze(0) 133 | else: 134 | 135 | # build list 136 | x3_1_inds = [(x3_1_mask[i]).nonzero()[0].item() if (x3_1_mask[i]).any() 137 | else x3_1_mask[i].numel() 138 | for i in range(x3_1_seq.shape[1])] 139 | x3_1_list = [x3_1_seq[1:x3_1_inds[i], i] for i in range(x3_1_seq.shape[1])] 140 | 141 | 142 | 143 | 144 | if self.nlb: 145 | x3_1_list = [self.newnlb(x.transpose(0, 1).unsqueeze(0))[0].transpose(0, 1) 146 | if x.shape[0] > 1 else x 147 | for x in x3_1_list] 148 | 149 | x3_1b = [(F.softmax(self.attention_scorer(x), 0) * x3_1_list[i]).sum(0).unsqueeze(0) 150 | for i, x in enumerate(x3_1_list)] 151 | x3_1b = torch.cat(x3_1b, 0) 152 | 153 | if getatt: 154 | attention_scores = [F.softmax(self.attention_scorer(x), 0) for x in x3_1_list] 155 | 156 | x3_1c = x3_1b.unsqueeze(1) 157 | x3_2b = x3_2.unsqueeze(0) 158 | x3_1_ids = torch.zeros((1, 2)) # just to have numel > 0 159 | 160 | if x3_1_ids.numel() > 0: 161 | x4 = (x3_1c - x3_2b) ** 2 162 | x5 = self.last(x4) 163 | else: 164 | x5 = None 165 | 166 | if getatt: 167 | return x3_1b, x3_2, x5, x3_1_seq, x3_1_mask, x3_1_ids, attention_scores 168 | 169 | return x3_1b, x3_2, x5, x3_1_seq, x3_1_mask, x3_1_ids 170 | 171 | 172 | class MatchLoss(object): 173 | def __init__(self): 174 | super(MatchLoss, self).__init__() 175 | self.criterion = nn.CrossEntropyLoss() 176 | 177 | def __call__(self, logits, proposals, gt_pairs, gt_styles, types, matched_idxs): 178 | 179 | target_pairs = [l[idxs] for l, idxs in zip(gt_pairs, matched_idxs)] 180 | target_styles = [l[idxs] for l, idxs in zip(gt_styles, matched_idxs)] 181 | # target_pairs, target_styles = self._prepare_target(proposals, targets) 182 | 183 | target_pairs_user = torch.cat(target_pairs)[types == 0] 184 | target_styles_user = torch.cat(target_styles)[types == 0] 185 | 186 | target_pairs_shop = torch.cat(target_pairs)[types == 1] 187 | target_styles_shop = torch.cat(target_styles)[types == 1] 188 | 189 | gts = torch.zeros(len(target_pairs_user), len(target_pairs_shop), dtype=torch.int64).to(logits.device) 190 | for i in range(len(target_pairs_user)): 191 | for j in range(len(target_pairs_shop)): 192 | tpu = target_pairs_user[i] 193 | tps = target_pairs_shop[j] 194 | tsu = target_styles_user[i] 195 | tss = target_styles_shop[j] 196 | if tps == tpu and tsu == tss: 197 | gts[i, j] = 1 198 | else: 199 | gts[i, j] = 0 200 | 201 | gts = gts.view(-1) 202 | logits = logits.view(-1, 2) 203 | 204 | loss = self.criterion(logits, gts) 205 | if loss > 1.0: 206 | loss = loss / 2.0 207 | return loss 208 | 209 | 210 | class MatchLossWeak(object): 211 | 212 | def __init__(self, device, match_threshold=-10.0): 213 | super(MatchLossWeak, self).__init__() 214 | self.criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0]).to(device)) 215 | self.match_threshold = match_threshold 216 | 217 | def __call__(self, logits, types, prod_ids, img_ids): 218 | img_ids = torch.tensor(img_ids) 219 | prod_ids = torch.tensor(prod_ids) 220 | gts = torch.zeros(logits.shape[0], logits.shape[1], dtype=torch.int64).to(logits.device) 221 | street_inds = (types == 0).nonzero().view(-1) 222 | shop_inds = (types == 1).nonzero().view(-1) 
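# Note on the weak supervision below: box-level match ground truth is not
# available in this setting, so for every street image the loop scores all of
# its detections against the shop detection of the same product and, if the
# best score exceeds self.match_threshold, marks only that street/shop pair as
# positive in gts; every other pair in the batch is left as a negative.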
223 | # associate to each detection the corresponding index in the logits 224 | reverse_street_shop_inds = torch.zeros_like(types, dtype=torch.int64) 225 | reverse_street_shop_inds[street_inds] = torch.arange(street_inds.shape[0]) 226 | reverse_street_shop_inds[shop_inds] = torch.arange(shop_inds.shape[0]) 227 | for ii in torch.unique(img_ids): 228 | # get data on detections of this image 229 | tmp_type = int(types[img_ids == ii][0]) 230 | if tmp_type == 1: 231 | continue 232 | # from here on we only have a street image 233 | tmp_prod_id = int(prod_ids[img_ids == ii][0]) 234 | tmp_inds = (img_ids == ii).nonzero().view(-1) 235 | # look for the corresponding shop detection 236 | shop_ind = ((prod_ids == tmp_prod_id) & (types == 1)).nonzero().view(-1) 237 | tmp_logits = logits[reverse_street_shop_inds[tmp_inds], reverse_street_shop_inds[shop_ind], 1].view(-1) 238 | max_score, max_score_ind = tmp_logits.max(), tmp_inds[tmp_logits.argmax()] 239 | 240 | if max_score > self.match_threshold: 241 | gts[reverse_street_shop_inds[max_score_ind], reverse_street_shop_inds[shop_ind]] = 1 242 | 243 | gts = gts.view(-1) 244 | logits = logits.view(-1, 2) 245 | loss = self.criterion(logits, gts) 246 | return loss 247 | 248 | def isin(ar1, ar2): 249 | return (ar1[..., None] == ar2).any(-1) 250 | 251 | 252 | class NEWBalancedAggregationMatchLossWeak(object): 253 | 254 | def __init__(self, device, temporal_aggregator, match_threshold=-10.0): 255 | super(NEWBalancedAggregationMatchLossWeak, self).__init__() 256 | self.criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 0.3]).to(device)) 257 | self.match_threshold = match_threshold 258 | self.temporal_aggregator = temporal_aggregator 259 | 260 | def __call__(self, match_logits, types, prod_ids, img_ids, roi_features): 261 | 262 | img_ids = torch.tensor(img_ids) 263 | prod_ids = torch.tensor(prod_ids) 264 | street_inds = (types == 0).nonzero().view(-1) 265 | shop_inds = (types == 1).nonzero().view(-1) 266 | # associate to each detection the corresponding index in the logits 267 | reverse_street_shop_inds = torch.zeros_like(types, dtype=torch.int64) 268 | reverse_street_shop_inds[street_inds] = torch.arange(street_inds.shape[0]) 269 | reverse_street_shop_inds[shop_inds] = torch.arange(shop_inds.shape[0]) 270 | aggregation_candidates = [] 271 | 272 | for pi in torch.unique(prod_ids): 273 | tmp_prod_inds = (prod_ids == pi).nonzero().view(-1) 274 | for iiind, ii in enumerate(torch.unique(img_ids[tmp_prod_inds])): 275 | tmp_type = int(types[img_ids == ii][0]) 276 | if tmp_type == 1: 277 | 278 | continue 279 | 280 | tmp_prod_id = pi 281 | tmp_inds = (img_ids == ii).nonzero().view(-1) 282 | shop_ind = ((prod_ids == tmp_prod_id) & (types == 1)).nonzero().view(-1) 283 | tmp_logits = match_logits[reverse_street_shop_inds[tmp_inds], reverse_street_shop_inds[shop_ind], 1].view(-1) 284 | # max_score_ind is the index within the boxes in this image id 285 | max_score, max_score_ind = tmp_logits.max(), tmp_logits.argmax() 286 | if max_score > self.match_threshold: 287 | # save the corresponding tracking ind 288 | aggregation_candidates.append(tmp_inds[max_score_ind]) 289 | if len(aggregation_candidates) == 0: 290 | # no aggregation candidates were found 291 | return torch.tensor(0, dtype=torch.float32).to(match_logits.device) 292 | # aggregation_candidates will contain the "best" box for each image 293 | aggregation_candidates = torch.tensor(aggregation_candidates) 294 | # CLEAN TRACKING LOGITS 295 | 296 | valid_prods = [] 297 | street_feature_inds = [] 298 | gt_flag = [] 299 | seq_ids = [] 300
| seq_count = 0 301 | # build aggregation combinations 302 | for pi in torch.unique(prod_ids[aggregation_candidates]): 303 | tmp_cands = aggregation_candidates[prod_ids[aggregation_candidates] == pi] 304 | if tmp_cands.numel() < self.temporal_aggregator.n_frames: 305 | continue 306 | valid_prods.append(pi) 307 | 308 | tmp_combs = tmp_cands 309 | street_feature_inds.append(tmp_combs) 310 | gt_flag.append(torch.tensor([1] * tmp_combs.numel())) 311 | seq_ids.append(torch.tensor([seq_count] * tmp_combs.numel())) 312 | seq_count += 1 313 | 314 | 315 | 316 | if len(valid_prods) == 0: 317 | # doesn't find enough valid frames 318 | return torch.tensor(0, dtype=torch.float32).to(roi_features.device) 319 | # products for which we have at least n_frames frames (we can compute aggregation matching for them) 320 | # we take the shop images for them 321 | valid_prods = torch.tensor(valid_prods) 322 | 323 | shop_feature_inds = [] 324 | for pi in valid_prods: 325 | shop_ind = ((prod_ids == pi) & (types == 1)).nonzero().view(-1) 326 | shop_feature_inds.append(shop_ind) 327 | shop_feature_inds = torch.tensor(shop_feature_inds) 328 | # used to index features in groups of n_frames, by repeating them 329 | street_feature_inds = torch.cat(street_feature_inds) 330 | gt_flag = torch.cat(gt_flag) 331 | seq_ids = torch.cat(seq_ids) 332 | 333 | feature_inds = torch.cat([street_feature_inds, shop_feature_inds]) 334 | seq_ids = torch.cat([seq_ids] 335 | + [torch.tensor(i + seq_count).view(-1) for i in range(shop_feature_inds.numel())]) 336 | new_roi_features = roi_features[feature_inds] 337 | new_types = types[feature_inds] 338 | 339 | _, _, aggregator_logits, _, _, _ = self.temporal_aggregator(new_roi_features, new_types, seq_ids) 340 | 341 | # gts has a row for every subset of frames and a columns for every valid shop product 342 | 343 | gts = torch.zeros(seq_ids[:street_feature_inds.numel()].unique().numel() 344 | , shop_feature_inds.numel(), dtype=torch.int64) 345 | # for every subset of frames 346 | for i, seq_id in enumerate(seq_ids[:street_feature_inds.numel()].unique()): 347 | # inds of this sequence 348 | street_inds = (seq_ids == seq_id).nonzero().view(-1) 349 | seq_inds = street_feature_inds[street_inds] 350 | # prod_id of this sequence 351 | tmp_prod_id = prod_ids[seq_inds][0] 352 | # find the column for it 353 | j = (tmp_prod_id == valid_prods).nonzero().view(-1) 354 | # set gt to true 355 | gts[i, j] = gt_flag[street_inds[0]] 356 | 357 | gts = gts.view(-1) 358 | logits = aggregator_logits.view(-1, 2) 359 | loss = self.criterion(logits, gts.to(logits.device)) 360 | return loss 361 | 362 | 363 | class MatchLossDF2(object): 364 | 365 | def __init__(self, device): 366 | super(MatchLossDF2, self).__init__() 367 | self.criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0]).to(device)) 368 | 369 | def __call__(self, logits, types, raw_gt): 370 | street_inds = (types == 0).nonzero().view(-1) 371 | shop_inds = (types == 1).nonzero().view(-1) 372 | raw_gt = torch.tensor(raw_gt) 373 | shop_prods = raw_gt[shop_inds] 374 | street_prods = raw_gt[street_inds] 375 | gts = shop_prods.unsqueeze(0) == street_prods.unsqueeze(1) 376 | gts = gts.view(-1).to(logits.device).to(torch.int64) 377 | logits = logits.view(-1, 2) 378 | loss = self.criterion(logits, gts) 379 | return loss 380 | 381 | 382 | class AggregationMatchLossDF2(object): 383 | 384 | def __init__(self, device, temporal_aggregator): 385 | super(AggregationMatchLossDF2, self).__init__() 386 | self.criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 
0.3]).to(device)) 387 | self.temporal_aggregator = temporal_aggregator 388 | 389 | def __call__(self, types, roi_features, raw_gt): 390 | 391 | street_inds = (types == 0).nonzero().view(-1) 392 | shop_inds = (types == 1).nonzero().view(-1) 393 | raw_gt = torch.tensor(raw_gt) 394 | unprods = raw_gt.unique() 395 | unprods = unprods[unprods > 0] 396 | 397 | 398 | valid_prods = [] 399 | street_feature_inds = [] 400 | seq_ids = [] 401 | 402 | seq_count = 0 403 | # build aggregation combinations 404 | for pi in unprods: 405 | tmp_combs = street_inds[raw_gt[street_inds] == pi] 406 | if tmp_combs.numel() < 3: #self.temporal_aggregator.n_frames: 407 | continue 408 | valid_prods.append(pi) 409 | street_feature_inds.append(tmp_combs) 410 | seq_ids.append(torch.tensor([seq_count] * tmp_combs.numel())) 411 | seq_count += 1 412 | valid_prods = torch.tensor(valid_prods) 413 | street_feature_inds = torch.cat(street_feature_inds) 414 | try: 415 | seq_ids = torch.cat(seq_ids) 416 | except: 417 | print(seq_ids) 418 | quit() 419 | 420 | shop_feature_inds = shop_inds 421 | shop_feature_inds = torch.tensor(shop_feature_inds) 422 | 423 | feature_inds = torch.cat([street_feature_inds, shop_feature_inds]) 424 | seq_ids = torch.cat([seq_ids] 425 | + [torch.tensor(i + seq_count).view(-1) for i in range(shop_feature_inds.numel())]) 426 | new_roi_features = roi_features[feature_inds] 427 | new_types = types[feature_inds] 428 | 429 | _, _, aggregator_logits, _, _, _ = self.temporal_aggregator(new_roi_features, new_types, seq_ids) 430 | 431 | shop_prods = raw_gt[shop_inds] 432 | street_prods = valid_prods 433 | gts = shop_prods.unsqueeze(0) == street_prods.unsqueeze(1) 434 | 435 | gts = gts.view(-1).to(aggregator_logits.device).to(torch.int64) 436 | logits = aggregator_logits.view(-1, 2) 437 | loss = self.criterion(logits, gts) 438 | return loss 439 | 440 | 441 | def filter_proposals(proposals, mask_roi_features, gt_proposals, matched_idxs): 442 | match_imgs_mask = [] 443 | new_mask_roi_features = [] 444 | for i, pr_prop in enumerate(proposals): 445 | g_prop = gt_proposals[i] 446 | match_idxs = matched_idxs[i] 447 | match_imgs_mask = match_imgs_mask + ([i] * match_idxs.size(0)) 448 | 449 | n_valid = g_prop.size(0) 450 | ious = torch.FloatTensor( 451 | maskUtils.iou(pr_prop.detach().cpu().numpy(), g_prop.detach().cpu().numpy(), 452 | [0] * n_valid)).squeeze() 453 | if len(pr_prop) > 1: 454 | topKidxs = torch.argsort(ious, descending=True, dim=0)[ 455 | :torch.min(torch.tensor([(8 // n_valid), len(pr_prop)]))].view(-1) 456 | proposals[i] = pr_prop[topKidxs, :] 457 | matched_idxs[i] = match_idxs[topKidxs] 458 | new_mask_roi_features.append(mask_roi_features[torch.where(torch.IntTensor(match_imgs_mask) == i)[0], ...]) 459 | new_mask_roi_features[i] = new_mask_roi_features[i][topKidxs, ...] 
460 | else: 461 | new_mask_roi_features.append(mask_roi_features[torch.where(torch.IntTensor(match_imgs_mask) == i)[0], ...]) 462 | 463 | return proposals, torch.cat(new_mask_roi_features), matched_idxs 464 | 465 | 466 | class MatchLossPreTrained(object): 467 | def __init__(self): 468 | super(MatchLossPreTrained, self).__init__() 469 | self.criterion = nn.CrossEntropyLoss() 470 | 471 | def __call__(self, logits, proposals, gt_proposals, gt_pairs, gt_styles, types, matched_idxs): 472 | 473 | target_pairs = [l[idxs] for l, idxs in zip(gt_pairs, matched_idxs)] 474 | target_styles = [l[idxs] for l, idxs in zip(gt_styles, matched_idxs)] 475 | # target_pairs, target_styles = self._prepare_target(proposals, targets) 476 | 477 | target_pairs_user = torch.cat(target_pairs)[types == 0] 478 | target_styles_user = torch.cat(target_styles)[types == 0] 479 | 480 | target_pairs_shop = torch.cat(target_pairs)[types == 1] 481 | target_styles_shop = torch.cat(target_styles)[types == 1] 482 | 483 | gts = torch.zeros(len(target_pairs_user), len(target_pairs_shop), dtype=torch.int64).to(logits.device) 484 | for i in range(len(target_pairs_user)): 485 | for j in range(len(target_pairs_shop)): 486 | tpu = target_pairs_user[i] 487 | tps = target_pairs_shop[j] 488 | tsu = target_styles_user[i] 489 | tss = target_styles_shop[j] 490 | if tps == tpu and tsu == tss and tss != 0 and tsu != 0: 491 | gts[i, j] = 1 492 | else: 493 | gts[i, j] = 0 494 | 495 | gts = gts.view(-1) 496 | 497 | logits = logits.view(-1, 2) 498 | loss = self.criterion(logits, gts) 499 | 500 | if loss > 1.0: 501 | loss = loss / 2.0 502 | 503 | 504 | return loss 505 | -------------------------------------------------------------------------------- /models/matchrcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone 3 | from torchvision.models.detection.mask_rcnn import MaskRCNN 4 | from torch.hub import load_state_dict_from_url 5 | from torchvision.ops import boxes as box_ops 6 | import torch.nn.functional as F 7 | from torchvision.models.detection.roi_heads import fastrcnn_loss, maskrcnn_loss, maskrcnn_inference 8 | from .match_head import MatchPredictor, MatchLossPreTrained, filter_proposals 9 | 10 | 11 | from torchvision.models.detection.rpn import AnchorGenerator 12 | from torchvision.ops import MultiScaleRoIAlign 13 | 14 | params = { 15 | 'rpn_anchor_generator': AnchorGenerator((32, 64, 128, 256, 512), (0.5, 1.0, 2.0)), 16 | 'rpn_pre_nms_top_n_train': 2000, 17 | 'rpn_pre_nms_top_n_test': 1000, 18 | 'rpn_post_nms_top_n_test': 4000, 19 | 'rpn_post_nms_top_n_train': 8000, 20 | 21 | 'box_roi_pool': MultiScaleRoIAlign( 22 | featmap_names=['0', '1', '2', '3'], 23 | output_size=7, 24 | sampling_ratio=2), 25 | 'mask_roi_pool': MultiScaleRoIAlign( 26 | featmap_names=['0', '1', '2', '3'], 27 | output_size=14, 28 | sampling_ratio=2), 29 | } 30 | 31 | 32 | model_urls = { 33 | 'maskrcnn_resnet50_fpn_coco': 34 | 'https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth', 35 | } 36 | 37 | 38 | class NewRoIHeads(torch.nn.Module): 39 | def __init__(self, orh): 40 | # orh: old_roi_heads 41 | super(NewRoIHeads, self).__init__() 42 | 43 | self.box_roi_pool = orh.box_roi_pool 44 | self.box_head = orh.box_head 45 | self.box_predictor = orh.box_predictor 46 | 47 | self.mask_roi_pool = orh.mask_roi_pool 48 | self.mask_head = orh.mask_head 49 | self.mask_predictor = orh.mask_predictor 50 | 51 | self.match_predictor = 
MatchPredictor() 52 | self.match_loss = MatchLossPreTrained() 53 | 54 | self.keypoint_roi_pool = None 55 | self.keypoint_head = None 56 | self.keypoint_predictor = None 57 | 58 | self.score_thresh = orh.score_thresh 59 | self.nms_thresh = orh.nms_thresh 60 | self.detections_per_img = orh.detections_per_img 61 | 62 | self.proposal_matcher = orh.proposal_matcher 63 | self.fg_bg_sampler = orh.fg_bg_sampler 64 | self.box_coder = orh.box_coder 65 | 66 | self.box_similarity = box_ops.box_iou 67 | 68 | @property 69 | def has_mask(self): 70 | if self.mask_roi_pool is None: 71 | return False 72 | if self.mask_head is None: 73 | return False 74 | if self.mask_predictor is None: 75 | return False 76 | return True 77 | 78 | @property 79 | def has_keypoint(self): 80 | if self.keypoint_roi_pool is None: 81 | return False 82 | if self.keypoint_head is None: 83 | return False 84 | if self.keypoint_predictor is None: 85 | return False 86 | return True 87 | 88 | @property 89 | def has_match(self): 90 | if self.match_predictor is None: 91 | return False 92 | if self.match_loss is None: 93 | return False 94 | return True 95 | 96 | def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels): 97 | matched_idxs = [] 98 | labels = [] 99 | for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels): 100 | match_quality_matrix = self.box_similarity(gt_boxes_in_image, proposals_in_image) 101 | matched_idxs_in_image = self.proposal_matcher(match_quality_matrix) 102 | 103 | clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0) 104 | 105 | labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image] 106 | labels_in_image = labels_in_image.to(dtype=torch.int64) 107 | 108 | # Label background (below the low threshold) 109 | bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD 110 | labels_in_image[bg_inds] = 0 111 | 112 | # Label ignore proposals (between low and high thresholds) 113 | ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS 114 | labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler 115 | 116 | matched_idxs.append(clamped_matched_idxs_in_image) 117 | labels.append(labels_in_image) 118 | return matched_idxs, labels 119 | 120 | def subsample(self, labels): 121 | sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) 122 | sampled_inds = [] 123 | for img_idx, (pos_inds_img, neg_inds_img) in enumerate( 124 | zip(sampled_pos_inds, sampled_neg_inds) 125 | ): 126 | img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1) 127 | sampled_inds.append(img_sampled_inds) 128 | return sampled_inds 129 | 130 | def add_gt_proposals(self, proposals, gt_boxes): 131 | proposals = [ 132 | torch.cat((proposal, gt_box)) 133 | for proposal, gt_box in zip(proposals, gt_boxes) 134 | ] 135 | 136 | return proposals 137 | 138 | def check_targets(self, targets): 139 | assert targets is not None 140 | assert all("boxes" in t for t in targets) 141 | assert all("labels" in t for t in targets) 142 | if self.has_mask: 143 | assert all("masks" in t for t in targets) 144 | 145 | def select_training_samples(self, proposals, targets): 146 | self.check_targets(targets) 147 | gt_boxes = [t["boxes"] for t in targets] 148 | gt_labels = [t["labels"] for t in targets] 149 | 150 | # append ground-truth bboxes to propos 151 | proposals = self.add_gt_proposals(proposals, gt_boxes) 152 | 153 | # get matching gt indices for each proposal 154 | matched_idxs, labels = 
self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels) 155 | # sample a fixed proportion of positive-negative proposals 156 | sampled_inds = self.subsample(labels) 157 | matched_gt_boxes = [] 158 | num_images = len(proposals) 159 | for img_id in range(num_images): 160 | img_sampled_inds = sampled_inds[img_id] 161 | proposals[img_id] = proposals[img_id][img_sampled_inds] 162 | labels[img_id] = labels[img_id][img_sampled_inds] 163 | matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds] 164 | matched_gt_boxes.append(gt_boxes[img_id][matched_idxs[img_id]]) 165 | 166 | regression_targets = self.box_coder.encode(matched_gt_boxes, proposals) 167 | return proposals, matched_idxs, labels, regression_targets 168 | 169 | def keypoints_to_heatmap(self, keypoints, rois, heatmap_size): 170 | offset_x = rois[:, 0] 171 | offset_y = rois[:, 1] 172 | scale_x = heatmap_size / (rois[:, 2] - rois[:, 0]) 173 | scale_y = heatmap_size / (rois[:, 3] - rois[:, 1]) 174 | 175 | offset_x = offset_x[:, None] 176 | offset_y = offset_y[:, None] 177 | scale_x = scale_x[:, None] 178 | scale_y = scale_y[:, None] 179 | 180 | x = keypoints[..., 0] 181 | y = keypoints[..., 1] 182 | 183 | x_boundary_inds = x == rois[:, 2][:, None] 184 | y_boundary_inds = y == rois[:, 3][:, None] 185 | 186 | x = (x - offset_x) * scale_x 187 | x = x.floor().long() 188 | y = (y - offset_y) * scale_y 189 | y = y.floor().long() 190 | 191 | x[x_boundary_inds] = heatmap_size - 1 192 | y[y_boundary_inds] = heatmap_size - 1 193 | 194 | valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size) 195 | vis = keypoints[..., 2] > 0 196 | valid = (valid_loc & vis).long() 197 | 198 | lin_ind = y * heatmap_size + x 199 | heatmaps = lin_ind * valid 200 | 201 | return heatmaps, valid 202 | 203 | def heatmaps_to_keypoints(self, maps, rois): 204 | """Extract predicted keypoint locations from heatmaps. Output has shape 205 | (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob) 206 | for each keypoint. 207 | """ 208 | # This function converts a discrete image coordinate in a HEATMAP_SIZE x 209 | # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain 210 | # consistency with keypoints_to_heatmap_labels by using the conversion from 211 | # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a 212 | # continuous coordinate. 
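# Worked example of the conversion used below: for a RoI 100 px wide whose
# heatmap is interpolated to roi_map_width = 100, width_correction = 1.0, so an
# argmax at discrete bin x_int = 24 is placed at x = (24 + 0.5) * 1.0 and then
# shifted by offset_x[i] back into image coordinates.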
213 | offset_x = rois[:, 0] 214 | offset_y = rois[:, 1] 215 | 216 | widths = rois[:, 2] - rois[:, 0] 217 | heights = rois[:, 3] - rois[:, 1] 218 | widths = widths.clamp(min=1) 219 | heights = heights.clamp(min=1) 220 | widths_ceil = widths.ceil() 221 | heights_ceil = heights.ceil() 222 | 223 | num_keypoints = maps.shape[1] 224 | xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device) 225 | end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device) 226 | for i in range(len(rois)): 227 | roi_map_width = int(widths_ceil[i].item()) 228 | roi_map_height = int(heights_ceil[i].item()) 229 | width_correction = widths[i] / roi_map_width 230 | height_correction = heights[i] / roi_map_height 231 | roi_map = torch.nn.functional.interpolate( 232 | maps[i][None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[0] 233 | # roi_map_probs = scores_to_probs(roi_map.copy()) 234 | w = roi_map.shape[2] 235 | pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) 236 | x_int = pos % w 237 | y_int = (pos - x_int) // w 238 | # assert (roi_map_probs[k, y_int, x_int] == 239 | # roi_map_probs[k, :, :].max()) 240 | x = (x_int.float() + 0.5) * width_correction 241 | y = (y_int.float() + 0.5) * height_correction 242 | xy_preds[i, 0, :] = x + offset_x[i] 243 | xy_preds[i, 1, :] = y + offset_y[i] 244 | xy_preds[i, 2, :] = 1 245 | end_scores[i, :] = roi_map[torch.arange(num_keypoints), y_int, x_int] 246 | 247 | return xy_preds.permute(0, 2, 1), end_scores 248 | 249 | def keypointrcnn_loss(self, keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs): 250 | N, K, H, W = keypoint_logits.shape 251 | assert H == W 252 | discretization_size = H 253 | heatmaps = [] 254 | valid = [] 255 | 256 | indx = [x for x in gt_keypoints] 257 | 258 | for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs): 259 | kp = gt_kp_in_image[midx] 260 | heatmaps_per_image, valid_per_image = self.keypoints_to_heatmap( 261 | kp, proposals_per_image, discretization_size 262 | ) 263 | heatmaps.append(heatmaps_per_image.view(-1)) 264 | valid.append(valid_per_image.view(-1)) 265 | 266 | keypoint_targets = torch.cat(heatmaps, dim=0) 267 | valid = torch.cat(valid, dim=0).to(dtype=torch.uint8) 268 | valid = torch.nonzero(valid).squeeze(1) 269 | 270 | # torch.mean (in binary_cross_entropy_with_logits) does'nt 271 | # accept empty tensors, so handle it sepaartely 272 | if keypoint_targets.numel() == 0 or len(valid) == 0: 273 | return keypoint_logits.sum() * 0 274 | 275 | keypoint_logits = keypoint_logits.view(N * K, H * W) 276 | 277 | keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid]) 278 | return keypoint_loss 279 | 280 | def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes): 281 | device = class_logits.device 282 | num_classes = class_logits.shape[-1] 283 | 284 | boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals] 285 | pred_boxes = self.box_coder.decode(box_regression, proposals) 286 | 287 | pred_scores = F.softmax(class_logits, -1) 288 | 289 | # split boxes and scores per image 290 | pred_boxes = pred_boxes.split(boxes_per_image, 0) 291 | pred_scores = pred_scores.split(boxes_per_image, 0) 292 | 293 | all_boxes = [] 294 | all_scores = [] 295 | all_labels = [] 296 | for boxes, scores, image_shape in zip(pred_boxes, pred_scores, image_shapes): 297 | boxes = box_ops.clip_boxes_to_image(boxes, image_shape) 298 | 299 | # create 
labels for each prediction 300 | labels = torch.arange(num_classes, device=device) 301 | labels = labels.view(1, -1).expand_as(scores) 302 | 303 | # remove predictions with the background label 304 | boxes = boxes[:, 1:] 305 | scores = scores[:, 1:] 306 | labels = labels[:, 1:] 307 | 308 | # batch everything, by making every class prediction be a separate instance 309 | boxes = boxes.reshape(-1, 4) 310 | scores = scores.flatten() 311 | labels = labels.flatten() 312 | 313 | # remove low scoring boxes 314 | inds = torch.nonzero(scores > self.score_thresh).squeeze(1) 315 | boxes, scores, labels = boxes[inds], scores[inds], labels[inds] 316 | 317 | # remove empty boxes 318 | keep = box_ops.remove_small_boxes(boxes, min_size=1e-2) 319 | boxes, scores, labels = boxes[keep], scores[keep], labels[keep] 320 | 321 | # non-maximum suppression, independently done per class 322 | keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh) 323 | # keep only topk scoring predictions 324 | keep = keep[:self.detections_per_img] 325 | boxes, scores, labels = boxes[keep], scores[keep], labels[keep] 326 | 327 | all_boxes.append(boxes) 328 | all_scores.append(scores) 329 | all_labels.append(labels) 330 | 331 | return all_boxes, all_scores, all_labels 332 | 333 | def forward(self, features, proposals, image_shapes, targets=None): 334 | """ 335 | Arguments: 336 | features (List[Tensor]) 337 | proposals (List[Tensor[N, 4]]) 338 | image_shapes (List[Tuple[H, W]]) 339 | targets (List[Dict]) 340 | """ 341 | if targets is not None: 342 | for t in targets: 343 | assert t["boxes"].dtype.is_floating_point, 'target boxes must of float type' 344 | assert t["labels"].dtype == torch.int64, 'target labels must of int64 type' 345 | if self.has_keypoint: 346 | assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type' 347 | 348 | if self.training: 349 | proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) 350 | 351 | box_features_roi = self.box_roi_pool(features, proposals, image_shapes) 352 | box_features = self.box_head(box_features_roi) 353 | class_logits, box_regression = self.box_predictor(box_features) 354 | 355 | result, losses = [], {} 356 | if self.training: 357 | loss_classifier, loss_box_reg = fastrcnn_loss( 358 | class_logits, box_regression, labels, regression_targets) 359 | losses = dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg) 360 | else: 361 | boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes) 362 | num_images = len(boxes) 363 | for i in range(num_images): 364 | if boxes[i].numel() > 0: 365 | result.append( 366 | dict( 367 | boxes=boxes[i], 368 | labels=labels[i], 369 | scores=scores[i], 370 | ) 371 | ) 372 | else: 373 | result.append( 374 | dict( 375 | boxes=torch.tensor([0.0, 0.0, image_shapes[i][1], image_shapes[i][0]]).to( 376 | boxes[i].device).unsqueeze(0), 377 | labels=torch.tensor([0]).to(boxes[i].device), 378 | scores=torch.tensor([1.0]).to(boxes[i].device), 379 | ) 380 | ) 381 | 382 | if self.has_mask: 383 | mask_proposals = [p["boxes"] for p in result] 384 | if self.training: 385 | # during training, only focus on positive boxes 386 | num_images = len(proposals) 387 | mask_proposals = [] 388 | pos_matched_idxs = [] 389 | for img_id in range(num_images): 390 | pos = torch.nonzero(labels[img_id] > 0).squeeze(1) 391 | mask_proposals.append(proposals[img_id][pos]) 392 | pos_matched_idxs.append(matched_idxs[img_id][pos]) 393 | 394 | mask_roi_features 
= self.mask_roi_pool(features, mask_proposals, image_shapes) 395 | mask_features = self.mask_head(mask_roi_features) 396 | mask_logits = self.mask_predictor(mask_features) 397 | 398 | loss_mask = {} 399 | if self.training: 400 | gt_masks = [t["masks"] for t in targets] 401 | gt_labels = [t["labels"] for t in targets] 402 | loss_mask = maskrcnn_loss( 403 | mask_logits, mask_proposals, 404 | gt_masks, gt_labels, pos_matched_idxs) 405 | loss_mask = dict(loss_mask=loss_mask) 406 | else: 407 | labels = [r["labels"] for r in result] 408 | masks_probs = maskrcnn_inference(mask_logits, labels) 409 | for mask_prob, r in zip(masks_probs, result): 410 | r["masks"] = mask_prob 411 | 412 | losses.update(loss_mask) 413 | 414 | if self.has_match: 415 | match_proposals = [p["boxes"] for p in result] 416 | if self.training: 417 | gt_proposals = [t["boxes"] for t in targets] 418 | num_images = len(proposals) 419 | match_proposals = [] 420 | pos_matched_idxs = [] 421 | for img_id in range(num_images): 422 | pos = torch.nonzero(labels[img_id] > 0).squeeze(1) 423 | match_proposals.append(proposals[img_id][pos]) 424 | pos_matched_idxs.append(matched_idxs[img_id][pos]) 425 | 426 | match_roi_features = self.mask_roi_pool(features, match_proposals, image_shapes) 427 | match_proposals, mask_roi_features, matched_idxs_match = filter_proposals(match_proposals, 428 | match_roi_features, 429 | gt_proposals, 430 | pos_matched_idxs) 431 | types = [] 432 | s_imgs = [] 433 | i = 0 434 | for p, s in zip(match_proposals, targets): 435 | types = types + ([1] * len(p) if s['sources'][0] == 1 else [0] * len(p)) 436 | s_imgs = s_imgs + ([i] * len(p)) 437 | i += 1 438 | types = torch.IntTensor(types) 439 | # match_roi_features = self.mask_roi_pool(features, match_proposals, image_shapes) 440 | final_features, match_logits = self.match_predictor(mask_roi_features, types) 441 | 442 | gt_pairs = [t["pair_ids"] for t in targets] 443 | gt_styles = [t["styles"] for t in targets] 444 | 445 | loss_match = self.match_loss(match_logits, match_proposals, gt_proposals, gt_pairs, gt_styles, types, 446 | pos_matched_idxs) 447 | 448 | loss_match = dict(loss_match=loss_match) 449 | 450 | 451 | else: 452 | loss_match = {} 453 | 454 | s_imgs = [] 455 | for i, p in enumerate(match_proposals): 456 | if i == 0: 457 | types = [0] * len(p) 458 | else: 459 | types = types + [1] * len(p) 460 | s_imgs = s_imgs + ([i] * len(p)) 461 | 462 | types = torch.IntTensor(types) 463 | match_roi_features = self.mask_roi_pool(features, match_proposals, image_shapes) 464 | final_features, match_logits = self.match_predictor(match_roi_features, types) 465 | for i, r in zip(range(len(match_proposals)), result): 466 | r['match_features'] = final_features[torch.IntTensor(s_imgs) == i, ...] 
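                    # The weight and bias of the match head's final linear layer are exported on the
                    # result dict below, together with the match features set above, so that code
                    # outside this forward pass (e.g. the evaluation scripts) can score street-shop
                    # feature pairs with the same last layer instead of re-running the head. A hedged
                    # sketch, assuming a paired feature tensor `paired` built the same way as inside
                    # the match predictor (hypothetical name, not defined in this file):
                    #     logits = paired @ r['w'].t() + r['b']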
467 | r['w'] = self.match_predictor.last.weight 468 | r['b'] = self.match_predictor.last.bias 469 | 470 | losses.update(loss_match) 471 | 472 | return result, losses 473 | 474 | 475 | class MatchRCNN(MaskRCNN): 476 | def __init__(self, backbone, num_classes, **kwargs): 477 | super(MatchRCNN, self).__init__(backbone, num_classes, **kwargs) 478 | self.roi_heads = NewRoIHeads(self.roi_heads) 479 | 480 | 481 | def matchrcnn_resnet50_fpn(pretrained=False, progress=True, 482 | num_classes=91, pretrained_backbone=True, **kwargs): 483 | if pretrained: 484 | # no need to download the backbone if pretrained is set 485 | pretrained_backbone = False 486 | backbone = resnet_fpn_backbone('resnet50', pretrained_backbone) 487 | model = MatchRCNN(backbone, num_classes, **kwargs) 488 | if pretrained: 489 | state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'], 490 | progress=progress) 491 | model.load_state_dict(state_dict) 492 | return model 493 | -------------------------------------------------------------------------------- /models/nlb.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | class _NonLocalBlockND(nn.Module): 5 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 6 | super(_NonLocalBlockND, self).__init__() 7 | 8 | assert dimension in [1, 2, 3] 9 | 10 | self.dimension = dimension 11 | self.sub_sample = sub_sample 12 | 13 | self.in_channels = in_channels 14 | self.inter_channels = inter_channels 15 | 16 | if self.inter_channels is None: 17 | self.inter_channels = in_channels // 2 18 | if self.inter_channels == 0: 19 | self.inter_channels = 1 20 | 21 | if dimension == 3: 22 | conv_nd = nn.Conv3d 23 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 24 | bn = nn.BatchNorm3d 25 | elif dimension == 2: 26 | conv_nd = nn.Conv2d 27 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 28 | bn = nn.BatchNorm2d 29 | else: 30 | conv_nd = nn.Conv1d 31 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 32 | bn = nn.BatchNorm1d 33 | 34 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 35 | kernel_size=1, stride=1, padding=0) 36 | 37 | if bn_layer: 38 | self.W = nn.Sequential( 39 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 40 | kernel_size=1, stride=1, padding=0), 41 | bn(self.in_channels) 42 | ) 43 | # nn.init.constant_(self.W[1].weight, 0) 44 | nn.init.constant_(self.W[1].bias, 0) 45 | else: 46 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 47 | kernel_size=1, stride=1, padding=0) 48 | nn.init.constant_(self.W.weight, 0) 49 | nn.init.constant_(self.W.bias, 0) 50 | 51 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 52 | kernel_size=1, stride=1, padding=0) 53 | 54 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 55 | kernel_size=1, stride=1, padding=0) 56 | 57 | self.concat_project = nn.Sequential( 58 | nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False), 59 | nn.ReLU() 60 | ) 61 | 62 | if sub_sample: 63 | self.g = nn.Sequential(self.g, max_pool_layer) 64 | self.phi = nn.Sequential(self.phi, max_pool_layer) 65 | 66 | def forward(self, x): 67 | ''' 68 | :param x: (b, c, t, h, w) 69 | :return: 70 | ''' 71 | 72 | batch_size = x.size(0) 73 | 74 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 75 | g_x = g_x.permute(0, 2, 1) 76 | 77 | # (b, c, N, 1) 78 | theta_x = 
self.theta(x).view(batch_size, self.inter_channels, -1, 1) 79 | # (b, c, 1, N) 80 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1) 81 | 82 | h = theta_x.size(2) 83 | w = phi_x.size(3) 84 | theta_x = theta_x.repeat(1, 1, 1, w) 85 | phi_x = phi_x.repeat(1, 1, h, 1) 86 | 87 | concat_feature = torch.cat([theta_x, phi_x], dim=1) 88 | f = self.concat_project(concat_feature) 89 | b, _, h, w = f.size() 90 | f = f.view(b, h, w) 91 | 92 | N = f.size(-1) 93 | f_div_C = f / N 94 | 95 | y = torch.matmul(f_div_C, g_x) 96 | y = y.permute(0, 2, 1).contiguous() 97 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 98 | W_y = self.W(y) 99 | z = W_y + x 100 | 101 | return z 102 | 103 | 104 | class NONLocalBlock1D(_NonLocalBlockND): 105 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 106 | super(NONLocalBlock1D, self).__init__(in_channels, 107 | inter_channels=inter_channels, 108 | dimension=1, sub_sample=sub_sample, 109 | bn_layer=bn_layer) 110 | 111 | 112 | class NONLocalBlock2D(_NonLocalBlockND): 113 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 114 | super(NONLocalBlock2D, self).__init__(in_channels, 115 | inter_channels=inter_channels, 116 | dimension=2, sub_sample=sub_sample, 117 | bn_layer=bn_layer) 118 | 119 | 120 | class NONLocalBlock3D(_NonLocalBlockND): 121 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 122 | super(NONLocalBlock3D, self).__init__(in_channels, 123 | inter_channels=inter_channels, 124 | dimension=3, sub_sample=sub_sample, 125 | bn_layer=bn_layer) 126 | 127 | 128 | if __name__ == '__main__': 129 | import torch 130 | 131 | for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]: 132 | img = torch.zeros(2, 3, 20) 133 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 134 | out = net(img) 135 | print(out.size()) 136 | 137 | img = torch.zeros(2, 3, 20, 20) 138 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 139 | out = net(img) 140 | print(out.size()) 141 | 142 | img = torch.randn(2, 3, 8, 20, 20) 143 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 144 | out = net(img) 145 | print(out.size()) 146 | -------------------------------------------------------------------------------- /models/video_matchrcnn.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.hub import load_state_dict_from_url 6 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone 7 | from torchvision.models.detection.mask_rcnn import MaskRCNN 8 | from torchvision.models.detection.roi_heads import fastrcnn_loss, maskrcnn_loss, maskrcnn_inference 9 | from torchvision.ops import boxes as box_ops 10 | 11 | from .match_head import MatchPredictor, MatchLoss, TemporalAggregationNLB as TemporalAggregation 12 | 13 | model_urls = { 14 | 'maskrcnn_resnet50_fpn_coco': 15 | 'https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth', 16 | } 17 | 18 | 19 | class TemporalRoIHeads(torch.nn.Module): 20 | 21 | def __init__(self, orh, n_frames): 22 | # orh: old_roi_heads 23 | super(TemporalRoIHeads, self).__init__() 24 | self.n_frames = n_frames 25 | 26 | self.box_roi_pool = orh.box_roi_pool 27 | self.box_head = orh.box_head 28 | self.box_predictor = orh.box_predictor 29 | 30 | self.mask_roi_pool = orh.mask_roi_pool 31 | 
self.mask_head = orh.mask_head 32 | self.mask_predictor = orh.mask_predictor 33 | 34 | self.match_predictor = MatchPredictor() 35 | self.match_loss = MatchLoss() 36 | 37 | self.temporal_aggregator = TemporalAggregation() 38 | 39 | self.keypoint_roi_pool = orh.keypoint_head 40 | self.keypoint_head = orh.keypoint_head 41 | self.keypoint_predictor = orh.keypoint_predictor 42 | 43 | self.score_thresh = orh.score_thresh 44 | self.nms_thresh = orh.nms_thresh 45 | self.detections_per_img = orh.detections_per_img 46 | 47 | self.proposal_matcher = orh.proposal_matcher 48 | self.fg_bg_sampler = orh.fg_bg_sampler 49 | self.box_coder = orh.box_coder 50 | 51 | self.box_similarity = box_ops.box_iou 52 | 53 | @property 54 | def has_mask(self): 55 | if self.mask_roi_pool is None: 56 | return False 57 | if self.mask_head is None: 58 | return False 59 | if self.mask_predictor is None: 60 | return False 61 | return True 62 | 63 | @property 64 | def has_keypoint(self): 65 | if self.keypoint_roi_pool is None: 66 | return False 67 | if self.keypoint_head is None: 68 | return False 69 | if self.keypoint_predictor is None: 70 | return False 71 | return True 72 | 73 | @property 74 | def has_match(self): 75 | if self.match_predictor is None: 76 | return False 77 | if self.match_loss is None: 78 | return False 79 | return True 80 | 81 | def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels): 82 | matched_idxs = [] 83 | labels = [] 84 | for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels): 85 | match_quality_matrix = self.box_similarity(gt_boxes_in_image, proposals_in_image) 86 | matched_idxs_in_image = self.proposal_matcher(match_quality_matrix) 87 | 88 | clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0) 89 | 90 | labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image] 91 | labels_in_image = labels_in_image.to(dtype=torch.int64) 92 | 93 | # Label background (below the low threshold) 94 | bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD 95 | labels_in_image[bg_inds] = 0 96 | 97 | # Label ignore proposals (between low and high thresholds) 98 | ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS 99 | labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler 100 | 101 | matched_idxs.append(clamped_matched_idxs_in_image) 102 | labels.append(labels_in_image) 103 | return matched_idxs, labels 104 | 105 | def subsample(self, labels): 106 | sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) 107 | sampled_inds = [] 108 | for img_idx, (pos_inds_img, neg_inds_img) in enumerate( 109 | zip(sampled_pos_inds, sampled_neg_inds) 110 | ): 111 | img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1) 112 | sampled_inds.append(img_sampled_inds) 113 | return sampled_inds 114 | 115 | def add_gt_proposals(self, proposals, gt_boxes): 116 | proposals = [ 117 | torch.cat((proposal, gt_box)) 118 | for proposal, gt_box in zip(proposals, gt_boxes) 119 | ] 120 | 121 | return proposals 122 | 123 | def check_targets(self, targets): 124 | assert targets is not None 125 | assert all("boxes" in t for t in targets) 126 | assert all("labels" in t for t in targets) 127 | if self.has_mask: 128 | assert all("masks" in t for t in targets) 129 | 130 | def select_training_samples(self, proposals, targets): 131 | self.check_targets(targets) 132 | gt_boxes = [t["boxes"] for t in targets] 133 | gt_labels = [t["labels"] for t in targets] 134 | 135 | # append ground-truth bboxes to 
propos 136 | proposals = self.add_gt_proposals(proposals, gt_boxes) 137 | 138 | # get matching gt indices for each proposal 139 | matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels) 140 | # sample a fixed proportion of positive-negative proposals 141 | sampled_inds = self.subsample(labels) 142 | matched_gt_boxes = [] 143 | num_images = len(proposals) 144 | for img_id in range(num_images): 145 | img_sampled_inds = sampled_inds[img_id] 146 | proposals[img_id] = proposals[img_id][img_sampled_inds] 147 | labels[img_id] = labels[img_id][img_sampled_inds] 148 | matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds] 149 | matched_gt_boxes.append(gt_boxes[img_id][matched_idxs[img_id]]) 150 | 151 | regression_targets = self.box_coder.encode(matched_gt_boxes, proposals) 152 | return proposals, matched_idxs, labels, regression_targets 153 | 154 | def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes): 155 | device = class_logits.device 156 | num_classes = class_logits.shape[-1] 157 | 158 | boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals] 159 | pred_boxes = self.box_coder.decode(box_regression, proposals) 160 | 161 | pred_scores = F.softmax(class_logits, -1) 162 | 163 | # split boxes and scores per image 164 | pred_boxes = pred_boxes.split(boxes_per_image, 0) 165 | pred_scores = pred_scores.split(boxes_per_image, 0) 166 | 167 | all_boxes = [] 168 | all_scores = [] 169 | all_labels = [] 170 | for boxes, scores, image_shape in zip(pred_boxes, pred_scores, image_shapes): 171 | boxes = box_ops.clip_boxes_to_image(boxes, image_shape) 172 | 173 | # create labels for each prediction 174 | labels = torch.arange(num_classes, device=device) 175 | labels = labels.view(1, -1).expand_as(scores) 176 | 177 | # remove predictions with the background label 178 | boxes = boxes[:, 1:] 179 | scores = scores[:, 1:] 180 | labels = labels[:, 1:] 181 | 182 | # batch everything, by making every class prediction be a separate instance 183 | boxes = boxes.reshape(-1, 4) 184 | scores = scores.flatten() 185 | labels = labels.flatten() 186 | 187 | # remove low scoring boxes 188 | inds = torch.nonzero(scores > self.score_thresh).squeeze(1) 189 | boxes, scores, labels = boxes[inds], scores[inds], labels[inds] 190 | 191 | # remove empty boxes 192 | keep = box_ops.remove_small_boxes(boxes, min_size=1e-2) 193 | boxes, scores, labels = boxes[keep], scores[keep], labels[keep] 194 | 195 | # non-maximum suppression, independently done per class 196 | keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh) 197 | # keep only topk scoring predictions 198 | keep = keep[:self.detections_per_img] 199 | boxes, scores, labels = boxes[keep], scores[keep], labels[keep] 200 | 201 | all_boxes.append(boxes) 202 | all_scores.append(scores) 203 | all_labels.append(labels) 204 | 205 | return all_boxes, all_scores, all_labels 206 | 207 | def forward(self, features, proposals, image_shapes, targets=None): 208 | """ 209 | Arguments: 210 | features (List[Tensor]) 211 | proposals (List[Tensor[N, 4]]) 212 | image_shapes (List[Tuple[H, W]]) 213 | targets (List[Dict]) 214 | """ 215 | if targets is not None: 216 | for t in targets: 217 | assert t["boxes"].dtype.is_floating_point, 'target boxes must of float type' 218 | assert t["labels"].dtype == torch.int64, 'target labels must of int64 type' 219 | if self.has_keypoint: 220 | assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type' 221 | 222 | if self.training: 223 | proposals, 
matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) 224 | 225 | box_features = self.box_roi_pool(features, proposals, image_shapes) 226 | box_features = self.box_head(box_features) 227 | class_logits, box_regression = self.box_predictor(box_features) 228 | 229 | result, losses = [], {} 230 | if self.training: 231 | loss_classifier, loss_box_reg = fastrcnn_loss( 232 | class_logits, box_regression, labels, regression_targets) 233 | losses = dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg) 234 | else: 235 | boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes) 236 | num_images = len(boxes) 237 | for i in range(num_images): 238 | if boxes[i].numel() > 0: 239 | result.append( 240 | dict( 241 | boxes=boxes[i], 242 | labels=labels[i], 243 | scores=scores[i], 244 | ) 245 | ) 246 | else: 247 | result.append( 248 | dict( 249 | boxes=torch.tensor([0.0, 0.0, image_shapes[i][1], image_shapes[i][0]]).to(boxes[i].device).unsqueeze(0), 250 | labels=torch.tensor([0]).to(boxes[i].device), 251 | scores=torch.tensor([0.1]).to(boxes[i].device), 252 | ) 253 | ) 254 | 255 | if self.has_mask: 256 | if targets is not None: 257 | assert(len(targets) == len(result)) 258 | # result.extend([{k:v for k, v in x.items() if k in ["boxes", "labels"]} for x in targets]) 259 | for i, r in enumerate(result): 260 | r["boxes"] = torch.cat([targets[i]["boxes"], r["boxes"]]) 261 | r["labels"] = torch.cat([targets[i]["labels"], r["labels"]]) 262 | r["scores"] = torch.cat([torch.ones((targets[i]["labels"].numel(),)).to(r["scores"].device), r["scores"]]) 263 | # if "scores" not in r: 264 | # r["scores"] = torch.ones((r["labels"].numel(),)) 265 | mask_proposals = [p["boxes"] for p in result] 266 | 267 | if self.training: 268 | # during training, only focus on positive boxes 269 | num_images = len(proposals) 270 | mask_proposals = [] 271 | pos_matched_idxs = [] 272 | for img_id in range(num_images): 273 | pos = torch.nonzero(labels[img_id] > 0).squeeze(1) 274 | mask_proposals.append(proposals[img_id][pos]) 275 | pos_matched_idxs.append(matched_idxs[img_id][pos]) 276 | 277 | mask_roi_features = self.mask_roi_pool(features, mask_proposals, image_shapes) 278 | mask_features = self.mask_head(mask_roi_features) 279 | mask_logits = self.mask_predictor(mask_features) 280 | 281 | loss_mask = {} 282 | if self.training: 283 | gt_masks = [t["masks"] for t in targets] 284 | gt_labels = [t["labels"] for t in targets] 285 | loss_mask = maskrcnn_loss( 286 | mask_logits, mask_proposals, 287 | gt_masks, gt_labels, pos_matched_idxs) 288 | loss_mask = dict(loss_mask=loss_mask) 289 | else: 290 | labels = [r["labels"] for r in result] 291 | masks_probs = maskrcnn_inference(mask_logits, labels) 292 | for mask_prob, r in zip(masks_probs, result): 293 | r["masks"] = mask_prob 294 | 295 | losses.update(loss_mask) 296 | 297 | if self.has_match and not self.training: 298 | loss_match = {} 299 | s_imgs = [] 300 | for i, p in enumerate(mask_proposals): 301 | if i == 0: 302 | types = [0] * len(p) 303 | else: 304 | types = types + [1] * len(p) 305 | s_imgs = s_imgs + ([i] * len(p)) 306 | 307 | types = torch.IntTensor(types) 308 | if mask_roi_features.shape[0] > 0: 309 | final_features, match_logits = self.match_predictor(mask_roi_features, types) 310 | for i, r in zip(range(len(mask_proposals)), result): 311 | r['match_features'] = final_features[torch.IntTensor(s_imgs) == i, ...] 
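                    # As in MatchRCNN, the final-layer weight/bias and the per-image roi_features are
                    # exported on the result dict below; the training loops in stuffs/engine.py keep
                    # only the keys in outputkeys_whitelist ('scores', 'boxes', 'roi_features') and
                    # feed roi_features to the match predictor / temporal aggregator outside the model.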
312 | r['w'] = self.match_predictor.last.weight 313 | r['b'] = self.match_predictor.last.bias 314 | r['roi_features'] = mask_roi_features[torch.IntTensor(s_imgs) == i] 315 | 316 | return result, losses 317 | 318 | 319 | 320 | class VideoMatchRCNN(MaskRCNN): 321 | def __init__(self, backbone, num_classes, n_frames, **kwargs): 322 | super(VideoMatchRCNN, self).__init__(backbone, num_classes, **kwargs) 323 | self.roi_heads = TemporalRoIHeads(self.roi_heads, n_frames) 324 | 325 | def load_saved_matchrcnn(self, sd): 326 | self.load_state_dict(sd, strict=False) 327 | self.roi_heads.temporal_aggregator\ 328 | .load_state_dict(deepcopy(self.roi_heads.match_predictor.state_dict()), strict=False) 329 | 330 | 331 | def videomatchrcnn_resnet50_fpn(pretrained=False, progress=True, 332 | num_classes=91, pretrained_backbone=True, 333 | n_frames=3, **kwargs): 334 | if pretrained: 335 | # no need to download the backbone if pretrained is set 336 | pretrained_backbone = False 337 | backbone = resnet_fpn_backbone('resnet50', pretrained_backbone) 338 | model = VideoMatchRCNN(backbone, num_classes, n_frames, **kwargs) 339 | if pretrained: 340 | state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'], 341 | progress=progress) 342 | model.load_state_dict(state_dict) 343 | return model 344 | -------------------------------------------------------------------------------- /stuffs/engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | from copy import deepcopy 5 | 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | from pycocotools import mask as maskUtils 10 | 11 | from models.match_head import MatchLossWeak, NEWBalancedAggregationMatchLossWeak \ 12 | , AggregationMatchLossDF2 13 | from stuffs import utils 14 | 15 | outputkeys_whitelist = ['scores', 'boxes', 'roi_features'] 16 | 17 | 18 | def train_one_epoch_matchrcnn(model, optimizer, data_loader, device, epoch, print_freq 19 | , writer=None): 20 | if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) > 1: 21 | rank = dist.get_rank() 22 | else: 23 | rank = 0 24 | model.train() 25 | metric_logger = utils.MetricLogger(delimiter=" ") 26 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 27 | header = 'Epoch: [{}]'.format(epoch) 28 | 29 | lr_scheduler = None 30 | if epoch == 0: 31 | warmup_factor = 1. 
/ 1000 32 | warmup_iters = min(1000, len(data_loader) - 1) 33 | 34 | lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) 35 | count = -1 36 | for i, (images, targets, idxs) in enumerate(metric_logger.log_every(data_loader, print_freq, header, rank=rank)): 37 | count += 1 38 | images = list(image.to(device) for image in images) 39 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 40 | loss_dict = model(images, targets) 41 | # print(args.local_rank) 42 | 43 | losses = sum(loss for loss in loss_dict.values()) 44 | 45 | # reduce losses over all GPUs for logging purposes 46 | loss_dict_reduced = utils.reduce_dict(loss_dict) 47 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 48 | if writer is not None and (((count % print_freq) == 0) or count == 0): 49 | global_step = (epoch * len(data_loader)) + count 50 | for k, v in loss_dict_reduced.items(): 51 | writer.add_scalar(k, v.item(), global_step=global_step) 52 | writer.add_scalar("loss", losses.item(), global_step=global_step) 53 | 54 | loss_value = losses_reduced.item() 55 | 56 | if not math.isfinite(loss_value): 57 | print("Loss is {}, stopping training".format(loss_value)) 58 | print(loss_dict_reduced) 59 | print(idxs) 60 | sys.exit(1) 61 | 62 | optimizer.zero_grad() 63 | losses.backward() 64 | optimizer.step() 65 | 66 | if lr_scheduler is not None: 67 | lr_scheduler.step() 68 | 69 | metric_logger.update(loss=losses_reduced, **loss_dict_reduced) 70 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 71 | print("Epoch finished by process #%d" % rank) 72 | 73 | 74 | 75 | 76 | def train_one_epoch_movingfashion(model, optimizer, data_loader, device, epoch, print_freq 77 | , score_thresh=0.7, writer=None, inferstep=10): 78 | if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) > 1: 79 | distributed = True 80 | rank = dist.get_rank() 81 | else: 82 | distributed = False 83 | rank = 0 84 | metric_logger = utils.MetricLogger(delimiter=" ") 85 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 86 | header = 'Epoch: [{}]'.format(epoch) 87 | 88 | lr_scheduler = None 89 | if epoch == 0: 90 | warmup_factor = 1. 
/ 1000 91 | warmup_iters = min(1000, len(data_loader) - 1) 92 | 93 | lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) 94 | 95 | # wait for all workers to be ready 96 | # if distributed: 97 | # dist.barrier() 98 | real_model = model if not hasattr(model, "module") else model.module 99 | match_predictor = real_model.roi_heads.match_predictor 100 | temporal_aggregator = real_model.roi_heads.temporal_aggregator 101 | 102 | 103 | match_loss = MatchLossWeak(device) 104 | aggregation_loss = NEWBalancedAggregationMatchLossWeak(device, temporal_aggregator) 105 | 106 | count = -1 107 | for images, targets in metric_logger.log_every(data_loader, print_freq, header, rank=rank): 108 | 109 | count += 1 110 | images = list(image.to(device) for image in images) 111 | # targets = [{k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in t.items()} for t in targets] 112 | # output: list of dicts: "boxes", "labels", "scores", "masks", "match_features", "w", "b", "roi_features" 113 | model.eval() 114 | with torch.no_grad(): 115 | output = [model(images[x:x + inferstep]) for x in range(0, len(images), inferstep)] 116 | output = [y for x in output for y in x] 117 | # clean output dict to save memory 118 | output = [{k: v for k, v in o.items() if k in outputkeys_whitelist} for o in output] 119 | 120 | match_predictor.train() 121 | temporal_aggregator.train() 122 | 123 | roi_features = [] 124 | types = [] 125 | prod_ids = [] 126 | img_ids = [] 127 | exclude_prod_ids = [] 128 | boxes = [] 129 | scores = [] 130 | for i, (t, o) in enumerate(zip(targets, output)): 131 | if t["i"] in exclude_prod_ids: 132 | # if a product is excluded, skip all frames 133 | continue 134 | if "roi_features" in o: 135 | indexes = (o["scores"] >= score_thresh).nonzero().view(-1) 136 | if indexes.numel() < 1: 137 | if t["tag"] == 1: 138 | # exclude street imgs if shop doesn't have any boxes 139 | exclude_prod_ids.append(t["i"]) 140 | continue 141 | if t["tag"] == 1: 142 | tmp_bs = o["boxes"][indexes] 143 | indexes = ((tmp_bs[:, 2] - tmp_bs[:, 0]) * (tmp_bs[:, 3] - tmp_bs[:, 1])).argmax().view(1) 144 | roi_features.append(o["roi_features"][indexes]) 145 | boxes.append(o["boxes"][indexes]) 146 | scores.append(o["scores"][indexes]) 147 | types = types + [t["tag"]] * indexes.shape[0] 148 | prod_ids = prod_ids + [t["i"]] * indexes.shape[0] 149 | img_ids = img_ids + [i] * indexes.shape[0] 150 | flag = False 151 | types = torch.IntTensor(types) 152 | # at least two boxes, one being a street and one being a shop 153 | if len(roi_features) >= 2 and (types == 0).any() and (types == 1).any(): 154 | roi_features = torch.cat(roi_features, 0) 155 | 156 | 157 | # predict matches street-shop 158 | _, logits = match_predictor(roi_features, types) 159 | # predict tracking street-street 160 | # first retrieve only the street items 161 | # duplicate them to match with each other 162 | weight_aggr = min(epoch / 1, 1.0) 163 | 164 | loss_dict = { 165 | 'match_loss': match_loss(logits, types, prod_ids, img_ids) 166 | , 'aggregation_loss': weight_aggr * 167 | aggregation_loss(logits, types, prod_ids, img_ids, roi_features) 168 | } 169 | 170 | losses = sum(loss for loss in loss_dict.values()) 171 | 172 | # reduce losses over all GPUs for logging purposes 173 | loss_dict_reduced = utils.reduce_dict(loss_dict) 174 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 175 | 176 | loss_value = losses_reduced.item() 177 | 178 | if not math.isfinite(loss_value): 179 | print("Loss is {}, stopping 
training".format(loss_value)) 180 | print(loss_dict_reduced) 181 | sys.exit(1) 182 | 183 | optimizer.zero_grad() 184 | losses.backward() 185 | optimizer.step() 186 | 187 | if lr_scheduler is not None: 188 | lr_scheduler.step() 189 | 190 | if writer is not None and (((count % print_freq) == 0) or count == 0): 191 | global_step = (epoch * len(data_loader)) + count 192 | for k, v in loss_dict_reduced.items(): 193 | writer.add_scalar(k, v.item(), global_step=global_step) 194 | writer.add_scalar("loss", losses.item(), global_step=global_step) 195 | 196 | 197 | metric_logger.update(loss=losses_reduced, **loss_dict_reduced) 198 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 199 | print("Epoch finished by process #%d" % rank) 200 | 201 | 202 | def train_one_epoch_multiDF2(model, optimizer, data_loader, device, epoch, print_freq 203 | , score_thresh=0.7, writer=None, inferstep=10, use_gt=False): 204 | if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) > 1: 205 | rank = dist.get_rank() 206 | else: 207 | rank = 0 208 | metric_logger = utils.MetricLogger(delimiter=" ") 209 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 210 | header = 'Epoch: [{}]'.format(epoch) 211 | 212 | lr_scheduler = None 213 | if epoch == 0: 214 | warmup_factor = 1. / 1000 215 | warmup_iters = min(1000, len(data_loader) - 1) 216 | 217 | lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) 218 | 219 | match_predictor = model.roi_heads.match_predictor 220 | temporal_aggregator = model.roi_heads.temporal_aggregator 221 | aggregation_loss = AggregationMatchLossDF2(device, temporal_aggregator) 222 | 223 | count2 = -1 224 | for images, targets, ids in metric_logger.log_every(data_loader, print_freq, header, rank=rank): 225 | # if count2 >= 5: 226 | # break 227 | count2 += 1 228 | images = list(image.to(device) for image in images) 229 | targets = [{k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in t.items()} for t in targets] 230 | targets = [{k: (v.float() if k == "boxes" else v) for k, v in t.items()} for t in targets] 231 | # output: list of dicts: "boxes", "labels", "scores", "masks", "match_features", "w", "b", "roi_features" 232 | model.eval() 233 | targets2 = deepcopy(targets) 234 | with torch.no_grad(): 235 | if use_gt: 236 | output = [model(images[x:x + inferstep], targets=targets2[x:x + inferstep]) for x in range(0, len(images), inferstep)] 237 | else: 238 | output = [model(images[x:x + inferstep]) for x in range(0, len(images), inferstep)] 239 | output = [y for x in output for y in x] 240 | # clean output dict to save memory 241 | output = [{k: v for k, v in o.items() if k in outputkeys_whitelist} for o in output] 242 | # torch.cuda.empty_cache() 243 | 244 | match_predictor.eval() 245 | temporal_aggregator.train() 246 | # print(args.local_rank) 247 | 248 | roi_features = [] 249 | types = [] 250 | prod_ids = [] 251 | img_ids = [] 252 | gt_infos = [] 253 | exclude_prod_ids = [] 254 | boxes = [] 255 | scores = [] 256 | i2tmpid = {} 257 | count = 0 258 | for i, (t, o) in enumerate(zip(targets, output)): 259 | if t["i"] in exclude_prod_ids: 260 | # if a product is excluded, skip all frames 261 | continue 262 | if "roi_features" in o: 263 | indexes = (o["scores"] >= score_thresh).nonzero().view(-1) 264 | if indexes.numel() < 1: 265 | if t["tag"] == 1: 266 | # exclude street imgs if shop doesn't have any boxes 267 | exclude_prod_ids.append(t["i"]) 268 | continue 269 | if t["i"] not in i2tmpid: 270 | i2tmpid[t["i"]] = count 
271 | count += 1 272 | pr_boxes = o["boxes"][indexes].detach().cpu().numpy() 273 | gt_boxes = t["boxes"].detach().cpu().numpy() 274 | pr_boxes[:, 2] = pr_boxes[:, 2] - pr_boxes[:, 0] 275 | pr_boxes[:, 3] = pr_boxes[:, 3] - pr_boxes[:, 1] 276 | gt_boxes[:, 2] = gt_boxes[:, 2] - gt_boxes[:, 0] 277 | gt_boxes[:, 3] = gt_boxes[:, 3] - gt_boxes[:, 1] 278 | iou = maskUtils.iou(gt_boxes, pr_boxes, np.zeros((pr_boxes.shape[0]))) # gts x preds 279 | style, pair_id = [int(x) for x in t["i"].split("_")] 280 | gt_prods = [(count if (t["styles"][ind] == style and t["pair_ids"][ind] == pair_id) else -1) 281 | for ind in range(gt_boxes.shape[0])] 282 | det_prods = [-1] * indexes.shape[0] 283 | det_prods[iou[torch.tensor(gt_prods).argmax()].argmax()] = count 284 | 285 | if t["tag"] == 1: 286 | indexes = indexes[torch.tensor(det_prods) == count] 287 | det_prods = [count] 288 | 289 | roi_features.append(o["roi_features"][indexes]) 290 | boxes.append(o["boxes"][indexes]) 291 | scores.append(o["scores"][indexes]) 292 | types = types + [t["tag"]] * indexes.shape[0] 293 | prod_ids = prod_ids + [i2tmpid[t["i"]]] * indexes.shape[0] 294 | img_ids = img_ids + [i] * indexes.shape[0] 295 | gt_infos = gt_infos + det_prods 296 | types = torch.IntTensor(types) 297 | # at least two boxes, one being a street and one being a shop 298 | if len(roi_features) >= 2 and (types == 0).any() and (types == 1).any(): 299 | roi_features = torch.cat(roi_features, 0) 300 | 301 | # predict matches street-shop 302 | _, logits = match_predictor(roi_features, types) 303 | 304 | weight_aggr = 1.0 305 | 306 | loss_dict = { 307 | 'aggregation_loss': weight_aggr * 308 | aggregation_loss(types, roi_features, gt_infos) 309 | } 310 | 311 | losses = sum(loss for loss in loss_dict.values()) 312 | 313 | # reduce losses over all GPUs for logging purposes 314 | loss_dict_reduced = utils.reduce_dict(loss_dict) 315 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 316 | 317 | loss_value = losses_reduced.item() 318 | 319 | if not math.isfinite(loss_value): 320 | print("Loss is {}, stopping training".format(loss_value)) 321 | print(loss_dict_reduced) 322 | sys.exit(1) 323 | 324 | optimizer.zero_grad() 325 | losses.backward() 326 | optimizer.step() 327 | 328 | if lr_scheduler is not None: 329 | lr_scheduler.step() 330 | 331 | if writer is not None and (((count % print_freq) == 0) or count == 0): 332 | global_step = (epoch * len(data_loader)) + count 333 | for k, v in loss_dict_reduced.items(): 334 | writer.add_scalar(k, v.item(), global_step=global_step) 335 | writer.add_scalar("loss", losses.item(), global_step=global_step) 336 | 337 | 338 | metric_logger.update(loss=losses_reduced, **loss_dict_reduced) 339 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 340 | print("Epoch finished by process #%d" % rank) 341 | -------------------------------------------------------------------------------- /stuffs/mask_utils.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tsungyi' 2 | 3 | import pycocotools._mask as _mask 4 | 5 | # Interface for manipulating masks stored in RLE format. 6 | # 7 | # RLE is a simple yet efficient format for storing binary masks. RLE 8 | # first divides a vector (or vectorized image) into a series of piecewise 9 | # constant regions and then for each piece simply stores the length of 10 | # that piece. 
For example, given M=[0 0 1 1 1 0 1] the RLE counts would 11 | # be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1] 12 | # (note that the odd counts are always the numbers of zeros). Instead of 13 | # storing the counts directly, additional compression is achieved with a 14 | # variable bitrate representation based on a common scheme called LEB128. 15 | # 16 | # Compression is greatest given large piecewise constant regions. 17 | # Specifically, the size of the RLE is proportional to the number of 18 | # *boundaries* in M (or for an image the number of boundaries in the y 19 | # direction). Assuming fairly simple shapes, the RLE representation is 20 | # O(sqrt(n)) where n is number of pixels in the object. Hence space usage 21 | # is substantially lower, especially for large simple objects (large n). 22 | # 23 | # Many common operations on masks can be computed directly using the RLE 24 | # (without need for decoding). This includes computations such as area, 25 | # union, intersection, etc. All of these operations are linear in the 26 | # size of the RLE, in other words they are O(sqrt(n)) where n is the area 27 | # of the object. Computing these operations on the original mask is O(n). 28 | # Thus, using the RLE can result in substantial computational savings. 29 | # 30 | # The following API functions are defined: 31 | # encode - Encode binary masks using RLE. 32 | # decode - Decode binary masks encoded via RLE. 33 | # merge - Compute union or intersection of encoded masks. 34 | # iou - Compute intersection over union between masks. 35 | # area - Compute area of encoded masks. 36 | # toBbox - Get bounding boxes surrounding encoded masks. 37 | # frPyObjects - Convert polygon, bbox, and uncompressed RLE to encoded RLE mask. 38 | # 39 | # Usage: 40 | # Rs = encode( masks ) 41 | # masks = decode( Rs ) 42 | # R = merge( Rs, intersect=false ) 43 | # o = iou( dt, gt, iscrowd ) 44 | # a = area( Rs ) 45 | # bbs = toBbox( Rs ) 46 | # Rs = frPyObjects( [pyObjects], h, w ) 47 | # 48 | # In the API the following formats are used: 49 | # Rs - [dict] Run-length encoding of binary masks 50 | # R - dict Run-length encoding of binary mask 51 | # masks - [hxwxn] Binary mask(s) (must have type np.ndarray(dtype=uint8) in column-major order) 52 | # iscrowd - [nx1] list of np.ndarray. 1 indicates corresponding gt image has crowd region to ignore 53 | # bbs - [nx4] Bounding box(es) stored as [x y w h] 54 | # poly - Polygon stored as [[x1 y1 x2 y2...],[x1 y1 ...],...] (2D list) 55 | # dt,gt - May be either bounding boxes or encoded masks 56 | # Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel). 57 | # 58 | # Finally, a note about the intersection over union (iou) computation. 59 | # The standard iou of a ground truth (gt) and detected (dt) object is 60 | # iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt)) 61 | # For "crowd" regions, we use a modified criteria. If a gt object is 62 | # marked as "iscrowd", we allow a dt to match any subregion of the gt. 63 | # Choosing gt' in the crowd gt that best matches the dt can be done using 64 | # gt'=intersect(dt,gt). Since by definition union(gt',dt)=dt, computing 65 | # iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt) 66 | # For crowd gt regions we use this modified criteria above for the iou. 67 | # 68 | # To compile run "python setup.py build_ext --inplace" 69 | # Please do not contact us for help with compiling. 70 | # 71 | # Microsoft COCO Toolbox. 
version 2.0 72 | # Data, paper, and tutorials available at: http://mscoco.org/ 73 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 74 | # Licensed under the Simplified BSD License [see coco/license.txt] 75 | 76 | iou = _mask.iou 77 | merge = _mask.merge 78 | frPyObjects = _mask.frPyObjects 79 | 80 | 81 | def encode(bimask): 82 | if len(bimask.shape) == 3: 83 | return _mask.encode(bimask) 84 | elif len(bimask.shape) == 2: 85 | h, w = bimask.shape 86 | return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0] 87 | 88 | 89 | def decode(rleObjs): 90 | if type(rleObjs) == list: 91 | return _mask.decode(rleObjs) 92 | else: 93 | return _mask.decode([rleObjs])[:, :, 0] 94 | 95 | 96 | def area(rleObjs): 97 | if type(rleObjs) == list: 98 | return _mask.area(rleObjs) 99 | else: 100 | return _mask.area([rleObjs])[0] 101 | 102 | 103 | def toBbox(rleObjs): 104 | if type(rleObjs) == list: 105 | return _mask.toBbox(rleObjs) 106 | else: 107 | return _mask.toBbox([rleObjs])[0] 108 | 109 | 110 | def annToRLE(ann, size): 111 | """ 112 | Convert annotation which can be polygons, uncompressed RLE to RLE. 113 | :return: binary mask (numpy 2D array) 114 | """ 115 | 116 | h, w = size[0], size[1] 117 | segm = ann['segmentation'] 118 | if type(segm) == list: 119 | # polygon -- a single object might consist of multiple parts 120 | # we merge all parts into one mask rle code 121 | rles = frPyObjects(segm, h, w) 122 | rle = merge(rles) 123 | elif type(segm['counts']) == list: 124 | # uncompressed RLE 125 | rle = frPyObjects(segm, h, w) 126 | else: 127 | # rle 128 | rle = ann['segmentation'] 129 | return rle 130 | 131 | 132 | def annToMask(ann, size): 133 | """ 134 | Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask. 135 | :return: binary mask (numpy 2D array) 136 | """ 137 | rle = annToRLE(ann, size) 138 | m = decode(rle) 139 | return m 140 | -------------------------------------------------------------------------------- /stuffs/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from torchvision.transforms import functional as F 4 | 5 | 6 | def _flip_coco_person_keypoints(kps, width): 7 | flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] 8 | flipped_data = kps[:, flip_inds] 9 | flipped_data[..., 0] = width - flipped_data[..., 0] 10 | # Maintain COCO convention that if visibility == 0, then x, y = 0 11 | inds = flipped_data[..., 2] == 0 12 | flipped_data[inds] = 0 13 | return flipped_data 14 | 15 | 16 | class Compose(object): 17 | def __init__(self, transforms): 18 | self.transforms = transforms 19 | 20 | def __call__(self, image, target): 21 | for t in self.transforms: 22 | image, target = t(image, target) 23 | return image, target 24 | 25 | 26 | class RandomHorizontalFlip(object): 27 | def __init__(self, prob): 28 | self.prob = prob 29 | 30 | def __call__(self, image, target): 31 | if random.random() < self.prob: 32 | height, width = image.shape[-2:] 33 | image = image.flip(-1) 34 | bbox = target["boxes"] 35 | bbox[:, [0, 2]] = width - bbox[:, [2, 0]] 36 | target["boxes"] = bbox 37 | if "masks" in target: 38 | target["masks"] = target["masks"].flip(-1) 39 | if "keypoints" in target: 40 | keypoints = target["keypoints"] 41 | keypoints = _flip_coco_person_keypoints(keypoints, width) 42 | target["keypoints"] = keypoints 43 | return image, target 44 | 45 | 46 | class ToTensor(object): 47 | def __call__(self, image, target): 48 | image = F.to_tensor(image) 49 | return image, target 
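# Minimal usage sketch for the transforms above, in the same spirit as the
# __main__ demo in models/nlb.py. The dummy image and single-box target are
# illustrative only (they are not part of the original training pipeline),
# but the dict keys ("boxes", "labels", "masks") match what the classes in
# this file expect.
if __name__ == '__main__':
    import torch
    from PIL import Image

    transforms = Compose([
        ToTensor(),                      # PIL image -> float tensor (C, H, W)
        RandomHorizontalFlip(prob=0.5),  # flips image, boxes and masks together
    ])

    image = Image.new('RGB', (320, 240))                      # width x height
    target = {
        "boxes": torch.tensor([[30.0, 40.0, 120.0, 200.0]]),  # [x1, y1, x2, y2]
        "labels": torch.tensor([1]),
        "masks": torch.zeros((1, 240, 320), dtype=torch.uint8),
    }

    image, target = transforms(image, target)
    print(image.shape, target["boxes"])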
-------------------------------------------------------------------------------- /stuffs/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import datetime 4 | import errno 5 | import math 6 | import os 7 | import pickle 8 | import time 9 | from collections import defaultdict, deque 10 | 11 | import torch 12 | import torch.distributed as dist 13 | 14 | 15 | def visualize_matches(logits, img_ids, types, ind0, ind1, images, boxes, scores, score_thresh=0.7): 16 | ''' 17 | ind0 street (tag 0) 18 | ind1 shop (tag 1) 19 | ''' 20 | import matplotlib.pyplot as plt 21 | import matplotlib.patches as patches 22 | colors = ['r', 'g', 'b', 'c', 'y', 'm', 'k', 'w'] 23 | 24 | # assert(ind0 in img_ids) 25 | # assert(ind1 in img_ids) 26 | if ind0 not in img_ids or ind1 not in img_ids: 27 | return 28 | assert(all(types[i] == 0 for i, x in enumerate(img_ids) if x == ind0)) # check img0 is street 29 | assert(all(types[i] == 1 for i, x in enumerate(img_ids) if x == ind1)) # check img1 is shop 30 | 31 | street_inds = (types == 0).nonzero().view(-1) 32 | shop_inds = (types == 1).nonzero().view(-1) 33 | # associate to each detection the corresponding index in the logits 34 | reverse_street_shop_inds = torch.zeros_like(types, dtype=torch.int64) 35 | reverse_street_shop_inds[street_inds] = torch.arange(street_inds.shape[0]) 36 | reverse_street_shop_inds[shop_inds] = torch.arange(shop_inds.shape[0]) 37 | box_inds0 = torch.LongTensor([i for i, x in enumerate(img_ids) 38 | if x == ind0 and scores[i] > score_thresh]) 39 | box_inds1 = torch.LongTensor([i for i, x in enumerate(img_ids) 40 | if x == ind1 and scores[i] > score_thresh]) 41 | img0 = images[ind0] 42 | out0 = {"boxes": boxes[box_inds0].view(-1, 4), "scores": scores[box_inds0].view(-1, 1)} 43 | img1 = images[ind1] 44 | out1 = {"boxes": boxes[box_inds1].view(-1, 4), "scores": scores[box_inds1].view(-1, 1)} 45 | logits_inds0 = reverse_street_shop_inds[box_inds0] 46 | logits_inds1 = reverse_street_shop_inds[box_inds1] 47 | new_logits = logits[logits_inds0][:, logits_inds1].view(logits_inds0.numel(), logits_inds1.numel(), 2) 48 | 49 | plt.figure() 50 | 51 | plt.subplot(1, 2, 2) 52 | plt.imshow(img1.permute(1, 2, 0).detach().cpu().numpy()) 53 | ax = plt.gca() 54 | rect = None 55 | max_size = 0 56 | maxind = None 57 | for i, (b, score) in enumerate(zip(out1['boxes'], out1['scores'])): 58 | if score > score_thresh: 59 | tmp_rect = patches.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1] 60 | , linewidth=1 61 | , edgecolor=colors[0] 62 | , facecolor='none') 63 | tmp_size = ((b[2] - b[0]) * (b[3] - b[1])) 64 | if tmp_size > max_size: 65 | rect = tmp_rect 66 | max_size = tmp_size 67 | maxind = i 68 | if rect is not None: 69 | ax.add_patch(rect) 70 | 71 | if maxind is not None: 72 | # print(new_logits[:, :, 0]) 73 | print(new_logits[:, maxind, 1]) 74 | else: 75 | # print(new_logits[:, :, 0]) 76 | print(new_logits[:, :, 1]) 77 | 78 | plt.subplot(1, 2, 1) 79 | plt.imshow(img0.permute(1, 2, 0).detach().cpu().numpy()) 80 | ax = plt.gca() 81 | for i, (b, score) in enumerate(zip(out0['boxes'], out0['scores'])): 82 | if score > score_thresh: 83 | # tmp_c = i if new_logits.shape[1] == 0 \ 84 | # else (out1['scores'] > score_thresh).nonzero()[0][new_logits[i, :, 1].argmax()] 85 | if maxind is not None: 86 | softmax_logit = new_logits[i, maxind, :].exp() / new_logits[i, maxind, :].exp().sum() 87 | tmp_c = (softmax_logit[0].item(), softmax_logit[1].item(), 0.0) 88 | else: 89 | tmp_c = 
colors[tmp_c] 90 | rect = patches.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1] 91 | , linewidth=1 92 | , edgecolor=tmp_c 93 | , facecolor='none') 94 | ax.add_patch(rect) 95 | 96 | # if logits_inds0.numel() == 0 or logits_inds1.numel() == 0: 97 | # print("no matches!") 98 | # else: 99 | # matches = new_logits.argmax(2) 100 | # print((matches == 1).nonzero()) 101 | 102 | print("**********") 103 | plt.savefig("%d%d.png" % (ind0, ind1)) 104 | plt.close() 105 | # plt.show() 106 | 107 | 108 | def visualize_tracking(match_logits, images, types, prod_ids, img_ids 109 | , tracking_prod_ids, tracking_img_ids, prod_id 110 | , boxes, match_threshold=-1.0): 111 | ''' 112 | ind0 street (tag 0) 113 | ind1 shop (tag 1) 114 | ''' 115 | import matplotlib.pyplot as plt 116 | import matplotlib.patches as patches 117 | colors = ['r', 'g', 'b', 'c', 'y', 'm', 'k', 'w'] 118 | 119 | # LOSS CODE UNTIL TRACKING CANDIDATES COMPUTATION 120 | # 121 | img_ids = torch.tensor(img_ids) 122 | prod_ids = torch.tensor(prod_ids) 123 | tracking_prod_ids = torch.tensor(tracking_prod_ids) 124 | tracking_img_ids = torch.tensor(tracking_img_ids) 125 | street_inds = (types == 0).nonzero().view(-1) 126 | shop_inds = (types == 1).nonzero().view(-1) 127 | # associate to each detection the corresponding index in the logits 128 | reverse_street_shop_inds = torch.zeros_like(types, dtype=torch.int64) 129 | reverse_street_shop_inds[street_inds] = torch.arange(street_inds.shape[0]) 130 | reverse_street_shop_inds[shop_inds] = torch.arange(shop_inds.shape[0]) 131 | tracking_candidates = [] 132 | # per ciascun prodotto 133 | for pi in torch.unique(tracking_prod_ids): 134 | # prendo gli indici tracking del prodotto 135 | tmp_prod_inds = (tracking_prod_ids == pi).nonzero().view(-1) 136 | 137 | # per ciascuna immagine di quel prodotto 138 | for ii in torch.unique(tracking_img_ids[tmp_prod_inds]): 139 | # prendo gli indici nei match logit corrispondenti 140 | tmp_prod_id = pi 141 | tmp_tracking_inds = (tracking_img_ids == ii).nonzero().view(-1) 142 | tmp_inds = (img_ids == ii).nonzero().view(-1) 143 | # cerco lo shop corrispondente 144 | shop_ind = ((prod_ids == tmp_prod_id) & (types == 1)).nonzero().view(-1) 145 | tmp_logits = match_logits[reverse_street_shop_inds[tmp_inds], reverse_street_shop_inds[shop_ind], 1].view( 146 | -1) 147 | # max_score_ind is the index within the boxes in this image id 148 | max_score, max_score_ind = tmp_logits.max(), tmp_logits.argmax() 149 | if max_score > match_threshold: 150 | # save the corresponding tracking ind 151 | tracking_candidates.append(tmp_tracking_inds[max_score_ind]) 152 | # tracking_candidates will contain the "best" box for each image 153 | tracking_candidates = torch.tensor(tracking_candidates) 154 | 155 | # VISUALIZATION CODE, SHOW TRACKING CANDIDATES for the prod_id 156 | selected_candidates = tracking_candidates[(tracking_prod_ids[tracking_candidates] == prod_id) 157 | .nonzero().view(-1)] 158 | if selected_candidates.numel() == 0: 159 | print("NO CANDIDATES") 160 | return 161 | boxes = boxes[selected_candidates] 162 | tracking_prod_ids = tracking_prod_ids[selected_candidates] 163 | tracking_img_ids = tracking_img_ids[selected_candidates] 164 | images = [images[x] for x in tracking_img_ids] 165 | plt.figure() 166 | 167 | for i, img in enumerate(images): 168 | plt.subplot(1, len(images), i + 1) 169 | plt.imshow(img.permute(1, 2, 0).detach().cpu().numpy()) 170 | ax = plt.gca() 171 | b = boxes[i] 172 | tmp_rect = patches.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1] 173 | , linewidth=1 174 | 
, edgecolor=colors[0] 175 | , facecolor='none') 176 | ax.add_patch(tmp_rect) 177 | 178 | print("**********") 179 | plt.savefig("tracking_%d.png" % prod_id) 180 | plt.close() 181 | # plt.show() 182 | 183 | 184 | def visualize_tracking_eval(images, boxes, cls, savename="tracking_eval", rows=1): 185 | ''' 186 | ind0 street (tag 0) 187 | ind1 shop (tag 1) 188 | ''' 189 | import matplotlib.pyplot as plt 190 | import matplotlib.patches as patches 191 | colors = ['r', 'g', 'b', 'c', 'y', 'm', 'k', 'w'] * 10 192 | 193 | spr = rows 194 | spc = len(images) if rows == 1 else int(math.ceil(len(images) / rows)) 195 | 196 | for i, img in enumerate(images): 197 | plt.subplot(spr, spc, i + 1) 198 | plt.imshow(img.permute(1, 2, 0).detach().cpu().numpy()) 199 | ax = plt.gca() 200 | tmp_boxes = boxes[i] 201 | tmp_cls = cls[i] 202 | for ii, b in enumerate(tmp_boxes): 203 | tmp_rect = patches.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1] 204 | , linewidth=1 205 | , edgecolor=colors[tmp_cls[ii]] 206 | , facecolor='none') 207 | ax.add_patch(tmp_rect) 208 | 209 | print("**********") 210 | plt.savefig(savename + ".png") 211 | plt.close() 212 | # plt.show() 213 | 214 | 215 | class SmoothedValue(object): 216 | """Track a series of values and provide access to smoothed values over a 217 | window or the global series average. 218 | """ 219 | 220 | def __init__(self, window_size=20, fmt=None): 221 | if fmt is None: 222 | fmt = "{median:.4f} ({global_avg:.4f})" 223 | self.deque = deque(maxlen=window_size) 224 | self.total = 0.0 225 | self.count = 0 226 | self.fmt = fmt 227 | 228 | def update(self, value, n=1): 229 | self.deque.append(value) 230 | self.count += n 231 | self.total += value * n 232 | 233 | def synchronize_between_processes(self): 234 | """ 235 | Warning: does not synchronize the deque! 
236 | """ 237 | if not is_dist_avail_and_initialized(): 238 | return 239 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 240 | dist.barrier() 241 | dist.all_reduce(t) 242 | t = t.tolist() 243 | self.count = int(t[0]) 244 | self.total = t[1] 245 | 246 | @property 247 | def median(self): 248 | d = torch.tensor(list(self.deque)) 249 | return d.median().item() 250 | 251 | @property 252 | def avg(self): 253 | d = torch.tensor(list(self.deque), dtype=torch.float32) 254 | return d.mean().item() 255 | 256 | @property 257 | def global_avg(self): 258 | return self.total / self.count 259 | 260 | @property 261 | def max(self): 262 | return max(self.deque) 263 | 264 | @property 265 | def value(self): 266 | return self.deque[-1] 267 | 268 | def __str__(self): 269 | return self.fmt.format( 270 | median=self.median, 271 | avg=self.avg, 272 | global_avg=self.global_avg, 273 | max=self.max, 274 | value=self.value) 275 | 276 | 277 | def all_gather(data): 278 | """ 279 | Run all_gather on arbitrary picklable data (not necessarily tensors) 280 | Args: 281 | data: any picklable object 282 | Returns: 283 | list[data]: list of data gathered from each rank 284 | """ 285 | world_size = get_world_size() 286 | if world_size == 1: 287 | return [data] 288 | 289 | # serialized to a Tensor 290 | buffer = pickle.dumps(data) 291 | storage = torch.ByteStorage.from_buffer(buffer) 292 | tensor = torch.ByteTensor(storage).to("cuda") 293 | 294 | # obtain Tensor size of each rank 295 | local_size = torch.tensor([tensor.numel()], device="cuda") 296 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 297 | dist.all_gather(size_list, local_size) 298 | size_list = [int(size.item()) for size in size_list] 299 | max_size = max(size_list) 300 | 301 | # receiving Tensor from all ranks 302 | # we pad the tensor because torch all_gather does not support 303 | # gathering tensors of different shapes 304 | tensor_list = [] 305 | for _ in size_list: 306 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 307 | if local_size != max_size: 308 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 309 | tensor = torch.cat((tensor, padding), dim=0) 310 | dist.all_gather(tensor_list, tensor) 311 | 312 | data_list = [] 313 | for size, tensor in zip(size_list, tensor_list): 314 | buffer = tensor.cpu().numpy().tobytes()[:size] 315 | data_list.append(pickle.loads(buffer)) 316 | 317 | return data_list 318 | 319 | 320 | def reduce_dict(input_dict, average=True): 321 | """ 322 | Args: 323 | input_dict (dict): all the values will be reduced 324 | average (bool): whether to do average or sum 325 | Reduce the values in the dictionary from all processes so that all processes 326 | have the averaged results. Returns a dict with the same fields as 327 | input_dict, after reduction. 
328 | """ 329 | world_size = get_world_size() 330 | if world_size < 2: 331 | return input_dict 332 | with torch.no_grad(): 333 | names = [] 334 | values = [] 335 | # sort the keys so that they are consistent across processes 336 | for k in sorted(input_dict.keys()): 337 | names.append(k) 338 | values.append(input_dict[k]) 339 | values = torch.stack(values, dim=0) 340 | dist.all_reduce(values) 341 | if average: 342 | values /= world_size 343 | reduced_dict = {k: v for k, v in zip(names, values)} 344 | return reduced_dict 345 | 346 | 347 | class MetricLogger(object): 348 | def __init__(self, delimiter="\t"): 349 | self.meters = defaultdict(SmoothedValue) 350 | self.delimiter = delimiter 351 | 352 | def update(self, **kwargs): 353 | for k, v in kwargs.items(): 354 | if isinstance(v, torch.Tensor): 355 | v = v.item() 356 | assert isinstance(v, (float, int)) 357 | self.meters[k].update(v) 358 | 359 | def __getattr__(self, attr): 360 | if attr in self.meters: 361 | return self.meters[attr] 362 | if attr in self.__dict__: 363 | return self.__dict__[attr] 364 | raise AttributeError("'{}' object has no attribute '{}'".format( 365 | type(self).__name__, attr)) 366 | 367 | def __str__(self): 368 | loss_str = [] 369 | for name, meter in self.meters.items(): 370 | loss_str.append( 371 | "{}: {}".format(name, str(meter)) 372 | ) 373 | return self.delimiter.join(loss_str) 374 | 375 | def synchronize_between_processes(self): 376 | for meter in self.meters.values(): 377 | meter.synchronize_between_processes() 378 | 379 | def add_meter(self, name, meter): 380 | self.meters[name] = meter 381 | 382 | def log_every(self, iterable, print_freq, header=None, rank=0): 383 | i = 0 384 | if not header: 385 | header = '' 386 | start_time = time.time() 387 | end = time.time() 388 | iter_time = SmoothedValue(fmt='{avg:.4f}') 389 | data_time = SmoothedValue(fmt='{avg:.4f}') 390 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 391 | if torch.cuda.is_available(): 392 | log_msg = self.delimiter.join([ 393 | header, 394 | '[{0' + space_fmt + '}/{1}]', 395 | 'eta: {eta}', 396 | '{meters}', 397 | 'time: {time}', 398 | 'data: {data}', 399 | 'max mem: {memory:.0f}' 400 | ]) 401 | else: 402 | log_msg = self.delimiter.join([ 403 | header, 404 | '[{0' + space_fmt + '}/{1}]', 405 | 'eta: {eta}', 406 | '{meters}', 407 | 'time: {time}', 408 | 'data: {data}' 409 | ]) 410 | MB = 1024.0 * 1024.0 411 | for obj in iterable: 412 | data_time.update(time.time() - end) 413 | yield obj 414 | iter_time.update(time.time() - end) 415 | if i % print_freq == 0 or i == len(iterable) - 1: 416 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 417 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 418 | if rank == 0: 419 | if torch.cuda.is_available(): 420 | print(log_msg.format( 421 | i, len(iterable), eta=eta_string, 422 | meters=str(self), 423 | time=str(iter_time), data=str(data_time), 424 | memory=torch.cuda.max_memory_allocated() / MB)) 425 | else: 426 | print(log_msg.format( 427 | i, len(iterable), eta=eta_string, 428 | meters=str(self), 429 | time=str(iter_time), data=str(data_time))) 430 | i += 1 431 | end = time.time() 432 | total_time = time.time() - start_time 433 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 434 | print('{} Total time: {} ({:.4f} s / it)'.format( 435 | header, total_time_str, total_time / len(iterable))) 436 | 437 | 438 | def collate_fn(batch): 439 | return tuple(zip(*batch)) 440 | 441 | 442 | def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): 443 | def f(x): 
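        # LR multiplier returned to LambdaLR: ramps linearly from warmup_factor
        # at x=0 up to 1.0 at x=warmup_iters, and stays at 1 afterwards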
444 | if x >= warmup_iters: 445 | return 1 446 | alpha = float(x) / warmup_iters 447 | return warmup_factor * (1 - alpha) + alpha 448 | 449 | return torch.optim.lr_scheduler.LambdaLR(optimizer, f) 450 | 451 | 452 | def mkdir(path): 453 | try: 454 | os.makedirs(path) 455 | except OSError as e: 456 | if e.errno != errno.EEXIST: 457 | raise 458 | 459 | 460 | def setup_for_distributed(is_master): 461 | """ 462 | This function disables printing when not in master process 463 | """ 464 | import builtins as __builtin__ 465 | builtin_print = __builtin__.print 466 | 467 | def print(*args, **kwargs): 468 | force = kwargs.pop('force', False) 469 | if is_master or force: 470 | builtin_print(*args, **kwargs) 471 | 472 | __builtin__.print = print 473 | 474 | 475 | def is_dist_avail_and_initialized(): 476 | if not dist.is_available(): 477 | return False 478 | if not dist.is_initialized(): 479 | return False 480 | return True 481 | 482 | 483 | def get_world_size(): 484 | if not is_dist_avail_and_initialized(): 485 | return 1 486 | return dist.get_world_size() 487 | 488 | 489 | def get_rank(): 490 | if not is_dist_avail_and_initialized(): 491 | return 0 492 | return dist.get_rank() 493 | 494 | 495 | def is_main_process(): 496 | return get_rank() == 0 497 | 498 | 499 | def save_on_master(*args, **kwargs): 500 | if is_main_process(): 501 | torch.save(*args, **kwargs) 502 | 503 | 504 | def init_distributed_mode(args): 505 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 506 | args.rank = int(os.environ["RANK"]) 507 | args.world_size = int(os.environ['WORLD_SIZE']) 508 | args.gpu = int(os.environ['LOCAL_RANK']) 509 | elif 'SLURM_PROCID' in os.environ: 510 | args.rank = int(os.environ['SLURM_PROCID']) 511 | args.gpu = args.rank % torch.cuda.device_count() 512 | else: 513 | print('Not using distributed mode') 514 | args.distributed = False 515 | return 516 | 517 | args.distributed = True 518 | 519 | torch.cuda.set_device(args.gpu) 520 | args.dist_backend = 'nccl' 521 | print('| distributed init (rank {}): {}'.format( 522 | args.rank, args.dist_url), flush=True) 523 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 524 | world_size=args.world_size, rank=args.rank) 525 | torch.distributed.barrier() 526 | setup_for_distributed(args.rank == 0) 527 | 528 | 529 | 530 | -------------------------------------------------------------------------------- /train_matchrcnn.py: -------------------------------------------------------------------------------- 1 | from stuffs import transform as T 2 | from datasets.DF2Dataset import DeepFashion2Dataset, get_dataloader 3 | import torch 4 | from models.matchrcnn import matchrcnn_resnet50_fpn 5 | from stuffs.engine import train_one_epoch_matchrcnn 6 | import os 7 | import torch.distributed as dist 8 | import argparse 9 | 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | gpu_map = [0, 1, 2, 3] 13 | 14 | 15 | def get_transform(train): 16 | transforms = [] 17 | transforms.append(T.ToTensor()) 18 | if train: 19 | transforms.append(T.RandomHorizontalFlip(0.5)) 20 | return T.Compose(transforms) 21 | 22 | 23 | # run with python -m torch.distributed.launch --nproc_per_node #GPUs train_matchrcnn.py 24 | 25 | 26 | def train(args): 27 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 28 | if 'WORLD_SIZE' in os.environ: 29 | distributed = int(os.environ['WORLD_SIZE']) > 1 30 | rank = args.local_rank 31 | print("Distributed training with %d processors. 
This is #%s" 32 | % (int(os.environ['WORLD_SIZE']), rank)) 33 | else: 34 | distributed = False 35 | rank = 0 36 | print("Not distributed training") 37 | 38 | if distributed: 39 | os.environ['NCCL_BLOCKING_WAIT'] = "1" 40 | torch.cuda.set_device(gpu_map[rank]) 41 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 42 | device = torch.device(torch.cuda.current_device()) 43 | else: 44 | device = torch.device("cuda") 45 | 46 | # DATASET ---------------------------------------------------------------------------------------------------------- 47 | 48 | train_dataset = DeepFashion2Dataset(root=args.root_train 49 | , ann_file=args.train_annots, 50 | transforms=get_transform(True)) 51 | 52 | # ------------------------------------------------------------------------------------------------------------------ 53 | 54 | # DATALOADER-------------------------------------------------------------------------------------------------------- 55 | 56 | data_loader = get_dataloader(train_dataset, batch_size=args.batch_size, is_parallel=distributed) 57 | 58 | # ------------------------------------------------------------------------------------------------------------------ 59 | 60 | # MODEL ------------------------------------------------------------------------------------------------------------ 61 | from models.maskrcnn import params as c_params 62 | model = matchrcnn_resnet50_fpn(pretrained_backbone=True, num_classes=14, **c_params) 63 | model.to(device) 64 | 65 | # ------------------------------------------------------------------------------------------------------------------ 66 | 67 | # OPTIMIZER AND SCHEDULER ------------------------------------------------------------------------------------------ 68 | 69 | # construct an optimizer 70 | params = [p for p in model.parameters() if p.requires_grad] 71 | optimizer = torch.optim.SGD(params, lr=args.learning_rate, 72 | momentum=0.9) 73 | 74 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones) 75 | 76 | # ------------------------------------------------------------------------------------------------------------------ 77 | 78 | if rank == 0: 79 | writer = SummaryWriter(os.path.join(args.save_path, args.save_tag)) 80 | else: 81 | writer = None 82 | 83 | for epoch in range(args.num_epochs): 84 | # train for one epoch, printing every 10 iterations 85 | print("Starting epoch %d for process %d" % (epoch, rank)) 86 | train_one_epoch_matchrcnn(model, optimizer, data_loader, device, epoch, args.print_freq, writer) 87 | # update the learning rate 88 | lr_scheduler.step() 89 | 90 | if rank == 0 and epoch != 0 and epoch % args.save_epochs == 0: 91 | os.makedirs(args.save_path, exist_ok=True) 92 | torch.save({ 93 | 'epoch': epoch, 94 | 'model_state_dict': model.state_dict(), 95 | 'optimizer_state_dict': optimizer.state_dict(), 96 | 'scheduler_state_dict': lr_scheduler.state_dict(), 97 | }, os.path.join(args.save_path, (args.save_tag + "_epoch%03d") % epoch)) 98 | 99 | os.makedirs(args.save_path, exist_ok=True) 100 | torch.save({ 101 | 'epoch': args.num_epochs, 102 | 'model_state_dict': model.state_dict(), 103 | 'optimizer_state_dict': optimizer.state_dict(), 104 | 'scheduler_state_dict': lr_scheduler.state_dict(), 105 | }, os.path.join(args.save_path, "final_model")) 106 | 107 | print("That's it!") 108 | 109 | 110 | if __name__ == '__main__': 111 | 112 | parser = argparse.ArgumentParser(description="Match R-CNN Training") 113 | parser.add_argument("--local_rank", type=int, default=0) 114 | 
parser.add_argument("--gpus", type=str, default="0,1") 115 | parser.add_argument("--n_workers", type=int, default=8) 116 | 117 | parser.add_argument("--batch_size", type=int, default=8) 118 | parser.add_argument("--root_train", type=str, default='data/deepfashion2/train/image') 119 | parser.add_argument("--train_annots", type=str, default='data/deepfashion2/train/annots.json') 120 | 121 | parser.add_argument("--num_epochs", type=int, default=12) 122 | parser.add_argument("--milestones", type=int, default=[6, 9]) 123 | parser.add_argument("--learning_rate", type=float, default=0.02) 124 | 125 | parser.add_argument("--print_freq", type=int, default=100) 126 | parser.add_argument("--save_epochs", type=int, default=2) 127 | 128 | parser.add_argument('--save_path', type=str, default="ckpt/matchrcnn") 129 | parser.add_argument('--save_tag', type=str, default="DF2-pretraining") 130 | 131 | args = parser.parse_args() 132 | 133 | train(args) 134 | -------------------------------------------------------------------------------- /train_movingfashion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import resource 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | from datasets.MFDataset import MovingFashionDataset, get_dataloader 10 | from evaluate_movingfashion import evaluate 11 | from models.video_matchrcnn import videomatchrcnn_resnet50_fpn 12 | from stuffs import transform as T 13 | from stuffs.engine import train_one_epoch_movingfashion 14 | 15 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 16 | resource.setrlimit(resource.RLIMIT_NOFILE, (16384, rlimit[1])) 17 | 18 | # DistributedDataParallel tutorial @ https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html 19 | # run with python -m torch.distributed.launch --nproc_per_node #GPUs train.py 20 | 21 | 22 | gpu_map = [0, 1, 2, 3] 23 | 24 | def get_transform(train): 25 | transforms = [] 26 | transforms.append(T.ToTensor()) 27 | if train: 28 | transforms.append(T.RandomHorizontalFlip(0.5)) 29 | return T.Compose(transforms) 30 | 31 | # how many frames to extract from the video of each product 32 | 33 | 34 | def train(args): 35 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 36 | if 'WORLD_SIZE' in os.environ: 37 | distributed = int(os.environ['WORLD_SIZE']) > 1 38 | rank = args.local_rank 39 | print("Distributed training with %d processors. 
This is #%s" 40 | % (int(os.environ['WORLD_SIZE']), rank)) 41 | else: 42 | distributed = False 43 | rank = 0 44 | print("Not distributed training") 45 | 46 | if distributed: 47 | os.environ['NCCL_BLOCKING_WAIT'] = "1" 48 | torch.cuda.set_device(gpu_map[rank]) 49 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 50 | device = torch.device(torch.cuda.current_device()) 51 | else: 52 | device = torch.device("cuda") 53 | 54 | # DATASET ---------------------------------------------------------------------------------------------------------- 55 | train_dataset = MovingFashionDataset(args.train_annots 56 | , transform=get_transform(True), noise=args.noise 57 | , root=args.root) 58 | test_dataset = MovingFashionDataset(args.test_annots 59 | , transform=T.ToTensor(), noise=args.noise 60 | , root=args.root) 61 | 62 | # ------------------------------------------------------------------------------------------------------------------ 63 | 64 | # DATALOADER-------------------------------------------------------------------------------------------------------- 65 | 66 | data_loader = get_dataloader(train_dataset, batch_size=args.batch_size_train 67 | , is_parallel=distributed, n_products=args.n_shops, num_workers=args.n_workers) 68 | data_loader_test = get_dataloader(test_dataset, batch_size=args.batch_size_test, is_parallel=distributed, num_workers=args.n_workers) 69 | 70 | # ------------------------------------------------------------------------------------------------------------------ 71 | 72 | # MODEL ------------------------------------------------------------------------------------------------------------ 73 | 74 | model = videomatchrcnn_resnet50_fpn(pretrained_backbone=True, num_classes=14 75 | , n_frames=3) 76 | 77 | 78 | 79 | if args.start_ckpt != None: 80 | savefile = torch.load(args.start_ckpt) 81 | start_ep = savefile['epoch'] + 1 82 | model.load_state_dict(savefile['model_state_dict']) 83 | pass 84 | else: 85 | savefile = torch.load(args.pretrained_path) 86 | sd = savefile['model_state_dict'] 87 | sd = {".".join(k.split(".")[1:]): v for k, v in sd.items()} 88 | model.load_saved_matchrcnn(sd) 89 | start_ep = 0 90 | model.to(device) 91 | # ------------------------------------------------------------------------------------------------------------------ 92 | 93 | # OPTIMIZER AND SCHEDULER ------------------------------------------------------------------------------------------ 94 | 95 | # construct an optimizer 96 | params = [p for p in model.parameters() if p.requires_grad] 97 | optimizer = torch.optim.SGD(params, lr=args.learning_rate, 98 | momentum=0.9, weight_decay=0.0005) 99 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer 100 | , milestones=args.milestones 101 | , gamma=0.1) 102 | if args.start_ckpt != None: 103 | optimizer.load_state_dict(savefile['optimizer_state_dict']) 104 | lr_scheduler.load_state_dict(savefile['scheduler_state_dict']) 105 | # ------------------------------------------------------------------------------------------------------------------ 106 | 107 | 108 | if rank == 0: 109 | writer = SummaryWriter(os.path.join(args.save_path, args.save_tag)) 110 | else: 111 | writer = None 112 | 113 | best_single, best_avg, best_aggr = 0.0, 0.0, 0.0 114 | 115 | for epoch in range(args.num_epochs): 116 | # train for one epoch, printing every 10 iterations 117 | print("Starting epoch %d for process %d" % (epoch, rank)) 118 | train_one_epoch_movingfashion(model, optimizer, data_loader, device, epoch 119 | , print_freq=args.print_freq, 
score_thresh=0.1, writer=writer, inferstep=15) 120 | # update the learning rate 121 | lr_scheduler.step() 122 | # evaluate on the test dataset 123 | 124 | if rank == 0 and ((epoch % args.save_epochs) == 0): 125 | os.makedirs(args.save_path, exist_ok=True) 126 | torch.save({ 127 | 'epoch': epoch, 128 | 'model_state_dict': model.state_dict(), 129 | 'optimizer_state_dict': optimizer.state_dict(), 130 | 'scheduler_state_dict': lr_scheduler.state_dict() 131 | }, os.path.join(args.save_path, (args.save_tag + "_epoch%03d") % epoch)) 132 | model = model.to(device) 133 | 134 | if rank == 0 and ((epoch % args.eval_freq) == 0): 135 | model.eval() 136 | res = evaluate(model, data_loader_test, device, frames_per_product=args.frames_per_shop_test) 137 | writer.add_scalar("single_acc", res[0], global_step=epoch) 138 | writer.add_scalar("avg_acc", res[1], global_step=epoch) 139 | writer.add_scalar("aggr_acc", res[2], global_step=epoch) 140 | best_single, best_avg, best_aggr = max(res[0], best_single), max(res[1], best_avg)\ 141 | , max(res[2], best_aggr) 142 | print("Best results:\n - Best single: %01.2f" 143 | "\n - Best avg: %01.2f\n - Best aggr: %01.2f\n" % (best_single, best_avg, best_aggr)) 144 | 145 | os.makedirs(args.save_path, exist_ok=True) 146 | torch.save({ 147 | 'epoch': args.num_epochs, 148 | 'model_state_dict': model.state_dict(), 149 | 'optimizer_state_dict': optimizer.state_dict(), 150 | 'scheduler_state_dict': lr_scheduler.state_dict() 151 | }, os.path.join(args.save_path, (args.save_tag + "_epoch%03d") % args.num_epochs)) 152 | if rank == 0: 153 | model.eval() 154 | _ = evaluate(model, data_loader_test, device, frames_per_product=args.frames_per_shop_test) 155 | print("That's it!") 156 | 157 | 158 | if __name__ == '__main__': 159 | 160 | parser = argparse.ArgumentParser(description="SEAM Training") 161 | parser.add_argument("--local_rank", type=int, default=0) 162 | parser.add_argument("--gpus", type=str, default="0") 163 | parser.add_argument("--n_workers", type=int, default=8) 164 | 165 | parser.add_argument("--frames_per_shop_train", type=int, default=10) 166 | parser.add_argument("--frames_per_shop_test", type=int, default=10) 167 | parser.add_argument("--n_shops", type=int, default=16) 168 | parser.add_argument("--root", type=str, default="data/MovingFashion") 169 | parser.add_argument("--train_annots", type=str, default="data/MovingFashion/train.json") 170 | parser.add_argument("--test_annots", type=str, default="data/MovingFashion/test.json") 171 | parser.add_argument("--noise", type=bool, default=True) 172 | 173 | parser.add_argument("--num_epochs", type=int, default=31) 174 | parser.add_argument("--milestones", type=int, default=[15, 25]) 175 | parser.add_argument("--learning_rate", type=float, default=0.04) #Please consider also the number of GPU 176 | parser.add_argument("--start_ckpt", type=str, default=None) #Insert ckpt model path to restart training from a fixed epoch 177 | parser.add_argument("--pretrained_path", type=str, default="pre-trained/df2matchrcnn") 178 | 179 | parser.add_argument("--print_freq", type=int, default=20) 180 | parser.add_argument("--eval_freq", type=int, default=4) 181 | parser.add_argument("--save_epochs", type=int, default=2) 182 | 183 | parser.add_argument('--save_path',type=str, default="ckpt/SEAM/MovingFashion") 184 | parser.add_argument('--save_tag', type=str, default="MF") 185 | 186 | args = parser.parse_args() 187 | 188 | args.batch_size_train = (1 + args.frames_per_shop_train) * args.n_shops 189 | args.batch_size_test = (1 + 
args.frames_per_shop_test) * 1 190 | 191 | train(args) 192 | -------------------------------------------------------------------------------- /train_multiDF2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import resource 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | from datasets.MultiDF2Dataset import MultiDeepFashion2Dataset, get_dataloader 10 | from evaluate_multiDF2 import evaluate 11 | from models.video_matchrcnn import videomatchrcnn_resnet50_fpn 12 | from stuffs import transform as T 13 | from stuffs.engine import train_one_epoch_multiDF2 14 | 15 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 16 | resource.setrlimit(resource.RLIMIT_NOFILE, (16384, rlimit[1])) 17 | 18 | gpu_map = [0, 1, 2, 3] 19 | 20 | 21 | def get_transform(train): 22 | transforms = [] 23 | transforms.append(T.ToTensor()) 24 | if train: 25 | transforms.append(T.RandomHorizontalFlip(0.5)) 26 | return T.Compose(transforms) 27 | 28 | 29 | # how many frames to extract from the video of each product 30 | 31 | 32 | def train(args): 33 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 34 | if 'WORLD_SIZE' in os.environ: 35 | distributed = int(os.environ['WORLD_SIZE']) > 1 36 | rank = args.local_rank 37 | print("Distributed training with %d processors. This is #%s" 38 | % (int(os.environ['WORLD_SIZE']), rank)) 39 | else: 40 | distributed = False 41 | rank = 0 42 | print("Not distributed training") 43 | 44 | if distributed: 45 | os.environ['NCCL_BLOCKING_WAIT'] = "1" 46 | torch.cuda.set_device(gpu_map[rank]) 47 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 48 | device = torch.device(torch.cuda.current_device()) 49 | else: 50 | device = torch.device("cuda") 51 | 52 | # DATASET ---------------------------------------------------------------------------------------------------------- 53 | train_dataset = MultiDeepFashion2Dataset(root=args.root_train 54 | , ann_file=args.train_annots, 55 | transforms=get_transform(True), noise=True, filter_onestreet=True) 56 | test_dataset = MultiDeepFashion2Dataset(root=args.root_test 57 | , ann_file=args.test_annots, 58 | transforms=get_transform(False), filter_onestreet=True) 59 | # ------------------------------------------------------------------------------------------------------------------ 60 | 61 | # DATALOADER-------------------------------------------------------------------------------------------------------- 62 | 63 | data_loader_train = get_dataloader(train_dataset, batch_size=args.batch_size_train 64 | , is_parallel=distributed, n_products=args.n_shops, n_workers=args.n_workers) 65 | data_loader_test = get_dataloader(test_dataset, batch_size=args.batch_size_test, is_parallel=distributed, 66 | n_products=1, n_workers=args.n_workers) 67 | 68 | # ------------------------------------------------------------------------------------------------------------------ 69 | 70 | # MODEL ------------------------------------------------------------------------------------------------------------ 71 | model = videomatchrcnn_resnet50_fpn(pretrained_backbone=True, num_classes=14 72 | , n_frames=3) 73 | 74 | if args.start_ckpt != None: 75 | savefile = torch.load(args.start_ckpt) 76 | start_ep = savefile['epoch'] + 1 77 | model.load_state_dict(savefile['model_state_dict']) 78 | pass 79 | else: 80 | savefile = torch.load(args.pretrained_path) 81 | sd = savefile['model_state_dict'] 82 | sd = 
{".".join(k.split(".")[1:]): v for k, v in sd.items()} 83 | model.load_saved_matchrcnn(sd) 84 | start_ep = 0 85 | model.to(device) 86 | # ------------------------------------------------------------------------------------------------------------------ 87 | 88 | # OPTIMIZER AND SCHEDULER ------------------------------------------------------------------------------------------ 89 | 90 | # construct an optimizer 91 | params = [p for p in model.parameters() if p.requires_grad] 92 | optimizer = torch.optim.SGD(params, lr=args.learning_rate, 93 | momentum=0.9, weight_decay=0.0005) 94 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer 95 | , milestones=args.milestones 96 | , gamma=0.1) 97 | if args.start_ckpt != None: 98 | optimizer.load_state_dict(savefile['optimizer_state_dict']) 99 | lr_scheduler.load_state_dict(savefile['scheduler_state_dict']) 100 | # ------------------------------------------------------------------------------------------------------------------ 101 | 102 | if rank == 0: 103 | writer = SummaryWriter(os.path.join(args.save_path, args.save_tag)) 104 | else: 105 | writer = None 106 | 107 | best_single, best_avg, best_aggr = 0.0, 0.0, 0.0 108 | 109 | for epoch in range(args.num_epochs): 110 | # train for one epoch, printing every 10 iterations 111 | print("Starting epoch %d for process %d" % (epoch, rank)) 112 | train_one_epoch_multiDF2(model, optimizer, data_loader_train, device, epoch 113 | , print_freq=args.print_freq, score_thresh=0.1, writer=writer, inferstep=15) 114 | # update the learning rate 115 | lr_scheduler.step() 116 | # evaluate on the test dataset 117 | 118 | if rank == 0 and ((epoch % args.save_epochs) == 0): 119 | os.makedirs(args.save_path, exist_ok=True) 120 | torch.save({ 121 | 'epoch': epoch, 122 | 'model_state_dict': model.state_dict(), 123 | 'optimizer_state_dict': optimizer.state_dict(), 124 | 'scheduler_state_dict': lr_scheduler.state_dict() 125 | }, os.path.join(args.save_path, (args.save_tag + "_epoch%03d") % epoch)) 126 | model = model.to(device) 127 | 128 | if rank == 0 and ((epoch % args.eval_freq) == 0): 129 | model.eval() 130 | res = evaluate(model, data_loader_test, device, frames_per_product=args.frames_per_shop_test) 131 | writer.add_scalar("single_acc", res[0], global_step=epoch) 132 | writer.add_scalar("avg_acc", res[1], global_step=epoch) 133 | writer.add_scalar("aggr_acc", res[2], global_step=epoch) 134 | best_single, best_avg, best_aggr = max(res[0], best_single), max(res[1], best_avg) \ 135 | , max(res[2], best_aggr) 136 | print("Best results:\n - Best single: %01.2f" 137 | "\n - Best avg: %01.2f\n - Best aggr: %01.2f\n" % (best_single, best_avg, best_aggr)) 138 | 139 | os.makedirs(args.save_path, exist_ok=True) 140 | torch.save({ 141 | 'epoch': args.num_epochs, 142 | 'model_state_dict': model.state_dict(), 143 | 'optimizer_state_dict': optimizer.state_dict(), 144 | 'scheduler_state_dict': lr_scheduler.state_dict() 145 | }, os.path.join(args.save_path, (args.save_tag + "_epoch%03d") % args.num_epochs)) 146 | if rank == 0: 147 | model.eval() 148 | _ = evaluate(model, data_loader_test, device, frames_per_product=args.frames_per_shop_test) 149 | print("That's it!") 150 | 151 | 152 | if __name__ == '__main__': 153 | parser = argparse.ArgumentParser(description="SEAM Training") 154 | parser.add_argument("--local_rank", type=int, default=0) 155 | parser.add_argument("--gpus", type=str, default="0,1") 156 | parser.add_argument("--n_workers", type=int, default=8) 157 | 158 | parser.add_argument("--frames_per_shop_train", 
type=int, default=10) 159 | parser.add_argument("--frames_per_shop_test", type=int, default=10) 160 | parser.add_argument("--n_shops", type=int, default=8) 161 | parser.add_argument("--root_train", type=str, default='data/deepfashion2/train/image') 162 | parser.add_argument("--root_test", type=str, default='data/deepfashion2/validation/image') 163 | parser.add_argument("--train_annots", type=str, default='data/deepfashion2/train/annots.json') 164 | parser.add_argument("--test_annots", type=str, default='data/deepfashion2/validation/annots.json') 165 | parser.add_argument("--noise", type=bool, default=True) 166 | 167 | parser.add_argument("--num_epochs", type=int, default=31) 168 | parser.add_argument("--milestones", type=int, default=[15, 25]) 169 | parser.add_argument("--learning_rate", type=float, default=0.02) 170 | parser.add_argument("--start_ckpt", type=str, default=None) #Insert ckpt model path to restart training from a fixed epoch 171 | parser.add_argument("--pretrained_path", type=str, 172 | default="pre-trained/df2matchrcnn") 173 | 174 | parser.add_argument("--print_freq", type=int, default=20) 175 | parser.add_argument("--eval_freq", type=int, default=4) 176 | parser.add_argument("--save_epochs", type=int, default=2) 177 | 178 | parser.add_argument('--save_path', type=str, default="ckpt/SEAM/multiDF2") 179 | parser.add_argument('--save_tag', type=str, default="DF2") 180 | 181 | args = parser.parse_args() 182 | 183 | args.batch_size_train = (1 + args.frames_per_shop_train) * args.n_shops 184 | args.batch_size_test = (1 + args.frames_per_shop_test) * 1 185 | 186 | train(args) 187 | --------------------------------------------------------------------------------