├── .gitignore ├── CHANGELOG.md ├── COPYRIGHT.txt ├── README.md ├── environment.yml ├── object-locator ├── __init__.py ├── __main__.py ├── argparser.py ├── bmm.py ├── checkpoints │ └── .gitignore ├── data.py ├── data_plant_stuff.py ├── find_lr.py ├── get_image_size.py ├── locate.py ├── logger.py ├── losses.py ├── make_metric_plots.py ├── metrics.py ├── metrics_from_results.py ├── models │ ├── __init__.py │ ├── unet_model.py │ ├── unet_parts.py │ └── utils.py ├── paint.py ├── train.py └── utils.py ├── scripts_dataset_and_results ├── generate_csv.py ├── parseResults.py └── spacing_stats_to_csv.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | .static_storage/ 56 | .media/ 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # vim 107 | *.swp 108 | -------------------------------------------------------------------------------- /COPYRIGHT.txt: -------------------------------------------------------------------------------- 1 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 2 | All rights reserved. 3 | 4 | This software is covered by US patents and copyright. 5 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 6 | 7 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 8 | 9 | Last Modified: 10/02/2019 10 | 11 | 12 | Attribution-NonCommercial-ShareAlike 4.0 International 13 | 14 | ======================================================================= 15 | 16 | Creative Commons Corporation ("Creative Commons") is not a law firm and 17 | does not provide legal services or legal advice. Distribution of 18 | Creative Commons public licenses does not create a lawyer-client or 19 | other relationship. Creative Commons makes its licenses and related 20 | information available on an "as-is" basis. 
Creative Commons gives no 21 | warranties regarding its licenses, any material licensed under their 22 | terms and conditions, or any related information. Creative Commons 23 | disclaims all liability for damages resulting from their use to the 24 | fullest extent possible. 25 | 26 | Using Creative Commons Public Licenses 27 | 28 | Creative Commons public licenses provide a standard set of terms and 29 | conditions that creators and other rights holders may use to share 30 | original works of authorship and other material subject to copyright 31 | and certain other rights specified in the public license below. The 32 | following considerations are for informational purposes only, are not 33 | exhaustive, and do not form part of our licenses. 34 | 35 | Considerations for licensors: Our public licenses are 36 | intended for use by those authorized to give the public 37 | permission to use material in ways otherwise restricted by 38 | copyright and certain other rights. Our licenses are 39 | irrevocable. Licensors should read and understand the terms 40 | and conditions of the license they choose before applying it. 41 | Licensors should also secure all rights necessary before 42 | applying our licenses so that the public can reuse the 43 | material as expected. Licensors should clearly mark any 44 | material not subject to the license. This includes other CC- 45 | licensed material, or material used under an exception or 46 | limitation to copyright. More considerations for licensors: 47 | wiki.creativecommons.org/Considerations_for_licensors 48 | 49 | Considerations for the public: By using one of our public 50 | licenses, a licensor grants the public permission to use the 51 | licensed material under specified terms and conditions. If 52 | the licensor's permission is not necessary for any reason--for 53 | example, because of any applicable exception or limitation to 54 | copyright--then that use is not regulated by the license. Our 55 | licenses grant only permissions under copyright and certain 56 | other rights that a licensor has authority to grant. Use of 57 | the licensed material may still be restricted for other 58 | reasons, including because others have copyright or other 59 | rights in the material. A licensor may make special requests, 60 | such as asking that all changes be marked or described. 61 | Although not required by our licenses, you are encouraged to 62 | respect those requests where reasonable. More considerations 63 | for the public: 64 | wiki.creativecommons.org/Considerations_for_licensees 65 | 66 | ======================================================================= 67 | 68 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 69 | Public License 70 | 71 | By exercising the Licensed Rights (defined below), You accept and agree 72 | to be bound by the terms and conditions of this Creative Commons 73 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 74 | ("Public License"). To the extent this Public License may be 75 | interpreted as a contract, You are granted the Licensed Rights in 76 | consideration of Your acceptance of these terms and conditions, and the 77 | Licensor grants You such rights in consideration of benefits the 78 | Licensor receives from making the Licensed Material available under 79 | these terms and conditions. 80 | 81 | 82 | Section 1 -- Definitions. 83 | 84 | a. 
Adapted Material means material subject to Copyright and Similar 85 | Rights that is derived from or based upon the Licensed Material 86 | and in which the Licensed Material is translated, altered, 87 | arranged, transformed, or otherwise modified in a manner requiring 88 | permission under the Copyright and Similar Rights held by the 89 | Licensor. For purposes of this Public License, where the Licensed 90 | Material is a musical work, performance, or sound recording, 91 | Adapted Material is always produced where the Licensed Material is 92 | synched in timed relation with a moving image. 93 | 94 | b. Adapter's License means the license You apply to Your Copyright 95 | and Similar Rights in Your contributions to Adapted Material in 96 | accordance with the terms and conditions of this Public License. 97 | 98 | c. BY-NC-SA Compatible License means a license listed at 99 | creativecommons.org/compatiblelicenses, approved by Creative 100 | Commons as essentially the equivalent of this Public License. 101 | 102 | d. Copyright and Similar Rights means copyright and/or similar rights 103 | closely related to copyright including, without limitation, 104 | performance, broadcast, sound recording, and Sui Generis Database 105 | Rights, without regard to how the rights are labeled or 106 | categorized. For purposes of this Public License, the rights 107 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 108 | Rights. 109 | 110 | e. Effective Technological Measures means those measures that, in the 111 | absence of proper authority, may not be circumvented under laws 112 | fulfilling obligations under Article 11 of the WIPO Copyright 113 | Treaty adopted on December 20, 1996, and/or similar international 114 | agreements. 115 | 116 | f. Exceptions and Limitations means fair use, fair dealing, and/or 117 | any other exception or limitation to Copyright and Similar Rights 118 | that applies to Your use of the Licensed Material. 119 | 120 | g. License Elements means the license attributes listed in the name 121 | of a Creative Commons Public License. The License Elements of this 122 | Public License are Attribution, NonCommercial, and ShareAlike. 123 | 124 | h. Licensed Material means the artistic or literary work, database, 125 | or other material to which the Licensor applied this Public 126 | License. 127 | 128 | i. Licensed Rights means the rights granted to You subject to the 129 | terms and conditions of this Public License, which are limited to 130 | all Copyright and Similar Rights that apply to Your use of the 131 | Licensed Material and that the Licensor has authority to license. 132 | 133 | j. Licensor means the individual(s) or entity(ies) granting rights 134 | under this Public License. 135 | 136 | k. NonCommercial means not primarily intended for or directed towards 137 | commercial advantage or monetary compensation. For purposes of 138 | this Public License, the exchange of the Licensed Material for 139 | other material subject to Copyright and Similar Rights by digital 140 | file-sharing or similar means is NonCommercial provided there is 141 | no payment of monetary compensation in connection with the 142 | exchange. 143 | 144 | l. 
Share means to provide material to the public by any means or 145 | process that requires permission under the Licensed Rights, such 146 | as reproduction, public display, public performance, distribution, 147 | dissemination, communication, or importation, and to make material 148 | available to the public including in ways that members of the 149 | public may access the material from a place and at a time 150 | individually chosen by them. 151 | 152 | m. Sui Generis Database Rights means rights other than copyright 153 | resulting from Directive 96/9/EC of the European Parliament and of 154 | the Council of 11 March 1996 on the legal protection of databases, 155 | as amended and/or succeeded, as well as other essentially 156 | equivalent rights anywhere in the world. 157 | 158 | n. You means the individual or entity exercising the Licensed Rights 159 | under this Public License. Your has a corresponding meaning. 160 | 161 | 162 | Section 2 -- Scope. 163 | 164 | a. License grant. 165 | 166 | 1. Subject to the terms and conditions of this Public License, 167 | the Licensor hereby grants You a worldwide, royalty-free, 168 | non-sublicensable, non-exclusive, irrevocable license to 169 | exercise the Licensed Rights in the Licensed Material to: 170 | 171 | a. reproduce and Share the Licensed Material, in whole or 172 | in part, for NonCommercial purposes only; and 173 | 174 | b. produce, reproduce, and Share Adapted Material for 175 | NonCommercial purposes only. 176 | 177 | 2. Exceptions and Limitations. For the avoidance of doubt, where 178 | Exceptions and Limitations apply to Your use, this Public 179 | License does not apply, and You do not need to comply with 180 | its terms and conditions. 181 | 182 | 3. Term. The term of this Public License is specified in Section 183 | 6(a). 184 | 185 | 4. Media and formats; technical modifications allowed. The 186 | Licensor authorizes You to exercise the Licensed Rights in 187 | all media and formats whether now known or hereafter created, 188 | and to make technical modifications necessary to do so. The 189 | Licensor waives and/or agrees not to assert any right or 190 | authority to forbid You from making technical modifications 191 | necessary to exercise the Licensed Rights, including 192 | technical modifications necessary to circumvent Effective 193 | Technological Measures. For purposes of this Public License, 194 | simply making modifications authorized by this Section 2(a) 195 | (4) never produces Adapted Material. 196 | 197 | 5. Downstream recipients. 198 | 199 | a. Offer from the Licensor -- Licensed Material. Every 200 | recipient of the Licensed Material automatically 201 | receives an offer from the Licensor to exercise the 202 | Licensed Rights under the terms and conditions of this 203 | Public License. 204 | 205 | b. Additional offer from the Licensor -- Adapted Material. 206 | Every recipient of Adapted Material from You 207 | automatically receives an offer from the Licensor to 208 | exercise the Licensed Rights in the Adapted Material 209 | under the conditions of the Adapter's License You apply. 210 | 211 | c. No downstream restrictions. You may not offer or impose 212 | any additional or different terms or conditions on, or 213 | apply any Effective Technological Measures to, the 214 | Licensed Material if doing so restricts exercise of the 215 | Licensed Rights by any recipient of the Licensed 216 | Material. 217 | 218 | 6. No endorsement. 
Nothing in this Public License constitutes or 219 | may be construed as permission to assert or imply that You 220 | are, or that Your use of the Licensed Material is, connected 221 | with, or sponsored, endorsed, or granted official status by, 222 | the Licensor or others designated to receive attribution as 223 | provided in Section 3(a)(1)(A)(i). 224 | 225 | b. Other rights. 226 | 227 | 1. Moral rights, such as the right of integrity, are not 228 | licensed under this Public License, nor are publicity, 229 | privacy, and/or other similar personality rights; however, to 230 | the extent possible, the Licensor waives and/or agrees not to 231 | assert any such rights held by the Licensor to the limited 232 | extent necessary to allow You to exercise the Licensed 233 | Rights, but not otherwise. 234 | 235 | 2. Patent and trademark rights are not licensed under this 236 | Public License. 237 | 238 | 3. To the extent possible, the Licensor waives any right to 239 | collect royalties from You for the exercise of the Licensed 240 | Rights, whether directly or through a collecting society 241 | under any voluntary or waivable statutory or compulsory 242 | licensing scheme. In all other cases the Licensor expressly 243 | reserves any right to collect such royalties, including when 244 | the Licensed Material is used other than for NonCommercial 245 | purposes. 246 | 247 | 248 | Section 3 -- License Conditions. 249 | 250 | Your exercise of the Licensed Rights is expressly made subject to the 251 | following conditions. 252 | 253 | a. Attribution. 254 | 255 | 1. If You Share the Licensed Material (including in modified 256 | form), You must: 257 | 258 | a. retain the following if it is supplied by the Licensor 259 | with the Licensed Material: 260 | 261 | i. identification of the creator(s) of the Licensed 262 | Material and any others designated to receive 263 | attribution, in any reasonable manner requested by 264 | the Licensor (including by pseudonym if 265 | designated); 266 | 267 | ii. a copyright notice; 268 | 269 | iii. a notice that refers to this Public License; 270 | 271 | iv. a notice that refers to the disclaimer of 272 | warranties; 273 | 274 | v. a URI or hyperlink to the Licensed Material to the 275 | extent reasonably practicable; 276 | 277 | b. indicate if You modified the Licensed Material and 278 | retain an indication of any previous modifications; and 279 | 280 | c. indicate the Licensed Material is licensed under this 281 | Public License, and include the text of, or the URI or 282 | hyperlink to, this Public License. 283 | 284 | 2. You may satisfy the conditions in Section 3(a)(1) in any 285 | reasonable manner based on the medium, means, and context in 286 | which You Share the Licensed Material. For example, it may be 287 | reasonable to satisfy the conditions by providing a URI or 288 | hyperlink to a resource that includes the required 289 | information. 290 | 3. If requested by the Licensor, You must remove any of the 291 | information required by Section 3(a)(1)(A) to the extent 292 | reasonably practicable. 293 | 294 | b. ShareAlike. 295 | 296 | In addition to the conditions in Section 3(a), if You Share 297 | Adapted Material You produce, the following conditions also apply. 298 | 299 | 1. The Adapter's License You apply must be a Creative Commons 300 | license with the same License Elements, this version or 301 | later, or a BY-NC-SA Compatible License. 302 | 303 | 2. You must include the text of, or the URI or hyperlink to, the 304 | Adapter's License You apply. 
You may satisfy this condition 305 | in any reasonable manner based on the medium, means, and 306 | context in which You Share Adapted Material. 307 | 308 | 3. You may not offer or impose any additional or different terms 309 | or conditions on, or apply any Effective Technological 310 | Measures to, Adapted Material that restrict exercise of the 311 | rights granted under the Adapter's License You apply. 312 | 313 | 314 | Section 4 -- Sui Generis Database Rights. 315 | 316 | Where the Licensed Rights include Sui Generis Database Rights that 317 | apply to Your use of the Licensed Material: 318 | 319 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 320 | to extract, reuse, reproduce, and Share all or a substantial 321 | portion of the contents of the database for NonCommercial purposes 322 | only; 323 | 324 | b. if You include all or a substantial portion of the database 325 | contents in a database in which You have Sui Generis Database 326 | Rights, then the database in which You have Sui Generis Database 327 | Rights (but not its individual contents) is Adapted Material, 328 | including for purposes of Section 3(b); and 329 | 330 | c. You must comply with the conditions in Section 3(a) if You Share 331 | all or a substantial portion of the contents of the database. 332 | 333 | For the avoidance of doubt, this Section 4 supplements and does not 334 | replace Your obligations under this Public License where the Licensed 335 | Rights include other Copyright and Similar Rights. 336 | 337 | 338 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 339 | 340 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 341 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 342 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 343 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 344 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 345 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 346 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 347 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 348 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 349 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 350 | 351 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 352 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 353 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 354 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 355 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 356 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 357 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 358 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 359 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 360 | 361 | c. The disclaimer of warranties and limitation of liability provided 362 | above shall be interpreted in a manner that, to the extent 363 | possible, most closely approximates an absolute disclaimer and 364 | waiver of all liability. 365 | 366 | 367 | Section 6 -- Term and Termination. 368 | 369 | a. This Public License applies for the term of the Copyright and 370 | Similar Rights licensed here. However, if You fail to comply with 371 | this Public License, then Your rights under this Public License 372 | terminate automatically. 373 | 374 | b. 
Where Your right to use the Licensed Material has terminated under 375 | Section 6(a), it reinstates: 376 | 377 | 1. automatically as of the date the violation is cured, provided 378 | it is cured within 30 days of Your discovery of the 379 | violation; or 380 | 381 | 2. upon express reinstatement by the Licensor. 382 | 383 | For the avoidance of doubt, this Section 6(b) does not affect any 384 | right the Licensor may have to seek remedies for Your violations 385 | of this Public License. 386 | 387 | c. For the avoidance of doubt, the Licensor may also offer the 388 | Licensed Material under separate terms or conditions or stop 389 | distributing the Licensed Material at any time; however, doing so 390 | will not terminate this Public License. 391 | 392 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 393 | License. 394 | 395 | 396 | Section 7 -- Other Terms and Conditions. 397 | 398 | a. The Licensor shall not be bound by any additional or different 399 | terms or conditions communicated by You unless expressly agreed. 400 | 401 | b. Any arrangements, understandings, or agreements regarding the 402 | Licensed Material not stated herein are separate from and 403 | independent of the terms and conditions of this Public License. 404 | 405 | 406 | Section 8 -- Interpretation. 407 | 408 | a. For the avoidance of doubt, this Public License does not, and 409 | shall not be interpreted to, reduce, limit, restrict, or impose 410 | conditions on any use of the Licensed Material that could lawfully 411 | be made without permission under this Public License. 412 | 413 | b. To the extent possible, if any provision of this Public License is 414 | deemed unenforceable, it shall be automatically reformed to the 415 | minimum extent necessary to make it enforceable. If the provision 416 | cannot be reformed, it shall be severed from this Public License 417 | without affecting the enforceability of the remaining terms and 418 | conditions. 419 | 420 | c. No term or condition of this Public License will be waived and no 421 | failure to comply consented to unless expressly agreed to by the 422 | Licensor. 423 | 424 | d. Nothing in this Public License constitutes or may be interpreted 425 | as a limitation upon, or waiver of, any privileges and immunities 426 | that apply to the Licensor or You, including from the legal 427 | processes of any jurisdiction or authority. 428 | 429 | ======================================================================= 430 | 431 | Creative Commons is not a party to its public 432 | licenses. Notwithstanding, Creative Commons may elect to apply one of 433 | its public licenses to material it publishes and in those instances 434 | will be considered the “Licensor.” The text of the Creative Commons 435 | public licenses is dedicated to the public domain under the CC0 Public 436 | Domain Dedication. Except for the limited purpose of indicating that 437 | material is shared under a Creative Commons public license or as 438 | otherwise permitted by the Creative Commons policies published at 439 | creativecommons.org/policies, Creative Commons does not authorize the 440 | use of the trademark "Creative Commons" or any other trademark or logo 441 | of Creative Commons without its prior written consent including, 442 | without limitation, in connection with any unauthorized modifications 443 | to any of its public licenses or any other arrangements, 444 | understandings, or agreements concerning use of licensed material. 
For 445 | the avoidance of doubt, this paragraph does not form part of the 446 | public licenses. 447 | 448 | Creative Commons may be contacted at creativecommons.org. 449 | 450 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Locating Objects Without Bounding Boxes 2 | PyTorch code for "Locating Objects Without Bounding Boxes" , CVPR 2019 - Oral, Best Paper Finalist (Top 1 %) [[Paper]](http://openaccess.thecvf.com/content_CVPR_2019/html/Ribera_Locating_Objects_Without_Bounding_Boxes_CVPR_2019_paper.html) [[Youtube]](https://youtu.be/8qkrPSjONhA?t=2620) 3 | 4 | 5 | 6 | 7 | 8 | ## Citing this work 9 | ``` 10 | @article{ribera2019, 11 | title={Locating Objects Without Bounding Boxes}, 12 | author={Javier Ribera and David G\"{u}era and Yuhao Chen and Edward J. Delp}, 13 | journal={Proceedings of the Computer Vision and Pattern Recognition (CVPR)}, 14 | month={June}, 15 | year={2019}, 16 | note={{Long Beach, CA}} 17 | } 18 | ``` 19 | 20 | ## Datasets 21 | The datasets used in the paper can be downloaded from: 22 | - [Mall dataset](http://personal.ie.cuhk.edu.hk/~ccloy/downloads_mall_dataset.html) 23 | - [Pupil dataset](http://www.ti.uni-tuebingen.de/Pupil-detection.1827.0.html) 24 | - [Plant dataset](https://engineering.purdue.edu/~sorghum/dataset-plant-centers-2016) 25 | 26 | ## Installation 27 | Use conda to recreate the environment provided with the code: 28 |
 29 | conda env create -f environment.yml
 30 | 
31 | 32 | Activate the environment: 33 |
 34 | conda activate object-locator
 35 | 
36 | 37 | Install the tool: 38 |
 39 | pip install .
 40 | 
41 | (do not forget the period) 42 | 43 | ## Usage 44 | If you are only interested in the code of the Weighted Hausdorff Distance (which is the loss used in the paper and the main contribution), you can just get the [losses.py](object-locator/losses.py) file. If you want to use the entire object location tool: 45 | 46 | Activate the environment: 47 |
 48 | conda activate object-locator
 49 | 
50 | 51 | Run this to get help (usage instructions): 52 |
 53 | python -m object-locator.locate -h
 54 | python -m object-locator.train -h
 55 | 
56 | 57 | Example: 58 | 59 |
 60 | python -m object-locator.locate \
 61 |        --dataset DIRECTORY \
 62 |        --out DIRECTORY \
 63 |        --model CHECKPOINTS \
 64 |        --evaluate \
 65 |        --no-gpu \
 66 |        --radius 5
 67 | 
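
The same `locate` command also works with the pre-trained checkpoints listed in the Pre-trained models section below. For instance, the pupil model was trained without the five central layers, so it additionally needs the `--ultrasmallnet` option. This is only a variant of the example above; the checkpoint path assumes you downloaded the pupil `.ckpt` file into the current directory:

```
python -m object-locator.locate \
       --dataset DIRECTORY \
       --out DIRECTORY \
       --model pupil,lambdaa=1,BS=64,SGD,LR1e-3,p=-1,ultrasmallNet.ckpt \
       --ultrasmallnet \
       --evaluate \
       --no-gpu \
       --radius 5
```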
68 | 69 |
 70 | python -m object-locator.train \
 71 |        --train-dir TRAINING_DIRECTORY \
 72 |        --batch-size 32 \
 73 |        --visdom-env mytrainingsession \
 74 |        --visdom-server localhost \
 75 |        --lr 1e-3 \
 76 |        --val-dir TRAINING_DIRECTORY \
 77 |        --optim Adam \
 78 |        --save saved_model.ckpt
 79 | 
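
As noted in the Usage section, the main contribution is the Weighted Hausdorff Distance loss in [losses.py](object-locator/losses.py). For intuition, the snippet below computes the plain *averaged* Hausdorff distance between two point sets, which is the quantity that loss generalizes to a probability map. This is an illustrative sketch only (the function name is ours), not the implementation in `losses.py`:

```python
import numpy as np
from scipy.spatial.distance import cdist

def averaged_hausdorff_distance(set1, set2):
    """Averaged Hausdorff distance between two non-empty point sets.

    set1, set2: arrays of shape (N, 2) and (M, 2) holding (y, x) locations.
    """
    # (N, M) matrix of pairwise Euclidean distances
    d = cdist(np.asarray(set1, dtype=float), np.asarray(set2, dtype=float))
    # Mean nearest-neighbor distance, computed in both directions
    return d.min(axis=1).mean() + d.min(axis=0).mean()

# Ground-truth vs. estimated object locations, in (y, x) pixel coordinates
gt  = [(28, 52), (58, 53), (135, 50)]
est = [(30, 50), (60, 55), (130, 52)]
print(averaged_hausdorff_distance(gt, est))
```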
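
The next section describes the `gt.csv` ground-truth file that `--dataset` and `--train-dir` expect. As a convenience, a short pandas sketch that writes a file in that exact format (using the same example points shown below) could look like this; any other tool that produces the same columns works equally well:

```python
import pandas as pd

# One entry per image: filename and its list of (y, x) object locations
annotations = {
    'img1.png': [(28, 52), (58, 53), (135, 50)],
    'img2.png': [(92, 47), (33, 82)],
}

rows = [{'filename': fname,
         'count': len(locs),
         'locations': str(locs)}   # stored as the string form of a list of (y, x) tuples
        for fname, locs in annotations.items()]

pd.DataFrame(rows, columns=['filename', 'count', 'locations']).to_csv('gt.csv', index=False)
```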
80 | 81 | ## Dataset format 82 | The options `--dataset` and `--train-dir` should point to a directory. 83 | This directory must contain your dataset, meaning: 84 | 1. One file per image to analyze (png, jpg, jpeg, tiff or tif). 85 | 2. One ground truth file called `gt.csv` with the following format: 86 | ``` 87 | filename,count,locations 88 | img1.png,3,"[(28, 52), (58, 53), (135, 50)]" 89 | img2.png,2,"[(92, 47), (33, 82)]" 90 | ``` 91 | Each row of the CSV must describe the ground truth of one image: the count and the locations of all objects in that image. 92 | The locations are in (y, x) format, with the origin at the top-left pixel, y being the pixel row number, and x being the pixel column number. 93 | 94 | Optionally, if you are working on precision agriculture or plant phenotyping, you can use an XML file `gt.xml` instead of a CSV. 95 | The required XML specifications can be found in 96 | [https://communityhub.purdue.edu/groups/phenosorg/wiki/APIspecs](https://communityhub.purdue.edu/groups/phenosorg/wiki/APIspecs) 97 | (accessible only to Purdue users) and in [this](https://hammer.figshare.com/articles/Image-based_Plant_Phenotyping_Using_Machine_Learning/7774313) thesis, but this is only useful in agronomy/phenotyping applications. 98 | The XML file is parsed by the file `data_plant_stuff.py`. 99 | 100 | ## Pre-trained models 101 | Models are trained separately for each of the four datasets, as described in the paper: 102 | 1. [Mall dataset](https://lorenz.ecn.purdue.edu/~cvpr2019/pretrained_models/mall,lambdaa=1,BS=32,Adam,LR1e-4.ckpt) 103 | 2. [Pupil dataset](https://lorenz.ecn.purdue.edu/~cvpr2019/pretrained_models/pupil,lambdaa=1,BS=64,SGD,LR1e-3,p=-1,ultrasmallNet.ckpt) 104 | 3. [Plant dataset](https://lorenz.ecn.purdue.edu/~cvpr2019/pretrained_models/plants_20160613_F54,BS=32,Adam,LR1e-5,p=-1.ckpt) 105 | 4. [ShanghaiTechB dataset](https://lorenz.ecn.purdue.edu/~cvpr2019/pretrained_models/shanghai,lambdaa=1,p=-1,BS=32,Adam,LR=1e-4.ckpt) 106 | 107 | The [COPYRIGHT](COPYRIGHT.txt) of the pre-trained models is the same as in this repository. 108 | 109 | As described in the paper, the pre-trained model for the pupil dataset excludes the five central layers. Thus, if you want to use this model, you will have to use the option `--ultrasmallnet`. 110 | 111 | ## Uninstall 112 |
113 | conda deactivate
114 | conda env remove --name object-locator
115 | 
116 | 117 | 118 | ## Code Versioning 119 | The code used in the paper corresponds to the tag `used-for-cvpr2019-submission`. 120 | If you want to reproduce the results, checkout that tag with `git checkout used-for-cvpr2019-submission`. 121 | The master branch is the latest version available, with convenient bug fixes and better documentation. 122 | If you want to develop or retrain your models, we recommend the master branch. 123 | Versions numbers follow [semantic versioning](https://semver.org) and the changelog is in [CHANGELOG.md](CHANGELOG.md). 124 | 125 | 126 | ## Creating an issue 127 | If you're experiencing a problem or a bug, creating a GitHub issue is encouraged, but please include the following: 128 | 1. The commit version of this repository that you ran (`git show | head -n 1`) 129 | 2. The dataset you used (including images and the CSV with groundtruth with the [appropriate format](#datasetformat)) 130 | 3. CPU and GPU model(s) you are using 131 | 4. The full standard output of the training log if you are training, and the testing log if you are evaluating (you can upload it to https://pastebin.com) 132 | 5. The operating system you are using 133 | 6. The command you run to train and evaluate 134 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: object-locator 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - imageio=2.3.0 8 | - ipdb=0.11 9 | - ipython=6.3.1 10 | - ipython_genutils=0.2.0 11 | - matplotlib=2.2.2 12 | - numpy=1.14.3 13 | - opencv=3.4.1 14 | - pandas=0.22.0 15 | - parse=1.8.2 16 | - pip=9.0.3 17 | - python=3.6.5 18 | - python-dateutil=2.7.2 19 | - scikit-image=0.13.1 20 | - scikit-learn=0.19.1 21 | - scipy=1.0.1 22 | - setuptools=39.1.0 23 | - tqdm=4.23.1 24 | - xmltodict=0.11.0 25 | - pytorch=1.0.0 26 | - pip: 27 | - ballpark==1.4.0 28 | - visdom==0.1.8.5 29 | - peterpy 30 | - torchvision==0.2.1 31 | 32 | -------------------------------------------------------------------------------- /object-locator/__init__.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | -------------------------------------------------------------------------------- /object-locator/__main__.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 
10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | 17 | # Allow printing Unicode characters 18 | import os 19 | os.environ["PYTHONIOENCODING"] = 'UTF-8' 20 | 21 | # Execute locate.py script 22 | from . import locate as object_locator 23 | -------------------------------------------------------------------------------- /object-locator/bmm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code from paper 3 | "A hybrid parameter estimation algorithm for beta mixtures 4 | and applications to methylation state classification" 5 | https://doi.org/10.1186/s13015-017-0112-1 6 | https://bitbucket.org/genomeinformatics/betamix 7 | """ 8 | 9 | import numpy as np 10 | 11 | from itertools import count 12 | from argparse import ArgumentParser 13 | 14 | import numpy as np 15 | from scipy.stats import beta 16 | 17 | 18 | def _get_values(x, left, right): 19 | y = x[np.logical_and(x>=left, x<=right)] 20 | n = len(y) 21 | if n == 0: 22 | m = (left+right) / 2.0 23 | v = (right-left) / 12.0 24 | else: 25 | m = np.mean(y) 26 | v = np.var(y) 27 | if v == 0.0: 28 | v = (right-left) / (12.0*(n+1)) 29 | return m, v, n 30 | 31 | 32 | def get_initialization(x, ncomponents, limit=0.8): 33 | # TODO: work with specific components instead of just their number 34 | points = np.linspace(0.0, 1.0, ncomponents+2) 35 | means = np.zeros(ncomponents) 36 | variances = np.zeros(ncomponents) 37 | pi = np.zeros(ncomponents) 38 | # init first component 39 | means[0], variances[0], pi[0] = _get_values(x, points[0], points[1]) 40 | # init intermediate components 41 | N = ncomponents - 1 42 | for j in range(1, N): 43 | means[j], variances[j], pi[j] = _get_values(x, points[j], points[j+2]) 44 | # init last component 45 | means[N], variances[N], pi[N] = _get_values(x, points[N+1], points[N+2]) 46 | 47 | # compute parameters ab, pi 48 | ab = [ab_from_mv(m,v) for (m,v) in zip(means,variances)] 49 | pi = pi / pi.sum() 50 | 51 | # adjust first and last 52 | if ab[0][0] >= limit: ab[0] = (limit, ab[0][1]) 53 | if ab[-1][1] >= limit: ab[-1] = (ab[-1][0], limit) 54 | return ab, pi 55 | 56 | 57 | def ab_from_mv(m, v): 58 | """ 59 | estimate beta parameters (a,b) from given mean and variance; 60 | return (a,b). 61 | 62 | Note, for uniform distribution on [0,1], (m,v)=(0.5,1/12) 63 | """ 64 | phi = m*(1-m)/v - 1 # z = 2 for uniform distribution 65 | return (phi*m, phi*(1-m)) # a = b = 1 for uniform distribution 66 | 67 | 68 | def get_weights(x, ab, pi): 69 | """return nsamples X ncomponents matrix with association weights""" 70 | bpdf = beta.pdf 71 | n, c = len(x), len(ab) 72 | y = np.zeros((n,c), dtype=float) 73 | s = np.zeros((n,1), dtype=float) 74 | for (j, p,(a,b)) in zip(count(), pi, ab): 75 | y[:,j] = p * bpdf(x, a, b) 76 | s = np.sum(y,1).reshape((n,1)) 77 | with np.warnings.catch_warnings(): 78 | np.warnings.filterwarnings('ignore') 79 | w = y / s # this may produce inf or nan; this is o.k.! 80 | # clean up weights w, remove infs, nans, etc. 
81 | wfirst = np.array([1] + [0]*(c-1), dtype=float) 82 | wlast = np.array([0]*(c-1) + [1], dtype=float) 83 | bad = (~np.isfinite(w)).any(axis=1) 84 | badfirst = np.logical_and(bad, x<0.5) 85 | badlast = np.logical_and(bad, x>=0.5) 86 | w[badfirst,:] = wfirst 87 | w[badlast,:] = wlast 88 | # now all weights are valid finite values and sum to 1 for each row 89 | assert np.all(np.isfinite(w)), (w, np.isfinite(w)) 90 | assert np.allclose(np.sum(w,1), 1.0), np.max(np.abs(np.sum(w,1)-1.0)) 91 | return w 92 | 93 | 94 | def relerror(x,y): 95 | if x==y: return 0.0 96 | return abs(x-y)/max(abs(x),abs(y)) 97 | 98 | def get_delta(ab, abold, pi, piold): 99 | epi = max(relerror(p,po) for (p,po) in zip(pi,piold)) 100 | ea = max(relerror(a,ao) for (a,_), (ao,_) in zip(ab,abold)) 101 | eb = max(relerror(b,bo) for (_,b), (_,bo) in zip(ab,abold)) 102 | return max(epi,ea,eb) 103 | 104 | 105 | def estimate_mixture(x, init, steps=1000, tolerance=1E-5): 106 | """ 107 | estimate a beta mixture model from the given data x 108 | with the given number of components and component types 109 | """ 110 | (ab, pi) = init 111 | n, ncomponents = len(x), len(ab) 112 | 113 | for step in count(): 114 | if step >= steps: 115 | break 116 | abold = list(ab) 117 | piold = pi[:] 118 | # E-step: compute component memberships for each x 119 | w = get_weights(x, ab, pi) 120 | # compute component means and variances and parameters 121 | for j in range(ncomponents): 122 | wj = w[:,j] 123 | pij = np.sum(wj) 124 | m = np.dot(wj,x) / pij 125 | v = np.dot(wj,(x-m)**2) / pij 126 | if np.isnan(m) or np.isnan(v): 127 | m = 0.5; v = 1/12 # uniform 128 | ab[j]=(1,1) # uniform 129 | assert pij == 0.0 130 | else: 131 | assert np.isfinite(m) and np.isfinite(v), (j,m,v,pij) 132 | ab[j] = ab_from_mv(m,v) 133 | pi[j] = pij / n 134 | delta = get_delta(ab, abold, pi, piold) 135 | if delta < tolerance: 136 | break 137 | usedsteps = step + 1 138 | return (ab, pi, usedsteps) 139 | 140 | 141 | def estimate(x, components, steps=1000, tolerance=1E-4): 142 | init = get_initialization(x, len(components)) 143 | (ab, pi, usedsteps) = estimate_mixture(x, init, steps=steps, tolerance=tolerance) 144 | return (ab, pi, usedsteps) 145 | 146 | 147 | class AccumHistogram1D(): 148 | """https://raw.githubusercontent.com/NichtJens/numpy-accumulative-histograms/master/accuhist.py""" 149 | 150 | def __init__(self, nbins, xlow, xhigh): 151 | self.nbins = nbins 152 | self.xlow = xlow 153 | self.xhigh = xhigh 154 | 155 | self.range = (xlow, xhigh) 156 | 157 | self.hist, edges = np.histogram([], bins=nbins, range=self.range) 158 | self.bins = (edges[:-1] + edges[1:]) / 2. 
159 | 160 | def fill(self, arr): 161 | hist, _ = np.histogram(arr, bins=self.nbins, range=self.range) 162 | self.hist += hist 163 | 164 | @property 165 | def data(self): 166 | return self.bins, self.hist 167 | 168 | 169 | -------------------------------------------------------------------------------- /object-locator/checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | # https://stackoverflow.com/questions/115983/how-can-i-add-an-empty-directory-to-a-git-repository#932982 2 | # Ignore everything in this directory 3 | * 4 | # Except this file 5 | !.gitignore 6 | -------------------------------------------------------------------------------- /object-locator/data.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | import os 19 | import random 20 | 21 | from PIL import Image 22 | import numpy as np 23 | import pandas as pd 24 | import torch 25 | import torchvision 26 | from ballpark import ballpark 27 | 28 | from . import get_image_size 29 | 30 | IMG_EXTENSIONS = ['.png', '.jpeg', '.jpg', '.tiff', '.tif'] 31 | 32 | torch.set_default_dtype(torch.float32) 33 | 34 | 35 | def build_dataset(directory, 36 | transforms=None, 37 | max_dataset_size=float('inf'), 38 | ignore_gt=False, 39 | seed=0): 40 | """ 41 | Build a dataset from a directory. 42 | Depending if the directory contains a CSV or an XML dataset, 43 | it builds an XMLDataset or a CSVDataset, which are subclasses 44 | of torch.utils.data.Dataset. 45 | :param directory: Directory with all the images and the CSV file. 46 | :param transform: Transform to be applied to each image. 47 | :param max_dataset_size: Only use the first N images in the directory. 48 | :param ignore_gt: Ignore the GT of the dataset, 49 | i.e, provide samples without locations or counts. 50 | :param seed: Random seed. 51 | :return: An XMLDataset or CSVDataset instance. 52 | """ 53 | if any(fn.endswith('.csv') for fn in os.listdir(directory)) \ 54 | or ignore_gt: 55 | dset = CSVDataset(directory=directory, 56 | transforms=transforms, 57 | max_dataset_size=max_dataset_size, 58 | ignore_gt=ignore_gt, 59 | seed=seed) 60 | else: 61 | from . import data_plant_stuff 62 | dset = data_plant_stuff.\ 63 | XMLDataset(directory=directory, 64 | transforms=transforms, 65 | max_dataset_size=max_dataset_size, 66 | ignore_gt=ignore_gt, 67 | seed=seed) 68 | 69 | return dset 70 | 71 | 72 | def get_train_val_loaders(train_dir, 73 | collate_fn, 74 | height, 75 | width, 76 | no_data_augmentation=False, 77 | max_trainset_size=np.infty, 78 | seed=0, 79 | batch_size=1, 80 | drop_last_batch=False, 81 | shuffle=True, 82 | num_workers=0, 83 | val_dir=None, 84 | max_valset_size=np.infty): 85 | """ 86 | Create a training loader and a validation set. 87 | If the validation directory is 'auto', 88 | 20% of the dataset is used for validation. 
89 | 90 | :param train_dir: Directory with all the training images and the CSV file. 91 | :param train_transforms: Transform to be applied to each training image. 92 | :param max_trainset_size: Only use first N images for training. 93 | :param collate_fn: Function to assemble samples into batches. 94 | :param height: Resize the images to this height. 95 | :param width: Resize the images to this width. 96 | :param no_data_augmentation: Do not perform data augmentation. 97 | :param seed: Random seed. 98 | :param batch_size: Number of samples in a batch, for training. 99 | :param drop_last_batch: Drop the last incomplete batch during training 100 | :param shuffle: Randomly shuffle the dataset before each epoch. 101 | :param num_workers: Number of subprocesses dedicated for data loading. 102 | :param val_dir: Directory with all the training images and the CSV file. 103 | :param max_valset_size: Only use first N images for validation. 104 | """ 105 | 106 | # Data augmentation for training 107 | training_transforms = [] 108 | if not no_data_augmentation: 109 | training_transforms += [RandomHorizontalFlipImageAndLabel(p=0.5, 110 | seed=seed)] 111 | training_transforms += [RandomVerticalFlipImageAndLabel(p=0.5, 112 | seed=seed)] 113 | training_transforms += [ScaleImageAndLabel(size=(height, width))] 114 | training_transforms += [torchvision.transforms.ToTensor()] 115 | training_transforms += [torchvision.transforms.Normalize((0.5, 0.5, 0.5), 116 | (0.5, 0.5, 0.5))] 117 | training_transforms = torchvision.transforms.Compose(training_transforms) 118 | 119 | # Data augmentation for validation 120 | validation_transforms = torchvision.transforms.Compose([ 121 | ScaleImageAndLabel(size=(height, width)), 122 | torchvision.transforms.ToTensor(), 123 | torchvision.transforms.\ 124 | Normalize((0.5, 0.5, 0.5), 125 | (0.5, 0.5, 0.5)), 126 | ]) 127 | 128 | # Training dataset 129 | trainset = build_dataset(directory=train_dir, 130 | transforms=training_transforms, 131 | max_dataset_size=max_trainset_size, 132 | seed=seed) 133 | 134 | # Validation dataset 135 | if val_dir is not None: 136 | if val_dir == 'auto': 137 | # Create a dataset just as in training 138 | valset = build_dataset(directory=train_dir, 139 | transforms=validation_transforms, 140 | max_dataset_size=max_trainset_size, 141 | seed=seed) 142 | 143 | # Split 80% for training, 20% for validation 144 | n_imgs_for_training = int(round(0.8*len(trainset))) 145 | if isinstance(trainset, CSVDataset): 146 | if trainset.there_is_gt: 147 | trainset.csv_df = \ 148 | trainset.csv_df[:n_imgs_for_training] 149 | valset.csv_df = \ 150 | valset.csv_df[n_imgs_for_training:].reset_index() 151 | else: 152 | trainset.listfiles = \ 153 | trainset.listfiles[:n_imgs_for_training] 154 | valset.listfiles = \ 155 | valset.listfiles[n_imgs_for_training:] 156 | else: # isinstance(trainset, XMLDataset): 157 | trainset.dict_list = trainset.dict_list[:n_imgs_for_training] 158 | valset.dict_list = valset.dict_list[n_imgs_for_training:] 159 | 160 | else: 161 | valset = build_dataset(val_dir, 162 | transforms=validation_transforms, 163 | max_dataset_size=max_valset_size, 164 | seed=seed) 165 | valset_loader = torch.utils.data.DataLoader(valset, 166 | batch_size=1, 167 | shuffle=True, 168 | num_workers=num_workers, 169 | collate_fn=csv_collator) 170 | else: 171 | valset, valset_loader = None, None 172 | 173 | print(f'# images for training: ' 174 | f'{ballpark(len(trainset))}') 175 | if valset is not None: 176 | print(f'# images for validation: ' 177 | f'{ballpark(len(valset))}') 
178 | else: 179 | print('W: no validation set was selected!') 180 | 181 | # Build data loaders from the datasets 182 | trainset_loader = torch.utils.data.DataLoader(trainset, 183 | batch_size=batch_size, 184 | drop_last=drop_last_batch, 185 | shuffle=True, 186 | num_workers=num_workers, 187 | collate_fn=csv_collator) 188 | if valset is not None: 189 | valset_loader = torch.utils.data.DataLoader(valset, 190 | batch_size=1, 191 | shuffle=True, 192 | num_workers=num_workers, 193 | collate_fn=csv_collator) 194 | 195 | return trainset_loader, valset_loader 196 | 197 | 198 | class CSVDataset(torch.utils.data.Dataset): 199 | def __init__(self, 200 | directory, 201 | transforms=None, 202 | max_dataset_size=float('inf'), 203 | ignore_gt=False, 204 | seed=0): 205 | """CSVDataset. 206 | The sample images of this dataset must be all inside one directory. 207 | Inside the same directory, there must be one CSV file. 208 | This file must contain one row per image. 209 | It can contain as many columns as wanted, i.e, filename, count... 210 | 211 | :param directory: Directory with all the images and the CSV file. 212 | :param transform: Transform to be applied to each image. 213 | :param max_dataset_size: Only use the first N images in the directory. 214 | :param ignore_gt: Ignore the GT of the dataset, 215 | i.e, provide samples without locations or counts. 216 | :param seed: Random seed. 217 | """ 218 | 219 | self.root_dir = directory 220 | self.transforms = transforms 221 | 222 | # Get groundtruth from CSV file 223 | listfiles = os.listdir(directory) 224 | csv_filename = None 225 | for filename in listfiles: 226 | if filename.endswith('.csv'): 227 | csv_filename = filename 228 | break 229 | 230 | # Ignore files that are not images 231 | listfiles = [f for f in listfiles 232 | if any(f.lower().endswith(ext) for ext in IMG_EXTENSIONS)] 233 | 234 | # Shuffle list of files 235 | np.random.seed(seed) 236 | random.shuffle(listfiles) 237 | 238 | if len(listfiles) == 0: 239 | raise ValueError(f"There are no images in '{directory}'") 240 | 241 | self.there_is_gt = (csv_filename is not None) and (not ignore_gt) 242 | 243 | # CSV does not exist (no GT available) 244 | if not self.there_is_gt: 245 | print('W: The dataset directory %s does not contain a CSV file with groundtruth. \n' 246 | ' Metrics will not be evaluated. Only estimations will be returned.' % directory) 247 | self.csv_df = None 248 | self.listfiles = listfiles 249 | 250 | # Make dataset smaller 251 | self.listfiles = self.listfiles[0:min(len(self.listfiles), 252 | max_dataset_size)] 253 | 254 | # CSV does exist (GT is available) 255 | else: 256 | self.csv_df = pd.read_csv(os.path.join(directory, csv_filename)) 257 | 258 | # Shuffle CSV dataframe 259 | self.csv_df = self.csv_df.sample(frac=1).reset_index(drop=True) 260 | 261 | # Make dataset smaller 262 | self.csv_df = self.csv_df[0:min( 263 | len(self.csv_df), max_dataset_size)] 264 | 265 | def __len__(self): 266 | if self.there_is_gt: 267 | return len(self.csv_df) 268 | else: 269 | return len(self.listfiles) 270 | 271 | def __getitem__(self, idx): 272 | """Get one element of the dataset. 273 | Returns a tuple. The first element is the image. 274 | The second element is a dictionary where the keys are the columns of the CSV. 275 | If the CSV did not exist in the dataset directory, 276 | the dictionary will only contain the filename of the image. 277 | :param idx: Index of the image in the dataset to get. 
278 | """ 279 | 280 | if self.there_is_gt: 281 | img_abspath = os.path.join(self.root_dir, self.csv_df.ix[idx].filename) 282 | dictionary = dict(self.csv_df.ix[idx]) 283 | else: 284 | img_abspath = os.path.join(self.root_dir, self.listfiles[idx]) 285 | dictionary = {'filename': self.listfiles[idx]} 286 | 287 | img = Image.open(img_abspath) 288 | 289 | if self.there_is_gt: 290 | # str -> lists 291 | dictionary['locations'] = eval(dictionary['locations']) 292 | dictionary['locations'] = [ 293 | list(loc) for loc in dictionary['locations']] 294 | 295 | # list --> Tensors 296 | with torch.no_grad(): 297 | dictionary['locations'] = torch.tensor( 298 | dictionary['locations'], dtype=torch.get_default_dtype()) 299 | dictionary['count'] = torch.tensor( 300 | [dictionary['count']], dtype=torch.get_default_dtype()) 301 | 302 | # Record original size 303 | orig_width, orig_height = get_image_size.get_image_size(img_abspath) 304 | with torch.no_grad(): 305 | orig_height = torch.tensor(orig_height, 306 | dtype=torch.get_default_dtype()) 307 | orig_width = torch.tensor(orig_width, 308 | dtype=torch.get_default_dtype()) 309 | dictionary['orig_width'] = orig_width 310 | dictionary['orig_height'] = orig_height 311 | 312 | img_transformed = img 313 | transformed_dictionary = dictionary 314 | 315 | # Apply all transformations provided 316 | if self.transforms is not None: 317 | for transform in self.transforms.transforms: 318 | if hasattr(transform, 'modifies_label'): 319 | img_transformed, transformed_dictionary = \ 320 | transform(img_transformed, transformed_dictionary) 321 | else: 322 | img_transformed = transform(img_transformed) 323 | 324 | # Prevents crash when making a batch out of an empty tensor 325 | if self.there_is_gt: 326 | if dictionary['count'][0] == 0: 327 | with torch.no_grad(): 328 | dictionary['locations'] = torch.tensor([-1, -1], 329 | dtype=torch.get_default_dtype()) 330 | 331 | return (img_transformed, transformed_dictionary) 332 | 333 | 334 | def csv_collator(samples): 335 | """Merge a list of samples to form a batch. 336 | The batch is a 2-element tuple, being the first element 337 | the BxHxW tensor and the second element a list of dictionaries. 338 | 339 | :param samples: List of samples returned by CSVDataset as (img, dict) tuples. 
340 | """ 341 | 342 | imgs = [] 343 | dicts = [] 344 | 345 | for sample in samples: 346 | img = sample[0] 347 | dictt = sample[1] 348 | 349 | # # We cannot deal with images with 0 objects (WHD is not defined) 350 | # if dictt['count'][0] == 0: 351 | # continue 352 | 353 | imgs.append(img) 354 | dicts.append(dictt) 355 | 356 | data = torch.stack(imgs) 357 | 358 | return data, dicts 359 | 360 | 361 | class RandomHorizontalFlipImageAndLabel(object): 362 | """ Horizontally flip a numpy array image and the GT with probability p """ 363 | 364 | def __init__(self, p, seed=0): 365 | self.modifies_label = True 366 | self.p = p 367 | np.random.seed(seed) 368 | 369 | def __call__(self, img, dictionary): 370 | transformed_img = img 371 | transformed_dictionary = dictionary 372 | 373 | if random.random() < self.p: 374 | transformed_img = hflip(img) 375 | width = img.size[0] 376 | for l, loc in enumerate(dictionary['locations']): 377 | dictionary['locations'][l][1] = (width - 1) - loc[1] 378 | 379 | return transformed_img, transformed_dictionary 380 | 381 | 382 | class RandomVerticalFlipImageAndLabel(object): 383 | """ Vertically flip a numpy array image and the GT with probability p """ 384 | 385 | def __init__(self, p, seed=0): 386 | self.modifies_label = True 387 | self.p = p 388 | np.random.seed(seed) 389 | 390 | def __call__(self, img, dictionary): 391 | transformed_img = img 392 | transformed_dictionary = dictionary 393 | 394 | if random.random() < self.p: 395 | transformed_img = vflip(img) 396 | height = img.size[1] 397 | for l, loc in enumerate(dictionary['locations']): 398 | dictionary['locations'][l][0] = (height - 1) - loc[0] 399 | 400 | return transformed_img, transformed_dictionary 401 | 402 | 403 | class ScaleImageAndLabel(torchvision.transforms.Resize): 404 | """ 405 | Scale a PIL Image and the GT to a given size. 406 | If there is no GT, then only scale the PIL Image. 407 | 408 | Args: 409 | size: Desired output size (h, w). 410 | interpolation (int, optional): Desired interpolation. 411 | Default is ``PIL.Image.BILINEAR``. 412 | """ 413 | 414 | def __init__(self, size, interpolation=Image.BILINEAR): 415 | self.modifies_label = True 416 | self.size = size 417 | super(ScaleImageAndLabel, self).__init__(size, interpolation) 418 | 419 | def __call__(self, img, dictionary): 420 | 421 | old_width, old_height = img.size 422 | scale_h = self.size[0]/old_height 423 | scale_w = self.size[1]/old_width 424 | 425 | # Scale image to new size 426 | img = super(ScaleImageAndLabel, self).__call__(img) 427 | 428 | # Scale GT 429 | if 'locations' in dictionary and len(dictionary['locations']) > 0: 430 | # print(dictionary['locations'].type()) 431 | # print(torch.tensor([scale_h, scale_w]).type()) 432 | with torch.no_grad(): 433 | dictionary['locations'] *= torch.tensor([scale_h, scale_w]) 434 | dictionary['locations'] = torch.round(dictionary['locations']) 435 | ys = torch.clamp( 436 | dictionary['locations'][:, 0], 0, self.size[0]) 437 | xs = torch.clamp( 438 | dictionary['locations'][:, 1], 0, self.size[1]) 439 | dictionary['locations'] = torch.cat((ys.view(-1, 1), 440 | xs.view(-1, 1)), 441 | 1) 442 | 443 | # Indicate new size in dictionary 444 | with torch.no_grad(): 445 | dictionary['resized_height'] = self.size[0] 446 | dictionary['resized_width'] = self.size[1] 447 | 448 | return img, dictionary 449 | 450 | 451 | def hflip(img): 452 | """Horizontally flip the given PIL Image. 453 | Args: 454 | img (PIL Image): Image to be flipped. 455 | Returns: 456 | PIL Image: Horizontall flipped image. 
457 | """ 458 | if not _is_pil_image(img): 459 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 460 | 461 | return img.transpose(Image.FLIP_LEFT_RIGHT) 462 | 463 | 464 | def vflip(img): 465 | """Vertically flip the given PIL Image. 466 | Args: 467 | img (PIL Image): Image to be flipped. 468 | Returns: 469 | PIL Image: Vertically flipped image. 470 | """ 471 | if not _is_pil_image(img): 472 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 473 | 474 | return img.transpose(Image.FLIP_TOP_BOTTOM) 475 | 476 | 477 | def _is_pil_image(img): 478 | return isinstance(img, Image.Image) 479 | 480 | 481 | """ 482 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 483 | All rights reserved. 484 | 485 | This software is covered by US patents and copyright. 486 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 487 | 488 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 489 | 490 | Last Modified: 10/02/2019 491 | """ 492 | -------------------------------------------------------------------------------- /object-locator/data_plant_stuff.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | import os 19 | import random 20 | from collections import OrderedDict 21 | 22 | from PIL import Image 23 | import numpy as np 24 | import torch 25 | from torchvision import datasets 26 | from torchvision import transforms 27 | import xmltodict 28 | from parse import parse 29 | 30 | from . import get_image_size 31 | 32 | IMG_EXTENSIONS = ['.png', '.jpeg', '.jpg', '.tiff'] 33 | 34 | torch.set_default_dtype(torch.float32) 35 | 36 | 37 | class XMLDataset(torch.utils.data.Dataset): 38 | def __init__(self, 39 | directory, 40 | transforms=None, 41 | max_dataset_size=float('inf'), 42 | ignore_gt=False, 43 | seed=0): 44 | """XMLDataset. 45 | The sample images of this dataset must be all inside one directory. 46 | Inside the same directory, there must be one XML file as described by 47 | https://communityhub.purdue.edu/groups/phenosorg/wiki/APIspecs 48 | (minimum XML API version is v.0.4.0). 49 | If there is no XML file, metrics will not be computed, 50 | and only estimations will be provided. 51 | :param directory: Directory with all the images and the XML file. 52 | :param transform: Transform to be applied to each image. 53 | :param max_dataset_size: Only use the first N images in the directory. 54 | :param ignore_gt: Ignore the GT in the XML file, 55 | i.e, provide samples without plant locations or counts. 56 | :param seed: Random seed. 
57 | """ 58 | 59 | self.root_dir = directory 60 | self.transforms = transforms 61 | 62 | # Get list of files in the dataset directory, 63 | # and the filename of the XML 64 | listfiles = os.listdir(directory) 65 | xml_filenames = [f for f in listfiles if f.endswith('.xml')] 66 | if len(xml_filenames) == 1: 67 | xml_filename = xml_filenames[0] 68 | elif len(xml_filenames) == 0: 69 | xml_filename = None 70 | else: 71 | print(f"E: there is more than one XML file in '{directory}'") 72 | exit(-1) 73 | 74 | # Ignore files that are not images 75 | listfiles = [f for f in listfiles 76 | if any(f.lower().endswith(ext) for ext in IMG_EXTENSIONS)] 77 | 78 | # Shuffle list of files 79 | np.random.seed(seed) 80 | random.shuffle(listfiles) 81 | 82 | if len(listfiles) == 0: 83 | raise ValueError(f"There are no images in '{directory}'") 84 | 85 | if xml_filename is None: 86 | print('W: The dataset directory %s does not contain ' 87 | 'a XML file with groundtruth. Metrics will not be evaluated.' 88 | 'Only estimations will be returned.' % directory) 89 | 90 | self.there_is_gt = (xml_filename is not None) and (not ignore_gt) 91 | 92 | # Read all XML as a string 93 | with open(os.path.join(directory, xml_filename), 'r') as fd: 94 | xml_str = fd.read() 95 | 96 | # Convert to dictionary 97 | # (some elements we expect to have multiple repetitions, 98 | # so put them in a list) 99 | xml_dict = xmltodict.parse(xml_str, 100 | force_list=['field', 101 | 'panel', 102 | 'plot', 103 | 'plant']) 104 | 105 | # Check API version number 106 | try: 107 | api_version = xml_dict['fields']['@apiversion'] 108 | except: 109 | # An unknown version number means it's the very first one 110 | # when we did not have api version numbers 111 | api_version = '0.1.0' 112 | major_version, minor_version, _ = parse('{}.{}.{}', api_version) 113 | major_version = int(major_version) 114 | minor_version = int(minor_version) 115 | if not(major_version == 0 and minor_version == 4): 116 | raise ValueError('An XML with API v0.4 is required.') 117 | 118 | # Create the dictionary with the entire dataset 119 | dictt = {} 120 | for field in xml_dict['fields']['field']: 121 | for panel in field['panels']['panel']: 122 | for plot in panel['plots']['plot']: 123 | 124 | if self.there_is_gt and \ 125 | not('plant_count' in plot and \ 126 | 'plants' in plot): 127 | # There is GT for some plots but not this one 128 | continue 129 | 130 | filename = plot['orthophoto_chop_filename'] 131 | if 'plot_number' in plot: 132 | plot_number = plot['plot_number'] 133 | else: 134 | plot_number = 'unknown' 135 | if 'subrow_grid_location' in plot: 136 | subrow_grid_x = \ 137 | int(plot['subrow_grid_location']['x']['#text']) 138 | subrow_grid_y = \ 139 | int(plot['subrow_grid_location']['y']['#text']) 140 | else: 141 | subrow_grid_x = 'unknown' 142 | subrow_grid_y = 'unknown' 143 | if 'row_number' in plot: 144 | row_number = plot['row_number'] 145 | else: 146 | row_number = 'unknown' 147 | if 'range_number' in plot: 148 | range_number = plot['range_number'] 149 | else: 150 | range_number = 'unknown' 151 | img_abspath = os.path.join(self.root_dir, filename) 152 | orig_width, orig_height = \ 153 | get_image_size.get_image_size(img_abspath) 154 | with torch.no_grad(): 155 | orig_height = torch.tensor( 156 | orig_height, dtype=torch.get_default_dtype()) 157 | orig_width = torch.tensor( 158 | orig_width, dtype=torch.get_default_dtype()) 159 | dictt[filename] = {'filename': filename, 160 | 'plot_number': plot_number, 161 | 'subrow_grid_location_x': subrow_grid_x, 162 | 
'subrow_grid_location_y': subrow_grid_y, 163 | 'row_number': row_number, 164 | 'range_number': range_number, 165 | 'orig_width': orig_width, 166 | 'orig_height': orig_height} 167 | if self.there_is_gt: 168 | count = int(plot['plant_count']) 169 | locations = [] 170 | for plant in plot['plants']['plant']: 171 | for y in plant['location']['y']: 172 | if y['@units'] == 'pixels' and \ 173 | y['@wrt'] == 'plot': 174 | y = float(y['#text']) 175 | break 176 | for x in plant['location']['x']: 177 | if x['@units'] == 'pixels' and \ 178 | x['@wrt'] == 'plot': 179 | x = float(x['#text']) 180 | break 181 | locations.append([y, x]) 182 | dictt[filename]['count'] = count 183 | dictt[filename]['locations'] = locations 184 | 185 | # Use an Ordered Dictionary to allow random access 186 | dictt = OrderedDict(dictt.items()) 187 | self.dict_list = list(dictt.items()) 188 | 189 | # Make dataset smaller 190 | new_dataset_length = min(len(dictt), max_dataset_size) 191 | dictt = {key: elem_dict 192 | for key, elem_dict in 193 | self.dict_list[:new_dataset_length]} 194 | self.dict_list = list(dictt.items()) 195 | 196 | def __len__(self): 197 | return len(self.dict_list) 198 | 199 | def __getitem__(self, idx): 200 | """Get one element of the dataset. 201 | Returns a tuple. The first element is the image. 202 | The second element is a dictionary containing the labels of that image. 203 | The dictionary may not contain the location and count if the original 204 | XML did not include it. 205 | 206 | :param idx: Index of the image in the dataset to get. 207 | """ 208 | 209 | filename, dictionary = self.dict_list[idx] 210 | img_abspath = os.path.join(self.root_dir, filename) 211 | 212 | if self.there_is_gt: 213 | # list --> Tensors 214 | with torch.no_grad(): 215 | dictionary['locations'] = torch.tensor( 216 | dictionary['locations'], 217 | dtype=torch.get_default_dtype()) 218 | dictionary['count'] = torch.tensor( 219 | dictionary['count'], 220 | dtype=torch.get_default_dtype()) 221 | # else: 222 | # filename = self.listfiles[idx] 223 | # img_abspath = os.path.join(self.root_dir, filename) 224 | # orig_width, orig_height = \ 225 | # get_image_size.get_image_size(img_abspath) 226 | # with torch.no_grad(): 227 | # orig_height = torch.tensor( 228 | # orig_height, dtype=torch.get_default_dtype()) 229 | # orig_width = torch.tensor( 230 | # orig_width, dtype=torch.get_default_dtype()) 231 | # dictionary = {'filename': self.listfiles[idx], 232 | # 'orig_width': orig_width, 233 | # 'orig_height': orig_height} 234 | 235 | img = Image.open(img_abspath) 236 | 237 | img_transformed = img 238 | transformed_dictionary = dictionary 239 | 240 | # Apply all transformations provided 241 | if self.transforms is not None: 242 | for transform in self.transforms.transforms: 243 | if hasattr(transform, 'modifies_label'): 244 | img_transformed, transformed_dictionary = \ 245 | transform(img_transformed, transformed_dictionary) 246 | else: 247 | img_transformed = transform(img_transformed) 248 | 249 | # Prevents crash when making a batch out of an empty tensor 250 | if self.there_is_gt and dictionary['count'].item() == 0: 251 | with torch.no_grad(): 252 | dictionary['locations'] = torch.tensor([-1, -1], 253 | dtype=torch.get_default_dtype()) 254 | 255 | return (img_transformed, transformed_dictionary) 256 | 257 | 258 | """ 259 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 260 | All rights reserved. 261 | 262 | This software is covered by US patents and copyright. 
263 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 264 | 265 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 266 | 267 | Last Modified: 10/02/2019 268 | """ 269 | -------------------------------------------------------------------------------- /object-locator/find_lr.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | __copyright__ = \ 4 | """ 5 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 6 | All rights reserved. 7 | 8 | This software is covered by US patents and copyright. 9 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 10 | 11 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 12 | 13 | Last Modified: 10/02/2019 14 | """ 15 | __license__ = "CC BY-NC-SA 4.0" 16 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 17 | __version__ = "1.6.0" 18 | 19 | 20 | import math 21 | import os 22 | from itertools import chain 23 | from tqdm import tqdm 24 | 25 | import numpy as np 26 | import torch 27 | import torch.optim as optim 28 | from torch import nn 29 | from torch.autograd import Variable 30 | from torchvision import transforms 31 | from torch.utils.data import DataLoader 32 | import torch.optim.lr_scheduler 33 | import matplotlib 34 | matplotlib.use('Agg') 35 | import skimage.transform 36 | from peterpy import peter 37 | from ballpark import ballpark 38 | from matplotlib import pyplot as plt 39 | 40 | from . import losses 41 | from .models import unet_model 42 | from .data import CSVDataset 43 | from .data import csv_collator 44 | from .data import RandomHorizontalFlipImageAndLabel 45 | from .data import RandomVerticalFlipImageAndLabel 46 | from .data import ScaleImageAndLabel 47 | from . 
import argparser 48 | 49 | 50 | # Parse command line arguments 51 | args = argparser.parse_command_args('training') 52 | 53 | # Tensor type to use, select CUDA or not 54 | torch.set_default_dtype(torch.float32) 55 | device_cpu = torch.device('cpu') 56 | device = torch.device('cuda') if args.cuda else device_cpu 57 | 58 | # Set seeds 59 | np.random.seed(args.seed) 60 | torch.manual_seed(args.seed) 61 | if args.cuda: 62 | torch.cuda.manual_seed_all(args.seed) 63 | 64 | # Data loading code 65 | training_transforms = [] 66 | if not args.no_data_augm: 67 | training_transforms += [RandomHorizontalFlipImageAndLabel(p=0.5)] 68 | training_transforms += [RandomVerticalFlipImageAndLabel(p=0.5)] 69 | training_transforms += [ScaleImageAndLabel(size=(args.height, args.width))] 70 | training_transforms += [transforms.ToTensor()] 71 | training_transforms += [transforms.Normalize((0.5, 0.5, 0.5), 72 | (0.5, 0.5, 0.5))] 73 | trainset = CSVDataset(args.train_dir, 74 | transforms=transforms.Compose(training_transforms), 75 | max_dataset_size=args.max_trainset_size) 76 | trainset_loader = DataLoader(trainset, 77 | batch_size=args.batch_size, 78 | drop_last=args.drop_last_batch, 79 | shuffle=True, 80 | num_workers=args.nThreads, 81 | collate_fn=csv_collator) 82 | 83 | # Model 84 | with peter('Building network'): 85 | model = unet_model.UNet(3, 1, 86 | height=args.height, 87 | width=args.width, 88 | known_n_points=args.n_points) 89 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 90 | print(f" with {ballpark(num_params)} trainable parameters. ", end='') 91 | model = nn.DataParallel(model) 92 | model.to(device) 93 | 94 | 95 | # Loss function 96 | loss_regress = nn.SmoothL1Loss() 97 | loss_loc = losses.WeightedHausdorffDistance(resized_height=args.height, 98 | resized_width=args.width, 99 | p=args.p, 100 | return_2_terms=True, 101 | device=device) 102 | l1_loss = nn.L1Loss(size_average=False) 103 | mse_loss = nn.MSELoss(reduce=False) 104 | 105 | optimizer = optim.SGD(model.parameters(), 106 | lr=999) # will be set later 107 | 108 | 109 | def find_lr(init_value = 1e-6, final_value=1e-3, beta = 0.7): 110 | num = len(trainset_loader)-1 111 | mult = (final_value / init_value) ** (1/num) 112 | lr = init_value 113 | optimizer.param_groups[0]['lr'] = lr 114 | avg_loss = 0. 115 | best_loss = 0. 
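    # What follows is a learning-rate range test (in the spirit of the
    # "LR finder" popularized by Leslie Smith's cyclical-learning-rate work):
    # the learning rate grows geometrically by `mult` after every mini-batch,
    # an exponentially smoothed loss (controlled by `beta`) is recorded at each
    # step, and the sweep stops early once the smoothed loss exceeds 4x the
    # best loss seen so far. Plotting `log_lrs` against `losses` afterwards
    # suggests a sensible starting learning rate.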
116 | batch_num = 0 117 | losses = [] 118 | log_lrs = [] 119 | for imgs, dicts in tqdm(trainset_loader): 120 | batch_num += 1 121 | 122 | # Pull info from this batch and move to device 123 | imgs = imgs.to(device) 124 | imgs = Variable(imgs) 125 | target_locations = [dictt['locations'].to(device) 126 | for dictt in dicts] 127 | target_counts = [dictt['count'].to(device) 128 | for dictt in dicts] 129 | target_orig_heights = [dictt['orig_height'].to(device) 130 | for dictt in dicts] 131 | target_orig_widths = [dictt['orig_width'].to(device) 132 | for dictt in dicts] 133 | 134 | # Lists -> Tensor batches 135 | target_counts = torch.stack(target_counts) 136 | target_orig_heights = torch.stack(target_orig_heights) 137 | target_orig_widths = torch.stack(target_orig_widths) 138 | target_orig_sizes = torch.stack((target_orig_heights, 139 | target_orig_widths)).transpose(0, 1) 140 | # As before, get the loss for this mini-batch of inputs/outputs 141 | optimizer.zero_grad() 142 | est_maps, est_counts = model.forward(imgs) 143 | term1, term2 = loss_loc.forward(est_maps, 144 | target_locations, 145 | target_orig_sizes) 146 | target_counts = target_counts.view(-1) 147 | est_counts = est_counts.view(-1) 148 | target_counts = target_counts.view(-1) 149 | term3 = loss_regress.forward(est_counts, target_counts) 150 | term3 *= args.lambdaa 151 | loss = term1 + term2 + term3 152 | 153 | # Compute the smoothed loss 154 | avg_loss = beta * avg_loss + (1-beta) *loss.item() 155 | smoothed_loss = avg_loss / (1 - beta**batch_num) 156 | 157 | # Stop if the loss is exploding 158 | if (batch_num > 1 and smoothed_loss > 4 * best_loss): 159 | return log_lrs, losses 160 | 161 | # Record the best loss 162 | if smoothed_loss < best_loss or batch_num==1: 163 | best_loss = smoothed_loss 164 | 165 | # Store the values 166 | losses.append(smoothed_loss) 167 | log_lrs.append(math.log10(lr)) 168 | 169 | # Do the SGD step 170 | loss.backward() 171 | optimizer.step() 172 | 173 | # Update the lr for the next step 174 | lr *= mult 175 | optimizer.param_groups[0]['lr'] = lr 176 | return log_lrs, losses 177 | 178 | logs, losses = find_lr() 179 | plt.plot(logs, losses) 180 | plt.savefig('/data/jprat/plot_beta0.7.png') 181 | 182 | 183 | """ 184 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 185 | All rights reserved. 186 | 187 | This software is covered by US patents and copyright. 188 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 189 | 190 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 191 | 192 | Last Modified: 10/02/2019 193 | """ 194 | -------------------------------------------------------------------------------- /object-locator/get_image_size.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function 4 | """ 5 | 6 | get_image_size.py 7 | ==================== 8 | 9 | :Name: get_image_size 10 | :Purpose: extract image dimensions given a file path 11 | 12 | :Author: Paulo Scardine (based on code from Emmanuel VAÏSSE) 13 | 14 | :Created: 26/09/2013 15 | :Copyright: (c) Paulo Scardine 2013 16 | :Licence: MIT 17 | 18 | """ 19 | import collections 20 | import json 21 | import os 22 | import struct 23 | 24 | FILE_UNKNOWN = "Sorry, don't know how to get size for this file." 
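# Minimal usage sketch (assuming this module is importable on its own;
# the image path below is hypothetical):
#
#     import get_image_size
#     w, h = get_image_size.get_image_size('/path/to/image.png')
#     meta = get_image_size.get_image_metadata('/path/to/image.png')
#     print(meta.width, meta.height, meta.type, meta.file_size)
#
# Both helpers read only the file header using the builtin os and struct
# modules, so no imaging library (PIL, OpenCV, ...) is required.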
25 | 26 | 27 | class UnknownImageFormat(Exception): 28 | pass 29 | 30 | 31 | types = collections.OrderedDict() 32 | BMP = types['BMP'] = 'BMP' 33 | GIF = types['GIF'] = 'GIF' 34 | ICO = types['ICO'] = 'ICO' 35 | JPEG = types['JPEG'] = 'JPEG' 36 | PNG = types['PNG'] = 'PNG' 37 | TIFF = types['TIFF'] = 'TIFF' 38 | 39 | image_fields = ['path', 'type', 'file_size', 'width', 'height'] 40 | 41 | 42 | class Image(collections.namedtuple('Image', image_fields)): 43 | 44 | def to_str_row(self): 45 | return ("%d\t%d\t%d\t%s\t%s" % ( 46 | self.width, 47 | self.height, 48 | self.file_size, 49 | self.type, 50 | self.path.replace('\t', '\\t'), 51 | )) 52 | 53 | def to_str_row_verbose(self): 54 | return ("%d\t%d\t%d\t%s\t%s\t##%s" % ( 55 | self.width, 56 | self.height, 57 | self.file_size, 58 | self.type, 59 | self.path.replace('\t', '\\t'), 60 | self)) 61 | 62 | def to_str_json(self, indent=None): 63 | return json.dumps(self._asdict(), indent=indent) 64 | 65 | 66 | def get_image_size(file_path): 67 | """ 68 | Return (width, height) for a given img file content - no external 69 | dependencies except the os and struct builtin modules 70 | """ 71 | img = get_image_metadata(file_path) 72 | return (img.width, img.height) 73 | 74 | 75 | def get_image_metadata(file_path): 76 | """ 77 | Return an `Image` object for a given img file content - no external 78 | dependencies except the os and struct builtin modules 79 | 80 | Args: 81 | file_path (str): path to an image file 82 | 83 | Returns: 84 | Image: (path, type, file_size, width, height) 85 | """ 86 | size = os.path.getsize(file_path) 87 | 88 | # be explicit with open arguments - we need binary mode 89 | with open(file_path, "rb") as input: 90 | height = -1 91 | width = -1 92 | data = input.read(26) 93 | msg = " raised while trying to decode as JPEG." 
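        # The branches below detect the format by sniffing magic bytes in these
        # first 26 bytes: 'GIF87a'/'GIF89a' for GIF, the 8-byte PNG signature,
        # 0xFFD8 for JPEG, 'BM' for BMP, 'II*\0' / 'MM\0*' for TIFF, and a
        # last-resort ICO branch. Width and height are then unpacked with
        # struct at format-specific offsets.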
94 | 95 | if (size >= 10) and data[:6] in (b'GIF87a', b'GIF89a'): 96 | # GIFs 97 | imgtype = GIF 98 | w, h = struct.unpack("= 24) and data.startswith(b'\211PNG\r\n\032\n') 102 | and (data[12:16] == b'IHDR')): 103 | # PNGs 104 | imgtype = PNG 105 | w, h = struct.unpack(">LL", data[16:24]) 106 | width = int(w) 107 | height = int(h) 108 | elif (size >= 16) and data.startswith(b'\211PNG\r\n\032\n'): 109 | # older PNGs 110 | imgtype = PNG 111 | w, h = struct.unpack(">LL", data[8:16]) 112 | width = int(w) 113 | height = int(h) 114 | elif (size >= 2) and data.startswith(b'\377\330'): 115 | # JPEG 116 | imgtype = JPEG 117 | input.seek(0) 118 | input.read(2) 119 | b = input.read(1) 120 | try: 121 | while (b and ord(b) != 0xDA): 122 | while (ord(b) != 0xFF): 123 | b = input.read(1) 124 | while (ord(b) == 0xFF): 125 | b = input.read(1) 126 | if (ord(b) >= 0xC0 and ord(b) <= 0xC3): 127 | input.read(3) 128 | h, w = struct.unpack(">HH", input.read(4)) 129 | break 130 | else: 131 | input.read( 132 | int(struct.unpack(">H", input.read(2))[0]) - 2) 133 | b = input.read(1) 134 | width = int(w) 135 | height = int(h) 136 | except struct.error: 137 | raise UnknownImageFormat("StructError" + msg) 138 | except ValueError: 139 | raise UnknownImageFormat("ValueError" + msg) 140 | except Exception as e: 141 | raise UnknownImageFormat(e.__class__.__name__ + msg) 142 | elif (size >= 26) and data.startswith(b'BM'): 143 | # BMP 144 | imgtype = 'BMP' 145 | headersize = struct.unpack("= 40: 151 | w, h = struct.unpack("= 8) and data[:4] in (b"II\052\000", b"MM\000\052"): 160 | # Standard TIFF, big- or little-endian 161 | # BigTIFF and other different but TIFF-like formats are not 162 | # supported currently 163 | imgtype = TIFF 164 | byteOrder = data[:2] 165 | boChar = ">" if byteOrder == "MM" else "<" 166 | # maps TIFF type id to size (in bytes) 167 | # and python format char for struct 168 | tiffTypes = { 169 | 1: (1, boChar + "B"), # BYTE 170 | 2: (1, boChar + "c"), # ASCII 171 | 3: (2, boChar + "H"), # SHORT 172 | 4: (4, boChar + "L"), # LONG 173 | 5: (8, boChar + "LL"), # RATIONAL 174 | 6: (1, boChar + "b"), # SBYTE 175 | 7: (1, boChar + "c"), # UNDEFINED 176 | 8: (2, boChar + "h"), # SSHORT 177 | 9: (4, boChar + "l"), # SLONG 178 | 10: (8, boChar + "ll"), # SRATIONAL 179 | 11: (4, boChar + "f"), # FLOAT 180 | 12: (8, boChar + "d") # DOUBLE 181 | } 182 | ifdOffset = struct.unpack(boChar + "L", data[4:8])[0] 183 | try: 184 | countSize = 2 185 | input.seek(ifdOffset) 186 | ec = input.read(countSize) 187 | ifdEntryCount = struct.unpack(boChar + "H", ec)[0] 188 | # 2 bytes: TagId + 2 bytes: type + 4 bytes: count of values + 4 189 | # bytes: value offset 190 | ifdEntrySize = 12 191 | for i in range(ifdEntryCount): 192 | entryOffset = ifdOffset + countSize + i * ifdEntrySize 193 | input.seek(entryOffset) 194 | tag = input.read(2) 195 | tag = struct.unpack(boChar + "H", tag)[0] 196 | if(tag == 256 or tag == 257): 197 | # if type indicates that value fits into 4 bytes, value 198 | # offset is not an offset but value itself 199 | type = input.read(2) 200 | type = struct.unpack(boChar + "H", type)[0] 201 | if type not in tiffTypes: 202 | raise UnknownImageFormat( 203 | "Unkown TIFF field type:" + 204 | str(type)) 205 | typeSize = tiffTypes[type][0] 206 | typeChar = tiffTypes[type][1] 207 | input.seek(entryOffset + 8) 208 | value = input.read(typeSize) 209 | value = int(struct.unpack(typeChar, value)[0]) 210 | if tag == 256: 211 | width = value 212 | else: 213 | height = value 214 | if width > -1 and height > -1: 215 | break 
216 | except Exception as e: 217 | raise UnknownImageFormat(str(e)) 218 | elif size >= 2: 219 | # see http://en.wikipedia.org/wiki/ICO_(file_format) 220 | imgtype = 'ICO' 221 | input.seek(0) 222 | reserved = input.read(2) 223 | if 0 != struct.unpack(" 1: 230 | import warnings 231 | warnings.warn("ICO File contains more than one image") 232 | # http://msdn.microsoft.com/en-us/library/ms997538.aspx 233 | w = input.read(1) 234 | h = input.read(1) 235 | width = ord(w) 236 | height = ord(h) 237 | else: 238 | raise UnknownImageFormat(FILE_UNKNOWN) 239 | 240 | return Image(path=file_path, 241 | type=imgtype, 242 | file_size=size, 243 | width=width, 244 | height=height) 245 | 246 | 247 | import unittest 248 | 249 | 250 | class Test_get_image_size(unittest.TestCase): 251 | data = [{ 252 | 'path': 'lookmanodeps.png', 253 | 'width': 251, 254 | 'height': 208, 255 | 'file_size': 22228, 256 | 'type': 'PNG'}] 257 | 258 | def setUp(self): 259 | pass 260 | 261 | def test_get_image_metadata(self): 262 | img = self.data[0] 263 | output = get_image_metadata(img['path']) 264 | self.assertTrue(output) 265 | self.assertEqual(output.path, img['path']) 266 | self.assertEqual(output.width, img['width']) 267 | self.assertEqual(output.height, img['height']) 268 | self.assertEqual(output.type, img['type']) 269 | self.assertEqual(output.file_size, img['file_size']) 270 | for field in image_fields: 271 | self.assertEqual(getattr(output, field), img[field]) 272 | 273 | def test_get_image_metadata__ENOENT_OSError(self): 274 | with self.assertRaises(OSError): 275 | get_image_metadata('THIS_DOES_NOT_EXIST') 276 | 277 | def test_get_image_metadata__not_an_image_UnknownImageFormat(self): 278 | with self.assertRaises(UnknownImageFormat): 279 | get_image_metadata('README.rst') 280 | 281 | def test_get_image_size(self): 282 | img = self.data[0] 283 | output = get_image_size(img['path']) 284 | self.assertTrue(output) 285 | self.assertEqual(output, 286 | (img['width'], 287 | img['height'])) 288 | 289 | def tearDown(self): 290 | pass 291 | 292 | 293 | def main(argv=None): 294 | """ 295 | Print image metadata fields for the given file path. 296 | 297 | Keyword Arguments: 298 | argv (list): commandline arguments (e.g. 
sys.argv[1:]) 299 | Returns: 300 | int: zero for OK 301 | """ 302 | import logging 303 | import optparse 304 | import sys 305 | 306 | prs = optparse.OptionParser( 307 | usage="%prog [-v|--verbose] [--json|--json-indent] []", 308 | description="Print metadata for the given image paths " 309 | "(without image library bindings).") 310 | 311 | prs.add_option('--json', 312 | dest='json', 313 | action='store_true') 314 | prs.add_option('--json-indent', 315 | dest='json_indent', 316 | action='store_true') 317 | 318 | prs.add_option('-v', '--verbose', 319 | dest='verbose', 320 | action='store_true',) 321 | prs.add_option('-q', '--quiet', 322 | dest='quiet', 323 | action='store_true',) 324 | prs.add_option('-t', '--test', 325 | dest='run_tests', 326 | action='store_true',) 327 | 328 | argv = list(argv) if argv is not None else sys.argv[1:] 329 | (opts, args) = prs.parse_args(args=argv) 330 | loglevel = logging.INFO 331 | if opts.verbose: 332 | loglevel = logging.DEBUG 333 | elif opts.quiet: 334 | loglevel = logging.ERROR 335 | logging.basicConfig(level=loglevel) 336 | log = logging.getLogger() 337 | log.debug('argv: %r', argv) 338 | log.debug('opts: %r', opts) 339 | log.debug('args: %r', args) 340 | 341 | if opts.run_tests: 342 | import sys 343 | sys.argv = [sys.argv[0]] + args 344 | import unittest 345 | return unittest.main() 346 | 347 | output_func = Image.to_str_row 348 | if opts.json_indent: 349 | import functools 350 | output_func = functools.partial(Image.to_str_json, indent=2) 351 | elif opts.json: 352 | output_func = Image.to_str_json 353 | elif opts.verbose: 354 | output_func = Image.to_str_row_verbose 355 | 356 | EX_OK = 0 357 | EX_NOT_OK = 2 358 | 359 | if len(args) < 1: 360 | prs.print_help() 361 | print('') 362 | prs.error("You must specify one or more paths to image files") 363 | 364 | errors = [] 365 | for path_arg in args: 366 | try: 367 | img = get_image_metadata(path_arg) 368 | print(output_func(img)) 369 | except KeyboardInterrupt: 370 | raise 371 | except OSError as e: 372 | log.error((path_arg, e)) 373 | errors.append((path_arg, e)) 374 | except Exception as e: 375 | log.exception(e) 376 | errors.append((path_arg, e)) 377 | pass 378 | if len(errors): 379 | import pprint 380 | print("ERRORS", file=sys.stderr) 381 | print("======", file=sys.stderr) 382 | print(pprint.pformat(errors, indent=2), file=sys.stderr) 383 | return EX_NOT_OK 384 | return EX_OK 385 | 386 | 387 | if __name__ == "__main__": 388 | import sys 389 | sys.exit(main(argv=sys.argv[1:])) 390 | -------------------------------------------------------------------------------- /object-locator/locate.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | __copyright__ = \ 4 | """ 5 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 6 | All rights reserved. 7 | 8 | This software is covered by US patents and copyright. 9 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 10 | 11 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 12 | 13 | Last Modified: 10/02/2019 14 | """ 15 | __license__ = "CC BY-NC-SA 4.0" 16 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. 
Delp" 17 | __version__ = "1.6.0" 18 | 19 | 20 | import argparse 21 | import os 22 | import sys 23 | import time 24 | import shutil 25 | from parse import parse 26 | import math 27 | from collections import OrderedDict 28 | import itertools 29 | 30 | import matplotlib 31 | matplotlib.use('Agg') 32 | import cv2 33 | from tqdm import tqdm 34 | import numpy as np 35 | import pandas as pd 36 | import skimage.io 37 | import torch 38 | from torch import nn 39 | from torch.autograd import Variable 40 | from torch.utils import data 41 | from torchvision import datasets 42 | from torchvision import transforms 43 | import torchvision as tv 44 | from torchvision.models import inception_v3 45 | import skimage.transform 46 | from peterpy import peter 47 | from ballpark import ballpark 48 | 49 | from .data import csv_collator 50 | from .data import ScaleImageAndLabel 51 | from .data import build_dataset 52 | from . import losses 53 | from . import argparser 54 | from .models import unet_model 55 | from .metrics import Judge 56 | from .metrics import make_metric_plots 57 | from . import utils 58 | 59 | 60 | # Parse command line arguments 61 | args = argparser.parse_command_args('testing') 62 | 63 | # Tensor type to use, select CUDA or not 64 | torch.set_default_dtype(torch.float32) 65 | device_cpu = torch.device('cpu') 66 | device = torch.device('cuda') if args.cuda else device_cpu 67 | 68 | # Set seeds 69 | np.random.seed(args.seed) 70 | torch.manual_seed(args.seed) 71 | if args.cuda: 72 | torch.cuda.manual_seed_all(args.seed) 73 | 74 | # Data loading code 75 | try: 76 | testset = build_dataset(args.dataset, 77 | transforms=transforms.Compose([ 78 | ScaleImageAndLabel(size=(args.height, 79 | args.width)), 80 | transforms.ToTensor(), 81 | transforms.Normalize((0.5, 0.5, 0.5), 82 | (0.5, 0.5, 0.5)), 83 | ]), 84 | ignore_gt=not args.evaluate, 85 | max_dataset_size=args.max_testset_size) 86 | except ValueError as e: 87 | print(f'E: {e}') 88 | exit(-1) 89 | testset_loader = data.DataLoader(testset, 90 | batch_size=1, 91 | num_workers=args.nThreads, 92 | collate_fn=csv_collator) 93 | 94 | # Array with [height, width] of the new size 95 | resized_size = np.array([args.height, args.width]) 96 | 97 | # Loss function 98 | criterion_training = losses.WeightedHausdorffDistance(resized_height=args.height, 99 | resized_width=args.width, 100 | return_2_terms=True, 101 | device=device) 102 | 103 | # Restore saved checkpoint (model weights) 104 | with peter("Loading checkpoint"): 105 | 106 | if os.path.isfile(args.model): 107 | if args.cuda: 108 | checkpoint = torch.load(args.model) 109 | else: 110 | checkpoint = torch.load( 111 | args.model, map_location=lambda storage, loc: storage) 112 | # Model 113 | if args.n_points is None: 114 | if 'n_points' not in checkpoint: 115 | # Model will also estimate # of points 116 | model = unet_model.UNet(3, 1, 117 | known_n_points=None, 118 | height=args.height, 119 | width=args.width, 120 | ultrasmall=args.ultrasmallnet) 121 | 122 | else: 123 | # The checkpoint tells us the # of points to estimate 124 | model = unet_model.UNet(3, 1, 125 | known_n_points=checkpoint['n_points'], 126 | height=args.height, 127 | width=args.width, 128 | ultrasmall=args.ultrasmallnet) 129 | else: 130 | # The user tells us the # of points to estimate 131 | model = unet_model.UNet(3, 1, 132 | known_n_points=args.n_points, 133 | height=args.height, 134 | width=args.width, 135 | ultrasmall=args.ultrasmallnet) 136 | 137 | # Parallelize 138 | if args.cuda: 139 | model = nn.DataParallel(model) 140 | model = 
model.to(device) 141 | 142 | # Load model in checkpoint 143 | if args.cuda: 144 | state_dict = checkpoint['model'] 145 | else: 146 | # remove 'module.' of DataParallel 147 | state_dict = OrderedDict() 148 | for k, v in checkpoint['model'].items(): 149 | name = k[7:] 150 | state_dict[name] = v 151 | model.load_state_dict(state_dict) 152 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 153 | print(f"\n\__ loaded checkpoint '{args.model}' " 154 | f"with {ballpark(num_params)} trainable parameters") 155 | # print(model) 156 | else: 157 | print(f"\n\__ E: no checkpoint found at '{args.model}'") 158 | exit(-1) 159 | 160 | tic = time.time() 161 | 162 | 163 | # Set the module in evaluation mode 164 | model.eval() 165 | 166 | # Accumulative histogram of estimated maps 167 | bmm_tracker = utils.AccBetaMixtureModel() 168 | 169 | 170 | if testset.there_is_gt: 171 | # Prepare Judges that will compute P/R as fct of r and th 172 | judges = [] 173 | for r, th in itertools.product(args.radii, args.taus): 174 | judge = Judge(r=r) 175 | judge.th = th 176 | judges.append(judge) 177 | 178 | # Empty output CSV (one per threshold) 179 | df_outs = [pd.DataFrame() for _ in args.taus] 180 | 181 | # --force will overwrite output directory 182 | if args.force: 183 | shutil.rmtree(args.out) 184 | 185 | for batch_idx, (imgs, dictionaries) in tqdm(enumerate(testset_loader), 186 | total=len(testset_loader)): 187 | 188 | # Move to device 189 | imgs = imgs.to(device) 190 | 191 | # Pull info from this batch and move to device 192 | if testset.there_is_gt: 193 | target_locations = [dictt['locations'].to(device) 194 | for dictt in dictionaries] 195 | target_count = [dictt['count'].to(device) 196 | for dictt in dictionaries] 197 | 198 | target_orig_heights = [dictt['orig_height'].to(device) 199 | for dictt in dictionaries] 200 | target_orig_widths = [dictt['orig_width'].to(device) 201 | for dictt in dictionaries] 202 | 203 | # Lists -> Tensor batches 204 | if testset.there_is_gt: 205 | target_count = torch.stack(target_count) 206 | target_orig_heights = torch.stack(target_orig_heights) 207 | target_orig_widths = torch.stack(target_orig_widths) 208 | target_orig_sizes = torch.stack((target_orig_heights, 209 | target_orig_widths)).transpose(0, 1) 210 | origsize = (dictionaries[0]['orig_height'].item(), 211 | dictionaries[0]['orig_width'].item()) 212 | 213 | # Tensor -> float & numpy 214 | if testset.there_is_gt: 215 | target_count = target_count.item() 216 | target_locations = \ 217 | target_locations[0].to(device_cpu).numpy().reshape(-1, 2) 218 | target_orig_size = \ 219 | target_orig_sizes[0].to(device_cpu).numpy().reshape(2) 220 | 221 | normalzr = utils.Normalizer(args.height, args.width) 222 | 223 | # Feed forward 224 | with torch.no_grad(): 225 | est_maps, est_count = model.forward(imgs) 226 | 227 | # Convert to original size 228 | est_map_np = est_maps[0, :, :].to(device_cpu).numpy() 229 | est_map_np_origsize = \ 230 | skimage.transform.resize(est_map_np, 231 | output_shape=origsize, 232 | mode='constant') 233 | orig_img_np = imgs[0].to(device_cpu).squeeze().numpy() 234 | orig_img_np_origsize = ((skimage.transform.resize(orig_img_np.transpose((1, 2, 0)), 235 | output_shape=origsize, 236 | mode='constant') + 1) / 2.0 * 255.0).\ 237 | astype(np.float32).transpose((2, 0, 1)) 238 | 239 | # Overlay output on original image as a heatmap 240 | orig_img_w_heatmap_origsize = utils.overlay_heatmap(img=orig_img_np_origsize, 241 | map=est_map_np_origsize).\ 242 | astype(np.float32) 243 | 244 | # Save 
estimated map to disk 245 | os.makedirs(os.path.join(args.out, 'intermediate', 'estimated_map'), 246 | exist_ok=True) 247 | cv2.imwrite(os.path.join(args.out, 248 | 'intermediate', 249 | 'estimated_map', 250 | dictionaries[0]['filename']), 251 | orig_img_w_heatmap_origsize.transpose((1, 2, 0))[:, :, ::-1]) 252 | 253 | # Tensor -> int 254 | est_count_int = int(round(est_count.item())) 255 | 256 | # The estimated map must be thresholded to obtain estimated points 257 | for t, tau in enumerate(args.taus): 258 | if tau != -2: 259 | mask, _ = utils.threshold(est_map_np_origsize, tau) 260 | else: 261 | mask, _, mix = utils.threshold(est_map_np_origsize, tau) 262 | bmm_tracker.feed(mix) 263 | centroids_wrt_orig = utils.cluster(mask, est_count_int, 264 | max_mask_pts=args.max_mask_pts) 265 | 266 | # Save thresholded map to disk 267 | os.makedirs(os.path.join(args.out, 268 | 'intermediate', 269 | 'estimated_map_thresholded', 270 | f'tau={round(tau, 4)}'), 271 | exist_ok=True) 272 | cv2.imwrite(os.path.join(args.out, 273 | 'intermediate', 274 | 'estimated_map_thresholded', 275 | f'tau={round(tau, 4)}', 276 | dictionaries[0]['filename']), 277 | mask) 278 | 279 | # Paint red dots if user asked for it 280 | if args.paint: 281 | # Paint a cross at the estimated centroids 282 | img_with_x_n_map = utils.paint_circles(img=orig_img_w_heatmap_origsize, 283 | points=centroids_wrt_orig, 284 | color='red', 285 | crosshair=True) 286 | # Save to disk 287 | os.makedirs(os.path.join(args.out, 288 | 'intermediate', 289 | 'painted_on_estimated_map', 290 | f'tau={round(tau, 4)}'), exist_ok=True) 291 | cv2.imwrite(os.path.join(args.out, 292 | 'intermediate', 293 | 'painted_on_estimated_map', 294 | f'tau={round(tau, 4)}', 295 | dictionaries[0]['filename']), 296 | img_with_x_n_map.transpose((1, 2, 0))[:, :, ::-1]) 297 | # Paint a cross at the estimated centroids 298 | img_with_x = utils.paint_circles(img=orig_img_np_origsize, 299 | points=centroids_wrt_orig, 300 | color='red', 301 | crosshair=True) 302 | # Save to disk 303 | os.makedirs(os.path.join(args.out, 304 | 'intermediate', 305 | 'painted_on_original', 306 | f'tau={round(tau, 4)}'), exist_ok=True) 307 | cv2.imwrite(os.path.join(args.out, 308 | 'intermediate', 309 | 'painted_on_original', 310 | f'tau={round(tau, 4)}', 311 | dictionaries[0]['filename']), 312 | img_with_x.transpose((1, 2, 0))[:, :, ::-1]) 313 | 314 | 315 | if args.evaluate: 316 | target_locations_wrt_orig = normalzr.unnormalize(target_locations, 317 | orig_img_size=target_orig_size) 318 | 319 | # Compute metrics for each value of r (for each Judge) 320 | for judge in judges: 321 | if judge.th != tau: 322 | continue 323 | judge.feed_points(centroids_wrt_orig, target_locations_wrt_orig, 324 | max_ahd=criterion_training.max_dist) 325 | judge.feed_count(est_count_int, target_count) 326 | 327 | # Save a new line in the CSV corresonding to the resuls of this img 328 | res_dict = dictionaries[0] 329 | res_dict['count'] = est_count_int 330 | res_dict['locations'] = str(centroids_wrt_orig.tolist()) 331 | for key, val in res_dict.copy().items(): 332 | if 'height' in key or 'width' in key: 333 | del res_dict[key] 334 | df = pd.DataFrame(data={idx: [val] for idx, val in res_dict.items()}) 335 | df = df.set_index('filename') 336 | df_outs[t] = df_outs[t].append(df) 337 | 338 | # Write CSVs to disk 339 | os.makedirs(os.path.join(args.out, 'estimations'), exist_ok=True) 340 | for df_out, tau in zip(df_outs, args.taus): 341 | df_out.to_csv(os.path.join(args.out, 342 | 'estimations', 343 | 
f'estimations_tau={round(tau, 4)}.csv')) 344 | 345 | os.makedirs(os.path.join(args.out, 'intermediate', 'metrics_plots'), 346 | exist_ok=True) 347 | 348 | if args.evaluate: 349 | 350 | with peter("Evauating metrics"): 351 | 352 | # Output CSV where we will put 353 | # all our metrics as a function of r and the threshold 354 | df_metrics = pd.DataFrame(columns=['r', 'th', 355 | 'precision', 'recall', 'fscore', 'MAHD', 356 | 'MAPE', 'ME', 'MPE', 'MAE', 357 | 'MSE', 'RMSE', 'r', 'R2']) 358 | df_metrics.index.name = 'idx' 359 | 360 | for j, judge in enumerate(tqdm(judges)): 361 | # Accumulate precision and recall in the CSV dataframe 362 | df = pd.DataFrame(data=[[judge.r, 363 | judge.th, 364 | judge.precision, 365 | judge.recall, 366 | judge.fscore, 367 | judge.mahd, 368 | judge.mape, 369 | judge.me, 370 | judge.mpe, 371 | judge.mae, 372 | judge.mse, 373 | judge.rmse, 374 | judge.pearson_corr, 375 | judge.coeff_of_determination]], 376 | columns=['r', 'th', 377 | 'precision', 'recall', 'fscore', 'MAHD', 378 | 'MAPE', 'ME', 'MPE', 'MAE', 379 | 'MSE', 'RMSE', 'r', 'R2'], 380 | index=[j]) 381 | df.index.name = 'idx' 382 | df_metrics = df_metrics.append(df) 383 | 384 | # Write CSV of metrics to disk 385 | df_metrics.to_csv(os.path.join(args.out, 'metrics.csv')) 386 | 387 | # Generate plots 388 | figs = make_metric_plots(csv_path=os.path.join(args.out, 'metrics.csv'), 389 | taus=args.taus, 390 | radii=args.radii) 391 | for label, fig in figs.items(): 392 | # Save to disk 393 | fig.savefig(os.path.join(args.out, 394 | 'intermediate', 395 | 'metrics_plots', 396 | f'{label}.png')) 397 | 398 | 399 | # Save plot figures of the statistics of the BMM-based threshold 400 | if -2 in args.taus: 401 | for label, fig in bmm_tracker.plot().items(): 402 | fig.savefig(os.path.join(args.out, 403 | 'intermediate', 404 | 'metrics_plots', 405 | f'{label}.png')) 406 | 407 | 408 | elapsed_time = int(time.time() - tic) 409 | print(f'It took {elapsed_time} seconds to evaluate all this dataset.') 410 | 411 | 412 | """ 413 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 414 | All rights reserved. 415 | 416 | This software is covered by US patents and copyright. 417 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 418 | 419 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 420 | 421 | Last Modified: 10/02/2019 422 | """ 423 | -------------------------------------------------------------------------------- /object-locator/logger.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | import visdom 19 | import torch 20 | import numbers 21 | from . 
import utils 22 | 23 | from torch.autograd import Variable 24 | 25 | class Logger(): 26 | def __init__(self, 27 | server=None, 28 | port=8989, 29 | env_name='main'): 30 | """ 31 | Logger that connects to a Visdom server 32 | and sends training losses/metrics and images of any kind. 33 | 34 | :param server: Host name of the server (e.g, http://localhost), 35 | without the port number. If None, 36 | this Logger will do nothing at all 37 | (it will not connect to any server, 38 | and the functions here will do nothing). 39 | :param port: Port number of the Visdom server. 40 | :param env_name: Name of the environment within the Visdom 41 | server where everything you sent to it will go. 42 | :param terms_legends: Legend of each term. 43 | """ 44 | 45 | if server is None: 46 | self.train_losses = utils.nothing 47 | self.val_losses = utils.nothing 48 | self.image = utils.nothing 49 | print('W: Not connected to any Visdom server. ' 50 | 'You will not visualize any training/validation plot ' 51 | 'or intermediate image') 52 | else: 53 | # Connect to Visdom 54 | self.client = visdom.Visdom(server=server, 55 | env=env_name, 56 | port=port) 57 | if self.client.check_connection(): 58 | print(f'Connected to Visdom server ' 59 | f'{server}:{port}') 60 | else: 61 | print(f'E: cannot connect to Visdom server ' 62 | f'{server}:{port}') 63 | exit(-1) 64 | 65 | # Each of the 'windows' in visdom web panel 66 | self.viz_train_input_win = None 67 | self.viz_train_loss_win = None 68 | self.viz_train_gt_win = None 69 | self.viz_train_est_win = None 70 | self.viz_val_input_win = None 71 | self.viz_val_loss_win = None 72 | self.viz_val_gt_win = None 73 | self.viz_val_est_win = None 74 | 75 | # Visdom only supports CPU Tensors 76 | self.device = torch.device("cpu") 77 | 78 | 79 | def train_losses(self, terms, iteration_number, terms_legends=None): 80 | """ 81 | Plot a new point of the training losses (scalars) to Visdom. 82 | All losses will be plotted in the same figure/window. 83 | 84 | :param terms: List of scalar losses. 85 | Each element will be a different plot in the y axis. 86 | :param iteration_number: Value of the x axis in the plot. 87 | :param terms_legends: Legend of each term. 
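        Example (a minimal sketch; the scalar loss terms and the epoch counter
        are hypothetical placeholders):

        >>> logger = Logger(server='http://localhost', port=8989, env_name='main')
        >>> logger.train_losses(terms=[loss_term1, loss_term2, loss_term3],
        ...                     iteration_number=epoch,
        ...                     terms_legends=['Term1', 'Term2', 'Term3'])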
88 | """ 89 | 90 | # Watch dog 91 | if terms_legends is not None and \ 92 | len(terms) != len(terms_legends): 93 | raise ValueError('The number of "terms" and "terms_legends" must be equal, got %s and %s, respectively' 94 | % (len(terms), len(terms_legends))) 95 | if not isinstance(iteration_number, numbers.Number): 96 | raise ValueError('iteration_number must be a number, got %s' 97 | % iteration_number) 98 | 99 | # Make terms CPU Tensors 100 | curated_terms = [] 101 | for term in terms: 102 | if isinstance(term, numbers.Number): 103 | curated_term = torch.tensor([term]) 104 | elif isinstance(term, torch.Tensor): 105 | curated_term = term 106 | else: 107 | raise ValueError('there is a term with an unsupported type' 108 | f'({type(term)}') 109 | curated_term = curated_term.to(self.device) 110 | curated_term = curated_term.view(1) 111 | curated_terms.append(curated_term) 112 | 113 | y = torch.cat(curated_terms).view(1, -1).data 114 | x = torch.Tensor([iteration_number]).repeat(1, len(terms)) 115 | if terms_legends is None: 116 | terms_legends = ['Term %s' % t 117 | for t in range(1, len(terms) + 1)] 118 | 119 | # Send training loss to Visdom 120 | self.win_train_loss = \ 121 | self.client.line(Y=y, 122 | X=x, 123 | opts=dict(title='Training', 124 | legend=terms_legends, 125 | ylabel='Loss', 126 | xlabel='Epoch'), 127 | update='append', 128 | win='train_losses') 129 | if self.win_train_loss == 'win does not exist': 130 | self.win_train_loss = \ 131 | self.client.line(Y=y, 132 | X=x, 133 | opts=dict(title='Training', 134 | legend=terms_legends, 135 | ylabel='Loss', 136 | xlabel='Epoch'), 137 | win='train_losses') 138 | 139 | def image(self, imgs, titles, window_ids): 140 | """Send images to Visdom. 141 | Each image will be shown in a different window/plot. 142 | 143 | :param imgs: List of numpy images. 144 | :param titles: List of titles of each image. 145 | :param window_ids: List of window IDs. 146 | """ 147 | 148 | # Watchdog 149 | if not(len(imgs) == len(titles) == len(window_ids)): 150 | raise ValueError('The number of "imgs", "titles" and ' 151 | '"window_ids" must be equal, got ' 152 | '%s, %s and %s, respectively' 153 | % (len(imgs), len(titles), len(window_ids))) 154 | 155 | for img, title, win in zip(imgs, titles, window_ids): 156 | self.client.image(img, 157 | opts=dict(title=title), 158 | win=str(win)) 159 | 160 | def val_losses(self, terms, iteration_number, terms_legends=None): 161 | """ 162 | Plot a new point of the training losses (scalars) to Visdom. All losses will be plotted in the same figure/window. 163 | 164 | :param terms: List of scalar losses. 165 | Each element will be a different plot in the y axis. 166 | :param iteration_number: Value of the x axis in the plot. 167 | :param terms_legends: Legend of each term. 
168 | """ 169 | 170 | # Watchdog 171 | if terms_legends is not None and \ 172 | len(terms) != len(terms_legends): 173 | raise ValueError('The number of "terms" and "terms_legends" must be equal, got %s and %s, respectively' 174 | % (len(terms), len(terms_legends))) 175 | if not isinstance(iteration_number, numbers.Number): 176 | raise ValueError('iteration_number must be a number, got %s' 177 | % iteration_number) 178 | 179 | # Make terms CPU Tensors 180 | curated_terms = [] 181 | for term in terms: 182 | if isinstance(term, numbers.Number): 183 | curated_term = torch.tensor([term], 184 | dtype=torch.get_default_dtype()) 185 | elif isinstance(term, torch.Tensor): 186 | curated_term = term 187 | else: 188 | raise ValueError('there is a term with an unsupported type' 189 | f'({type(term)}') 190 | curated_term = curated_term.to(self.device) 191 | curated_term = curated_term.view(1) 192 | curated_terms.append(curated_term) 193 | 194 | y = torch.stack(curated_terms).view(1, -1) 195 | x = torch.Tensor([iteration_number]).repeat(1, len(terms)) 196 | if terms_legends is None: 197 | terms_legends = ['Term %s' % t for t in range(1, len(terms) + 1)] 198 | 199 | # Send validation loss to Visdom 200 | self.win_val_loss = \ 201 | self.client.line(Y=y, 202 | X=x, 203 | opts=dict(title='Validation', 204 | legend=terms_legends, 205 | ylabel='Loss', 206 | xlabel='Epoch'), 207 | update='append', 208 | win='val_metrics') 209 | if self.win_val_loss == 'win does not exist': 210 | self.win_val_loss = \ 211 | self.client.line(Y=y, 212 | X=x, 213 | opts=dict(title='Validation', 214 | legend=terms_legends, 215 | ylabel='Loss', 216 | xlabel='Epoch'), 217 | win='val_metrics') 218 | 219 | 220 | """ 221 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 222 | All rights reserved. 223 | 224 | This software is covered by US patents and copyright. 225 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 226 | 227 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 228 | 229 | Last Modified: 10/02/2019 230 | """ 231 | -------------------------------------------------------------------------------- /object-locator/losses.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. 
Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | import math 19 | import torch 20 | from sklearn.utils.extmath import cartesian 21 | import numpy as np 22 | from torch.nn import functional as F 23 | import os 24 | import time 25 | from sklearn.metrics.pairwise import pairwise_distances 26 | from sklearn.neighbors.kde import KernelDensity 27 | import skimage.io 28 | from matplotlib import pyplot as plt 29 | from torch import nn 30 | 31 | 32 | torch.set_default_dtype(torch.float32) 33 | 34 | 35 | def _assert_no_grad(variables): 36 | for var in variables: 37 | assert not var.requires_grad, \ 38 | "nn criterions don't compute the gradient w.r.t. targets - please " \ 39 | "mark these variables as volatile or not requiring gradients" 40 | 41 | 42 | def cdist(x, y): 43 | """ 44 | Compute distance between each pair of the two collections of inputs. 45 | :param x: Nxd Tensor 46 | :param y: Mxd Tensor 47 | :res: NxM matrix where dist[i,j] is the norm between x[i,:] and y[j,:], 48 | i.e. dist[i,j] = ||x[i,:]-y[j,:]|| 49 | 50 | """ 51 | differences = x.unsqueeze(1) - y.unsqueeze(0) 52 | distances = torch.sum(differences**2, -1).sqrt() 53 | return distances 54 | 55 | 56 | def averaged_hausdorff_distance(set1, set2, max_ahd=np.inf): 57 | """ 58 | Compute the Averaged Hausdorff Distance function 59 | between two unordered sets of points (the function is symmetric). 60 | Batches are not supported, so squeeze your inputs first! 61 | :param set1: Array/list where each row/element is an N-dimensional point. 62 | :param set2: Array/list where each row/element is an N-dimensional point. 63 | :param max_ahd: Maximum AHD possible to return if any set is empty. Default: inf. 64 | :return: The Averaged Hausdorff Distance between set1 and set2. 65 | """ 66 | 67 | if len(set1) == 0 or len(set2) == 0: 68 | return max_ahd 69 | 70 | set1 = np.array(set1) 71 | set2 = np.array(set2) 72 | 73 | assert set1.ndim == 2, 'got %s' % set1.ndim 74 | assert set2.ndim == 2, 'got %s' % set2.ndim 75 | 76 | assert set1.shape[1] == set2.shape[1], \ 77 | 'The points in both sets must have the same number of dimensions, got %s and %s.'\ 78 | % (set2.shape[1], set2.shape[1]) 79 | 80 | d2_matrix = pairwise_distances(set1, set2, metric='euclidean') 81 | 82 | res = np.average(np.min(d2_matrix, axis=0)) + \ 83 | np.average(np.min(d2_matrix, axis=1)) 84 | 85 | return res 86 | 87 | 88 | class AveragedHausdorffLoss(nn.Module): 89 | def __init__(self): 90 | super(nn.Module, self).__init__() 91 | 92 | def forward(self, set1, set2): 93 | """ 94 | Compute the Averaged Hausdorff Distance function 95 | between two unordered sets of points (the function is symmetric). 96 | Batches are not supported, so squeeze your inputs first! 97 | :param set1: Tensor where each row is an N-dimensional point. 98 | :param set2: Tensor where each row is an N-dimensional point. 99 | :return: The Averaged Hausdorff Distance between set1 and set2. 
100 | """ 101 | 102 | assert set1.ndimension() == 2, 'got %s' % set1.ndimension() 103 | assert set2.ndimension() == 2, 'got %s' % set2.ndimension() 104 | 105 | assert set1.size()[1] == set2.size()[1], \ 106 | 'The points in both sets must have the same number of dimensions, got %s and %s.'\ 107 | % (set2.size()[1], set2.size()[1]) 108 | 109 | d2_matrix = cdist(set1, set2) 110 | 111 | # Modified Chamfer Loss 112 | term_1 = torch.mean(torch.min(d2_matrix, 1)[0]) 113 | term_2 = torch.mean(torch.min(d2_matrix, 0)[0]) 114 | 115 | res = term_1 + term_2 116 | 117 | return res 118 | 119 | 120 | class WeightedHausdorffDistance(nn.Module): 121 | def __init__(self, 122 | resized_height, resized_width, 123 | p=-9, 124 | return_2_terms=False, 125 | device=torch.device('cpu')): 126 | """ 127 | :param resized_height: Number of rows in the image. 128 | :param resized_width: Number of columns in the image. 129 | :param p: Exponent in the generalized mean. -inf makes it the minimum. 130 | :param return_2_terms: Whether to return the 2 terms 131 | of the WHD instead of their sum. 132 | Default: False. 133 | :param device: Device where all Tensors will reside. 134 | """ 135 | super(nn.Module, self).__init__() 136 | 137 | # Prepare all possible (row, col) locations in the image 138 | self.height, self.width = resized_height, resized_width 139 | self.resized_size = torch.tensor([resized_height, 140 | resized_width], 141 | dtype=torch.get_default_dtype(), 142 | device=device) 143 | self.max_dist = math.sqrt(resized_height**2 + resized_width**2) 144 | self.n_pixels = resized_height * resized_width 145 | self.all_img_locations = torch.from_numpy(cartesian([np.arange(resized_height), 146 | np.arange(resized_width)])) 147 | # Convert to appropiate type 148 | self.all_img_locations = self.all_img_locations.to(device=device, 149 | dtype=torch.get_default_dtype()) 150 | 151 | self.return_2_terms = return_2_terms 152 | self.p = p 153 | 154 | def forward(self, prob_map, gt, orig_sizes): 155 | """ 156 | Compute the Weighted Hausdorff Distance function 157 | between the estimated probability map and ground truth points. 158 | The output is the WHD averaged through all the batch. 159 | 160 | :param prob_map: (B x H x W) Tensor of the probability map of the estimation. 161 | B is batch size, H is height and W is width. 162 | Values must be between 0 and 1. 163 | :param gt: List of Tensors of the Ground Truth points. 164 | Must be of size B as in prob_map. 165 | Each element in the list must be a 2D Tensor, 166 | where each row is the (y, x), i.e, (row, col) of a GT point. 167 | :param orig_sizes: Bx2 Tensor containing the size 168 | of the original images. 169 | B is batch size. 170 | The size must be in (height, width) format. 171 | :param orig_widths: List of the original widths for each image 172 | in the batch. 173 | :return: Single-scalar Tensor with the Weighted Hausdorff Distance. 174 | If self.return_2_terms=True, then return a tuple containing 175 | the two terms of the Weighted Hausdorff Distance. 
176 | """ 177 | 178 | _assert_no_grad(gt) 179 | 180 | assert prob_map.dim() == 3, 'The probability map must be (B x H x W)' 181 | assert prob_map.size()[1:3] == (self.height, self.width), \ 182 | 'You must configure the WeightedHausdorffDistance with the height and width of the ' \ 183 | 'probability map that you are using, got a probability map of size %s'\ 184 | % str(prob_map.size()) 185 | 186 | batch_size = prob_map.shape[0] 187 | assert batch_size == len(gt) 188 | 189 | terms_1 = [] 190 | terms_2 = [] 191 | for b in range(batch_size): 192 | 193 | # One by one 194 | prob_map_b = prob_map[b, :, :] 195 | gt_b = gt[b] 196 | orig_size_b = orig_sizes[b, :] 197 | norm_factor = (orig_size_b/self.resized_size).unsqueeze(0) 198 | n_gt_pts = gt_b.size()[0] 199 | 200 | # Corner case: no GT points 201 | if gt_b.ndimension() == 1 and (gt_b < 0).all().item() == 0: 202 | terms_1.append(torch.tensor([0], 203 | dtype=torch.get_default_dtype())) 204 | terms_2.append(torch.tensor([self.max_dist], 205 | dtype=torch.get_default_dtype())) 206 | continue 207 | 208 | # Pairwise distances between all possible locations and the GTed locations 209 | n_gt_pts = gt_b.size()[0] 210 | normalized_x = norm_factor.repeat(self.n_pixels, 1) *\ 211 | self.all_img_locations 212 | normalized_y = norm_factor.repeat(len(gt_b), 1)*gt_b 213 | d_matrix = cdist(normalized_x, normalized_y) 214 | 215 | # Reshape probability map as a long column vector, 216 | # and prepare it for multiplication 217 | p = prob_map_b.view(prob_map_b.nelement()) 218 | n_est_pts = p.sum() 219 | p_replicated = p.view(-1, 1).repeat(1, n_gt_pts) 220 | 221 | # Weighted Hausdorff Distance 222 | term_1 = (1 / (n_est_pts + 1e-6)) * \ 223 | torch.sum(p * torch.min(d_matrix, 1)[0]) 224 | weighted_d_matrix = (1 - p_replicated)*self.max_dist + p_replicated*d_matrix 225 | minn = generaliz_mean(weighted_d_matrix, 226 | p=self.p, 227 | dim=0, keepdim=False) 228 | term_2 = torch.mean(minn) 229 | 230 | # terms_1[b] = term_1 231 | # terms_2[b] = term_2 232 | terms_1.append(term_1) 233 | terms_2.append(term_2) 234 | 235 | terms_1 = torch.stack(terms_1) 236 | terms_2 = torch.stack(terms_2) 237 | 238 | if self.return_2_terms: 239 | res = terms_1.mean(), terms_2.mean() 240 | else: 241 | res = terms_1.mean() + terms_2.mean() 242 | 243 | return res 244 | 245 | 246 | def generaliz_mean(tensor, dim, p=-9, keepdim=False): 247 | # """ 248 | # Computes the softmin along some axes. 249 | # Softmin is the same as -softmax(-x), i.e, 250 | # softmin(x) = -log(sum_i(exp(-x_i))) 251 | 252 | # The smoothness of the operator is controlled with k: 253 | # softmin(x) = -log(sum_i(exp(-k*x_i)))/k 254 | 255 | # :param input: Tensor of any dimension. 256 | # :param dim: (int or tuple of ints) The dimension or dimensions to reduce. 257 | # :param keepdim: (bool) Whether the output tensor has dim retained or not. 258 | # :param k: (float>0) How similar softmin is to min (the lower the more smooth). 259 | # """ 260 | # return -torch.log(torch.sum(torch.exp(-k*input), dim, keepdim))/k 261 | """ 262 | The generalized mean. It corresponds to the minimum when p = -inf. 263 | https://en.wikipedia.org/wiki/Generalized_mean 264 | :param tensor: Tensor of any dimension. 265 | :param dim: (int or tuple of ints) The dimension or dimensions to reduce. 266 | :param keepdim: (bool) Whether the output tensor has dim retained or not. 267 | :param p: (float<0). 
268 | """ 269 | assert p < 0 270 | res= torch.mean((tensor + 1e-6)**p, dim, keepdim=keepdim)**(1./p) 271 | return res 272 | 273 | 274 | """ 275 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 276 | All rights reserved. 277 | 278 | This software is covered by US patents and copyright. 279 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 280 | 281 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 282 | 283 | Last Modified: 10/02/2019 284 | """ 285 | -------------------------------------------------------------------------------- /object-locator/make_metric_plots.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | import os 19 | import numpy as np 20 | import pandas as pd 21 | import argparse 22 | 23 | from . import metrics 24 | 25 | # Parse command-line arguments 26 | parser = argparse.ArgumentParser( 27 | description='Create a bunch of plot from the metrics in a CSV.', 28 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 29 | parser.add_argument('csv', 30 | help='CSV file with the precision and recall results.') 31 | parser.add_argument('out', 32 | help='Output directory.') 33 | parser.add_argument('--title', 34 | default='', 35 | help='Title of the plot in the figure.') 36 | parser.add_argument('--taus', 37 | type=str, 38 | required=True, 39 | help='Detection threshold taus. ' 40 | 'For each of these taus, a precision(r) and recall(r) will be created.' 41 | 'The closest to these values will be used.') 42 | parser.add_argument('--radii', 43 | type=str, 44 | required=True, 45 | help='List of values, each with different colors in the scatter plot. ' 46 | 'Maximum distance to consider a True Positive. ' 47 | 'The closest to this value will be used.') 48 | args = parser.parse_args() 49 | 50 | 51 | os.makedirs(args.out, exist_ok=True) 52 | 53 | taus = [float(tau) for tau in args.taus.replace('[', '').replace(']', '').split(',')] 54 | radii = [int(r) for r in args.radii.replace('[', '').replace(']', '').split(',')] 55 | 56 | figs = metrics.make_metric_plots(csv_path=args.csv, 57 | taus=taus, 58 | radii=radii, 59 | title=args.title) 60 | 61 | for label, fig in figs.items(): 62 | # Save to disk 63 | fig.savefig(os.path.join(args.out, f'{label}.png')) 64 | 65 | 66 | """ 67 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 68 | All rights reserved. 69 | 70 | This software is covered by US patents and copyright. 71 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 72 | 73 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 
74 | 75 | Last Modified: 10/02/2019 76 | """ 77 | -------------------------------------------------------------------------------- /object-locator/metrics.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | import math 19 | 20 | import pandas as pd 21 | import numpy as np 22 | import matplotlib 23 | matplotlib.use('Agg') 24 | from matplotlib import pyplot as plt 25 | import sklearn.metrics 26 | import sklearn.neighbors 27 | import scipy.stats 28 | from . import losses 29 | 30 | class Judge(): 31 | """ 32 | A Judge computes the following metrics: 33 | (Location metrics) 34 | - Precision 35 | - Recall 36 | - Fscore 37 | - Mean Average Hausdorff Distance (MAHD) 38 | (Count metrics) 39 | - Mean Error (ME) 40 | - Mean Absolute Error (MAE) 41 | - Mean Percent Error (MPE) 42 | - Mean Absolute Percent Error (MAPE) 43 | - Mean Squared Error (MSE) 44 | - Root Mean Squared Error (RMSE) 45 | - Pearson correlation (r) 46 | - Coefficient of determination (R^2) 47 | """ 48 | 49 | def __init__(self, r): 50 | """ 51 | Create a Judge that will compute metrics with a particular r 52 | (r is only used to compute Precision, Recall, and Fscore). 53 | 54 | :param r: If an estimated point and a ground truth point 55 | are at a distance <= r, then a True Positive is counted. 56 | """ 57 | # Location metrics 58 | self.r = r 59 | self.tp = 0 60 | self.fp = 0 61 | self.fn = 0 62 | 63 | # Count data points 64 | self._predicted_counts = [] 65 | self._true_counts = [] 66 | 67 | # Internal variables 68 | self._sum_ahd = 0 69 | self._sum_e = 0 70 | self._sum_pe = 0 71 | self._sum_ae = 0 72 | self._sum_se = 0 73 | self._sum_ape = 0 74 | self._n_calls_to_feed_points = 0 75 | self._n_calls_to_feed_count = 0 76 | 77 | def feed_points(self, pts, gt, max_ahd=np.inf): 78 | """ 79 | Evaluate the location metrics of one set of estimations. 80 | This set can correspond to the estimated points and 81 | the groundtruthed points of one image. 82 | The TP, FP, FN, Precision, Recall, Fscore, and AHD will be 83 | accumulated into this Judge. 84 | 85 | :param pts: List of estmated points. 86 | :param gt: List of ground truth points. 87 | :param max_ahd: Maximum AHD possible to return if any set is empty. Default: inf. 
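        Illustrative sketch (the (row, col) coordinates are made up):

            judge = Judge(r=5)
            judge.feed_points(pts=[(10, 12), (40, 41)],
                              gt=[(11, 12), (39, 40), (80, 80)],
                              max_ahd=100)
            # judge.tp, judge.fp, judge.fn and the accumulated AHD are updated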
88 | """ 89 | 90 | if len(pts) == 0: 91 | tp = 0 92 | fp = 0 93 | fn = len(gt) 94 | else: 95 | nbr = sklearn.neighbors.NearestNeighbors(n_neighbors=1, metric='euclidean').fit(gt) 96 | dis, idx = nbr.kneighbors(pts) 97 | detected_pts = (dis[:, 0] <= self.r).astype(np.uint8) 98 | 99 | nbr = sklearn.neighbors.NearestNeighbors(n_neighbors=1, metric='euclidean').fit(pts) 100 | dis, idx = nbr.kneighbors(gt) 101 | detected_gt = (dis[:, 0] <= self.r).astype(np.uint8) 102 | 103 | tp = np.sum(detected_pts) 104 | fp = len(pts) - tp 105 | fn = len(gt) - np.sum(detected_gt) 106 | 107 | self.tp += tp 108 | self.fp += fp 109 | self.fn += fn 110 | 111 | # Evaluation using the Averaged Hausdorff Distance 112 | ahd = losses.averaged_hausdorff_distance(pts, gt, 113 | max_ahd=max_ahd) 114 | self._sum_ahd += ahd 115 | self._n_calls_to_feed_points += 1 116 | 117 | def feed_count(self, estim_count, gt_count): 118 | """ 119 | Evaluate count metrics for a count estimation. 120 | This count can correspond to the estimated and groundtruthed count 121 | of one image. The ME, MAE, MPE, MAPE, MSE, and RMSE will be updated 122 | accordignly. 123 | 124 | :param estim_count: (positive number) Estimated count. 125 | :param gt_count: (positive number) Groundtruthed count. 126 | """ 127 | 128 | if estim_count < 0: 129 | raise ValueError(f'estim_count < 0, got {estim_count}') 130 | if gt_count < 0: 131 | raise ValueError(f'gt_count < 0, got {gt_count}') 132 | 133 | self._predicted_counts.append(estim_count) 134 | self._true_counts.append(gt_count) 135 | 136 | e = estim_count - gt_count 137 | ae = abs(e) 138 | if gt_count == 0: 139 | ape = 100*ae 140 | pe = 100*e 141 | else: 142 | ape = 100 * ae / gt_count 143 | pe = 100 * e / gt_count 144 | se = e**2 145 | 146 | self._sum_e += e 147 | self._sum_pe += pe 148 | self._sum_ae += ae 149 | self._sum_se += se 150 | self._sum_ape += ape 151 | 152 | self._n_calls_to_feed_count += 1 153 | 154 | @property 155 | def me(self): 156 | """ Mean Error (float) """ 157 | return float(self._sum_e / self._n_calls_to_feed_count) 158 | 159 | @property 160 | def mae(self): 161 | """ Mean Absolute Error (positive float) """ 162 | return float(self._sum_ae / self._n_calls_to_feed_count) 163 | 164 | @property 165 | def mpe(self): 166 | """ Mean Percent Error (float) """ 167 | return float(self._sum_pe / self._n_calls_to_feed_count) 168 | 169 | @property 170 | def mape(self): 171 | """ Mean Absolute Percent Error (positive float) """ 172 | return float(self._sum_ape / self._n_calls_to_feed_count) 173 | 174 | @property 175 | def mse(self): 176 | """ Mean Squared Error (positive float)""" 177 | return float(self._sum_se / self._n_calls_to_feed_count) 178 | 179 | @property 180 | def rmse(self): 181 | """ Root Mean Squared Error (positive float)""" 182 | return float(math.sqrt(self.mse)) 183 | 184 | @property 185 | def coeff_of_determination(self): 186 | """ Coefficient of Determination (-inf, 1]""" 187 | return sklearn.metrics.r2_score(self._true_counts, 188 | self._predicted_counts) 189 | 190 | @property 191 | def pearson_corr(self): 192 | """ Pearson coefficient of Correlation [-1, 1]""" 193 | return scipy.stats.pearsonr(self._true_counts, 194 | self._predicted_counts)[0] 195 | 196 | @property 197 | def mahd(self): 198 | """ Mean Average Hausdorff Distance (positive float)""" 199 | return float(self._sum_ahd / self._n_calls_to_feed_points) 200 | 201 | @property 202 | def precision(self): 203 | """ Precision (positive float) """ 204 | return float(100*self.tp / (self.tp + self.fp)) \ 205 | if self.tp > 0 
else 0 206 | 207 | @property 208 | def recall(self): 209 | """ Recall (positive float) """ 210 | return float(100*self.tp / (self.tp + self.fn)) \ 211 | if self.tp > 0 else 0 212 | 213 | @property 214 | def fscore(self): 215 | """ F-score (positive float) """ 216 | return float(2 * (self.precision*self.recall / 217 | (self.precision+self.recall))) \ 218 | if self.tp > 0 else 0 219 | 220 | 221 | def make_metric_plots(csv_path, taus, radii, title=''): 222 | """ 223 | Create a bunch of plots from the metrics contained in a CSV file. 224 | 225 | :param csv_path: Path to a CSV file containing metrics. 226 | :param taus: Detection thresholds tau. 227 | For each of these taus, a precision(r) and recall(r) will be created. 228 | The closest to each of these values will be used. 229 | :param radii: List of values, each with different colors in the scatter plot. 230 | Maximum distance to consider a True Positive. 231 | The closest to each of these values will be used. 232 | :param title: (optional) Title of the plot in the figure. 233 | :return: Dictionary with matplotlib figures. 234 | """ 235 | 236 | dic = {} 237 | 238 | # Data extraction 239 | df = pd.read_csv(csv_path) 240 | 241 | plt.ioff() 242 | 243 | # ==== Precision and Recall as a function of R, fixing t ==== 244 | for tau in taus: 245 | # Find closest threshold 246 | tau_selected = df.th.values[np.argmin(np.abs(df.th.values - tau))] 247 | print(f'Making Precision(r) and Recall(r) using tau={tau_selected}') 248 | 249 | # Use only a particular r 250 | precision = df.precision.values[df.th.values == tau_selected] 251 | recall = df.recall.values[df.th.values == tau_selected] 252 | r = df.r.values[df.th.values == tau_selected] 253 | 254 | # Create the figure for "Crowd" Dataset 255 | fig, ax = plt.subplots() 256 | precision = ax.plot(r, precision, 'r--',label='Precision') 257 | recall = ax.plot(r, recall, 'b:',label='Recall') 258 | ax.legend() 259 | ax.set_ylabel('%') 260 | ax.set_xlabel(r'$r$ (in pixels)') 261 | ax.grid(True) 262 | plt.title(title + f' tau={round(tau_selected, 4)}') 263 | 264 | # Hide grid lines below the plot 265 | ax.set_axisbelow(True) 266 | 267 | # Add figure to dictionary 268 | dic[f'precision_and_recall_vs_r,_tau={round(tau_selected, 4)}'] = fig 269 | plt.close(fig) 270 | 271 | # ==== Precision vs Recall ==== 272 | colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] 273 | if len(radii) > len(colors): 274 | print(f'W: {len(radii)} are too many radii to plot, ' 275 | f'taking {len(colors)} randomly.') 276 | radii = list(radii) 277 | np.random.shuffle(radii) 278 | radii = radii[:len(colors)] 279 | radii = sorted(radii) 280 | 281 | # Create figure 282 | fig, ax = plt.subplots() 283 | plt.ioff() 284 | ax.set_ylabel('Precision') 285 | ax.set_xlabel('Recall') 286 | ax.grid(True) 287 | plt.title(title) 288 | 289 | for r, c in zip(radii, colors): 290 | # Find closest R 291 | r_selected = df.r.values[np.argmin(np.abs(df.r.values - r))] 292 | 293 | # Use only a particular r for all fixed thresholds 294 | selection = (df.r.values == r_selected) & (df.th.values >= 0) 295 | if selection.any(): 296 | precision = df.precision.values[selection] 297 | recall = df.recall.values[selection] 298 | 299 | # Sort by ascending recall 300 | idxs = np.argsort(recall) 301 | recall = recall[idxs] 302 | precision = precision[idxs] 303 | 304 | # Plot precision vs. 
recall for this r 305 | ax.scatter(recall, precision, 306 | c=c, s=2, label=f'$r={r}$') 307 | 308 | # Otsu threshold (tau = -1) 309 | selection = (df.r.values == r_selected) & (df.th.values == -1) 310 | if selection.any(): 311 | precision = df.precision.values[selection] 312 | recall = df.recall.values[selection] 313 | ax.scatter(recall, precision, 314 | c=c, s=8, marker='+', label=f'$r={r}$, Otsu') 315 | 316 | # BMM threshold (tau = -2) 317 | selection = (df.r.values == r_selected) & (df.th.values == -2) 318 | if selection.any(): 319 | precision = df.precision.values[selection] 320 | recall = df.recall.values[selection] 321 | ax.scatter(recall, precision, 322 | c=c, s=8, marker='s', label=f'$r={r}$, BMM') 323 | 324 | # Invert legend order 325 | handles, labels = ax.get_legend_handles_labels() 326 | handles, labels = handles[::-1], labels[::-1] 327 | 328 | # Put legend outside the plot 329 | box = ax.get_position() 330 | ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) 331 | ax.legend(handles, labels, loc='upper left', bbox_to_anchor=(1, 1.03)) 332 | 333 | # Hide grid lines below the plot 334 | ax.set_axisbelow(True) 335 | 336 | # Add figure to dictionary 337 | dic['precision_vs_recall'] = fig 338 | plt.close(fig) 339 | 340 | 341 | # ==== Precision as a function of tau for all provided R ==== 342 | # Create figure 343 | fig, ax = plt.subplots() 344 | plt.ioff() 345 | ax.set_ylabel('Precision') 346 | ax.set_xlabel(r'$\tau$') 347 | ax.grid(True) 348 | plt.title(title) 349 | 350 | list_of_precisions = [] 351 | 352 | for r, c in zip(radii, colors): 353 | # Find closest R 354 | r_selected = df.r.values[np.argmin(np.abs(df.r.values - r))] 355 | 356 | # Use only a particular r for all fixed thresholds 357 | selection = (df.r.values == r_selected) & (df.th.values >= 0) 358 | if selection.any(): 359 | precision = df.precision.values[selection] 360 | list_of_precisions.append(precision) 361 | taus = df.th.values[selection] 362 | 363 | # Plot precision vs tau for this r 364 | ax.scatter(taus, precision, c=c, s=2, label=f'$r={r}$') 365 | 366 | # Otsu threshold (tau = -1) 367 | selection = (df.r.values == r_selected) & (df.th.values == -1) 368 | if selection.any(): 369 | precision = df.precision.values[selection] 370 | ax.axhline(y=precision, 371 | linestyle='-', 372 | c=c, label=f'$r={r}$, Otsu') 373 | 374 | # BMM threshold (tau = -1) 375 | selection = (df.r.values == r_selected) & (df.th.values == -2) 376 | if selection.any(): 377 | precision = df.precision.values[selection] 378 | ax.axhline(y=precision, 379 | linestyle='--', 380 | c=c, label=f'$r={r}$, BMM') 381 | 382 | if len(list_of_precisions) > 0: 383 | # Plot average precision for all r's 384 | ax.scatter(taus, np.average(np.stack(list_of_precisions), axis=0), 385 | c='k', marker='x', s=7, label='avg along r') 386 | 387 | 388 | 389 | # Invert legend order 390 | handles, labels = ax.get_legend_handles_labels() 391 | handles, labels = handles[::-1], labels[::-1] 392 | 393 | # Put legend outside the plot 394 | box = ax.get_position() 395 | ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) 396 | ax.legend(handles, labels, loc='upper left', bbox_to_anchor=(1, 1.03)) 397 | 398 | # Hide grid lines below the plot 399 | ax.set_axisbelow(True) 400 | 401 | # Add figure to dictionary 402 | dic['precision_vs_th'] = fig 403 | plt.close(fig) 404 | 405 | # ==== Recall as a function of tau for all provided R ==== 406 | # Create figure 407 | fig, ax = plt.subplots() 408 | plt.ioff() 409 | ax.set_ylabel('Recall') 410 | 
ax.set_xlabel(r'$\tau$') 411 | ax.grid(True) 412 | plt.title(title) 413 | 414 | list_of_recalls = [] 415 | 416 | for r, c in zip(radii, colors): 417 | # Find closest R 418 | r_selected = df.r.values[np.argmin(np.abs(df.r.values - r))] 419 | 420 | # Use only a particular r 421 | selection = (df.r.values == r_selected) & (df.th.values >= 0) 422 | if selection.any(): 423 | recall = df.recall.values[selection] 424 | list_of_recalls.append(recall) 425 | taus = df.th.values[selection] 426 | 427 | # Plot precision vs tau for this r 428 | ax.scatter(taus, recall, c=c, s=2, label=f'$r={r}$') 429 | 430 | # Otsu threshold (tau = -1) 431 | selection = (df.r.values == r_selected) & (df.th.values == -1) 432 | if selection.any(): 433 | recall = df.recall.values[selection] 434 | ax.axhline(y=recall, 435 | linestyle='-', 436 | c=c, label=f'$r={r}$, Otsu') 437 | 438 | # BMM threshold (tau = -2) 439 | selection = (df.r.values == r_selected) & (df.th.values == -2) 440 | if selection.any(): 441 | recall = df.recall.values[selection] 442 | ax.axhline(y=recall, 443 | linestyle='--', 444 | c=c, label=f'$r={r}$, BMM') 445 | 446 | 447 | if len(list_of_recalls) > 0: 448 | ax.scatter(taus, np.average(np.stack(list_of_recalls), axis=0), 449 | c='k', marker='x', s=7, label='avg along $r$') 450 | 451 | # Invert legend order 452 | handles, labels = ax.get_legend_handles_labels() 453 | handles, labels = handles[::-1], labels[::-1] 454 | 455 | # Put legend outside the plot 456 | box = ax.get_position() 457 | ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) 458 | ax.legend(handles, labels, loc='upper left', bbox_to_anchor=(1, 1.03)) 459 | 460 | # Hide grid lines below the plot 461 | ax.set_axisbelow(True) 462 | 463 | # Add figure to dictionary 464 | dic['recall_vs_tau'] = fig 465 | plt.close(fig) 466 | 467 | 468 | # ==== F-score as a function of tau for all provided R ==== 469 | # Create figure 470 | fig, ax = plt.subplots() 471 | plt.ioff() 472 | ax.set_ylabel('F-score') 473 | ax.set_xlabel(r'$\tau$') 474 | ax.grid(True) 475 | plt.title(title) 476 | 477 | list_of_fscores = [] 478 | 479 | for r, c in zip(radii, colors): 480 | # Find closest R 481 | r_selected = df.r.values[np.argmin(np.abs(df.r.values - r))] 482 | 483 | # Use only a particular r 484 | selection = (df.r.values == r_selected) & (df.th.values >= 0) 485 | if selection.any(): 486 | fscore = df.fscore.values[selection] 487 | list_of_fscores.append(fscore) 488 | taus = df.th.values[selection] 489 | 490 | # Plot precision vs tau for this r 491 | ax.scatter(taus, fscore, c=c, s=2, label=f'$r={r}$') 492 | 493 | # Otsu threshold (tau = -1) 494 | selection = (df.r.values == r_selected) & (df.th.values == -1) 495 | if selection.any(): 496 | fscore = df.fscore.values[selection] 497 | ax.axhline(y=fscore, 498 | linestyle='-', 499 | c=c, label=f'$r={r}$, Otsu') 500 | 501 | # BMM threshold (tau = -2) 502 | selection = (df.r.values == r_selected) & (df.th.values == -2) 503 | if selection.any(): 504 | fscore = df.fscore.values[selection] 505 | ax.axhline(y=fscore, 506 | linestyle='--', 507 | c=c, label=f'$r={r}$, BMM') 508 | 509 | if len(list_of_fscores) > 0: 510 | ax.scatter(taus, np.average(np.stack(list_of_fscores), axis=0), 511 | c='k', marker='x', s=7, label='avg along r') 512 | 513 | # Invert legend order 514 | handles, labels = ax.get_legend_handles_labels() 515 | handles, labels = handles[::-1], labels[::-1] 516 | 517 | # Put legend outside the plot 518 | box = ax.get_position() 519 | ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) 520 | 
ax.legend(handles, labels, loc='upper left', bbox_to_anchor=(1, 1.03)) 521 | 522 | # Hide grid lines below the plot 523 | ax.set_axisbelow(True) 524 | 525 | # Add figure to dictionary 526 | dic['fscore_vs_tau'] = fig 527 | plt.close(fig) 528 | 529 | return dic 530 | 531 | 532 | """ 533 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 534 | All rights reserved. 535 | 536 | This software is covered by US patents and copyright. 537 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 538 | 539 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 540 | 541 | Last Modified: 10/02/2019 542 | """ 543 | -------------------------------------------------------------------------------- /object-locator/metrics_from_results.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | import os 19 | import argparse 20 | import ast 21 | import math 22 | 23 | from tqdm import tqdm 24 | import numpy as np 25 | import pandas as pd 26 | 27 | from . import metrics 28 | from . import get_image_size 29 | 30 | # Parse command-line arguments 31 | parser = argparse.ArgumentParser( 32 | description='Compute metrics from results and GT.', 33 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 34 | required_args = parser.add_argument_group('MANDATORY arguments') 35 | optional_args = parser._action_groups.pop() 36 | required_args.add_argument('results', 37 | help='Input CSV file with the estimated locations.') 38 | required_args.add_argument('gt', 39 | help='Input CSV file with the groundtruthed locations.') 40 | required_args.add_argument('metrics', 41 | help='Output CSV file with the metrics ' 42 | '(MAE, AHD, Precision, Recall...)') 43 | required_args.add_argument('--dataset', 44 | type=str, 45 | required=True, 46 | help='Dataset directory with the images. 
' 47 | 'This is used only to get the image diagonal, ' 48 | 'as the worst estimate for the AHD.') 49 | optional_args.add_argument('--radii', 50 | type=str, 51 | default=range(0, 15 + 1), 52 | metavar='Rs', 53 | help='Detections at dist <= R to a GT pt are True Positives.') 54 | args = parser.parse_args() 55 | 56 | 57 | # Prepare Judges that will compute P/R as fct of r and th 58 | judges = [metrics.Judge(r=r) for r in args.radii] 59 | 60 | df_results = pd.read_csv(args.results) 61 | df_gt = pd.read_csv(args.gt) 62 | 63 | df_metrics = pd.DataFrame(columns=['r', 64 | 'precision', 'recall', 'fscore', 'MAHD', 65 | 'MAPE', 'ME', 'MPE', 'MAE', 66 | 'MSE', 'RMSE', 'r', 'R2']) 67 | 68 | for j, judge in enumerate(tqdm(judges)): 69 | 70 | for idx, row_result in df_results.iterrows(): 71 | filename = row_result['filename'] 72 | row_gt = df_gt[df_gt['filename'] == filename].iloc()[0] 73 | 74 | w, h = get_image_size.get_image_size(os.path.join(args.dataset, filename)) 75 | diagonal = math.sqrt(w**2 + h**2) 76 | 77 | judge.feed_count(row_result['count'], 78 | row_gt['count']) 79 | judge.feed_points(ast.literal_eval(row_result['locations']), 80 | ast.literal_eval(row_gt['locations']), 81 | max_ahd=diagonal) 82 | 83 | df = pd.DataFrame(data=[[judge.r, 84 | judge.precision, 85 | judge.recall, 86 | judge.fscore, 87 | judge.mahd, 88 | judge.mape, 89 | judge.me, 90 | judge.mpe, 91 | judge.mae, 92 | judge.mse, 93 | judge.rmse, 94 | judge.pearson_corr \ 95 | if not np.isnan(judge.pearson_corr) else 1, 96 | judge.coeff_of_determination]], 97 | columns=['r', 98 | 'precision', 'recall', 'fscore', 'MAHD', 99 | 'MAPE', 'ME', 'MPE', 'MAE', 100 | 'MSE', 'RMSE', 'r', 'R2'], 101 | index=[j]) 102 | df.index.name = 'idx' 103 | df_metrics = df_metrics.append(df) 104 | 105 | # Write CSV of metrics to disk 106 | df_metrics.to_csv(args.metrics) 107 | 108 | 109 | """ 110 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 111 | All rights reserved. 112 | 113 | This software is covered by US patents and copyright. 114 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 115 | 116 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 117 | 118 | Last Modified: 10/02/2019 119 | """ 120 | -------------------------------------------------------------------------------- /object-locator/models/__init__.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | -------------------------------------------------------------------------------- /object-locator/models/unet_model.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 
5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 11/11/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.autograd import Variable 22 | 23 | from .unet_parts import * 24 | 25 | 26 | class UNet(nn.Module): 27 | def __init__(self, n_channels, n_classes, 28 | height, width, 29 | known_n_points=None, 30 | ultrasmall=False, 31 | device=torch.device('cuda')): 32 | """ 33 | Instantiate a UNet network. 34 | :param n_channels: Number of input channels (e.g, 3 for RGB) 35 | :param n_classes: Number of output classes 36 | :param height: Height of the input images 37 | :param known_n_points: If you know the number of points, 38 | (e.g, one pupil), then set it. 39 | Otherwise it will be estimated by a lateral NN. 40 | If provided, no lateral network will be build 41 | and the resulting UNet will be a FCN. 42 | :param ultrasmall: If True, the 5 central layers are removed, 43 | resulting in a much smaller UNet. 44 | :param device: Which torch device to use. Default: CUDA (GPU). 45 | """ 46 | super(UNet, self).__init__() 47 | 48 | self.ultrasmall = ultrasmall 49 | self.device = device 50 | 51 | # With this network depth, there is a minimum image size 52 | if height < 256 or width < 256: 53 | raise ValueError('Minimum input image size is 256x256, got {}x{}'.\ 54 | format(height, width)) 55 | 56 | self.inc = inconv(n_channels, 64) 57 | self.down1 = down(64, 128) 58 | self.down2 = down(128, 256) 59 | if self.ultrasmall: 60 | self.down3 = down(256, 512, normaliz=False) 61 | self.up1 = up(768, 128) 62 | self.up2 = up(256, 64) 63 | self.up3 = up(128, 64, activ=False) 64 | else: 65 | self.down3 = down(256, 512) 66 | self.down4 = down(512, 512) 67 | self.down5 = down(512, 512) 68 | self.down6 = down(512, 512) 69 | self.down7 = down(512, 512) 70 | self.down8 = down(512, 512, normaliz=False) 71 | self.up1 = up(1024, 512) 72 | self.up2 = up(1024, 512) 73 | self.up3 = up(1024, 512) 74 | self.up4 = up(1024, 512) 75 | self.up5 = up(1024, 256) 76 | self.up6 = up(512, 128) 77 | self.up7 = up(256, 64) 78 | self.up8 = up(128, 64, activ=False) 79 | self.outc = outconv(64, n_classes) 80 | self.out_nonlin = nn.Sigmoid() 81 | 82 | self.known_n_points = known_n_points 83 | if known_n_points is None: 84 | steps = 3 if self.ultrasmall else 8 85 | height_mid_features = height//(2**steps) 86 | width_mid_features = width//(2**steps) 87 | self.branch_1 = nn.Sequential(nn.Linear(height_mid_features*\ 88 | width_mid_features*\ 89 | 512, 90 | 64), 91 | nn.ReLU(inplace=True), 92 | nn.Dropout(p=0.5)) 93 | self.branch_2 = nn.Sequential(nn.Linear(height*width, 64), 94 | nn.ReLU(inplace=True), 95 | nn.Dropout(p=0.5)) 96 | self.regressor = nn.Sequential(nn.Linear(64 + 64, 1), 97 | nn.ReLU()) 98 | 99 | # This layer is not connected anywhere 100 | # It is only here for backward compatibility 101 | self.lin = nn.Linear(1, 1, bias=False) 102 | 103 | def forward(self, x): 104 | 105 | batch_size = x.shape[0] 106 | 107 | x1 = self.inc(x) 108 | x2 = self.down1(x1) 109 | x3 = self.down2(x2) 110 | x4 = self.down3(x3) 111 | if self.ultrasmall: 112 | x = self.up1(x4, x3) 113 | x = self.up2(x, x2) 114 | x = 
self.up3(x, x1) 115 | else: 116 | x5 = self.down4(x4) 117 | x6 = self.down5(x5) 118 | x7 = self.down6(x6) 119 | x8 = self.down7(x7) 120 | x9 = self.down8(x8) 121 | x = self.up1(x9, x8) 122 | x = self.up2(x, x7) 123 | x = self.up3(x, x6) 124 | x = self.up4(x, x5) 125 | x = self.up5(x, x4) 126 | x = self.up6(x, x3) 127 | x = self.up7(x, x2) 128 | x = self.up8(x, x1) 129 | x = self.outc(x) 130 | x = self.out_nonlin(x) 131 | 132 | # Reshape Bx1xHxW -> BxHxW 133 | # because probability map is real-valued by definition 134 | x = x.squeeze(1) 135 | 136 | if self.known_n_points is None: 137 | middle_layer = x4 if self.ultrasmall else x9 138 | middle_layer_flat = middle_layer.view(batch_size, -1) 139 | x_flat = x.view(batch_size, -1) 140 | 141 | lateral_flat = self.branch_1(middle_layer_flat) 142 | x_flat = self.branch_2(x_flat) 143 | 144 | regression_features = torch.cat((x_flat, lateral_flat), dim=1) 145 | regression = self.regressor(regression_features) 146 | 147 | return x, regression 148 | else: 149 | n_pts = torch.tensor([self.known_n_points]*batch_size, 150 | dtype=torch.get_default_dtype()) 151 | n_pts = n_pts.to(self.device) 152 | return x, n_pts 153 | # summ = torch.sum(x) 154 | # count = self.lin(summ) 155 | 156 | # count = torch.abs(count) 157 | 158 | # if self.known_n_points is not None: 159 | # count = Variable(torch.cuda.FloatTensor([self.known_n_points])) 160 | 161 | # return x, count 162 | 163 | 164 | """ 165 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 166 | All rights reserved. 167 | 168 | This software is covered by US patents and copyright. 169 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 170 | 171 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 172 | 173 | Last Modified: 11/11/2019 174 | """ 175 | -------------------------------------------------------------------------------- /object-locator/models/unet_parts.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. 
Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | # sub-parts of the U-Net model 19 | 20 | import math 21 | import warnings 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | 28 | class double_conv(nn.Module): 29 | def __init__(self, in_ch, out_ch, normaliz=True, activ=True): 30 | super(double_conv, self).__init__() 31 | 32 | ops = [] 33 | ops += [nn.Conv2d(in_ch, out_ch, 3, padding=1)] 34 | # ops += [nn.Dropout(p=0.1)] 35 | if normaliz: 36 | ops += [nn.BatchNorm2d(out_ch)] 37 | if activ: 38 | ops += [nn.ReLU(inplace=True)] 39 | ops += [nn.Conv2d(out_ch, out_ch, 3, padding=1)] 40 | # ops += [nn.Dropout(p=0.1)] 41 | if normaliz: 42 | ops += [nn.BatchNorm2d(out_ch)] 43 | if activ: 44 | ops += [nn.ReLU(inplace=True)] 45 | 46 | self.conv = nn.Sequential(*ops) 47 | 48 | def forward(self, x): 49 | x = self.conv(x) 50 | return x 51 | 52 | 53 | class inconv(nn.Module): 54 | def __init__(self, in_ch, out_ch): 55 | super(inconv, self).__init__() 56 | self.conv = double_conv(in_ch, out_ch) 57 | 58 | def forward(self, x): 59 | x = self.conv(x) 60 | return x 61 | 62 | 63 | class down(nn.Module): 64 | def __init__(self, in_ch, out_ch, normaliz=True): 65 | super(down, self).__init__() 66 | self.mpconv = nn.Sequential( 67 | nn.MaxPool2d(2), 68 | double_conv(in_ch, out_ch, normaliz=normaliz) 69 | ) 70 | 71 | def forward(self, x): 72 | x = self.mpconv(x) 73 | return x 74 | 75 | 76 | class up(nn.Module): 77 | def __init__(self, in_ch, out_ch, normaliz=True, activ=True): 78 | super(up, self).__init__() 79 | self.up = nn.Upsample(scale_factor=2, 80 | mode='bilinear', 81 | align_corners=True) 82 | # self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2) 83 | self.conv = double_conv(in_ch, out_ch, 84 | normaliz=normaliz, activ=activ) 85 | 86 | def forward(self, x1, x2): 87 | with warnings.catch_warnings(): 88 | warnings.simplefilter("ignore") # Upsample is deprecated 89 | x1 = self.up(x1) 90 | diffY = x2.size()[2] - x1.size()[2] 91 | diffX = x2.size()[3] - x1.size()[3] 92 | x1 = F.pad(x1, (diffX // 2, int(math.ceil(diffX / 2)), 93 | diffY // 2, int(math.ceil(diffY / 2)))) 94 | x = torch.cat([x2, x1], dim=1) 95 | x = self.conv(x) 96 | return x 97 | 98 | 99 | class outconv(nn.Module): 100 | def __init__(self, in_ch, out_ch): 101 | super(outconv, self).__init__() 102 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 103 | # self.conv = nn.Sequential( 104 | # nn.Conv2d(in_ch, out_ch, 1), 105 | # ) 106 | 107 | def forward(self, x): 108 | x = self.conv(x) 109 | return x 110 | 111 | 112 | """ 113 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 114 | All rights reserved. 115 | 116 | This software is covered by US patents and copyright. 117 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 118 | 119 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 120 | 121 | Last Modified: 10/02/2019 122 | """ 123 | -------------------------------------------------------------------------------- /object-locator/models/utils.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 
8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | import h5py 19 | import torch 20 | import shutil 21 | 22 | def save_net(fname, net): 23 | with h5py.File(fname, 'w') as h5f: 24 | for k, v in net.state_dict().items(): 25 | h5f.create_dataset(k, data=v.cpu().numpy()) 26 | def load_net(fname, net): 27 | with h5py.File(fname, 'r') as h5f: 28 | for k, v in net.state_dict().items(): 29 | param = torch.from_numpy(np.asarray(h5f[k])) 30 | v.copy_(param) 31 | 32 | def save_checkpoint(state, is_best,task_id, filename='checkpoint.pth.tar'): 33 | torch.save(state, task_id+filename) 34 | if is_best: 35 | shutil.copyfile(task_id+filename, task_id+'model_best.pth.tar') 36 | 37 | 38 | """ 39 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 40 | All rights reserved. 41 | 42 | This software is covered by US patents and copyright. 43 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 44 | 45 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 46 | 47 | Last Modified: 10/02/2019 48 | """ 49 | -------------------------------------------------------------------------------- /object-locator/paint.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | __copyright__ = \ 4 | """ 5 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 6 | All rights reserved. 7 | 8 | This software is covered by US patents and copyright. 9 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 10 | 11 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 12 | 13 | Last Modified: 10/02/2019 14 | """ 15 | __license__ = "CC BY-NC-SA 4.0" 16 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 17 | __version__ = "1.6.0" 18 | 19 | 20 | import os 21 | import sys 22 | 23 | import cv2 24 | from tqdm import tqdm 25 | import numpy as np 26 | import torch 27 | from torchvision import transforms 28 | from torch.utils import data 29 | 30 | from .data import CSVDataset 31 | from .data import csv_collator 32 | from . import argparser 33 | from . 
import utils 34 | 35 | 36 | # Parse command line arguments 37 | args = argparser.parse_command_args('testing') 38 | 39 | # Tensor type to use, select CUDA or not 40 | torch.set_default_dtype(torch.float32) 41 | device_cpu = torch.device('cpu') 42 | 43 | # Set seeds 44 | np.random.seed(args.seed) 45 | torch.manual_seed(args.seed) 46 | if args.cuda: 47 | torch.cuda.manual_seed_all(args.seed) 48 | 49 | # Data loading code 50 | try: 51 | testset = CSVDataset(args.dataset, 52 | transforms=transforms.Compose([ 53 | transforms.ToTensor(), 54 | ]), 55 | max_dataset_size=args.max_testset_size) 56 | except ValueError as e: 57 | print(f'E: {e}') 58 | exit(-1) 59 | dataset_loader = data.DataLoader(testset, 60 | batch_size=1, 61 | num_workers=args.nThreads, 62 | collate_fn=csv_collator) 63 | 64 | os.makedirs(os.path.join(args.out), exist_ok=True) 65 | 66 | for img, dictionary in tqdm(dataset_loader): 67 | 68 | # Move to device 69 | img = img.to(device_cpu) 70 | 71 | # One image at a time (BS=1) 72 | img = img[0] 73 | dictionary = dictionary[0] 74 | 75 | # Tensor -> float & numpy 76 | target_locs = dictionary['locations'].to(device_cpu).numpy().reshape(-1, 2) 77 | img = img.to(device_cpu).numpy() 78 | 79 | img *= 255 80 | 81 | # Paint circles on top of image 82 | img_with_x = utils.paint_circles(img=img, 83 | points=target_locs, 84 | color='white') 85 | img_with_x = np.moveaxis(img_with_x, 0, 2) 86 | img_with_x = img_with_x[:, :, ::-1] 87 | 88 | cv2.imwrite(os.path.join(args.out, dictionary['filename']), 89 | img_with_x) 90 | 91 | 92 | """ 93 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 94 | All rights reserved. 95 | 96 | This software is covered by US patents and copyright. 97 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 98 | 99 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 100 | 101 | Last Modified: 10/02/2019 102 | """ 103 | -------------------------------------------------------------------------------- /object-locator/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | __copyright__ = \ 4 | """ 5 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 6 | All rights reserved. 7 | 8 | This software is covered by US patents and copyright. 9 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 10 | 11 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 12 | 13 | Last Modified: 10/02/2019 14 | """ 15 | __license__ = "CC BY-NC-SA 4.0" 16 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 17 | __version__ = "1.6.0" 18 | 19 | 20 | import math 21 | import cv2 22 | import os 23 | import sys 24 | import time 25 | import shutil 26 | from itertools import chain 27 | from tqdm import tqdm 28 | 29 | import numpy as np 30 | import torch 31 | import torch.optim as optim 32 | from torch import nn 33 | from torch.autograd import Variable 34 | import torchvision as tv 35 | from torchvision.models import inception_v3 36 | from torchvision import transforms 37 | from torch.utils.data import DataLoader 38 | import matplotlib 39 | matplotlib.use('Agg') 40 | import skimage.transform 41 | from peterpy import peter 42 | from ballpark import ballpark 43 | 44 | from . 
import losses 45 | from .models import unet_model 46 | from .metrics import Judge 47 | from . import logger 48 | from . import argparser 49 | from . import utils 50 | from . import data 51 | from .data import csv_collator 52 | from .data import RandomHorizontalFlipImageAndLabel 53 | from .data import RandomVerticalFlipImageAndLabel 54 | from .data import ScaleImageAndLabel 55 | 56 | 57 | # Parse command line arguments 58 | args = argparser.parse_command_args('training') 59 | 60 | # Tensor type to use, select CUDA or not 61 | torch.set_default_dtype(torch.float32) 62 | device_cpu = torch.device('cpu') 63 | device = torch.device('cuda') if args.cuda else device_cpu 64 | 65 | # Create directory for checkpoint to be saved 66 | if args.save: 67 | os.makedirs(os.path.split(args.save)[0], exist_ok=True) 68 | 69 | # Set seeds 70 | np.random.seed(args.seed) 71 | torch.manual_seed(args.seed) 72 | if args.cuda: 73 | torch.cuda.manual_seed_all(args.seed) 74 | 75 | # Visdom setup 76 | log = logger.Logger(server=args.visdom_server, 77 | port=args.visdom_port, 78 | env_name=args.visdom_env) 79 | 80 | 81 | # Create data loaders (return data in batches) 82 | trainset_loader, valset_loader = \ 83 | data.get_train_val_loaders(train_dir=args.train_dir, 84 | max_trainset_size=args.max_trainset_size, 85 | collate_fn=csv_collator, 86 | height=args.height, 87 | width=args.width, 88 | seed=args.seed, 89 | batch_size=args.batch_size, 90 | drop_last_batch=args.drop_last_batch, 91 | num_workers=args.nThreads, 92 | val_dir=args.val_dir, 93 | max_valset_size=args.max_valset_size) 94 | 95 | # Model 96 | with peter('Building network'): 97 | model = unet_model.UNet(3, 1, 98 | height=args.height, 99 | width=args.width, 100 | known_n_points=args.n_points, 101 | device=device, 102 | ultrasmall=args.ultrasmallnet) 103 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 104 | print(f" with {ballpark(num_params)} trainable parameters. ", end='') 105 | model = nn.DataParallel(model) 106 | model.to(device) 107 | 108 | # Loss functions 109 | loss_regress = nn.SmoothL1Loss() 110 | loss_loc = losses.WeightedHausdorffDistance(resized_height=args.height, 111 | resized_width=args.width, 112 | p=args.p, 113 | return_2_terms=True, 114 | device=device) 115 | 116 | # Optimization strategy 117 | if args.optimizer == 'sgd': 118 | optimizer = optim.SGD(model.parameters(), 119 | lr=args.lr, 120 | momentum=0.9) 121 | elif args.optimizer == 'adam': 122 | optimizer = optim.Adam(model.parameters(), 123 | lr=args.lr, 124 | amsgrad=True) 125 | 126 | start_epoch = 0 127 | lowest_mahd = np.infty 128 | 129 | # Restore saved checkpoint (model weights + epoch + optimizer state) 130 | if args.resume: 131 | with peter('Loading checkpoint'): 132 | if os.path.isfile(args.resume): 133 | checkpoint = torch.load(args.resume) 134 | start_epoch = checkpoint['epoch'] 135 | try: 136 | lowest_mahd = checkpoint['mahd'] 137 | except KeyError: 138 | lowest_mahd = np.infty 139 | print('W: Loaded checkpoint has not been validated. 
', end='') 140 | model.load_state_dict(checkpoint['model']) 141 | if not args.replace_optimizer: 142 | optimizer.load_state_dict(checkpoint['optimizer']) 143 | print(f"\n\__ loaded checkpoint '{args.resume}'" 144 | f"(now on epoch {checkpoint['epoch']})") 145 | else: 146 | print(f"\n\__ E: no checkpoint found at '{args.resume}'") 147 | exit(-1) 148 | 149 | running_avg = utils.RunningAverage(len(trainset_loader)) 150 | 151 | normalzr = utils.Normalizer(args.height, args.width) 152 | 153 | # Time at the last evaluation 154 | tic_train = -np.infty 155 | tic_val = -np.infty 156 | 157 | epoch = start_epoch 158 | it_num = 0 159 | while epoch < args.epochs: 160 | 161 | loss_avg_this_epoch = 0 162 | iter_train = tqdm(trainset_loader, 163 | desc=f'Epoch {epoch} ({len(trainset_loader.dataset)} images)') 164 | 165 | # === TRAIN === 166 | 167 | # Set the module in training mode 168 | model.train() 169 | 170 | for batch_idx, (imgs, dictionaries) in enumerate(iter_train): 171 | 172 | # Pull info from this batch and move to device 173 | imgs = imgs.to(device) 174 | target_locations = [dictt['locations'].to(device) 175 | for dictt in dictionaries] 176 | target_counts = [dictt['count'].to(device) 177 | for dictt in dictionaries] 178 | target_orig_heights = [dictt['orig_height'].to(device) 179 | for dictt in dictionaries] 180 | target_orig_widths = [dictt['orig_width'].to(device) 181 | for dictt in dictionaries] 182 | 183 | # Lists -> Tensor batches 184 | target_counts = torch.stack(target_counts) 185 | target_orig_heights = torch.stack(target_orig_heights) 186 | target_orig_widths = torch.stack(target_orig_widths) 187 | target_orig_sizes = torch.stack((target_orig_heights, 188 | target_orig_widths)).transpose(0, 1) 189 | 190 | # One training step 191 | optimizer.zero_grad() 192 | est_maps, est_counts = model.forward(imgs) 193 | term1, term2 = loss_loc.forward(est_maps, 194 | target_locations, 195 | target_orig_sizes) 196 | est_counts = est_counts.view(-1) 197 | target_counts = target_counts.view(-1) 198 | term3 = loss_regress.forward(est_counts, target_counts) 199 | term3 *= args.lambdaa 200 | loss = term1 + term2 + term3 201 | loss.backward() 202 | optimizer.step() 203 | 204 | # Update progress bar 205 | running_avg.put(loss.item()) 206 | iter_train.set_postfix(running_avg=f'{round(running_avg.avg/3, 1)}') 207 | 208 | # Log training error 209 | if time.time() > tic_train + args.log_interval: 210 | tic_train = time.time() 211 | 212 | # Log training losses 213 | log.train_losses(terms=[term1, term2, term3, loss / 3, running_avg.avg / 3], 214 | iteration_number=epoch + 215 | batch_idx/len(trainset_loader), 216 | terms_legends=['Term1', 217 | 'Term2', 218 | 'Term3*%s' % args.lambdaa, 219 | 'Sum/3', 220 | 'Sum/3 runn avg']) 221 | 222 | # Resize images to original size 223 | orig_shape = target_orig_sizes[0].data.to(device_cpu).numpy().tolist() 224 | orig_img_origsize = ((skimage.transform.resize(imgs[0].data.squeeze().to(device_cpu).numpy().transpose((1, 2, 0)), 225 | output_shape=orig_shape, 226 | mode='constant') + 1) / 2.0 * 255.0).\ 227 | astype(np.float32).transpose((2, 0, 1)) 228 | est_map_origsize = skimage.transform.resize(est_maps[0].data.unsqueeze(0).to(device_cpu).numpy().transpose((1, 2, 0)), 229 | output_shape=orig_shape, 230 | mode='constant').\ 231 | astype(np.float32).transpose((2, 0, 1)).squeeze(0) 232 | 233 | # Overlay output on heatmap 234 | orig_img_w_heatmap_origsize = utils.overlay_heatmap(img=orig_img_origsize, 235 | map=est_map_origsize).\ 236 | astype(np.float32) 237 | 238 | # Send 
heatmap with circles at the labeled points to Visdom 239 | target_locs_np = target_locations[0].\ 240 | to(device_cpu).numpy().reshape(-1, 2) 241 | target_orig_size_np = target_orig_sizes[0].\ 242 | to(device_cpu).numpy().reshape(2) 243 | target_locs_wrt_orig = normalzr.unnormalize(target_locs_np, 244 | orig_img_size=target_orig_size_np) 245 | img_with_x = utils.paint_circles(img=orig_img_w_heatmap_origsize, 246 | points=target_locs_wrt_orig, 247 | color='white') 248 | log.image(imgs=[img_with_x], 249 | titles=['(Training) Image w/ output heatmap and labeled points'], 250 | window_ids=[1]) 251 | 252 | # # Read image with GT dots from disk 253 | # gt_img_numpy = skimage.io.imread( 254 | # os.path.join('/home/jprat/projects/phenosorg/data/plant_counts_dots/20160613_F54_training_256x256_white_bigdots', 255 | # dictionary['filename'][0])) 256 | # # dots_img_tensor = torch.from_numpy(gt_img_numpy).permute( 257 | # # 2, 0, 1)[0, :, :].type(torch.FloatTensor) / 255 258 | # # Send GT image to Visdom 259 | # viz.image(np.moveaxis(gt_img_numpy, 2, 0), 260 | # opts=dict(title='(Training) Ground Truth'), 261 | # win=3) 262 | 263 | it_num += 1 264 | 265 | # Never do validation? 266 | if not args.val_dir or \ 267 | not valset_loader or \ 268 | len(valset_loader) == 0 or \ 269 | args.val_freq == 0: 270 | 271 | # Time to save checkpoint? 272 | if args.save and (epoch + 1) % args.val_freq == 0: 273 | torch.save({'epoch': epoch, 274 | 'model': model.state_dict(), 275 | 'optimizer': optimizer.state_dict(), 276 | 'n_points': args.n_points, 277 | }, args.save) 278 | epoch += 1 279 | continue 280 | 281 | # Time to do validation? 282 | if (epoch + 1) % args.val_freq != 0: 283 | epoch += 1 284 | continue 285 | 286 | # === VALIDATION === 287 | 288 | # Set the module in evaluation mode 289 | model.eval() 290 | 291 | judge = Judge(r=args.radius) 292 | sum_term1 = 0 293 | sum_term2 = 0 294 | sum_term3 = 0 295 | sum_loss = 0 296 | iter_val = tqdm(valset_loader, 297 | desc=f'Validating Epoch {epoch} ({len(valset_loader.dataset)} images)') 298 | for batch_idx, (imgs, dictionaries) in enumerate(iter_val): 299 | 300 | # Pull info from this batch and move to device 301 | imgs = imgs.to(device) 302 | target_locations = [dictt['locations'].to(device) 303 | for dictt in dictionaries] 304 | target_counts = [dictt['count'].to(device) 305 | for dictt in dictionaries] 306 | target_orig_heights = [dictt['orig_height'].to(device) 307 | for dictt in dictionaries] 308 | target_orig_widths = [dictt['orig_width'].to(device) 309 | for dictt in dictionaries] 310 | 311 | with torch.no_grad(): 312 | target_counts = torch.stack(target_counts) 313 | target_orig_heights = torch.stack(target_orig_heights) 314 | target_orig_widths = torch.stack(target_orig_widths) 315 | target_orig_sizes = torch.stack((target_orig_heights, 316 | target_orig_widths)).transpose(0, 1) 317 | orig_shape = (dictionaries[0]['orig_height'].item(), 318 | dictionaries[0]['orig_width'].item()) 319 | 320 | # Tensor -> float & numpy 321 | target_count_int = int(round(target_counts.item())) 322 | target_locations_np = \ 323 | target_locations[0].to(device_cpu).numpy().reshape(-1, 2) 324 | target_orig_size_np = \ 325 | target_orig_sizes[0].to(device_cpu).numpy().reshape(2) 326 | 327 | normalzr = utils.Normalizer(args.height, args.width) 328 | 329 | if target_count_int == 0: 330 | continue 331 | 332 | # Feed-forward 333 | with torch.no_grad(): 334 | est_maps, est_counts = model.forward(imgs) 335 | 336 | # Tensor -> int 337 | est_count_int = int(round(est_counts.item())) 338 | 
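        # est_counts comes from the lateral regression branch of the UNet
        # (the second output of model.forward). Its rounded value is used
        # below both as the number of clusters when extracting centroids
        # from the thresholded map and as the estimated count fed to the Judge.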
339 | # The 3 terms 340 | with torch.no_grad(): 341 | est_counts = est_counts.view(-1) 342 | target_counts = target_counts.view(-1) 343 | term1, term2 = loss_loc.forward(est_maps, 344 | target_locations, 345 | target_orig_sizes) 346 | term3 = loss_regress.forward(est_counts, target_counts) 347 | term3 *= args.lambdaa 348 | sum_term1 += term1.item() 349 | sum_term2 += term2.item() 350 | sum_term3 += term3.item() 351 | sum_loss += term1 + term2 + term3 352 | 353 | # Update progress bar 354 | loss_avg_this_epoch = sum_loss.item() / (batch_idx + 1) 355 | iter_val.set_postfix( 356 | avg_val_loss_this_epoch=f'{loss_avg_this_epoch:.1f}-----') 357 | 358 | # The estimated map must be thresholed to obtain estimated points 359 | # BMM thresholding 360 | est_map_numpy = est_maps[0, :, :].to(device_cpu).numpy() 361 | est_map_numpy_origsize = skimage.transform.resize(est_map_numpy, 362 | output_shape=orig_shape, 363 | mode='constant') 364 | mask, _ = utils.threshold(est_map_numpy_origsize, tau=-1) 365 | # Obtain centroids of the mask 366 | centroids_wrt_orig = utils.cluster(mask, est_count_int, 367 | max_mask_pts=args.max_mask_pts) 368 | 369 | # Validation metrics 370 | target_locations_wrt_orig = normalzr.unnormalize(target_locations_np, 371 | orig_img_size=target_orig_size_np) 372 | judge.feed_points(centroids_wrt_orig, target_locations_wrt_orig, 373 | max_ahd=loss_loc.max_dist) 374 | judge.feed_count(est_count_int, target_count_int) 375 | 376 | if time.time() > tic_val + args.log_interval: 377 | tic_val = time.time() 378 | 379 | # Resize to original size 380 | orig_img_origsize = ((skimage.transform.resize(imgs[0].to(device_cpu).squeeze().numpy().transpose((1, 2, 0)), 381 | output_shape=target_orig_size_np.tolist(), 382 | mode='constant') + 1) / 2.0 * 255.0).\ 383 | astype(np.float32).transpose((2, 0, 1)) 384 | est_map_origsize = skimage.transform.resize(est_maps[0].to(device_cpu).unsqueeze(0).numpy().transpose((1, 2, 0)), 385 | output_shape=orig_shape, 386 | mode='constant').\ 387 | astype(np.float32).transpose((2, 0, 1)).squeeze(0) 388 | 389 | # Overlay output on heatmap 390 | orig_img_w_heatmap_origsize = utils.overlay_heatmap(img=orig_img_origsize, 391 | map=est_map_origsize).\ 392 | astype(np.float32) 393 | 394 | # # Read image with GT dots from disk 395 | # gt_img_numpy = skimage.io.imread( 396 | # os.path.join('/home/jprat/projects/phenosorg/data/plant_counts_dots/20160613_F54_validation_256x256_white_bigdots', 397 | # dictionary['filename'][0])) 398 | # # dots_img_tensor = torch.from_numpy(gt_img_numpy).permute( 399 | # # 2, 0, 1)[0, :, :].type(torch.FloatTensor) / 255 400 | # # Send GT image to Visdom 401 | # viz.image(np.moveaxis(gt_img_numpy, 2, 0), 402 | # opts=dict(title='(Validation) Ground Truth'), 403 | # win=7) 404 | if not args.paint: 405 | # Send input and output heatmap (first one in the batch) 406 | log.image(imgs=[orig_img_w_heatmap_origsize], 407 | titles=['(Validation) Image w/ output heatmap'], 408 | window_ids=[5]) 409 | else: 410 | # Send heatmap with a cross at the estimated centroids to Visdom 411 | img_with_x = utils.paint_circles(img=orig_img_w_heatmap_origsize, 412 | points=centroids_wrt_orig, 413 | color='red', 414 | crosshair=True ) 415 | log.image(imgs=[img_with_x], 416 | titles=['(Validation) Image w/ output heatmap ' 417 | 'and point estimations'], 418 | window_ids=[8]) 419 | 420 | avg_term1_val = sum_term1 / len(valset_loader) 421 | avg_term2_val = sum_term2 / len(valset_loader) 422 | avg_term3_val = sum_term3 / len(valset_loader) 423 | avg_loss_val = sum_loss / 
len(valset_loader) 424 | 425 | # Log validation metrics 426 | log.val_losses(terms=(avg_term1_val, 427 | avg_term2_val, 428 | avg_term3_val, 429 | avg_loss_val / 3, 430 | judge.mahd, 431 | judge.mae, 432 | judge.rmse, 433 | judge.mape, 434 | judge.coeff_of_determination, 435 | judge.pearson_corr \ 436 | if not np.isnan(judge.pearson_corr) else 1, 437 | judge.precision, 438 | judge.recall), 439 | iteration_number=epoch, 440 | terms_legends=['Term 1', 441 | 'Term 2', 442 | 'Term3*%s' % args.lambdaa, 443 | 'Sum/3', 444 | 'AHD', 445 | 'MAE', 446 | 'RMSE', 447 | 'MAPE (%)', 448 | 'R^2', 449 | 'r', 450 | f'r{args.radius}-Precision (%)', 451 | f'r{args.radius}-Recall (%)']) 452 | 453 | # If this is the best epoch (in terms of validation error) 454 | if judge.mahd < lowest_mahd: 455 | # Keep the best model 456 | lowest_mahd = judge.mahd 457 | if args.save: 458 | torch.save({'epoch': epoch + 1, # when resuming, we will start at the next epoch 459 | 'model': model.state_dict(), 460 | 'mahd': lowest_mahd, 461 | 'optimizer': optimizer.state_dict(), 462 | 'n_points': args.n_points, 463 | }, args.save) 464 | print("Saved best checkpoint so far in %s " % args.save) 465 | 466 | epoch += 1 467 | 468 | 469 | """ 470 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 471 | All rights reserved. 472 | 473 | This software is covered by US patents and copyright. 474 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 475 | 476 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 477 | 478 | Last Modified: 10/02/2019 479 | """ 480 | -------------------------------------------------------------------------------- /object-locator/utils.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | import torch 19 | import numpy as np 20 | import sklearn.mixture 21 | import scipy.stats 22 | import cv2 23 | from . import bmm 24 | from matplotlib import pyplot as plt 25 | import matplotlib.cm 26 | import scipy.stats 27 | 28 | class Normalizer(): 29 | def __init__(self, new_size_height, new_size_width): 30 | """ 31 | Normalizer. 32 | Converts coordinates in an original image size 33 | to a new image size (resized/normalized). 34 | 35 | :param new_size_height: (int) Height of the new (resized) image size. 36 | :param new_size_width: (int) Width of the new (resized) image size. 37 | """ 38 | new_size_height = int(new_size_height) 39 | new_size_width = int(new_size_width) 40 | 41 | self.new_size = np.array([new_size_height, new_size_width]) 42 | 43 | def unnormalize(self, coordinates_yx_normalized, orig_img_size): 44 | """ 45 | Unnormalize coordinates, 46 | i.e, make them with respect to the original image. 47 | 48 | :param coordinates_yx_normalized: 49 | :param orig_size: Original image size ([height, width]). 
50 |         :return: Unnormalized coordinates.
51 |         """
52 | 
53 |         orig_img_size = np.array(orig_img_size)
54 |         assert orig_img_size.ndim == 1
55 |         assert len(orig_img_size) == 2
56 | 
57 |         norm_factor = orig_img_size / self.new_size
58 |         norm_factor = np.tile(norm_factor, (len(coordinates_yx_normalized), 1))
59 |         coordinates_yx_unnormalized = norm_factor*coordinates_yx_normalized
60 | 
61 |         return coordinates_yx_unnormalized
62 | 
63 | def threshold(array, tau):
64 |     """
65 |     Threshold an array using either hard thresholding, Otsu thresholding, or Beta-fitting.
66 | 
67 |     If the threshold value is fixed, this function returns
68 |     the mask and the threshold used to obtain the mask.
69 |     When using tau=-1, the threshold is obtained with Otsu's method.
70 |     When using tau=-2, it also returns the fitted 2-component Beta Mixture Model.
71 | 
72 | 
73 |     :param array: Array to threshold.
74 |     :param tau: (float) Threshold to use.
75 |                 Values above tau become 1, and values below tau become 0.
76 |                 If -1, use Otsu thresholding.
77 |                 If -2, fit a mixture of 2 Beta distributions and use
78 |                 the mean of the second component as the threshold.
79 |     :return: The tuple (mask, tau).
80 |              If tau==-2, returns (mask, tau, ((rv1, pi1), (rv2, pi2))), i.e., the fitted Beta RVs and their mixture weights.
81 | 
82 |     """
83 |     if tau == -1:
84 |         # Otsu thresholding
85 |         minn, maxx = array.min(), array.max()
86 |         array_scaled = ((array - minn)/(maxx - minn)*255) \
87 |             .round().astype(np.uint8).squeeze()
88 |         tau, mask = cv2.threshold(array_scaled,
89 |                                   0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
90 |         tau = minn + (tau/255)*(maxx - minn)
91 |         # print(f'Otsu selected tau={tau}')
92 |     elif tau == -2:
93 |         array_flat = array.flatten()
94 |         ((a1, b1), (a2, b2)), (pi1, pi2), niter = bmm.estimate(array_flat, list(range(2)))
95 |         rv1 = scipy.stats.beta(a1, b1)
96 |         rv2 = scipy.stats.beta(a2, b2)
97 | 
98 |         tau = rv2.mean()
99 |         mask = cv2.inRange(array, tau, 1)
100 | 
101 |         return mask, tau, ((rv1, pi1), (rv2, pi2))
102 |     else:
103 |         # Thresholding with a fixed threshold tau
104 |         mask = cv2.inRange(array, tau, 1)
105 | 
106 |     return mask, tau
107 | 
108 | 
109 | class AccBetaMixtureModel():
110 | 
111 |     def __init__(self, n_components=2, n_pts=1000):
112 |         """
113 |         Accumulator that tracks multiple Mixture Models based on Beta distributions.
114 |         Each component of a mixture is a tuple (scipy.RV, weight).
115 | 
116 |         :param n_components: (int) Number of components in the mixtures.
117 |         :param n_pts: Number of points in the x axis (values the RV can take in [0, 1]).
118 |         """
119 |         self.n_components = n_components
120 |         self.mixtures = []
121 |         self.x = np.linspace(0, 1, n_pts)
122 | 
123 |     def feed(self, mixture):
124 |         """
125 |         Accumulate another mixture so that this AccBetaMixtureModel can track it.
126 | 
127 |         :param mixture: List/Tuple of components, i.e., ((RV, weight), (RV, weight), ...)
128 |         """
129 |         assert len(mixture) == self.n_components
130 | 
131 |         self.mixtures.append(mixture)
132 | 
133 |     def plot(self):
134 |         """
135 |         Create and return plots showing a variety of stats
136 |         of the mixtures fed into this object.
137 | """ 138 | assert len(self.mixtures) > 0 139 | 140 | figs = {} 141 | 142 | # Compute the mean of the pdf of each component 143 | pdf_means = [(1/len(self.mixtures))*np.clip(rv.pdf(self.x), a_min=0, a_max=8)\ 144 | for rv, w in self.mixtures[0]] 145 | for mix in self.mixtures[1:]: 146 | for c, (rv, w) in enumerate(mix): 147 | pdf_means[c] += (1/len(self.mixtures))*np.clip(rv.pdf(self.x), a_min=0, a_max=8) 148 | 149 | # Compute the stdev of the pdf of each component 150 | if len(self.mixtures) > 1: 151 | pdfs_sq_err_sum = [(np.clip(rv.pdf(self.x), a_min=0, a_max=8) - pdf_means[c])**2 \ 152 | for c, (rv, w) in enumerate(self.mixtures[0])] 153 | for mix in self.mixtures[1:]: 154 | for c, (rv, w) in enumerate(mix): 155 | pdfs_sq_err_sum[c] += (np.clip(rv.pdf(self.x), a_min=0, a_max=8) - pdf_means[c])**2 156 | pdf_stdevs = [np.sqrt(pdf_sq_err_sum)/(len(self.mixtures) - 1) \ 157 | for pdf_sq_err_sum in pdfs_sq_err_sum] 158 | 159 | # Plot the means of the pdfs 160 | fig, ax = plt.subplots() 161 | colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k'] 162 | for c, (pdf_mean, color) in enumerate(zip(pdf_means, colors)): 163 | ax.plot(self.x, pdf_mean, c=color, label=f'BMM Component #{c}') 164 | ax.set_xlabel('Pixel value / $\\tau$') 165 | ax.set_ylabel('Probability Density') 166 | plt.legend() 167 | 168 | if len(self.mixtures) > 1: 169 | # # Plot the standard deviations of the pdfs 170 | # fig, ax = plt.subplots() 171 | # colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k'] 172 | # max_stdev = 0 173 | # for c, (pdf_stdev, color) in enumerate(zip(pdf_stdevs, colors)): 174 | # ax.plot(self.x, pdf_stdev, c=color, label=f'Component #{c}') 175 | # max_stdev = max(max_stdev, max(pdf_stdev)) 176 | # ax.set_title('Standard Deviation of the\nProbability Density Functions\n' 177 | # 'of the fitted bimodal Beta Mixture Model') 178 | # ax.set_xlabel('Pixel value') 179 | # ax.set_ylabel('Standard Deviation') 180 | # ax.set_ylim([0, max_stdev]) 181 | # figs['std_bmm'] = fig 182 | # plt.close(fig) 183 | 184 | # Plot the KDE of the histogram of the threshold (the mean of last RV) 185 | thresholds = [mix[-1][0].mean() for mix in self.mixtures] 186 | thresholds = np.array(thresholds)[np.bitwise_not(np.isnan(thresholds))] 187 | kde = scipy.stats.gaussian_kde(thresholds.reshape(1, -1)) 188 | ax.plot(self.x, kde.pdf(self.x), 189 | '--', 190 | label='KDE of $\\tau$ selected by BMM method') 191 | ax.set_xlabel('Pixel value / $\\tau$') 192 | ax.set_ylabel('Probability Density') 193 | plt.legend() 194 | figs['bmm_stats'] = fig 195 | plt.close(fig) 196 | 197 | return figs 198 | 199 | def cluster(array, n_clusters, max_mask_pts=np.infty): 200 | """ 201 | Cluster a 2-D binary array. 202 | Applies a Gaussian Mixture Model on the positive elements of the array, 203 | and returns the number of clusters. 204 | 205 | :param array: Binary array. 206 | :param n_clusters: Number of clusters (Gaussians) to fit, 207 | :param max_mask_pts: Randomly subsample "max_pts" points 208 | from the array before fitting. 209 | :return: Centroids in the input array. 
210 | """ 211 | 212 | array = np.array(array) 213 | 214 | assert array.ndim == 2 215 | 216 | coord = np.where(array > 0) 217 | y = coord[0].reshape((-1, 1)) 218 | x = coord[1].reshape((-1, 1)) 219 | c = np.concatenate((y, x), axis=1) 220 | if len(c) == 0: 221 | centroids = np.array([]) 222 | else: 223 | # Subsample our points randomly so it is faster 224 | if max_mask_pts != np.infty: 225 | n_pts = min(len(c), max_mask_pts) 226 | np.random.shuffle(c) 227 | c = c[:n_pts] 228 | 229 | # If the estimation is horrible, we cannot fit a GMM if n_components > n_samples 230 | n_components = max(min(n_clusters, x.size), 1) 231 | centroids = sklearn.mixture.GaussianMixture(n_components=n_components, 232 | n_init=1, 233 | covariance_type='full').\ 234 | fit(c).means_.astype(np.int) 235 | 236 | return centroids 237 | 238 | 239 | class RunningAverage(): 240 | 241 | def __init__(self, size): 242 | self.list = [] 243 | self.size = size 244 | 245 | def put(self, elem): 246 | if len(self.list) >= self.size: 247 | self.list.pop(0) 248 | self.list.append(elem) 249 | 250 | def pop(self): 251 | self.list.pop(0) 252 | 253 | @property 254 | def avg(self): 255 | return np.average(self.list) 256 | 257 | 258 | def overlay_heatmap(img, map, colormap=matplotlib.cm.viridis): 259 | """ 260 | Overlay a scalar map onto an image by using a heatmap 261 | 262 | :param img: RGB image (numpy array). 263 | Must be between 0 and 255. 264 | First dimension must be color. 265 | :param map: Scalar image (numpy array) 266 | Must be a 2D array between 0 and 1. 267 | :param colormap: Colormap to use to convert grayscale values 268 | to pseudo-color. 269 | :return: Heatmap on top of the original image in [0, 255] 270 | """ 271 | assert img.ndim == 3 272 | assert map.ndim == 2 273 | assert img.shape[0] == 3 274 | 275 | # Convert image to CHW->HWC 276 | img = img.transpose(1, 2, 0) 277 | 278 | # Generate pseudocolor 279 | heatmap = colormap(map)[:, :, :3] 280 | 281 | # Scale heatmap [0, 1] -> [0, 255] 282 | heatmap *= 255 283 | 284 | # Fusion! 285 | img_w_heatmap = (img + heatmap)/2 286 | 287 | # Convert output to HWC->CHW 288 | img_w_heatmap = img_w_heatmap.transpose(2, 0, 1) 289 | 290 | return img_w_heatmap 291 | 292 | 293 | def paint_circles(img, points, color='red', crosshair=False): 294 | """ 295 | Paint points as circles on top of an image. 296 | 297 | :param img: RGB image (numpy array). 298 | Must be between 0 and 255. 299 | First dimension must be color. 300 | :param centroids: List of centroids in (y, x) format. 301 | :param color: String of the color used to paint centroids. 302 | Default: 'red'. 303 | :param crosshair: Paint crosshair instead of circle. 304 | Default: False. 305 | :return: Image with painted circles centered on the points. 306 | First dimension is be color. 307 | """ 308 | 309 | if color == 'red': 310 | color = [255, 0, 0] 311 | elif color == 'white': 312 | color = [255, 255, 255] 313 | else: 314 | raise NotImplementedError(f'color {color} not implemented') 315 | 316 | points = points.round().astype(np.uint16) 317 | 318 | img = np.moveaxis(img, 0, 2).copy() 319 | if not crosshair: 320 | for y, x in points: 321 | img = cv2.circle(img, (x, y), 3, color, -1) 322 | else: 323 | for y, x in points: 324 | img = cv2.drawMarker(img, 325 | (x, y), 326 | color, cv2.MARKER_TILTED_CROSS, 7, 1, cv2.LINE_AA) 327 | img = np.moveaxis(img, 2, 0) 328 | 329 | return img 330 | 331 | 332 | def nothing(*args, **kwargs): 333 | """ A useless function that does nothing at all. 
""" 334 | pass 335 | 336 | 337 | """ 338 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 339 | All rights reserved. 340 | 341 | This software is covered by US patents and copyright. 342 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 343 | 344 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 345 | 346 | Last Modified: 10/02/2019 347 | """ 348 | -------------------------------------------------------------------------------- /scripts_dataset_and_results/generate_csv.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | import pandas as pd 19 | import cv2 20 | import numpy as np 21 | import sys 22 | import os 23 | import ast 24 | import random 25 | import shutil 26 | from tqdm import tqdm 27 | 28 | np.random.seed(0) 29 | 30 | train_df = pd.DataFrame(columns=['plant_count']) 31 | test_df = pd.DataFrame(columns=['plant_count']) 32 | validate_df = pd.DataFrame(columns=['plant_count']) 33 | 34 | if not os.path.exists('train'): 35 | os.makedirs('train') 36 | if not os.path.exists('test'): 37 | os.makedirs('test') 38 | if not os.path.exists('validate'): 39 | os.makedirs('validate') 40 | 41 | dirs = [i for i in range(1, 18)] 42 | dirs.pop(11) 43 | 44 | filecounter = 0 45 | for dirnum in dirs: 46 | dirname = 'dataset' + str(dirnum).zfill(2) 47 | 48 | fd = open(os.path.join(dirname,'gt.txt')) 49 | 50 | data = [] 51 | for line in fd: 52 | line = line.strip() 53 | imgnum = line.split(' ')[1] 54 | x = line.split(' ')[2] 55 | if (x == 'X'): 56 | continue 57 | y = line.split(' ')[3] 58 | 59 | imagename = imgnum.zfill(10)+'.png' 60 | if not os.path.exists(os.path.join(dirname,imagename)): 61 | continue 62 | image = cv2.imread(os.path.join(dirname,imagename)) 63 | 64 | h = image.shape[0] 65 | x = int(x)/2 66 | y = h - int(y)/2 67 | data.append([imagename, y, x]) 68 | 69 | #print(imagename) 70 | #print(x, y) 71 | 72 | random.shuffle(data) 73 | for i in range(len(data)): 74 | item = data[i] 75 | imagename = item[0] 76 | y = item[1] 77 | x = item[2] 78 | 79 | # newname = str(filecounter).zfill(10) + '.png' 80 | newname = dirname + '_' + imagename 81 | df = pd.DataFrame(data=[[1, [[y, x]]]], 82 | index=[newname], 83 | columns=['plant_count', 'plant_locations']) 84 | if (i < len(data)*0.8): 85 | if os.path.isfile('train/'+newname): 86 | print('%s exists' % 'train/'+newname) 87 | exit(-1) 88 | shutil.move(os.path.join(dirname,imagename), 'train/'+newname) 89 | train_df = train_df.append(df) 90 | elif (i < len(data)*0.9): 91 | if os.path.isfile('train/'+newname): 92 | print('%s exists' % 'test/'+newname) 93 | exit(-1) 94 | shutil.move(os.path.join(dirname,imagename), 'test/'+newname) 95 | test_df = test_df.append(df) 96 | else: 97 | if os.path.isfile('train/'+newname): 98 | print('%s exists' % 
99 |                 exit(-1)
100 |             shutil.move(os.path.join(dirname,imagename), 'validate/'+newname)
101 |             validate_df = validate_df.append(df)
102 | 
103 | train_df.to_csv('train.csv')
104 | shutil.move('train.csv', 'train')
105 | test_df.to_csv('test.csv')
106 | shutil.move('test.csv', 'test')
107 | validate_df.to_csv('validate.csv')
108 | shutil.move('validate.csv', 'validate')
109 | 
110 | 
111 | """
112 | Copyright (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation.
113 | All rights reserved.
114 | 
115 | This software is covered by US patents and copyright.
116 | This source code is to be used for academic research purposes only, and no commercial use is allowed.
117 | 
118 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University.
119 | 
120 | Last Modified: 10/02/2019
121 | """
122 | 
--------------------------------------------------------------------------------
/scripts_dataset_and_results/parseResults.py:
--------------------------------------------------------------------------------
1 | __copyright__ = \
2 | """
3 | Copyright (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation.
4 | All rights reserved.
5 | 
6 | This software is covered by US patents and copyright.
7 | This source code is to be used for academic research purposes only, and no commercial use is allowed.
8 | 
9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University.
10 | 
11 | Last Modified: 10/02/2019
12 | """
13 | __license__ = "CC BY-NC-SA 4.0"
14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp"
15 | __version__ = "1.6.0"
16 | 
17 | 
18 | import pandas as pd
19 | import numpy as np
20 | import sys
21 | import ast
22 | import cv2
23 | from sklearn.cluster import KMeans
24 | from sklearn.metrics.pairwise import pairwise_distances
25 | from sklearn import mixture
26 | 
27 | CSV_FILE = "estimations.csv"
28 | 
29 | def eval_plant_locations(estimated, gt):
30 |     """
31 |     Distance function between the estimated plant locations and the ground
32 |     truth.
33 |     This function is symmetric: which argument holds the estimated
34 |     locations and which holds the ground truth does not matter.
35 |     The returned value is guaranteed to be always positive,
36 |     and is only zero if both lists are exactly equal.
37 | 
38 |     :param estimated: List of (x, y) or (y, x) plant locations.
39 |     :param gt: List of (x, y) or (y, x) plant locations.
40 |     :return: Symmetric distance between the two point sets.
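
    Example (hypothetical values, added for illustration; doctest-style):

        >>> eval_plant_locations([(0, 0), (10, 10)], [(0, 1), (10, 10)])  # doctest: +SKIP
        1.0

    Here the only mismatch is one pixel between (0, 0) and (0, 1), so each
    averaged one-sided minimal distance is 0.5 and the symmetric sum is 1.0.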
41 | """ 42 | 43 | estimated = np.array(estimated) 44 | gt = np.array(gt) 45 | 46 | # Check dimension 47 | assert estimated.ndim == gt.ndim == 2, \ 48 | 'Both estimated and GT plant locations must be 2D, i.e, (x, y) or (y, x)' 49 | 50 | d2_matrix = pairwise_distances(estimated, gt, metric='euclidean') 51 | 52 | res = np.average(np.min(d2_matrix, axis=0)) + \ 53 | np.average(np.min(d2_matrix, axis=1)) 54 | 55 | return res 56 | 57 | def processImg(image, n, GMM=False): 58 | #extract mask from the image 59 | mask = cv2.inRange(image, (5,5,5), (255,255,255)) 60 | coord = np.where(mask > 0) 61 | y = coord[0].reshape((-1, 1)) 62 | x = coord[1].reshape((-1, 1)) 63 | 64 | c = np.concatenate((y, x), axis=1) 65 | 66 | if GMM: 67 | gmm = mixture.GaussianMixture(n_components=n, n_init=1, covariance_type='full').fit(c) 68 | return gmm.means_.astype(np.int) 69 | 70 | else: 71 | 72 | #find kmean cluster 73 | kmeans = KMeans(n_clusters=n, random_state=0).fit(c) 74 | return kmeans.cluster_centers_ 75 | 76 | def processCSV(csvfile): 77 | 78 | df = pd.read_csv(csvfile) 79 | res_array = [] 80 | for i in range(len(df.iloc[:])): 81 | filename = df.iloc[:, 1][i] 82 | 83 | plant_count = df.iloc[:, 2][i] 84 | plant_count = float(plant_count.split('\n')[1].strip()) 85 | 86 | gt = df.iloc[:, 3][i] 87 | gt = ast.literal_eval(gt) 88 | 89 | image = cv2.imread(filename) 90 | detected = processImg(image, int(plant_count), GMM=True) 91 | 92 | res = eval_plant_locations(detected, gt) 93 | res_array.append(res) 94 | print(res) 95 | break 96 | return res_array 97 | 98 | 99 | #Note the script needs to be put into the data directory with the CSV file 100 | res = processCSV(CSV_FILE) 101 | 102 | 103 | """ 104 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 105 | All rights reserved. 106 | 107 | This software is covered by US patents and copyright. 108 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 109 | 110 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 111 | 112 | Last Modified: 10/02/2019 113 | """ 114 | -------------------------------------------------------------------------------- /scripts_dataset_and_results/spacing_stats_to_csv.py: -------------------------------------------------------------------------------- 1 | __copyright__ = \ 2 | """ 3 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 4 | All rights reserved. 5 | 6 | This software is covered by US patents and copyright. 7 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 8 | 9 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 10 | 11 | Last Modified: 10/02/2019 12 | """ 13 | __license__ = "CC BY-NC-SA 4.0" 14 | __authors__ = "Javier Ribera, David Guera, Yuhao Chen, Edward J. Delp" 15 | __version__ = "1.6.0" 16 | 17 | 18 | import argparse 19 | import os 20 | import pandas as pd 21 | from tqdm import tqdm 22 | from scipy.spatial.distance import euclidean as distance 23 | import statistics 24 | import matplotlib.mlab as mlab 25 | import matplotlib.pyplot as plt 26 | import numpy as np 27 | 28 | if __name__ == '__main__': 29 | # Parse command-line arguments 30 | parser = argparse.ArgumentParser( 31 | description='Compute intra-row spacing stats of a CSV. ' 32 | 'Add mean, median, and stdev of each row. 
' 33 | 'Optional: plot histograms', 34 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 35 | parser.add_argument('in_csv', 36 | help='Input CSV with plant location info.') 37 | parser.add_argument('out_csv', 38 | help='Output CSV with the added stats.') 39 | parser.add_argument('--hist', 40 | metavar='DIR', 41 | help='Directory with histograms.') 42 | parser.add_argument('--res', 43 | metavar='DIR', 44 | type=float, 45 | default=1, 46 | help='Resolution in centimeters.') 47 | args = parser.parse_args() 48 | 49 | # Import GT from CSV 50 | df = pd.read_csv(args.in_csv) 51 | 52 | # Store stats of each single-row plot 53 | means, medians, stds = [], [], [] 54 | 55 | for idx, row in tqdm(df.iterrows(), total=len(df.index)): 56 | if row['locations_wrt_orthophoto'] is np.nan: 57 | continue 58 | locs = eval(row['locations_wrt_orthophoto']) 59 | 60 | # 1. Sort by row coordinate 61 | locs = sorted(locs, key=lambda x: x[0]) 62 | 63 | # 2. Compute distances (chain-like) between plants 64 | dists = list(map(distance, locs[:-1], locs[1:])) 65 | 66 | # 3. pixels -> centimeters 67 | dists = [d * args.res for d in dists] 68 | 69 | # 4. Statistics! 70 | mean = statistics.mean(dists) 71 | median = statistics.median(dists) 72 | std = statistics.stdev(dists) 73 | means.append(mean) 74 | medians.append(median) 75 | stds.append(std) 76 | 77 | # 5. Put in CSV 78 | df.loc[idx, 'mean_intrarow_spacing_in_cm'] = mean 79 | df.loc[idx, 'median_intrarow_spacing_in_cm'] = median 80 | df.loc[idx, 'stdev_intrarow_spacing_in_cm'] = std 81 | 82 | # Save to disk as CSV 83 | df.to_csv(args.out_csv) 84 | 85 | if args.hist is not None: 86 | os.makedirs(args.hist, exist_ok=True) 87 | 88 | # 6. Generate nice graphs for presentation 89 | # Means 90 | fig = plt.figure() 91 | n, bins, patches = plt.hist( 92 | means, 30, normed=1, facecolor='green', alpha=0.75, label='Histogram') 93 | # add a 'best fit' norm line 94 | y = mlab.normpdf(bins, statistics.mean(means), statistics.stdev(means)) 95 | l = plt.plot(bins, y, 'r--', linewidth=1, label='Fitted Gaussian') 96 | plt.xlabel('Average intra-row spacing [cm]') 97 | plt.ylabel('Probability') 98 | plt.title('Histogram of average intra-row spacing') 99 | plt.axis([5, 30, 0, 0.3]) 100 | plt.grid(True) 101 | plt.legend() 102 | fig.savefig(os.path.join( 103 | args.hist, 'histogram_averages.png'), dpi=fig.dpi) 104 | 105 | # Medians 106 | fig = plt.figure() 107 | n, bins, patches = plt.hist( 108 | medians, 30, normed=1, facecolor='green', alpha=0.75, label='Histogram') 109 | # add a 'best fit' norm line 110 | y = mlab.normpdf(bins, statistics.mean( 111 | medians), statistics.stdev(medians)) 112 | l = plt.plot(bins, y, 'r--', linewidth=1, label='Fitted Gaussian') 113 | plt.xlabel('Median of intra-row spacing [cm]') 114 | plt.ylabel('Probability') 115 | plt.title('Histogram of medians intra-row spacing') 116 | plt.axis([5, 30, 0, 0.3]) 117 | plt.grid(True) 118 | plt.legend() 119 | fig.savefig(os.path.join( 120 | args.hist, 'histogram_medians.png'), dpi=fig.dpi) 121 | 122 | # Standard deviations 123 | fig = plt.figure() 124 | n, bins, patches = plt.hist( 125 | stds, 30, normed=1, facecolor='green', alpha=0.75, label='Histogram') 126 | # add a 'best fit' norm line 127 | y = mlab.normpdf(bins, statistics.mean(stds), statistics.stdev(stds)) 128 | l = plt.plot(bins, y, 'r--', linewidth=1, label='Fitted Gaussian') 129 | plt.xlabel('Standard deviation of intra-row spacing [cm]') 130 | plt.ylabel('Probability') 131 | plt.title('Histogram of standard deviations of intra-row spacing') 132 | 
plt.axis([0, 25, 0, 0.3]) 133 | plt.grid(True) 134 | plt.legend() 135 | fig.savefig(os.path.join( 136 | args.hist, 'histogram_stdevs.png'), dpi=fig.dpi) 137 | 138 | 139 | """ 140 | Copyright ©right © (c) 2019 The Board of Trustees of Purdue University and the Purdue Research Foundation. 141 | All rights reserved. 142 | 143 | This software is covered by US patents and copyright. 144 | This source code is to be used for academic research purposes only, and no commercial use is allowed. 145 | 146 | For any questions, please contact Edward J. Delp (ace@ecn.purdue.edu) at Purdue University. 147 | 148 | Last Modified: 10/02/2019 149 | """ 150 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='object-locator', 5 | version='1.6.0', 6 | description='Object Location using PyTorch.', 7 | 8 | # The project's main homepage. 9 | url='https://engineering.purdue.edu/~sorghum', 10 | 11 | # Author details 12 | author='Javier Ribera, David Guera, Yuhao Chen, and Edward J. Delp', 13 | author_email='ace@ecn.purdue.edu', 14 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 15 | classifiers=[ 16 | # How mature is this project? Common values are 17 | # 3 - Alpha 18 | # 4 - Beta 19 | # 5 - Production/Stable 20 | 'Development Status :: 4 - Beta', 21 | 22 | # Specify the Python versions you support here. In particular, ensure 23 | # that you indicate whether you support Python 2, Python 3 or both. 24 | 'Programming Language :: Python :: 3.6', 25 | ], 26 | python_requires='~=3.6', 27 | # What does your project relate to? 28 | keywords='object localization location purdue', 29 | 30 | # You can just specify the packages manually here if your project is 31 | # simple. Or you can use find_packages(). 32 | packages=['object-locator', 'object-locator.models'], 33 | package_dir={'object-locator': 'object-locator'}, 34 | 35 | package_data={'object-locator': ['checkpoints/*.ckpt', 36 | '../COPYRIGHT.txt', 37 | '../README.md']}, 38 | include_package_data=True, 39 | 40 | # List run-time dependencies here. These will be installed by pip when 41 | # your project is installed. For an analysis of "install_requires" vs pip's 42 | # requirements files see: 43 | # https://packaging.python.org/en/latest/requirements.html 44 | # (We actually use conda for dependency management) 45 | # install_requires=['matplotlib', 'numpy', 46 | # 'scikit-image', 'tqdm', 'argparse', 'parse', 47 | # 'scikit-learn', 'pandas'], 48 | ) 49 | --------------------------------------------------------------------------------
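
Usage note (illustrative sketch, not a file in the repository): the helpers in object-locator/utils.py above are chained at inference time in the order threshold -> cluster -> overlay_heatmap/paint_circles. The snippet below shows that flow on synthetic inputs. The importlib import, the array sizes, and the estimated count are assumptions for this example only, and it expects the repository's pinned dependencies (cluster() still relies on the legacy np.int alias removed in recent NumPy releases).

import importlib
import numpy as np

# Import object-locator/utils.py; assumes the repository root is on sys.path.
# importlib is used because the package directory name contains a hyphen.
utils = importlib.import_module('object-locator.utils')

# Synthetic stand-ins for the network outputs used by train.py / locate.py.
est_map = np.random.rand(256, 256)                            # estimated map in [0, 1]
img = np.random.rand(3, 256, 256).astype(np.float32) * 255    # CHW RGB in [0, 255]
est_count = 5                                                 # estimated number of objects

# 1. Otsu thresholding (tau=-1) returns the binary mask and the threshold used.
mask, tau = utils.threshold(est_map, tau=-1)

# 2. Fit a GMM to the positive pixels; the component means are the (y, x) centroids.
centroids = utils.cluster(mask, est_count, max_mask_pts=100)

# 3. Blend a pseudo-color heatmap with the image and mark the centroids.
img_with_heatmap = utils.overlay_heatmap(img=img, map=est_map)
img_painted = utils.paint_circles(img=img_with_heatmap,
                                  points=centroids,
                                  color='red',
                                  crosshair=True)
print(f'tau={tau:.3f}, {len(centroids)} centroids, output shape {img_painted.shape}')

This mirrors the validation branch of train.py shown earlier in this dump, where the same three steps produce the centroids fed to the metrics judge and the heatmap sent to the logger.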