├── .gitignore
├── LICENSE
├── README.md
├── agedb-dir
│   ├── README.md
│   ├── data
│   │   ├── agedb.csv
│   │   ├── create_agedb.py
│   │   └── preprocess_agedb.py
│   ├── datasets.py
│   ├── loss.py
│   ├── ranking.py
│   ├── ranksim.py
│   ├── resnet.py
│   ├── train.py
│   └── utils.py
├── imdb-wiki-dir
│   ├── README.md
│   ├── data
│   │   ├── create_imd_wiki.py
│   │   ├── download_imdb_wiki.py
│   │   ├── imdb_wiki.csv
│   │   └── preprocess_imdb_wiki.py
│   ├── datasets.py
│   ├── loss.py
│   ├── ranking.py
│   ├── ranksim.py
│   ├── resnet.py
│   ├── train.py
│   └── utils.py
├── nyud2-dir
│   ├── ConR.py
│   ├── README.md
│   ├── balanaced_mse.py
│   ├── data
│   │   ├── nyu2_test.csv
│   │   ├── nyu2_train.csv
│   │   └── nyu2_train_FDS_subset.csv
│   ├── download_nyud2.py
│   ├── loaddata.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── fds.py
│   │   ├── modules.py
│   │   ├── net.py
│   │   └── resnet.py
│   ├── nyu_transform.py
│   ├── preprocess_gmm.py
│   ├── preprocess_nyud2.py
│   ├── test.py
│   ├── train.py
│   └── util.py
└── teaser.jpg
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. 
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 58 | Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 63 | ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-NC-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. 
Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution, NonCommercial, and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. NonCommercial means not primarily intended for or directed towards 126 | commercial advantage or monetary compensation. For purposes of 127 | this Public License, the exchange of the Licensed Material for 128 | other material subject to Copyright and Similar Rights by digital 129 | file-sharing or similar means is NonCommercial provided there is 130 | no payment of monetary compensation in connection with the 131 | exchange. 132 | 133 | l. Share means to provide material to the public by any means or 134 | process that requires permission under the Licensed Rights, such 135 | as reproduction, public display, public performance, distribution, 136 | dissemination, communication, or importation, and to make material 137 | available to the public including in ways that members of the 138 | public may access the material from a place and at a time 139 | individually chosen by them. 140 | 141 | m. Sui Generis Database Rights means rights other than copyright 142 | resulting from Directive 96/9/EC of the European Parliament and of 143 | the Council of 11 March 1996 on the legal protection of databases, 144 | as amended and/or succeeded, as well as other essentially 145 | equivalent rights anywhere in the world. 146 | 147 | n. You means the individual or entity exercising the Licensed Rights 148 | under this Public License. Your has a corresponding meaning. 149 | 150 | 151 | Section 2 -- Scope. 152 | 153 | a. License grant. 154 | 155 | 1. Subject to the terms and conditions of this Public License, 156 | the Licensor hereby grants You a worldwide, royalty-free, 157 | non-sublicensable, non-exclusive, irrevocable license to 158 | exercise the Licensed Rights in the Licensed Material to: 159 | 160 | a. reproduce and Share the Licensed Material, in whole or 161 | in part, for NonCommercial purposes only; and 162 | 163 | b. produce, reproduce, and Share Adapted Material for 164 | NonCommercial purposes only. 165 | 166 | 2. Exceptions and Limitations. For the avoidance of doubt, where 167 | Exceptions and Limitations apply to Your use, this Public 168 | License does not apply, and You do not need to comply with 169 | its terms and conditions. 170 | 171 | 3. Term. 
The term of this Public License is specified in Section 172 | 6(a). 173 | 174 | 4. Media and formats; technical modifications allowed. The 175 | Licensor authorizes You to exercise the Licensed Rights in 176 | all media and formats whether now known or hereafter created, 177 | and to make technical modifications necessary to do so. The 178 | Licensor waives and/or agrees not to assert any right or 179 | authority to forbid You from making technical modifications 180 | necessary to exercise the Licensed Rights, including 181 | technical modifications necessary to circumvent Effective 182 | Technological Measures. For purposes of this Public License, 183 | simply making modifications authorized by this Section 2(a) 184 | (4) never produces Adapted Material. 185 | 186 | 5. Downstream recipients. 187 | 188 | a. Offer from the Licensor -- Licensed Material. Every 189 | recipient of the Licensed Material automatically 190 | receives an offer from the Licensor to exercise the 191 | Licensed Rights under the terms and conditions of this 192 | Public License. 193 | 194 | b. Additional offer from the Licensor -- Adapted Material. 195 | Every recipient of Adapted Material from You 196 | automatically receives an offer from the Licensor to 197 | exercise the Licensed Rights in the Adapted Material 198 | under the conditions of the Adapter's License You apply. 199 | 200 | c. No downstream restrictions. You may not offer or impose 201 | any additional or different terms or conditions on, or 202 | apply any Effective Technological Measures to, the 203 | Licensed Material if doing so restricts exercise of the 204 | Licensed Rights by any recipient of the Licensed 205 | Material. 206 | 207 | 6. No endorsement. Nothing in this Public License constitutes or 208 | may be construed as permission to assert or imply that You 209 | are, or that Your use of the Licensed Material is, connected 210 | with, or sponsored, endorsed, or granted official status by, 211 | the Licensor or others designated to receive attribution as 212 | provided in Section 3(a)(1)(A)(i). 213 | 214 | b. Other rights. 215 | 216 | 1. Moral rights, such as the right of integrity, are not 217 | licensed under this Public License, nor are publicity, 218 | privacy, and/or other similar personality rights; however, to 219 | the extent possible, the Licensor waives and/or agrees not to 220 | assert any such rights held by the Licensor to the limited 221 | extent necessary to allow You to exercise the Licensed 222 | Rights, but not otherwise. 223 | 224 | 2. Patent and trademark rights are not licensed under this 225 | Public License. 226 | 227 | 3. To the extent possible, the Licensor waives any right to 228 | collect royalties from You for the exercise of the Licensed 229 | Rights, whether directly or through a collecting society 230 | under any voluntary or waivable statutory or compulsory 231 | licensing scheme. In all other cases the Licensor expressly 232 | reserves any right to collect such royalties, including when 233 | the Licensed Material is used other than for NonCommercial 234 | purposes. 235 | 236 | 237 | Section 3 -- License Conditions. 238 | 239 | Your exercise of the Licensed Rights is expressly made subject to the 240 | following conditions. 241 | 242 | a. Attribution. 243 | 244 | 1. If You Share the Licensed Material (including in modified 245 | form), You must: 246 | 247 | a. retain the following if it is supplied by the Licensor 248 | with the Licensed Material: 249 | 250 | i. 
identification of the creator(s) of the Licensed 251 | Material and any others designated to receive 252 | attribution, in any reasonable manner requested by 253 | the Licensor (including by pseudonym if 254 | designated); 255 | 256 | ii. a copyright notice; 257 | 258 | iii. a notice that refers to this Public License; 259 | 260 | iv. a notice that refers to the disclaimer of 261 | warranties; 262 | 263 | v. a URI or hyperlink to the Licensed Material to the 264 | extent reasonably practicable; 265 | 266 | b. indicate if You modified the Licensed Material and 267 | retain an indication of any previous modifications; and 268 | 269 | c. indicate the Licensed Material is licensed under this 270 | Public License, and include the text of, or the URI or 271 | hyperlink to, this Public License. 272 | 273 | 2. You may satisfy the conditions in Section 3(a)(1) in any 274 | reasonable manner based on the medium, means, and context in 275 | which You Share the Licensed Material. For example, it may be 276 | reasonable to satisfy the conditions by providing a URI or 277 | hyperlink to a resource that includes the required 278 | information. 279 | 3. If requested by the Licensor, You must remove any of the 280 | information required by Section 3(a)(1)(A) to the extent 281 | reasonably practicable. 282 | 283 | b. ShareAlike. 284 | 285 | In addition to the conditions in Section 3(a), if You Share 286 | Adapted Material You produce, the following conditions also apply. 287 | 288 | 1. The Adapter's License You apply must be a Creative Commons 289 | license with the same License Elements, this version or 290 | later, or a BY-NC-SA Compatible License. 291 | 292 | 2. You must include the text of, or the URI or hyperlink to, the 293 | Adapter's License You apply. You may satisfy this condition 294 | in any reasonable manner based on the medium, means, and 295 | context in which You Share Adapted Material. 296 | 297 | 3. You may not offer or impose any additional or different terms 298 | or conditions on, or apply any Effective Technological 299 | Measures to, Adapted Material that restrict exercise of the 300 | rights granted under the Adapter's License You apply. 301 | 302 | 303 | Section 4 -- Sui Generis Database Rights. 304 | 305 | Where the Licensed Rights include Sui Generis Database Rights that 306 | apply to Your use of the Licensed Material: 307 | 308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 309 | to extract, reuse, reproduce, and Share all or a substantial 310 | portion of the contents of the database for NonCommercial purposes 311 | only; 312 | 313 | b. if You include all or a substantial portion of the database 314 | contents in a database in which You have Sui Generis Database 315 | Rights, then the database in which You have Sui Generis Database 316 | Rights (but not its individual contents) is Adapted Material, 317 | including for purposes of Section 3(b); and 318 | 319 | c. You must comply with the conditions in Section 3(a) if You Share 320 | all or a substantial portion of the contents of the database. 321 | 322 | For the avoidance of doubt, this Section 4 supplements and does not 323 | replace Your obligations under this Public License where the Licensed 324 | Rights include other Copyright and Similar Rights. 325 | 326 | 327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 328 | 329 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 339 | 340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 349 | 350 | c. The disclaimer of warranties and limitation of liability provided 351 | above shall be interpreted in a manner that, to the extent 352 | possible, most closely approximates an absolute disclaimer and 353 | waiver of all liability. 354 | 355 | 356 | Section 6 -- Term and Termination. 357 | 358 | a. This Public License applies for the term of the Copyright and 359 | Similar Rights licensed here. However, if You fail to comply with 360 | this Public License, then Your rights under this Public License 361 | terminate automatically. 362 | 363 | b. Where Your right to use the Licensed Material has terminated under 364 | Section 6(a), it reinstates: 365 | 366 | 1. automatically as of the date the violation is cured, provided 367 | it is cured within 30 days of Your discovery of the 368 | violation; or 369 | 370 | 2. upon express reinstatement by the Licensor. 371 | 372 | For the avoidance of doubt, this Section 6(b) does not affect any 373 | right the Licensor may have to seek remedies for Your violations 374 | of this Public License. 375 | 376 | c. For the avoidance of doubt, the Licensor may also offer the 377 | Licensed Material under separate terms or conditions or stop 378 | distributing the Licensed Material at any time; however, doing so 379 | will not terminate this Public License. 380 | 381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 382 | License. 383 | 384 | 385 | Section 7 -- Other Terms and Conditions. 386 | 387 | a. The Licensor shall not be bound by any additional or different 388 | terms or conditions communicated by You unless expressly agreed. 389 | 390 | b. Any arrangements, understandings, or agreements regarding the 391 | Licensed Material not stated herein are separate from and 392 | independent of the terms and conditions of this Public License. 393 | 394 | 395 | Section 8 -- Interpretation. 396 | 397 | a. For the avoidance of doubt, this Public License does not, and 398 | shall not be interpreted to, reduce, limit, restrict, or impose 399 | conditions on any use of the Licensed Material that could lawfully 400 | be made without permission under this Public License. 401 | 402 | b. 
To the extent possible, if any provision of this Public License is 403 | deemed unenforceable, it shall be automatically reformed to the 404 | minimum extent necessary to make it enforceable. If the provision 405 | cannot be reformed, it shall be severed from this Public License 406 | without affecting the enforceability of the remaining terms and 407 | conditions. 408 | 409 | c. No term or condition of this Public License will be waived and no 410 | failure to comply consented to unless expressly agreed to by the 411 | Licensor. 412 | 413 | d. Nothing in this Public License constitutes or may be interpreted 414 | as a limitation upon, or waiver of, any privileges and immunities 415 | that apply to the Licensor or You, including from the legal 416 | processes of any jurisdiction or authority. 417 | 418 | ======================================================================= 419 | 420 | Creative Commons is not a party to its public 421 | licenses. Notwithstanding, Creative Commons may elect to apply one of 422 | its public licenses to material it publishes and in those instances 423 | will be considered the “Licensor.” The text of the Creative Commons 424 | public licenses is dedicated to the public domain under the CC0 Public 425 | Domain Dedication. Except for the limited purpose of indicating that 426 | material is shared under a Creative Commons public license or as 427 | otherwise permitted by the Creative Commons policies published at 428 | creativecommons.org/policies, Creative Commons does not authorize the 429 | use of the trademark "Creative Commons" or any other trademark or logo 430 | of Creative Commons without its prior written consent including, 431 | without limitation, in connection with any unauthorized modifications 432 | to any of its public licenses or any other arrangements, 433 | understandings, or agreements concerning use of licensed material. For 434 | the avoidance of doubt, this paragraph does not form part of the 435 | public licenses. 436 | 437 | Creative Commons may be contacted at creativecommons.org. 438 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConR: Contrastive Regularizer for Deep Imbalanced Regression 2 | 3 | This repository contains the code for the ICLR 2024 paper: 4 | [__ConR: Contrastive Regularizer for Deep Imbalanced Regression__](https://openreview.net/forum?id=RIuevDSK5V)
5 | Mahsa Keramati, Lili Meng, R David Evans
6 | 7 | ![ConR key insights](teaser.jpg) 8 | 9 | ConR key insights. a) Without ConR, it is common to have minority examples mixed with majority examples. b) ConR adds additional loss weight for minority and mis-labelled examples, resulting in better feature representations and c) better prediction error. 15 |
16 | 17 | ## Quick Preview 18 | ConR is complementary to conventional imbalanced learning techniques. The following code snippet shows the implementation of ConR for the task of age estimation. 19 | 20 | ```python 21 | def ConR(features, targets, preds, w=1, weights=1, t=0.07, e=0.01): 22 | q = torch.nn.functional.normalize(features, dim=1) 23 | k = torch.nn.functional.normalize(features, dim=1) 24 | 25 | l_k = targets.flatten()[None, :] 26 | l_q = targets 27 | 28 | p_k = preds.flatten()[None, :] 29 | p_q = preds 30 | 31 | l_dist = torch.abs(l_q - l_k) 32 | p_dist = torch.abs(p_q - p_k) 33 | 34 | pos_i = l_dist.le(w) 35 | neg_i = ((~ (l_dist.le(w))) * (p_dist.le(w))) 36 | 37 | for i in range(pos_i.shape[0]): 38 | pos_i[i][i] = 0 39 | 40 | prod = torch.einsum("nc,kc->nk", [q, k]) / t 41 | pos = prod * pos_i 42 | neg = prod * neg_i 43 | 44 | pushing_w = weights * torch.exp(l_dist * e) 45 | neg_exp_dot = (pushing_w * (torch.exp(neg)) * neg_i).sum(1) 46 | 47 | # For each query sample, if there is no negative pair, zero-out the loss. 48 | no_neg_flag = (neg_i).sum(1).bool() 49 | 50 | # Loss = sum over all samples in the batch (sum over (positive dot product/(negative dot product+positive dot product))) 51 | denom = pos_i.sum(1) 52 | 53 | loss = ((-torch.log( 54 | torch.div(torch.exp(pos), (torch.exp(pos).sum(1) + neg_exp_dot).unsqueeze(-1))) * ( 55 | pos_i)).sum(1) / denom) 56 | 57 | loss = (weights * (loss * no_neg_flag).unsqueeze(-1)).mean() 58 | 59 | return loss 60 | ```
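To make the regularizer concrete, here is a minimal usage sketch (our illustration, not part of the released scripts). It assumes the `ConR` function above is in scope; the dummy tensors and the stand-in linear `head` are hypothetical, the plain L1 term stands in for whichever primary regression loss is used, and the 4.0 factor mirrors the default `--beta` scale of the ConR loss in the per-dataset READMEs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
head = torch.nn.Linear(128, 1)                    # stand-in for the model's regression head
features = torch.randn(8, 128, requires_grad=True)
# labels chosen so every sample has at least one positive pair within w=1
targets = torch.tensor([[20.], [20.], [21.], [35.], [35.], [36.], [70.], [71.]])
preds = head(features)

reg_loss = F.l1_loss(preds, targets)              # primary regression loss
conr_loss = ConR(features, targets, preds, w=1)   # contrastive regularizer
loss = reg_loss + 4.0 * conr_loss                 # 4.0 plays the role of --beta
loss.backward()                                   # gradients flow to head and features
```

In the actual `train.py` scripts, `features` and `preds` come from the model's forward pass, and the `weights` argument can carry the per-sample re-weighting factors returned by the dataset.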
61 | 62 | ## Usage 63 | 64 | Please go into each sub-folder to run experiments on the different datasets. 65 | 66 | - [IMDB-WIKI-DIR](./imdb-wiki-dir) 67 | - [AgeDB-DIR](./agedb-dir) 68 | - [NYUD2-DIR](./nyud2-dir) 69 | 70 | 71 | ## Acknowledgment 72 | 73 | The code is based on [Yang et al., Delving into Deep Imbalanced Regression, ICML 2021](https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir) and [Ren et al., Balanced MSE for Imbalanced Visual Regression, CVPR 2022](https://github.com/jiawei-ren/BalancedMSE). 74 | -------------------------------------------------------------------------------- /agedb-dir/README.md: -------------------------------------------------------------------------------- 1 | # ConR on AgeDB-DIR 2 | This repository contains the implementation of __ConR__ on the *AgeDB-DIR* dataset. 3 | 4 | The imbalanced regression framework and LDS+FDS are based on the public repository of [Gong et al., ICML 2022](https://github.com/BorealisAI/ranksim-imbalanced-regression). 5 | 6 | 7 | 8 | ## Installation 9 | 10 | #### Prerequisites 11 | 12 | 1. Download the AgeDB dataset from [here](https://ibug.doc.ic.ac.uk/resources/agedb/) and extract the zip file (you may need to contact the authors of the AgeDB dataset for the zip password) to the folder `./data` 13 | 14 | 2. We use the standard train/val/test split file (`agedb.csv` in folder `./data`) provided by Yang et al. (ICML 2021), which sets up the balanced val/test sets. To reproduce the results in the paper, please use this file directly. You can also generate it using 15 | 16 | ```bash 17 | python data/create_agedb.py 18 | python data/preprocess_agedb.py 19 | ``` 20 | 21 | #### Dependencies 22 | 23 | - PyTorch (>= 1.2, tested on 1.6) 24 | - tensorboard_logger 25 | - numpy, pandas, scipy, tqdm, matplotlib, PIL, wget 26 | 27 | ## Code Overview 28 | 29 | #### Main Files 30 | 31 | - `train.py`: main training and evaluation script 32 | - `create_agedb.py`: create the AgeDB raw meta data 33 | - `preprocess_agedb.py`: create the AgeDB-DIR meta file `agedb.csv` with balanced val/test sets 34 | 35 | #### Main Arguments 36 | 37 | - `--data_dir`: data directory to place data and meta file 38 | - `--reweight`: cost-sensitive re-weighting scheme to use 39 | - `--loss`: training loss type 40 | - `--conr`: whether to use ConR or not 41 | - `-w`: distance threshold (default 1.0) 42 | - `--beta`: the scale of the ConR loss (default 4.0) 43 | - `-t`: temperature (default 0.2) 44 | - `-e`: pushing power scale (default 0.01) 45 | ## Getting Started 46 | 47 | ### 1. Train baselines 48 | 49 | To use the Vanilla model 50 | 51 | ```bash 52 | python train.py --batch_size 64 --lr 2.5e-4 53 | ``` 54 | 55 | 56 | 57 | ### 2. Train a model with ConR 58 | ##### batch size 64, learning rate 2.5e-4 59 | 60 | ```bash 61 | python train.py --batch_size 64 --lr 2.5e-4 --conr -w 1.0 --beta 4.0 -e 0.01 62 | ``` 63 | 64 | 65 | 66 | ### 3. Evaluate and reproduce 67 | 68 | If you do not want to train the model yourself, you can evaluate the model and reproduce our results directly using the pretrained weights. 69 | 70 | ```bash 71 | python train.py --evaluate [...evaluation model arguments...] --resume <path_to_checkpoint> 72 | ``` 73 | 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /agedb-dir/data/create_agedb.py: -------------------------------------------------------------------------------- 1 | ######################################################################################## 2 | # MIT License 3 | 4 | # Copyright (c) 2021 Yuzhe Yang 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE.
23 | ######################################################################################## 24 | import os 25 | import argparse 26 | import pandas as pd 27 | from tqdm import tqdm 28 | 29 | 30 | def get_args(): 31 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 32 | parser.add_argument("--data_path", type=str, default="./data") 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def main(): 38 | args = get_args() 39 | ages, img_paths = [], [] 40 | 41 | for filename in tqdm(os.listdir(os.path.join(args.data_path, 'AgeDB'))): 42 | _, _, age, gender = filename.split('.')[0].split('_') 43 | 44 | ages.append(age) 45 | img_paths.append(f"AgeDB/{filename}") 46 | 47 | outputs = dict(age=ages, path=img_paths) 48 | output_dir = os.path.join(args.data_path, "meta") 49 | os.makedirs(output_dir, exist_ok=True) 50 | output_path = os.path.join(output_dir, "agedb.csv") 51 | df = pd.DataFrame(data=outputs) 52 | df.to_csv(str(output_path), index=False) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | 58 | """ 59 | age,path,split 60 | 31,AgeDB/11706_OliviaHussey_31_f.jpg,train 61 | 59,AgeDB/11684_MireilleDarc_59_f.jpg,val 62 | 44,AgeDB/7955_GilbertRoland_44_m.jpg,train 63 | 61,AgeDB/9352_GeorgesMarchal_61_m.jpg,val 64 | 28,AgeDB/3888_TomasMilian_28_m.jpg,val 65 | 8,AgeDB/16107_DannyGlover_8_m.jpg,test 66 | 34,AgeDB/13784_ThelmaRitter_34_f.jpg,train 67 | 74,AgeDB/9945_AliMacGraw_74_f.jpg,train 68 | """ -------------------------------------------------------------------------------- /agedb-dir/data/preprocess_agedb.py: -------------------------------------------------------------------------------- 1 | ######################################################################################## 2 | # MIT License 3 | 4 | # Copyright (c) 2021 Yuzhe Yang 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | ######################################################################################## 24 | from os.path import join 25 | import pandas as pd 26 | import matplotlib.pyplot as plt 27 | 28 | 29 | BASE_PATH = './data' 30 | 31 | 32 | def visualize_dataset(db="agedb"): 33 | file_path = join(BASE_PATH, "meta", "agedb.csv") 34 | data = pd.read_csv(file_path) 35 | _, ax = plt.subplots(figsize=(6, 3), sharex='all', sharey='all') 36 | ax.hist(data['age'], range(max(data['age']) + 2)) 37 | # ax.set_xlim([0, 102]) 38 | plt.title(f"{db.upper()} (total: {data.shape[0]})") 39 | plt.tight_layout() 40 | plt.show() 41 | 42 | 43 | def make_balanced_testset(db="agedb", max_size=30, seed=666, verbose=True, vis=True, save=False): 44 | file_path = join(BASE_PATH, "meta", f"{db}.csv") 45 | df = pd.read_csv(file_path) 46 | df['age'] = df.age.astype(int) 47 | val_set, test_set = [], [] 48 | import random 49 | random.seed(seed) 50 | for value in range(121): 51 | curr_df = df[df['age'] == value] 52 | curr_data = curr_df['path'].values 53 | random.shuffle(curr_data) 54 | curr_size = min(len(curr_data) // 3, max_size) 55 | val_set += list(curr_data[:curr_size]) 56 | test_set += list(curr_data[curr_size:curr_size * 2]) 57 | if verbose: 58 | print(f"Val: {len(val_set)}\nTest: {len(test_set)}") 59 | assert len(set(val_set).intersection(set(test_set))) == 0 60 | combined_set = dict(zip(val_set, ['val' for _ in range(len(val_set))])) 61 | combined_set.update(dict(zip(test_set, ['test' for _ in range(len(test_set))]))) 62 | df['split'] = df['path'].map(combined_set) 63 | df['split'].fillna('train', inplace=True) 64 | if verbose: 65 | print(df) 66 | if save: 67 | df.to_csv(str(join(BASE_PATH, f"{db}.csv")), index=False) 68 | if vis: 69 | _, ax = plt.subplots(3, figsize=(6, 9), sharex='all') 70 | df_train = df[df['split'] == 'train'] 71 | ax[0].hist(df_train['age'], range(max(df['age']))) 72 | ax[0].set_title(f"[{db.upper()}] train: {df_train.shape[0]}") 73 | ax[1].hist(df[df['split'] == 'val']['age'], range(max(df['age']))) 74 | ax[1].set_title(f"[{db.upper()}] val: {df[df['split'] == 'val'].shape[0]}") 75 | ax[2].hist(df[df['split'] == 'test']['age'], range(max(df['age']))) 76 | ax[2].set_title(f"[{db.upper()}] test: {df[df['split'] == 'test'].shape[0]}") 77 | ax[0].set_xlim([0, 120]) 78 | plt.tight_layout() 79 | plt.show() 80 | 81 | 82 | if __name__ == '__main__': 83 | make_balanced_testset() 84 | visualize_dataset() 85 | -------------------------------------------------------------------------------- /agedb-dir/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-present, Royal Bank of Canada. 2 | # Copyright (c) 2021-present, Yuzhe Yang 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | ######################################################################################## 9 | # Code is based on the LDS and FDS (https://arxiv.org/pdf/2102.09554.pdf) implementation 10 | # from https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir 11 | # by Yuzhe Yang et al. 
12 | ######################################################################################## 13 | import os 14 | import logging 15 | import numpy as np 16 | from PIL import Image 17 | from scipy.ndimage import convolve1d 18 | from torch.utils import data 19 | import torchvision.transforms as transforms 20 | 21 | from utils import get_lds_kernel_window 22 | import torch 23 | 24 | 25 | print = logging.info 26 | 27 | 28 | class AgeDB(data.Dataset): 29 | def __init__(self, df, data_dir, img_size, split='', reweight='none', 30 | lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2, aug_plus=False, reg=False, args=None): 31 | self.df = df 32 | self.data_dir = data_dir 33 | self.img_size = img_size 34 | self.split = split 35 | self.aug_plus = aug_plus 36 | self.reg = reg 37 | self.weights = self._prepare_weights(reweight=reweight, lds=lds, lds_kernel=lds_kernel, lds_ks=lds_ks, lds_sigma=lds_sigma) 38 | self.args = args 39 | def __len__(self): 40 | return len(self.df) 41 | 42 | def __getitem__(self, index): 43 | index = index % len(self.df) 44 | row = self.df.iloc[index] 45 | img = Image.open(os.path.join(self.data_dir, row['path'])).convert('RGB') 46 | transform = self.get_transform() 47 | 48 | label = np.asarray([row['age']]).astype('float32') 49 | weight = np.asarray([self.weights[index]]).astype('float32') if self.weights is not None else np.asarray([np.float32(1.)]) 50 | 51 | if self.split == 'train': 52 | return transform(img), [transform(img), transform(img)], label, weight # two extra augmented views (used by the ConR objective) 53 | else: 54 | return transform(img), label 55 | 56 | 57 | def get_transform(self): 58 | 59 | reg_aug = transforms.Compose([ 60 | transforms.Resize((self.img_size, self.img_size)), 61 | transforms.RandomCrop(self.img_size, padding=16), 62 | transforms.RandomHorizontalFlip(), 63 | transforms.ToTensor(), 64 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]), 65 | 66 | ]) 67 | if self.split == 'train': 68 | transform = reg_aug 69 | else: 70 | transform = transforms.Compose([ 71 | transforms.Resize((self.img_size, self.img_size)), 72 | transforms.ToTensor(), 73 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]), 74 | ]) 75 | 76 | 77 | return transform 78 | 79 | def _prepare_weights(self, reweight, max_target=121, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2): 80 | assert reweight in {'none', 'inverse', 'sqrt_inv'} 81 | assert reweight != 'none' if lds else True, \ 82 | "Set reweight to \'sqrt_inv\' (default) or \'inverse\' when using LDS" 83 | 84 | value_dict = {x: 0 for x in range(max_target)} 85 | labels = self.df['age'].values 86 | 87 | for label in labels: 88 | value_dict[min(max_target - 1, int(label))] += 1 89 | if reweight == 'sqrt_inv': 90 | value_dict = {k: np.sqrt(v) for k, v in value_dict.items()} 91 | elif reweight == 'inverse': 92 | value_dict = {k: np.clip(v, 5, 1000) for k, v in value_dict.items()} # clip weights for inverse re-weight 93 | num_per_label = [value_dict[min(max_target - 1, int(label))] for label in labels] 94 | if not len(num_per_label) or reweight == 'none': 95 | return None 96 | 97 | if lds: 98 | lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma) 99 | smoothed_value = convolve1d( 100 | np.asarray([v for _, v in value_dict.items()]), weights=lds_kernel_window, mode='constant') 101 | num_per_label = [smoothed_value[min(max_target - 1, int(label))] for label in labels] 102 | 103 | weights = [np.float32(1 / x) for x in num_per_label] 104 | scaling = len(weights) / np.sum(weights) 105 | weights = [scaling * x for x in weights] 106 | return weights 107 | 108 | 109 | 110 |
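# Worked example of the re-weighting above (added note; illustrative, not from the
# original source): with reweight='sqrt_inv', a label seen 100 times gets bin value
# sqrt(100) = 10 and a label seen 25 times gets sqrt(25) = 5, so their samples get
# raw weights 1/10 and 1/5; the final scaling step rescales all weights to average 1.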
111 | 112 | class TwoCropsTransform: 113 | """Take two random crops of one image as the query and key.""" 114 | 115 | def __init__(self, base_transform): 116 | self.base_transform = base_transform 117 | 118 | def __call__(self, x): 119 | q = self.base_transform(x) 120 | k = self.base_transform(x) 121 | return [q, k] 122 | 123 | -------------------------------------------------------------------------------- /agedb-dir/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-present, Royal Bank of Canada. 2 | # Copyright (c) 2021-present, Yuzhe Yang 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | ######################################################################################## 9 | # Code is based on the LDS and FDS (https://arxiv.org/pdf/2102.09554.pdf) implementation 10 | # from https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir 11 | # by Yuzhe Yang et al. 12 | ######################################################################################## 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | 17 | 18 | def weighted_mse_loss(inputs, targets, weights=None): 19 | loss = (inputs - targets) ** 2 20 | if weights is not None: 21 | loss *= weights.expand_as(loss) 22 | loss = torch.mean(loss) 23 | return loss 24 | 25 | 26 | def weighted_l1_loss(inputs, targets, weights=None): 27 | loss = F.l1_loss(inputs, targets, reduction='none') 28 | if weights is not None: 29 | loss *= weights.expand_as(loss) 30 | loss = torch.mean(loss) 31 | return loss 32 | 33 | 34 | def weighted_focal_mse_loss(inputs, targets, activate='sigmoid', beta=.2, gamma=1, weights=None): 35 | loss = (inputs - targets) ** 2 36 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \ 37 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma 38 | if weights is not None: 39 | loss *= weights.expand_as(loss) 40 | loss = torch.mean(loss) 41 | return loss 42 | 43 | 44 | def weighted_focal_l1_loss(inputs, targets, activate='sigmoid', beta=.2, gamma=1, weights=None): 45 | loss = F.l1_loss(inputs, targets, reduction='none') 46 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \ 47 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma 48 | if weights is not None: 49 | loss *= weights.expand_as(loss) 50 | loss = torch.mean(loss) 51 | return loss 52 | 53 | 54 | def weighted_huber_loss(inputs, targets, beta=1., weights=None): 55 | l1_loss = torch.abs(inputs - targets) 56 | cond = l1_loss < beta 57 | loss = torch.where(cond, 0.5 * l1_loss ** 2 / beta, l1_loss - 0.5 * beta) 58 | if weights is not None: 59 | loss *= weights.expand_as(loss) 60 | loss = torch.mean(loss) 61 | return loss 62 | 63 | 74 | # ConR loss function 75 | def ConR(features, targets, preds, w=1, weights=1, t=0.2, e=0.01): 76 | 77 | t = 0.07 # note: the temperature is fixed to 0.07 here, overriding the t argument above 78 | 79 | q = torch.nn.functional.normalize(features, dim=1) 80 | k = torch.nn.functional.normalize(features, dim=1) 81 | 82 | l_k = targets.flatten()[None, :]
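# Shape note (added for clarity): targets and preds are (N, 1) column tensors, while
# l_k above and p_k below are (1, N) rows, so the l_dist/p_dist subtractions below
# broadcast into full (N, N) matrices of pairwise label/prediction distances.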
83 | l_q = targets 84 | 85 | p_k = preds.flatten()[None, :] 86 | p_q = preds 87 | 88 | # the label distance later scales the pushing weight for negative samples (see pushing_w below) 90 | 91 | l_dist = torch.abs(l_q - l_k) 92 | p_dist = torch.abs(p_q - p_k) 93 | 94 | 95 | pos_i = l_dist.le(w) 96 | neg_i = ((~ (l_dist.le(w))) * (p_dist.le(w))) 97 | 98 | for i in range(pos_i.shape[0]): 99 | pos_i[i][i] = 0 100 | 101 | prod = torch.einsum("nc,kc->nk", [q, k]) / t 102 | pos = prod * pos_i 103 | neg = prod * neg_i 104 | 105 | pushing_w = weights * torch.exp(l_dist * e) 106 | neg_exp_dot = (pushing_w * (torch.exp(neg)) * neg_i).sum(1) 107 | 108 | # For each query sample, if there is no negative pair, zero-out the loss. 109 | no_neg_flag = (neg_i).sum(1).bool() 110 | 111 | # Loss = sum over all samples in the batch (sum over (positive dot product/(negative dot product+positive dot product))) 112 | denom = pos_i.sum(1) 113 | 114 | loss = ((-torch.log(torch.div(torch.exp(pos), (torch.exp(pos).sum(1) + neg_exp_dot).unsqueeze(-1))) * (pos_i)).sum(1) / denom) 115 | 116 | loss = (weights * (loss * no_neg_flag).unsqueeze(-1)).mean() 117 | 118 | 119 | 120 | return loss 121 | -------------------------------------------------------------------------------- /agedb-dir/ranking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-present, Royal Bank of Canada. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | 10 | 11 | def rank(seq): 12 | return torch.argsort(torch.argsort(seq).flip(1)) 13 | 14 | 15 | def rank_normalised(seq): 16 | return (rank(seq) + 1).float() / seq.size()[1] 17 | 18 | 19 | class TrueRanker(torch.autograd.Function): 20 | @staticmethod 21 | def forward(ctx, sequence, lambda_val): 22 | rank = rank_normalised(sequence) 23 | ctx.lambda_val = lambda_val 24 | ctx.save_for_backward(sequence, rank) 25 | return rank 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output): 29 | sequence, rank = ctx.saved_tensors 30 | assert grad_output.shape == rank.shape 31 | sequence_prime = sequence + ctx.lambda_val * grad_output 32 | rank_prime = rank_normalised(sequence_prime) 33 | gradient = -(rank - rank_prime) / (ctx.lambda_val + 1e-8) 34 | return gradient, None 35 | -------------------------------------------------------------------------------- /agedb-dir/ranksim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-present, Royal Bank of Canada. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | import torch 9 | import random 10 | import torch.nn.functional as F 11 | 12 | from ranking import TrueRanker, rank_normalised 13 | 14 | 15 | def batchwise_ranking_regularizer(features, targets, lambda_val): 16 | loss = 0 17 | 18 | # Reduce ties and boost relative representation of infrequent labels by computing the 19 | # regularizer over a subset of the batch in which each label appears at most once 20 | batch_unique_targets = torch.unique(targets) 21 | if len(batch_unique_targets) < len(targets): 22 | sampled_indices = [] 23 | for target in batch_unique_targets: 24 | sampled_indices.append(random.choice((targets == target).nonzero()[:,0]).item()) 25 | x = features[sampled_indices] 26 | y = targets[sampled_indices] 27 | else: 28 | x = features 29 | y = targets 30 | 31 | # Compute feature similarities 32 | xxt = torch.matmul(F.normalize(x.view(x.size(0),-1)), F.normalize(x.view(x.size(0),-1)).permute(1,0)) 33 | 34 | # Compute ranking loss 35 | for i in range(len(y)): 36 | label_ranks = rank_normalised(-torch.abs(y[i] - y).transpose(0,1)) 37 | feature_ranks = TrueRanker.apply(xxt[i].unsqueeze(dim=0), lambda_val) 38 | loss += F.mse_loss(feature_ranks, label_ranks) 39 | 40 | return loss -------------------------------------------------------------------------------- /agedb-dir/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-present, Royal Bank of Canada. 2 | # Copyright (c) 2021-present, Yuzhe Yang 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | ######################################################################################## 9 | # Code is based on the LDS and FDS (https://arxiv.org/pdf/2102.09554.pdf) implementation 10 | # from https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir 11 | # by Yuzhe Yang et al. 
12 | ######################################################################################## 13 | import math 14 | 15 | import numpy as np 16 | from scipy.ndimage import gaussian_filter1d 17 | from scipy.signal.windows import triang 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | from utils import calibrate_mean_var 24 | 25 | import logging 26 | 27 | 28 | print = logging.info 29 | 30 | class FDS(nn.Module): 31 | 32 | def __init__(self, feature_dim, bucket_num=100, bucket_start=3, start_update=0, start_smooth=1, 33 | kernel='gaussian', ks=5, sigma=2, momentum=0.9): 34 | super(FDS, self).__init__() 35 | self.feature_dim = feature_dim 36 | self.bucket_num = bucket_num 37 | self.bucket_start = bucket_start 38 | self.kernel_window = self._get_kernel_window(kernel, ks, sigma) 39 | self.half_ks = (ks - 1) // 2 40 | self.momentum = momentum 41 | self.start_update = start_update 42 | self.start_smooth = start_smooth 43 | 44 | self.register_buffer('epoch', torch.zeros(1).fill_(start_update)) 45 | self.register_buffer('running_mean', torch.zeros(bucket_num - bucket_start, feature_dim)) 46 | self.register_buffer('running_var', torch.ones(bucket_num - bucket_start, feature_dim)) 47 | self.register_buffer('running_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 48 | self.register_buffer('running_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 49 | self.register_buffer('smoothed_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 50 | self.register_buffer('smoothed_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 51 | self.register_buffer('num_samples_tracked', torch.zeros(bucket_num - bucket_start)) 52 | 53 | @staticmethod 54 | def _get_kernel_window(kernel, ks, sigma): 55 | assert kernel in ['gaussian', 'triang', 'laplace'] 56 | half_ks = (ks - 1) // 2 57 | if kernel == 'gaussian': 58 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks 59 | base_kernel = np.array(base_kernel, dtype=np.float32) 60 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / sum(gaussian_filter1d(base_kernel, sigma=sigma)) 61 | elif kernel == 'triang': 62 | kernel_window = triang(ks) / sum(triang(ks)) 63 | else: 64 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. 
* sigma) 65 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / sum(map(laplace, np.arange(-half_ks, half_ks + 1))) 66 | 67 | print(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})') 68 | return torch.tensor(kernel_window, dtype=torch.float32).cuda() 69 | 70 | def _update_last_epoch_stats(self): 71 | self.running_mean_last_epoch = self.running_mean 72 | self.running_var_last_epoch = self.running_var 73 | 74 | self.smoothed_mean_last_epoch = F.conv1d( 75 | input=F.pad(self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0), 76 | pad=(self.half_ks, self.half_ks), mode='reflect'), 77 | weight=self.kernel_window.view(1, 1, -1), padding=0 78 | ).permute(2, 1, 0).squeeze(1) 79 | self.smoothed_var_last_epoch = F.conv1d( 80 | input=F.pad(self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0), 81 | pad=(self.half_ks, self.half_ks), mode='reflect'), 82 | weight=self.kernel_window.view(1, 1, -1), padding=0 83 | ).permute(2, 1, 0).squeeze(1) 84 | 85 | def reset(self): 86 | self.running_mean.zero_() 87 | self.running_var.fill_(1) 88 | self.running_mean_last_epoch.zero_() 89 | self.running_var_last_epoch.fill_(1) 90 | self.smoothed_mean_last_epoch.zero_() 91 | self.smoothed_var_last_epoch.fill_(1) 92 | self.num_samples_tracked.zero_() 93 | 94 | def update_last_epoch_stats(self, epoch): 95 | if epoch == self.epoch + 1: 96 | self.epoch += 1 97 | self._update_last_epoch_stats() 98 | print(f"Updated smoothed statistics on Epoch [{epoch}]!") 99 | 100 | def update_running_stats(self, features, labels, epoch): 101 | if epoch < self.epoch: 102 | return 103 | 104 | assert self.feature_dim == features.size(1), "Input feature dimension is not aligned!" 105 | assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!" 
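# Added note: the loop below folds each batch's per-label feature statistics into the
# running mean/var; the edge buckets (bucket_start and bucket_num - 1) also absorb all
# labels below/above them, and `factor` blends old and new statistics by momentum
# (or by sample count when momentum is None).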
106 | 107 | for label in torch.unique(labels): 108 | if label > self.bucket_num - 1 or label < self.bucket_start: 109 | continue 110 | elif label == self.bucket_start: 111 | curr_feats = features[labels <= label] 112 | elif label == self.bucket_num - 1: 113 | curr_feats = features[labels >= label] 114 | else: 115 | curr_feats = features[labels == label] 116 | curr_num_sample = curr_feats.size(0) 117 | curr_mean = torch.mean(curr_feats, 0) 118 | curr_var = torch.var(curr_feats, 0, unbiased=True if curr_feats.size(0) != 1 else False) 119 | 120 | self.num_samples_tracked[int(label - self.bucket_start)] += curr_num_sample 121 | factor = self.momentum if self.momentum is not None else \ 122 | (1 - curr_num_sample / float(self.num_samples_tracked[int(label - self.bucket_start)])) 123 | factor = 0 if epoch == self.start_update else factor 124 | self.running_mean[int(label - self.bucket_start)] = \ 125 | (1 - factor) * curr_mean + factor * self.running_mean[int(label - self.bucket_start)] 126 | self.running_var[int(label - self.bucket_start)] = \ 127 | (1 - factor) * curr_var + factor * self.running_var[int(label - self.bucket_start)] 128 | 129 | print(f"Updated running statistics with Epoch [{epoch}] features!") 130 | 131 | def smooth(self, features, labels, epoch): 132 | if epoch < self.start_smooth: 133 | return features 134 | 135 | labels = labels.squeeze(1) 136 | for label in torch.unique(labels): 137 | if label > self.bucket_num - 1 or label < self.bucket_start: 138 | continue 139 | elif label == self.bucket_start: 140 | features[labels <= label] = calibrate_mean_var( 141 | features[labels <= label], 142 | self.running_mean_last_epoch[int(label - self.bucket_start)], 143 | self.running_var_last_epoch[int(label - self.bucket_start)], 144 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)], 145 | self.smoothed_var_last_epoch[int(label - self.bucket_start)]) 146 | elif label == self.bucket_num - 1: 147 | features[labels >= label] = calibrate_mean_var( 148 | features[labels >= label], 149 | self.running_mean_last_epoch[int(label - self.bucket_start)], 150 | self.running_var_last_epoch[int(label - self.bucket_start)], 151 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)], 152 | self.smoothed_var_last_epoch[int(label - self.bucket_start)]) 153 | else: 154 | features[labels == label] = calibrate_mean_var( 155 | features[labels == label], 156 | self.running_mean_last_epoch[int(label - self.bucket_start)], 157 | self.running_var_last_epoch[int(label - self.bucket_start)], 158 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)], 159 | self.smoothed_var_last_epoch[int(label - self.bucket_start)]) 160 | return features 161 | 162 | 163 | def conv3x3(in_planes, out_planes, stride=1): 164 | """3x3 convolution with padding""" 165 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 166 | 167 | 168 | class BasicBlock(nn.Module): 169 | expansion = 1 170 | 171 | def __init__(self, inplanes, planes, stride=1, downsample=None): 172 | super(BasicBlock, self).__init__() 173 | self.conv1 = conv3x3(inplanes, planes, stride) 174 | self.bn1 = nn.BatchNorm2d(planes) 175 | self.relu = nn.ReLU(inplace=True) 176 | self.conv2 = conv3x3(planes, planes) 177 | self.bn2 = nn.BatchNorm2d(planes) 178 | self.downsample = downsample 179 | self.stride = stride 180 | 181 | def forward(self, x): 182 | residual = x 183 | out = self.conv1(x) 184 | out = self.bn1(out) 185 | out = self.relu(out) 186 | out = self.conv2(out) 187 | out = self.bn2(out) 188 | if 
self.downsample is not None: 189 | residual = self.downsample(x) 190 | out += residual 191 | out = self.relu(out) 192 | return out 193 | 194 | 195 | class Bottleneck(nn.Module): 196 | expansion = 4 197 | 198 | def __init__(self, inplanes, planes, stride=1, downsample=None): 199 | super(Bottleneck, self).__init__() 200 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 201 | self.bn1 = nn.BatchNorm2d(planes) 202 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 203 | self.bn2 = nn.BatchNorm2d(planes) 204 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 205 | self.bn3 = nn.BatchNorm2d(planes * 4) 206 | self.relu = nn.ReLU(inplace=True) 207 | self.downsample = downsample 208 | self.stride = stride 209 | 210 | def forward(self, x): 211 | residual = x 212 | out = self.conv1(x) 213 | out = self.bn1(out) 214 | out = self.relu(out) 215 | out = self.conv2(out) 216 | out = self.bn2(out) 217 | out = self.relu(out) 218 | out = self.conv3(out) 219 | out = self.bn3(out) 220 | if self.downsample is not None: 221 | residual = self.downsample(x) 222 | out += residual 223 | out = self.relu(out) 224 | return out 225 | 226 | 227 | class ResNet(nn.Module): 228 | 229 | def __init__(self, block, layers, fds, bucket_num, bucket_start, start_update, start_smooth, 230 | kernel, ks, sigma, momentum, dropout=None, return_features=False): 231 | self.inplanes = 64 232 | super(ResNet, self).__init__() 233 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 234 | self.bn1 = nn.BatchNorm2d(64) 235 | self.relu = nn.ReLU(inplace=True) 236 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 237 | self.layer1 = self._make_layer(block, 64, layers[0]) 238 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 239 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 240 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 241 | self.avgpool = nn.AvgPool2d(7, stride=1) 242 | self.linear = nn.Linear(512 * block.expansion, 1) 243 | 244 | if fds: 245 | self.FDS = FDS( 246 | feature_dim=512 * block.expansion, bucket_num=bucket_num, bucket_start=bucket_start, 247 | start_update=start_update, start_smooth=start_smooth, kernel=kernel, ks=ks, sigma=sigma, momentum=momentum 248 | ) 249 | self.fds = fds 250 | self.start_smooth = start_smooth 251 | self.return_features = return_features 252 | 253 | self.use_dropout = True if dropout else False 254 | if self.use_dropout: 255 | print(f'Using dropout: {dropout}') 256 | self.dropout = nn.Dropout(p=dropout) 257 | 258 | for m in self.modules(): 259 | if isinstance(m, nn.Conv2d): 260 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 261 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 262 | elif isinstance(m, nn.BatchNorm2d): 263 | m.weight.data.fill_(1) 264 | m.bias.data.zero_() 265 | 266 | def _make_layer(self, block, planes, blocks, stride=1): 267 | downsample = None 268 | if stride != 1 or self.inplanes != planes * block.expansion: 269 | downsample = nn.Sequential( 270 | nn.Conv2d(self.inplanes, planes * block.expansion, 271 | kernel_size=1, stride=stride, bias=False), 272 | nn.BatchNorm2d(planes * block.expansion), 273 | ) 274 | layers = [] 275 | layers.append(block(self.inplanes, planes, stride, downsample)) 276 | self.inplanes = planes * block.expansion 277 | for i in range(1, blocks): 278 | layers.append(block(self.inplanes, planes)) 279 | 280 | return nn.Sequential(*layers) 281 | 282 | def forward(self, x, targets=None, epoch=None,reg = True): 283 | x = self.conv1(x) 284 | x = self.bn1(x) 285 | x = self.relu(x) 286 | x = self.maxpool(x) 287 | 288 | x = self.layer1(x) 289 | x = self.layer2(x) 290 | x = self.layer3(x) 291 | x = self.layer4(x) 292 | x = self.avgpool(x) 293 | encoding = x.view(x.size(0), -1) 294 | 295 | encoding_s = encoding 296 | 297 | if self.training and self.fds and reg: 298 | if epoch >= self.start_smooth: 299 | encoding_s = self.FDS.smooth(encoding_s, targets, epoch) 300 | 301 | if self.use_dropout: 302 | encoding_s = self.dropout(encoding_s) 303 | x = self.linear(encoding_s) 304 | 305 | 306 | return x, encoding 307 | 308 | 309 | def resnet50(**kwargs): 310 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) -------------------------------------------------------------------------------- /agedb-dir/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-present, Royal Bank of Canada. 2 | # Copyright (c) 2021-present, Yuzhe Yang 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | ######################################################################################## 9 | # Code is based on the LDS and FDS (https://arxiv.org/pdf/2102.09554.pdf) implementation 10 | # from https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir 11 | # by Yuzhe Yang et al. 12 | ######################################################################################## 13 | 14 | import os 15 | import shutil 16 | import torch 17 | import logging 18 | import numpy as np 19 | import pandas as pd 20 | from scipy.ndimage import gaussian_filter1d 21 | from scipy.signal.windows import triang 22 | import math 23 | from collections import defaultdict 24 | from scipy.stats import gmean 25 | 26 | import seaborn as sns 27 | import matplotlib.pyplot as plt 28 | from sklearn.manifold import TSNE 29 | from scipy.ndimage.filters import gaussian_filter 30 | import random 31 | import matplotlib 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | def query_yes_no(question): 42 | """ Ask a yes/no question via input() and return their answer. 
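An empty answer defaults to 'yes'.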
""" 43 | valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False} 44 | prompt = " [Y/n] " 45 | 46 | while True: 47 | print(question + prompt, end=':') 48 | choice = input().lower() 49 | if choice == '': 50 | return valid['y'] 51 | elif choice in valid: 52 | return valid[choice] 53 | else: 54 | print("Please respond with 'yes' or 'no' (or 'y' or 'n').\n") 55 | 56 | 57 | def prepare_folders(args): 58 | folders_util = [args.store_root, os.path.join(args.store_root, args.store_name)] 59 | if os.path.exists(folders_util[-1]) and not args.resume and not args.pretrained and not args.evaluate: 60 | if query_yes_no('overwrite previous folder: {} ?'.format(folders_util[-1])): 61 | shutil.rmtree(folders_util[-1]) 62 | print(folders_util[-1] + ' removed.') 63 | else: 64 | raise RuntimeError('Output folder {} already exists'.format(folders_util[-1])) 65 | for folder in folders_util: 66 | if not os.path.exists(folder): 67 | print(f"===> Creating folder: {folder}") 68 | os.mkdir(folder) 69 | 70 | 71 | 72 | 73 | 74 | def save_checkpoint(args, state, is_best, prefix=''): 75 | filename = f"{args.store_root}/{args.store_name}/{prefix}ckpt.pth.tar" 76 | torch.save(state, filename) 77 | if is_best: 78 | logging.info("===> Saving current best checkpoint...") 79 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar')) 80 | 81 | def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.1, clip_max=10): 82 | if torch.sum(v1) < 1e-10: 83 | return matrix 84 | if (v1 == 0.).any(): 85 | valid = (v1 != 0.) 86 | factor = torch.clamp(v2[valid] / v1[valid], clip_min, clip_max) 87 | matrix[:, valid] = (matrix[:, valid] - m1[valid]) * torch.sqrt(factor) + m2[valid] 88 | return matrix 89 | 90 | factor = torch.clamp(v2 / v1, clip_min, clip_max) 91 | return (matrix - m1) * torch.sqrt(factor) + m2 92 | 93 | 94 | 95 | 96 | def get_lds_kernel_window(kernel, ks, sigma): 97 | assert kernel in ['gaussian', 'triang', 'laplace'] 98 | half_ks = (ks - 1) // 2 99 | if kernel == 'gaussian': 100 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks 101 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma)) 102 | elif kernel == 'triang': 103 | kernel_window = triang(ks) 104 | else: 105 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. 
* sigma) 106 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / max(map(laplace, np.arange(-half_ks, half_ks + 1))) 107 | 108 | return kernel_window 109 | 110 | 111 | 112 | def shot_metrics(preds, labels, train_labels, many_shot_thr=100, low_shot_thr=20): 113 | train_labels = np.array(train_labels).astype(int) 114 | 115 | if isinstance(preds, torch.Tensor): 116 | preds = preds.detach().cpu().numpy() 117 | labels = labels.detach().cpu().numpy() 118 | elif isinstance(preds, np.ndarray): 119 | pass 120 | else: 121 | raise TypeError(f'Type ({type(preds)}) of predictions not supported') 122 | 123 | train_class_count, test_class_count = [], [] 124 | mse_per_class, l1_per_class, l1_all_per_class = [], [], [] 125 | for l in np.unique(labels): 126 | train_class_count.append(len(train_labels[train_labels == l])) 127 | test_class_count.append(len(labels[labels == l])) 128 | mse_per_class.append(np.sum((preds[labels == l] - labels[labels == l]) ** 2)) 129 | l1_per_class.append(np.sum(np.abs(preds[labels == l] - labels[labels == l]))) 130 | l1_all_per_class.append(np.abs(preds[labels == l] - labels[labels == l])) 131 | 132 | many_shot_mse, median_shot_mse, low_shot_mse = [], [], [] 133 | many_shot_l1, median_shot_l1, low_shot_l1 = [], [], [] 134 | many_shot_gmean, median_shot_gmean, low_shot_gmean = [], [], [] 135 | many_shot_cnt, median_shot_cnt, low_shot_cnt = [], [], [] 136 | 137 | for i in range(len(train_class_count)): 138 | if train_class_count[i] > many_shot_thr: 139 | many_shot_mse.append(mse_per_class[i]) 140 | many_shot_l1.append(l1_per_class[i]) 141 | many_shot_gmean += list(l1_all_per_class[i]) 142 | many_shot_cnt.append(test_class_count[i]) 143 | elif train_class_count[i] < low_shot_thr: 144 | low_shot_mse.append(mse_per_class[i]) 145 | low_shot_l1.append(l1_per_class[i]) 146 | low_shot_gmean += list(l1_all_per_class[i]) 147 | low_shot_cnt.append(test_class_count[i]) 148 | else: 149 | median_shot_mse.append(mse_per_class[i]) 150 | median_shot_l1.append(l1_per_class[i]) 151 | median_shot_gmean += list(l1_all_per_class[i]) 152 | median_shot_cnt.append(test_class_count[i]) 153 | 154 | shot_dict = defaultdict(dict) 155 | shot_dict['many']['mse'] = np.sum(many_shot_mse) / np.sum(many_shot_cnt) 156 | shot_dict['many']['l1'] = np.sum(many_shot_l1) / np.sum(many_shot_cnt) 157 | shot_dict['many']['gmean'] = gmean(np.hstack(many_shot_gmean), axis=None).astype(float) 158 | shot_dict['median']['mse'] = np.sum(median_shot_mse) / np.sum(median_shot_cnt) 159 | shot_dict['median']['l1'] = np.sum(median_shot_l1) / np.sum(median_shot_cnt) 160 | shot_dict['median']['gmean'] = gmean(np.hstack(median_shot_gmean), axis=None).astype(float) 161 | shot_dict['low']['mse'] = np.sum(low_shot_mse) / np.sum(low_shot_cnt) 162 | shot_dict['low']['l1'] = np.sum(low_shot_l1) / np.sum(low_shot_cnt) 163 | shot_dict['low']['gmean'] = gmean(np.hstack(low_shot_gmean), axis=None).astype(float) 164 | 165 | return shot_dict 166 | 167 | class AverageMeter(object): 168 | """Computes and stores the average and current value""" 169 | 170 | def __init__(self, name, fmt=":f"): 171 | self.name = name 172 | self.fmt = fmt 173 | self.reset() 174 | 175 | def reset(self): 176 | self.val = 0 177 | self.avg = 0 178 | self.sum = 0 179 | self.count = 0 180 | 181 | def update(self, val, n=1): 182 | self.val = val 183 | self.sum += val * n 184 | self.count += n 185 | self.avg = self.sum / self.count 186 | 187 | def __str__(self): 188 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 189 | return 
fmtstr.format(**self.__dict__)
190 | 
191 | 
192 | class ProgressMeter(object):
193 |     def __init__(self, num_batches, meters, prefix=""):
194 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
195 |         self.meters = meters
196 |         self.prefix = prefix
197 | 
198 |     def display(self, batch):
199 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
200 |         entries += [str(meter) for meter in self.meters]
201 |         logging.info('\t'.join(entries))
202 | 
203 |     def _get_batch_fmtstr(self, num_batches):
204 |         num_digits = len(str(num_batches // 1))
205 |         fmt = "{:" + str(num_digits) + "d}"
206 |         return "[" + fmt + "/" + fmt.format(num_batches) + "]"
207 | 
208 | def adjust_learning_rate(optimizer, epoch, args):
209 |     """Decay the learning rate based on schedule"""
210 |     lr = args.lr
211 |     # stepwise lr schedule
212 |     for milestone in args.schedule:
213 |         lr *= 0.1 if epoch >= milestone else 1.0
214 |     for param_group in optimizer.param_groups:
215 |         param_group["lr"] = lr
216 | 
-------------------------------------------------------------------------------- /imdb-wiki-dir/README.md: --------------------------------------------------------------------------------
1 | # ConR on IMDB-WIKI-DIR
2 | This repository contains the implementation of __ConR__ on the *IMDB-WIKI-DIR* dataset.
3 | 
4 | The imbalanced regression framework and LDS+FDS are based on the public repository of [Gong et al., ICML 2022](https://github.com/BorealisAI/ranksim-imbalanced-regression).
5 | 
6 | 
7 | 
8 | ## Installation
9 | 
10 | #### Prerequisites
11 | 
12 | 1. Download and extract the IMDB and WIKI face images using
13 | 
14 | ```bash
15 | python download_imdb_wiki.py
16 | ```
17 | 
18 | 2. We use the standard train/val/test split file (`imdb_wiki.csv` in folder `./data`) provided by Yang et al. (ICML 2021), which is used to set up the balanced val/test sets. To reproduce the results in the paper, please use this file directly. You can also generate it using
19 | 
20 | ```bash
21 | python data/create_imd_wiki.py
22 | python data/preprocess_imdb_wiki.py
23 | ```
24 | 
25 | #### Dependencies
26 | 
27 | - PyTorch (>= 1.2, tested on 1.6)
28 | - numpy, pandas, scipy, tqdm, matplotlib, PIL, wget
29 | 
30 | ## Code Overview
31 | 
32 | #### Main Files
33 | 
34 | - `train.py`: main training and evaluation script
35 | - `create_imd_wiki.py`: create the IMDB-WIKI raw meta data
36 | - `preprocess_imdb_wiki.py`: create the IMDB-WIKI-DIR meta file `imdb_wiki.csv` with a balanced val/test set
37 | 
38 | #### Main Arguments
39 | 
40 | - `--data_dir`: data directory to place data and meta file
41 | - `--reweight`: cost-sensitive re-weighting scheme to use
42 | - `--loss`: training loss type
43 | - `--conr`: whether to use ConR or not
44 | - `-w`: distance threshold (default 1.0)
45 | - `--beta`: the scale of the ConR loss (default 4.0)
46 | - `-t`: temperature (default 0.2)
47 | - `-e`: pushing power scale (default 0.01)
48 | 
49 | ## Getting Started
50 | 
51 | ### 1. Train baselines
52 | 
53 | To train a vanilla model
54 | 
55 | ```bash
56 | python train.py --batch_size 64 --lr 2.5e-4
57 | ```
58 | 
59 | 
60 | 
61 | ### 2. Train a model with ConR
62 | ##### batch size 64, learning rate 2.5e-4
63 | 
64 | ```bash
65 | python train.py --batch_size 64 --lr 2.5e-4 --conr -w 1.0 --beta 4.0 -e 0.01
66 | ```
67 | 
68 | 
69 | 
70 | ### 3. Evaluate and reproduce
71 | 
72 | If you do not train the model yourself, you can evaluate it and reproduce our results directly using the pretrained weights from the anonymous links below.
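Pass the path of the downloaded checkpoint to `--resume`: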
73 | 74 | ```bash 75 | python train.py --evaluate [...evaluation model arguments...] --resume 76 | ``` 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /imdb-wiki-dir/data/create_imd_wiki.py: -------------------------------------------------------------------------------- 1 | ######################################################################################## 2 | # MIT License 3 | 4 | # Copyright (c) 2021 Yuzhe Yang 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | ######################################################################################## 24 | import os 25 | import argparse 26 | import numpy as np 27 | import pandas as pd 28 | from tqdm import tqdm 29 | from scipy.io import loadmat 30 | from datetime import datetime 31 | 32 | 33 | def calc_age(taken, dob): 34 | birth = datetime.fromordinal(max(int(dob) - 366, 1)) 35 | # assume the photo was taken in the middle of the year 36 | if birth.month < 7: 37 | return taken - birth.year 38 | else: 39 | return taken - birth.year - 1 40 | 41 | 42 | def get_meta(mat_path, db): 43 | meta = loadmat(mat_path) 44 | full_path = meta[db][0, 0]["full_path"][0] 45 | dob = meta[db][0, 0]["dob"][0] # date 46 | gender = meta[db][0, 0]["gender"][0] 47 | photo_taken = meta[db][0, 0]["photo_taken"][0] # year 48 | face_score = meta[db][0, 0]["face_score"][0] 49 | second_face_score = meta[db][0, 0]["second_face_score"][0] 50 | age = [calc_age(photo_taken[i], dob[i]) for i in range(len(dob))] 51 | 52 | return full_path, dob, gender, photo_taken, face_score, second_face_score, age 53 | 54 | 55 | def load_data(mat_path): 56 | d = loadmat(mat_path) 57 | return d["image"], d["gender"][0], d["age"][0], d["db"][0], d["img_size"][0, 0], d["min_score"][0, 0] 58 | 59 | 60 | def combine_dataset(path='meta'): 61 | args = get_args() 62 | data_imdb = pd.read_csv(os.path.join(args.data_path, path, "imdb.csv")) 63 | data_wiki = pd.read_csv(os.path.join(args.data_path, path, "wiki.csv")) 64 | data_imdb['path'] = data_imdb['path'].apply(lambda x: f"imdb_crop/{x}") 65 | data_wiki['path'] = data_wiki['path'].apply(lambda x: f"wiki_crop/{x}") 66 | df = pd.concat((data_imdb, data_wiki)) 67 | output_path = os.path.join(args.data_path, path, "imdb_wiki.csv") 68 | df.to_csv(str(output_path), index=False) 69 | 70 | 71 | def get_args(): 72 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 73 | 
parser.add_argument("--data_path", type=str, default="./data") 74 | parser.add_argument("--min_score", type=float, default=1., help="minimum face score") 75 | args = parser.parse_args() 76 | return args 77 | 78 | 79 | def create(db): 80 | args = get_args() 81 | mat_path = os.path.join(args.data_path, f"{db}_crop", f"{db}.mat") 82 | full_path, dob, gender, photo_taken, face_score, second_face_score, age = get_meta(mat_path, db) 83 | 84 | ages, img_paths = [], [] 85 | 86 | for i in tqdm(range(len(face_score))): 87 | if face_score[i] < args.min_score: 88 | continue 89 | 90 | if (~np.isnan(second_face_score[i])) and second_face_score[i] > 0.0: 91 | continue 92 | 93 | if ~(0 <= age[i] <= 200): 94 | continue 95 | 96 | ages.append(age[i]) 97 | img_paths.append(full_path[i][0]) 98 | 99 | outputs = dict(age=ages, path=img_paths) 100 | output_dir = os.path.join(args.data_path, "meta") 101 | os.makedirs(output_dir, exist_ok=True) 102 | output_path = os.path.join(output_dir, f"{db}.csv") 103 | df = pd.DataFrame(data=outputs) 104 | df.to_csv(str(output_path), index=False) 105 | 106 | 107 | if __name__ == '__main__': 108 | create("imdb") 109 | create("wiki") 110 | combine_dataset() 111 | -------------------------------------------------------------------------------- /imdb-wiki-dir/data/download_imdb_wiki.py: -------------------------------------------------------------------------------- 1 | ######################################################################################## 2 | # MIT License 3 | 4 | # Copyright (c) 2021 Yuzhe Yang 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | ######################################################################################## 24 | import os 25 | import wget 26 | 27 | print("Downloading IMDB faces...") 28 | imdb_file = "imdb_crop.tar" 29 | wget.download("https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/imdb_crop.tar", out=imdb_file) 30 | print("Downloading WIKI faces...") 31 | wiki_file = "wiki_crop.tar" 32 | wget.download("https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/wiki_crop.tar", out=wiki_file) 33 | print("Extracting IMDB faces...") 34 | os.system(f"tar -xvf {imdb_file} -C ./data") 35 | print("Extracting WIKI faces...") 36 | os.system(f"tar -xvf {wiki_file} -C ./data") 37 | os.remove(imdb_file) 38 | os.remove(wiki_file) 39 | print("\nCompleted!") -------------------------------------------------------------------------------- /imdb-wiki-dir/data/preprocess_imdb_wiki.py: -------------------------------------------------------------------------------- 1 | ######################################################################################## 2 | # MIT License 3 | 4 | # Copyright (c) 2021 Yuzhe Yang 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | ######################################################################################## 24 | from os.path import join 25 | import pandas as pd 26 | import matplotlib.pyplot as plt 27 | 28 | 29 | BASE_PATH = './' 30 | 31 | 32 | def visualize_dataset(db="imdb_wiki"): 33 | file_path = join(BASE_PATH, f"{db}.csv") 34 | data = pd.read_csv(file_path) 35 | _, ax = plt.subplots(figsize=(6, 3), sharex='all', sharey='all') 36 | ax.hist(data['age'], range(max(data['age']))) 37 | ax.set_xlim([0, 120]) 38 | plt.title(f"{db.upper()} (total: {data.shape[0]})") 39 | plt.tight_layout() 40 | plt.show() 41 | 42 | 43 | def make_balanced_testset(db="imdb_wiki", max_size=150, seed=666, verbose=True, vis=True, save=True): 44 | file_path = join(BASE_PATH, f"{db}.csv") 45 | df = pd.read_csv(file_path) 46 | df['age'] = df.age.astype(int) 47 | val_set, test_set = [], [] 48 | import random 49 | random.seed(seed) 50 | for value in range(121): 51 | curr_df = df[df['age'] == value] 52 | curr_data = curr_df['path'].values 53 | random.shuffle(curr_data) 54 | curr_size = min(len(curr_data) // 5, max_size) 55 | val_set += list(curr_data[:curr_size]) 56 | test_set += list(curr_data[curr_size:curr_size * 2]) 57 | if verbose: 58 | print(f"Val: {len(val_set)}\nTest: {len(test_set)}") 59 | assert len(set(val_set).intersection(set(test_set))) == 0 60 | combined_set = dict(zip(val_set, ['val' for _ in range(len(val_set))])) 61 | combined_set.update(dict(zip(test_set, ['test' for _ in range(len(test_set))]))) 62 | df['split'] = df['path'].map(combined_set) 63 | df['split'].fillna('train', inplace=True) 64 | if verbose: 65 | print(df) 66 | if save: 67 | df.to_csv(str(join(BASE_PATH, f"{db}.csv")), index=False) 68 | if vis: 69 | _, ax = plt.subplots(3, figsize=(6, 9), sharex='all') 70 | df_train = df[df['split'] == 'train'] 71 | # df_train = df_train[(df_train['age'] <= 20) | (df_train['age'] > 50)] 72 | ax[0].hist(df_train['age'], range(max(df['age']))) 73 | ax[0].set_title(f"[{db.upper()}] train: {df_train.shape[0]}") 74 | ax[1].hist(df[df['split'] == 'val']['age'], range(max(df['age']))) 75 | ax[1].set_title(f"[{db.upper()}] val: {df[df['split'] == 'val'].shape[0]}") 76 | ax[2].hist(df[df['split'] == 'test']['age'], range(max(df['age']))) 77 | ax[2].set_title(f"[{db.upper()}] test: {df[df['split'] == 'test'].shape[0]}") 78 | ax[0].set_xlim([0, 120]) 79 | plt.tight_layout() 80 | plt.show() 81 | 82 | 83 | if __name__ == '__main__': 84 | make_balanced_testset() 85 | visualize_dataset(db="imdb_wiki") 86 | -------------------------------------------------------------------------------- /imdb-wiki-dir/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-present, Royal Bank of Canada. 2 | # Copyright (c) 2021-present, Yuzhe Yang 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | ######################################################################################## 9 | # Code is based on the LDS and FDS (https://arxiv.org/pdf/2102.09554.pdf) implementation 10 | # from https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir 11 | # by Yuzhe Yang et al. 
12 | ########################################################################################
13 | import os
14 | import logging
15 | import numpy as np
16 | from PIL import Image
17 | from scipy.ndimage import convolve1d
18 | from torch.utils import data
19 | import torchvision.transforms as transforms
20 | 
21 | from utils import get_lds_kernel_window
22 | 
23 | 
24 | print = logging.info
25 | 
26 | 
27 | class IMDBWIKI(data.Dataset):
28 |     def __init__(self, df, data_dir, img_size, split='train', reweight='none',
29 |                  lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2, args=None):
30 |         self.df = df
31 |         self.data_dir = data_dir
32 |         self.img_size = img_size
33 |         self.split = split
34 | 
35 |         self.weights = self._prepare_weights(reweight=reweight, lds=lds, lds_kernel=lds_kernel, lds_ks=lds_ks, lds_sigma=lds_sigma)
36 | 
37 |     def __len__(self):
38 |         return len(self.df)
39 | 
40 |     def __getitem__(self, index):
41 |         index = index % len(self.df)
42 |         row = self.df.iloc[index]
43 |         img = Image.open(os.path.join(self.data_dir, row['path'])).convert('RGB')
44 |         transform = self.get_transform()
45 | 
46 |         label = np.asarray([row['age']]).astype('float32')
47 |         weight = np.asarray([self.weights[index]]).astype('float32') if self.weights is not None else np.asarray([np.float32(1.)])
48 | 
49 |         if self.split == 'train':
50 |             return transform(img), [transform(img), transform(img)], label, weight  # two extra independently augmented views, used by ConR
51 |         else:
52 |             return transform(img), label
53 | 
54 |     def get_transform(self):
55 | 
56 |         reg_aug = transforms.Compose([
57 |             transforms.Resize((self.img_size, self.img_size)),
58 |             transforms.RandomCrop(self.img_size, padding=16),
59 |             transforms.RandomHorizontalFlip(),
60 |             transforms.ToTensor(),
61 |             transforms.Normalize([.5, .5, .5], [.5, .5, .5]),
62 | 
63 |         ])
64 |         if self.split == 'train':
65 |             # the same stochastic augmentation pipeline is applied independently to each view
66 |             transform = reg_aug
67 |         else:
68 |             transform = transforms.Compose([
69 |                 transforms.Resize((self.img_size, self.img_size)),
70 |                 transforms.ToTensor(),
71 |                 transforms.Normalize([.5, .5, .5], [.5, .5, .5]),
72 | 
73 |             ])
74 |         return transform
75 | 
76 |     def _prepare_weights(self, reweight, max_target=121, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
77 |         assert reweight in {'none', 'inverse', 'sqrt_inv'}
78 |         assert reweight != 'none' if lds else True, \
79 |             "Set reweight to \'sqrt_inv\' (default) or \'inverse\' when using LDS"
80 | 
81 |         value_dict = {x: 0 for x in range(max_target)}
82 |         labels = self.df['age'].values
83 |         for label in labels:
84 |             value_dict[min(max_target - 1, int(label))] += 1
85 |         if reweight == 'sqrt_inv':
86 |             value_dict = {k: np.sqrt(v) for k, v in value_dict.items()}
87 |         elif reweight == 'inverse':
88 |             value_dict = {k: np.clip(v, 5, 1000) for k, v in value_dict.items()}  # clip weights for inverse re-weight
89 |         num_per_label = [value_dict[min(max_target - 1, int(label))] for label in labels]
90 |         if not len(num_per_label) or reweight == 'none':
91 |             return None
92 |         print(f"Using re-weighting: [{reweight.upper()}]")
93 | 
94 |         if lds:
95 |             lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma)
96 |             print(f'Using LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})')
97 |             smoothed_value = convolve1d(
98 |                 np.asarray([v for _, v in value_dict.items()]), weights=lds_kernel_window, mode='constant')
99 |             num_per_label = [smoothed_value[min(max_target - 1, int(label))] for label in labels]
100 | 
101 |         weights = [np.float32(1 / x) for x in num_per_label]
102 |         scaling = len(weights) / np.sum(weights)
103 |         weights = 
[scaling * x for x in weights]
104 |         return weights
105 | 
-------------------------------------------------------------------------------- /imdb-wiki-dir/loss.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2023-present, Royal Bank of Canada.
2 | # Copyright (c) 2021-present, Yuzhe Yang
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 | ########################################################################################
9 | # Code is based on the LDS and FDS (https://arxiv.org/pdf/2102.09554.pdf) implementation
10 | # from https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir
11 | # by Yuzhe Yang et al.
12 | ########################################################################################
13 | import torch
14 | import torch.nn.functional as F
15 | 
16 | 
17 | def weighted_mse_loss(inputs, targets, weights=None):
18 |     loss = (inputs - targets) ** 2
19 |     if weights is not None:
20 |         loss *= weights.expand_as(loss)
21 |     loss = torch.mean(loss)
22 |     return loss
23 | 
24 | 
25 | def weighted_l1_loss(inputs, targets, weights=None):
26 |     loss = F.l1_loss(inputs, targets, reduction='none')
27 |     if weights is not None:
28 |         loss *= weights.expand_as(loss)
29 |     loss = torch.mean(loss)
30 |     return loss
31 | 
32 | 
33 | def weighted_focal_mse_loss(inputs, targets, activate='sigmoid', beta=.2, gamma=1, weights=None):
34 |     loss = (inputs - targets) ** 2
35 |     loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \
36 |         (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma
37 |     if weights is not None:
38 |         loss *= weights.expand_as(loss)
39 |     loss = torch.mean(loss)
40 |     return loss
41 | 
42 | 
43 | def weighted_focal_l1_loss(inputs, targets, activate='sigmoid', beta=.2, gamma=1, weights=None):
44 |     loss = F.l1_loss(inputs, targets, reduction='none')
45 |     loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \
46 |         (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma
47 |     if weights is not None:
48 |         loss *= weights.expand_as(loss)
49 |     loss = torch.mean(loss)
50 |     return loss
51 | 
52 | 
53 | def weighted_huber_loss(inputs, targets, beta=1., weights=None):
54 |     l1_loss = torch.abs(inputs - targets)
55 |     cond = l1_loss < beta
56 |     loss = torch.where(cond, 0.5 * l1_loss ** 2 / beta, l1_loss - 0.5 * beta)
57 |     if weights is not None:
58 |         loss *= weights.expand_as(loss)
59 |     loss = torch.mean(loss)
60 |     return loss
61 | 
62 | # ConR loss function
63 | def ConR(features, targets, preds, w=1, weights=1, t=0.07, e=0.01):
64 |     # Every sample in the batch acts as both a query and a key; pairs are
65 |     # compared in the L2-normalized feature space.
66 | 
67 |     q = torch.nn.functional.normalize(features, dim=1)
68 |     k = torch.nn.functional.normalize(features, dim=1)
69 | 
70 |     l_k = targets.flatten()[None, :]
71 |     l_q = targets
72 | 
73 |     p_k = preds.flatten()[None, :]
74 |     p_q = preds
75 | 
76 |     # pairwise label distances and prediction distances
77 |     l_dist = torch.abs(l_q - l_k)
78 |     p_dist = torch.abs(p_q - p_k)
79 | 
80 |     # positives: labels within w of each other; negatives: dissimilar labels but similar predictions
81 |     pos_i = l_dist.le(w)
82 |     neg_i = ((~(l_dist.le(w))) * (p_dist.le(w)))
83 |     # exclude self-pairs from the positives
84 |     for i in range(pos_i.shape[0]):
85 |         pos_i[i][i] = 0
86 | 
87 |     prod = torch.einsum("nc,kc->nk", [q, k]) / t
88 |     pos = prod * pos_i
89 |     neg = prod * neg_i
90 |     # pushing weight: grows with the label distance of each negative pair
91 |     pushing_w = weights * torch.exp(l_dist * e)
92 |     neg_exp_dot = (pushing_w * (torch.exp(neg)) * neg_i).sum(1)
93 | 
94 |     # For each query sample, if there is no negative pair, zero-out the loss.
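    # (neg_i.sum(1) counts the negatives of each query; the resulting boolean mask is False for queries with no negatives, and the final per-query loss is multiplied by it.)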
95 |     no_neg_flag = (neg_i).sum(1).bool()
96 | 
97 |     # Loss: mean over queries of the per-pair softmax contrastive term, summed over positives and normalized by the number of positive pairs
98 |     denom = pos_i.sum(1).clamp(min=1)  # guard against division by zero for queries without positive pairs (their numerator is already zero)
99 | 
100 |     loss = ((-torch.log(torch.div(torch.exp(pos), (torch.exp(pos).sum(1) + neg_exp_dot).unsqueeze(-1))) * (pos_i)).sum(1) / denom)
101 | 
102 |     loss = (weights * (loss * no_neg_flag).unsqueeze(-1)).mean()
103 | 
104 | 
105 | 
106 |     return loss
107 | 
108 | 
-------------------------------------------------------------------------------- /imdb-wiki-dir/ranking.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2021-present, Royal Bank of Canada.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | 
8 | import torch
9 | 
10 | def rank(seq):
11 |     return torch.argsort(torch.argsort(seq).flip(1))
12 | 
13 | 
14 | def rank_normalised(seq):
15 |     return (rank(seq) + 1).float() / seq.size()[1]
16 | 
17 | 
18 | class TrueRanker(torch.autograd.Function):
19 |     @staticmethod
20 |     def forward(ctx, sequence, lambda_val):
21 |         rank = rank_normalised(sequence)
22 |         ctx.lambda_val = lambda_val
23 |         ctx.save_for_backward(sequence, rank)
24 |         return rank
25 | 
26 |     @staticmethod
27 |     def backward(ctx, grad_output):
28 |         sequence, rank = ctx.saved_tensors
29 |         assert grad_output.shape == rank.shape
30 |         sequence_prime = sequence + ctx.lambda_val * grad_output
31 |         rank_prime = rank_normalised(sequence_prime)
32 |         gradient = -(rank - rank_prime) / (ctx.lambda_val + 1e-8)
33 |         return gradient, None
34 | 
-------------------------------------------------------------------------------- /imdb-wiki-dir/ranksim.py: --------------------------------------------------------------------------------
1 | 
2 | # Copyright (c) 2023-present, Royal Bank of Canada.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 | 
9 | import torch
10 | import random
11 | import torch.nn.functional as F
12 | 
13 | from ranking import TrueRanker, rank_normalised
14 | 
15 | def batchwise_ranking_regularizer(features, targets, lambda_val):
16 |     loss = 0
17 | 
18 |     # Reduce ties and boost relative representation of infrequent labels by computing the
19 |     # regularizer over a subset of the batch in which each label appears at most once
20 |     batch_unique_targets = torch.unique(targets)
21 |     if len(batch_unique_targets) < len(targets):
22 |         sampled_indices = []
23 |         for target in batch_unique_targets:
24 |             sampled_indices.append(random.choice((targets == target).nonzero()[:, 0]).item())
25 |         x = features[sampled_indices]
26 |         y = targets[sampled_indices]
27 |     else:
28 |         x = features
29 |         y = targets
30 | 
31 |     # Compute feature similarities
32 |     xxt = torch.matmul(F.normalize(x.view(x.size(0), -1)), F.normalize(x.view(x.size(0), -1)).permute(1, 0))
33 | 
34 |     # Compute ranking similarity loss
35 |     for i in range(len(y)):
36 |         label_ranks = rank_normalised(-torch.abs(y[i] - y).transpose(0, 1))
37 |         feature_ranks = TrueRanker.apply(xxt[i].unsqueeze(dim=0), lambda_val)
38 |         loss += F.mse_loss(feature_ranks, label_ranks)
39 | 
40 |     return loss
-------------------------------------------------------------------------------- /imdb-wiki-dir/resnet.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2023-present, Royal Bank of Canada.
2 | # Copyright (c) 2021-present, Yuzhe Yang 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | ######################################################################################## 9 | # Code is based on the LDS and FDS (https://arxiv.org/pdf/2102.09554.pdf) implementation 10 | # from https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir 11 | # by Yuzhe Yang et al. 12 | ######################################################################################## 13 | 14 | import math 15 | 16 | import numpy as np 17 | from scipy.ndimage import gaussian_filter1d 18 | from scipy.signal.windows import triang 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | 24 | from utils import calibrate_mean_var 25 | 26 | import logging 27 | 28 | print = logging.info 29 | 30 | 31 | class FDS(nn.Module): 32 | 33 | def __init__(self, feature_dim, bucket_num=100, bucket_start=0, start_update=0, start_smooth=1, 34 | kernel='gaussian', ks=5, sigma=2, momentum=0.9): 35 | super(FDS, self).__init__() 36 | self.feature_dim = feature_dim 37 | self.bucket_num = bucket_num 38 | self.bucket_start = bucket_start 39 | self.kernel_window = self._get_kernel_window(kernel, ks, sigma) 40 | self.half_ks = (ks - 1) // 2 41 | self.momentum = momentum 42 | self.start_update = start_update 43 | self.start_smooth = start_smooth 44 | 45 | self.register_buffer('epoch', torch.zeros(1).fill_(start_update)) 46 | self.register_buffer('running_mean', torch.zeros(bucket_num - bucket_start, feature_dim)) 47 | self.register_buffer('running_var', torch.ones(bucket_num - bucket_start, feature_dim)) 48 | self.register_buffer('running_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 49 | self.register_buffer('running_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 50 | self.register_buffer('smoothed_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 51 | self.register_buffer('smoothed_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 52 | self.register_buffer('num_samples_tracked', torch.zeros(bucket_num - bucket_start)) 53 | 54 | @staticmethod 55 | def _get_kernel_window(kernel, ks, sigma): 56 | assert kernel in ['gaussian', 'triang', 'laplace'] 57 | half_ks = (ks - 1) // 2 58 | if kernel == 'gaussian': 59 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks 60 | base_kernel = np.array(base_kernel, dtype=np.float32) 61 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / sum(gaussian_filter1d(base_kernel, sigma=sigma)) 62 | elif kernel == 'triang': 63 | kernel_window = triang(ks) / sum(triang(ks)) 64 | else: 65 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. 
* sigma) 66 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / sum(map(laplace, np.arange(-half_ks, half_ks + 1))) 67 | 68 | print(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})') 69 | return torch.tensor(kernel_window, dtype=torch.float32).cuda() 70 | 71 | def _update_last_epoch_stats(self): 72 | self.running_mean_last_epoch = self.running_mean 73 | self.running_var_last_epoch = self.running_var 74 | 75 | self.smoothed_mean_last_epoch = F.conv1d( 76 | input=F.pad(self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0), 77 | pad=(self.half_ks, self.half_ks), mode='reflect'), 78 | weight=self.kernel_window.view(1, 1, -1), padding=0 79 | ).permute(2, 1, 0).squeeze(1) 80 | self.smoothed_var_last_epoch = F.conv1d( 81 | input=F.pad(self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0), 82 | pad=(self.half_ks, self.half_ks), mode='reflect'), 83 | weight=self.kernel_window.view(1, 1, -1), padding=0 84 | ).permute(2, 1, 0).squeeze(1) 85 | 86 | def reset(self): 87 | self.running_mean.zero_() 88 | self.running_var.fill_(1) 89 | self.running_mean_last_epoch.zero_() 90 | self.running_var_last_epoch.fill_(1) 91 | self.smoothed_mean_last_epoch.zero_() 92 | self.smoothed_var_last_epoch.fill_(1) 93 | self.num_samples_tracked.zero_() 94 | 95 | def update_last_epoch_stats(self, epoch): 96 | if epoch == self.epoch + 1: 97 | self.epoch += 1 98 | self._update_last_epoch_stats() 99 | print(f"Updated smoothed statistics on Epoch [{epoch}]!") 100 | 101 | def update_running_stats(self, features, labels, epoch): 102 | if epoch < self.epoch: 103 | return 104 | 105 | assert self.feature_dim == features.size(1), "Input feature dimension is not aligned!" 106 | assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!" 
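        # Per-bucket running statistics; the boundary buckets absorb labels at or beyond the ends of the bucket range.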
107 | 108 | for label in torch.unique(labels): 109 | if label > self.bucket_num - 1 or label < self.bucket_start: 110 | continue 111 | elif label == self.bucket_start: 112 | curr_feats = features[labels <= label] 113 | elif label == self.bucket_num - 1: 114 | curr_feats = features[labels >= label] 115 | else: 116 | curr_feats = features[labels == label] 117 | curr_num_sample = curr_feats.size(0) 118 | curr_mean = torch.mean(curr_feats, 0) 119 | curr_var = torch.var(curr_feats, 0, unbiased=True if curr_feats.size(0) != 1 else False) 120 | 121 | self.num_samples_tracked[int(label - self.bucket_start)] += curr_num_sample 122 | factor = self.momentum if self.momentum is not None else \ 123 | (1 - curr_num_sample / float(self.num_samples_tracked[int(label - self.bucket_start)])) 124 | factor = 0 if epoch == self.start_update else factor 125 | self.running_mean[int(label - self.bucket_start)] = \ 126 | (1 - factor) * curr_mean + factor * self.running_mean[int(label - self.bucket_start)] 127 | self.running_var[int(label - self.bucket_start)] = \ 128 | (1 - factor) * curr_var + factor * self.running_var[int(label - self.bucket_start)] 129 | 130 | print(f"Updated running statistics with Epoch [{epoch}] features!") 131 | 132 | def smooth(self, features, labels, epoch): 133 | if epoch < self.start_smooth: 134 | return features 135 | 136 | labels = labels.squeeze(1) 137 | for label in torch.unique(labels): 138 | if label > self.bucket_num - 1 or label < self.bucket_start: 139 | continue 140 | elif label == self.bucket_start: 141 | features[labels <= label] = calibrate_mean_var( 142 | features[labels <= label], 143 | self.running_mean_last_epoch[int(label - self.bucket_start)], 144 | self.running_var_last_epoch[int(label - self.bucket_start)], 145 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)], 146 | self.smoothed_var_last_epoch[int(label - self.bucket_start)]) 147 | elif label == self.bucket_num - 1: 148 | features[labels >= label] = calibrate_mean_var( 149 | features[labels >= label], 150 | self.running_mean_last_epoch[int(label - self.bucket_start)], 151 | self.running_var_last_epoch[int(label - self.bucket_start)], 152 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)], 153 | self.smoothed_var_last_epoch[int(label - self.bucket_start)]) 154 | else: 155 | features[labels == label] = calibrate_mean_var( 156 | features[labels == label], 157 | self.running_mean_last_epoch[int(label - self.bucket_start)], 158 | self.running_var_last_epoch[int(label - self.bucket_start)], 159 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)], 160 | self.smoothed_var_last_epoch[int(label - self.bucket_start)]) 161 | return features 162 | 163 | 164 | def conv3x3(in_planes, out_planes, stride=1): 165 | """3x3 convolution with padding""" 166 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 167 | 168 | 169 | class BasicBlock(nn.Module): 170 | expansion = 1 171 | 172 | def __init__(self, inplanes, planes, stride=1, downsample=None): 173 | super(BasicBlock, self).__init__() 174 | self.conv1 = conv3x3(inplanes, planes, stride) 175 | self.bn1 = nn.BatchNorm2d(planes) 176 | self.relu = nn.ReLU(inplace=True) 177 | self.conv2 = conv3x3(planes, planes) 178 | self.bn2 = nn.BatchNorm2d(planes) 179 | self.downsample = downsample 180 | self.stride = stride 181 | 182 | def forward(self, x): 183 | residual = x 184 | out = self.conv1(x) 185 | out = self.bn1(out) 186 | out = self.relu(out) 187 | out = self.conv2(out) 188 | out = self.bn2(out) 189 | if 
self.downsample is not None: 190 | residual = self.downsample(x) 191 | out += residual 192 | out = self.relu(out) 193 | return out 194 | 195 | 196 | class Bottleneck(nn.Module): 197 | expansion = 4 198 | 199 | def __init__(self, inplanes, planes, stride=1, downsample=None): 200 | super(Bottleneck, self).__init__() 201 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 202 | self.bn1 = nn.BatchNorm2d(planes) 203 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 204 | self.bn2 = nn.BatchNorm2d(planes) 205 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 206 | self.bn3 = nn.BatchNorm2d(planes * 4) 207 | self.relu = nn.ReLU(inplace=True) 208 | self.downsample = downsample 209 | self.stride = stride 210 | 211 | def forward(self, x): 212 | residual = x 213 | out = self.conv1(x) 214 | out = self.bn1(out) 215 | out = self.relu(out) 216 | out = self.conv2(out) 217 | out = self.bn2(out) 218 | out = self.relu(out) 219 | out = self.conv3(out) 220 | out = self.bn3(out) 221 | if self.downsample is not None: 222 | residual = self.downsample(x) 223 | out += residual 224 | out = self.relu(out) 225 | return out 226 | 227 | 228 | class ResNet(nn.Module): 229 | 230 | def __init__(self, block, layers, fds, bucket_num, bucket_start, start_update, start_smooth, 231 | kernel, ks, sigma, momentum, dropout=None, return_features=False): 232 | self.inplanes = 64 233 | super(ResNet, self).__init__() 234 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 235 | self.bn1 = nn.BatchNorm2d(64) 236 | self.relu = nn.ReLU(inplace=True) 237 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 238 | self.layer1 = self._make_layer(block, 64, layers[0]) 239 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 240 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 241 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 242 | self.avgpool = nn.AvgPool2d(7, stride=1) 243 | self.linear = nn.Linear(512 * block.expansion, 1) 244 | 245 | if fds: 246 | self.FDS = FDS( 247 | feature_dim=512 * block.expansion, bucket_num=bucket_num, bucket_start=bucket_start, 248 | start_update=start_update, start_smooth=start_smooth, kernel=kernel, ks=ks, sigma=sigma, momentum=momentum 249 | ) 250 | self.fds = fds 251 | self.start_smooth = start_smooth 252 | 253 | self.use_dropout = True if dropout else False 254 | if self.use_dropout: 255 | print(f'Using dropout: {dropout}') 256 | self.dropout = nn.Dropout(p=dropout) 257 | 258 | self.return_features = return_features 259 | 260 | for m in self.modules(): 261 | if isinstance(m, nn.Conv2d): 262 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 263 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 264 | elif isinstance(m, nn.BatchNorm2d): 265 | m.weight.data.fill_(1) 266 | m.bias.data.zero_() 267 | 268 | def _make_layer(self, block, planes, blocks, stride=1): 269 | downsample = None 270 | if stride != 1 or self.inplanes != planes * block.expansion: 271 | downsample = nn.Sequential( 272 | nn.Conv2d(self.inplanes, planes * block.expansion, 273 | kernel_size=1, stride=stride, bias=False), 274 | nn.BatchNorm2d(planes * block.expansion), 275 | ) 276 | layers = [] 277 | layers.append(block(self.inplanes, planes, stride, downsample)) 278 | self.inplanes = planes * block.expansion 279 | for i in range(1, blocks): 280 | layers.append(block(self.inplanes, planes)) 281 | 282 | return nn.Sequential(*layers) 283 | 284 | def forward(self, x, targets=None, epoch=None,reg = True): 285 | x = self.conv1(x) 286 | x = self.bn1(x) 287 | x = self.relu(x) 288 | x = self.maxpool(x) 289 | 290 | x = self.layer1(x) 291 | x = self.layer2(x) 292 | x = self.layer3(x) 293 | x = self.layer4(x) 294 | x = self.avgpool(x) 295 | encoding = x.view(x.size(0), -1) 296 | 297 | encoding_s = encoding 298 | 299 | if self.training and self.fds and reg: 300 | if epoch >= self.start_smooth: 301 | encoding_s = self.FDS.smooth(encoding_s, targets, epoch) 302 | 303 | if self.use_dropout: 304 | encoding_s = self.dropout(encoding_s) 305 | x = self.linear(encoding_s) 306 | 307 | return x, encoding 308 | 309 | 310 | def resnet50(**kwargs): 311 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 312 | -------------------------------------------------------------------------------- /imdb-wiki-dir/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-present, Royal Bank of Canada. 2 | # Copyright (c) 2021-present, Yuzhe Yang 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | ######################################################################################## 9 | # Code is based on the LDS and FDS (https://arxiv.org/pdf/2102.09554.pdf) implementation 10 | # from https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir 11 | # by Yuzhe Yang et al. 
12 | ######################################################################################## 13 | 14 | import os 15 | import shutil 16 | import torch 17 | import logging 18 | import numpy as np 19 | from scipy.ndimage import gaussian_filter1d 20 | from scipy.signal.windows import triang 21 | 22 | 23 | class AverageMeter(object): 24 | def __init__(self, name, fmt=':f'): 25 | self.name = name 26 | self.fmt = fmt 27 | self.reset() 28 | 29 | def reset(self): 30 | self.val = 0 31 | self.avg = 0 32 | self.sum = 0 33 | self.count = 0 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | def __str__(self): 42 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 43 | return fmtstr.format(**self.__dict__) 44 | 45 | 46 | class ProgressMeter(object): 47 | def __init__(self, num_batches, meters, prefix=""): 48 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 49 | self.meters = meters 50 | self.prefix = prefix 51 | 52 | def display(self, batch): 53 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 54 | entries += [str(meter) for meter in self.meters] 55 | logging.info('\t'.join(entries)) 56 | 57 | @staticmethod 58 | def _get_batch_fmtstr(num_batches): 59 | num_digits = len(str(num_batches // 1)) 60 | fmt = '{:' + str(num_digits) + 'd}' 61 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 62 | 63 | 64 | def query_yes_no(question): 65 | """ Ask a yes/no question via input() and return their answer. """ 66 | valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False} 67 | prompt = " [Y/n] " 68 | 69 | while True: 70 | print(question + prompt, end=':') 71 | choice = input().lower() 72 | if choice == '': 73 | return valid['y'] 74 | elif choice in valid: 75 | return valid[choice] 76 | else: 77 | print("Please respond with 'yes' or 'no' (or 'y' or 'n').\n") 78 | 79 | 80 | def prepare_folders(args): 81 | folders_util = [args.store_root, os.path.join(args.store_root, args.store_name)] 82 | if os.path.exists(folders_util[-1]) and not args.resume and not args.pretrained and not args.evaluate: 83 | if query_yes_no('overwrite previous folder: {} ?'.format(folders_util[-1])): 84 | shutil.rmtree(folders_util[-1]) 85 | print(folders_util[-1] + ' removed.') 86 | else: 87 | raise RuntimeError('Output folder {} already exists'.format(folders_util[-1])) 88 | for folder in folders_util: 89 | if not os.path.exists(folder): 90 | print(f"===> Creating folder: {folder}") 91 | os.mkdir(folder) 92 | 93 | 94 | def adjust_learning_rate(optimizer, epoch, args): 95 | lr = args.lr 96 | for milestone in args.schedule: 97 | lr *= 0.1 if epoch >= milestone else 1. 98 | for param_group in optimizer.param_groups: 99 | param_group['lr'] = lr 100 | 101 | 102 | def save_checkpoint(args, state, is_best, prefix=''): 103 | filename = f"{args.store_root}/{args.store_name}/{prefix}ckpt.pth.tar" 104 | torch.save(state, filename) 105 | if is_best: 106 | logging.info("===> Saving current best checkpoint...") 107 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar')) 108 | 109 | def save_checkpoint_per_epoch(args, state, epoch, prefix=''): 110 | filename = f"{args.store_root}/{args.store_name}/{prefix}ckpt_ep"+str(epoch)+".pth.tar" 111 | torch.save(state, filename) 112 | 113 | def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.1, clip_max=10): 114 | if torch.sum(v1) < 1e-10: 115 | return matrix 116 | if (v1 == 0.).any(): 117 | valid = (v1 != 0.) 
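        # rescale only the feature dimensions whose source variance is nonzero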
118 |         factor = torch.clamp(v2[valid] / v1[valid], clip_min, clip_max)
119 |         matrix[:, valid] = (matrix[:, valid] - m1[valid]) * torch.sqrt(factor) + m2[valid]
120 |         return matrix
121 | 
122 |     factor = torch.clamp(v2 / v1, clip_min, clip_max)
123 |     return (matrix - m1) * torch.sqrt(factor) + m2
124 | 
125 | 
126 | def get_lds_kernel_window(kernel, ks, sigma):
127 |     assert kernel in ['gaussian', 'triang', 'laplace']
128 |     half_ks = (ks - 1) // 2
129 |     if kernel == 'gaussian':
130 |         base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
131 |         kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma))
132 |     elif kernel == 'triang':
133 |         kernel_window = triang(ks)
134 |     else:
135 |         laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
136 |         kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / max(map(laplace, np.arange(-half_ks, half_ks + 1)))
137 | 
138 |     return kernel_window
139 | 
140 | 
141 | 
142 | def get_lambda(epoch, max_epoch):
143 |     p = epoch / max_epoch
144 |     return 2. / (1 + np.exp(-10. * p)) - 1.
145 | 
-------------------------------------------------------------------------------- /nyud2-dir/ConR.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2023-present, Royal Bank of Canada.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import torch
8 | 
9 | def ConR(feature, depth, output, weights, w=0.2, t=0.07, e=0.2):
10 |     # Flatten each per-pixel feature map so that every image in the batch
11 |     # becomes one query/key vector; depth maps and predictions are reduced
12 |     # to a single mean value per image.
13 |     k = feature.reshape([feature.shape[0], -1])
14 |     q = feature.reshape([feature.shape[0], -1])
15 | 
16 | 
17 |     depth = depth.reshape(depth.shape[0], -1)
18 |     l_k = torch.mean(depth, dim=1).unsqueeze(-1)
19 |     l_q = torch.mean(depth, dim=1).unsqueeze(-1)
20 | 
21 |     output = output.reshape(output.shape[0], -1)
22 |     p_k = torch.mean(output, dim=1).unsqueeze(-1)
23 |     p_q = torch.mean(output, dim=1).unsqueeze(-1)
24 | 
25 | 
26 |     # pairwise distances between the mean depth labels and between the
27 |     # mean depth predictions
28 |     l_dist = torch.abs(l_q - l_k.T)
29 |     p_dist = torch.abs(p_q - p_k.T)
30 | 
31 | 
32 | 
33 |     # normalize the feature vectors before computing similarities
34 |     q = torch.nn.functional.normalize(q, dim=1)
35 |     k = torch.nn.functional.normalize(k, dim=1)
36 | 
37 |     # (the temperature used below is the function argument t)
38 | 
39 |     # positives: keys whose label is within w of the query; negatives: keys with dissimilar labels but similar predictions
40 |     pos_i = l_dist.le(w)
41 |     neg_i = ((~(l_dist.le(w))) * (p_dist.le(w)))
42 |     # exclude self-pairs from the positives
43 |     for i in range(pos_i.shape[0]):
44 |         pos_i[i][i] = 0
45 | 
46 |     prod = torch.einsum("nc,kc->nk", [q, k]) / t
47 |     pos = prod * pos_i
48 |     neg = prod * neg_i
49 | 
50 | 
51 | 
52 |     # Pushing weight: scales with label distance and the per-sample weight
53 |     weights = torch.mean(weights.reshape(weights.shape[0], -1), dim=1).unsqueeze(-1)
54 |     pushing_w = l_dist * weights * e
55 | 
56 | 
57 |     # Sum of exponentiated negative dot products
58 |     neg_exp_dot = (pushing_w * (torch.exp(neg)) * (neg_i)).sum(1)
59 | 
60 |     # For each query sample, if there is no negative pair, zero-out the loss.
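    # (no_neg_flag below is False exactly for those queries, and the per-query loss is multiplied by it before averaging.)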
61 | no_neg_flag = (neg_i).sum(1).bool() 62 | 63 | # Per-anchor loss: -log(exp(pos) / (sum of exp(pos) + weighted negative term)), averaged over that anchor's positive pairs. 64 | denom = l_dist.le(w).sum(1) 65 | 66 | loss = ((-torch.log(torch.div(torch.exp(pos), (torch.exp(pos).sum(1) + neg_exp_dot).unsqueeze(-1))) * (pos_i)).sum(1) / denom) 67 | 68 | 69 | 70 | 71 | loss = ((loss * no_neg_flag).unsqueeze(-1)).mean() 72 | 73 | 74 | 75 | return loss -------------------------------------------------------------------------------- /nyud2-dir/README.md: -------------------------------------------------------------------------------- 1 | # ConR on NYUD2-DIR 2 | This repository contains the implementation of __ConR__ on NYUD2-DIR. 3 | 4 | The imbalanced regression framework and LDS+FDS are based on the public repository of [Ren et al., CVPR 2022](https://github.com/jiawei-ren/BalancedMSE). 5 | 6 | ## Installation 7 | 8 | #### Prerequisites 9 | 10 | 1. Download and extract the NYU v2 dataset to folder `./data` using 11 | 12 | ```bash 13 | python download_nyud2.py 14 | ``` 15 | 16 | 2. __(Optional)__ We use the meta files `nyu2_train_FDS_subset.csv` and `test_balanced_mask.npy` provided by Yang et al. (ICML 2021), which are used to set up efficient FDS feature-statistics computation and the balanced test set mask in folder `./data`. To reproduce the results in the paper, please use these two files directly. To generate different FDS computation subsets and balanced test set masks, you can run 17 | 18 | ```bash 19 | python preprocess_nyud2.py 20 | ``` 21 | 22 | #### Dependencies 23 | 24 | - PyTorch (>= 1.2, tested on 1.6) 25 | - numpy, pandas, scipy, tqdm, matplotlib, PIL, gdown, tensorboardX 26 | 27 | 28 | ## Getting Started 29 | 30 | ### 1. Train baselines 31 | 32 | To train the Balanced MSE baseline, run 33 | 34 | ```bash 35 | python train.py --bmse --imp bni --init_noise_sigma 1.0 --fix_noise_sigma 36 | ``` 37 | 38 | 39 | 40 | ### 2. Train a model with ConR 41 | 42 | 43 | 44 | ```bash 45 | python train.py --conr -w 0.2 --beta 0.2 -e 0.2 46 | ``` 47 | ### 3. Evaluate and reproduce 48 | To evaluate a trained model and reproduce the reported results, run 49 | 50 | ```bash 51 | python test.py --eval_model 52 | ``` -------------------------------------------------------------------------------- /nyud2-dir/balanaced_mse.py: -------------------------------------------------------------------------------- 1 | 2 | ##################################################################################### 3 | # MIT License 4 | 5 | # Copyright (c) 2022 Jiawei Ren 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | #################################################################################### 25 | import torch 26 | import torch.nn.functional as F 27 | from torch.nn.modules.loss import _Loss 28 | import joblib 29 | 30 | 31 | class GAILoss(_Loss): 32 | def __init__(self, init_noise_sigma, gmm): 33 | super(GAILoss, self).__init__() 34 | self.gmm = joblib.load(gmm) 35 | self.gmm = {k: torch.tensor(self.gmm[k]).cuda() for k in self.gmm} 36 | self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma, device="cuda")) 37 | 38 | def forward(self, pred, target): 39 | noise_var = self.noise_sigma ** 2 40 | loss = gai_loss(pred, target, self.gmm, noise_var) 41 | return loss 42 | 43 | 44 | def gai_loss(pred, target, gmm, noise_var): 45 | gmm = {k: gmm[k].reshape(1, -1).expand(pred.shape[0], -1) for k in gmm} 46 | mse_term = F.mse_loss(pred, target, reduction='none') / 2 / noise_var + 0.5 * noise_var.log() 47 | sum_var = gmm['variances'] + noise_var 48 | balancing_term = - 0.5 * sum_var.log() - 0.5 * (pred - gmm['means']).pow(2) / sum_var + gmm['weights'].log() 49 | balancing_term = torch.logsumexp(balancing_term, dim=-1, keepdim=True) 50 | loss = mse_term + balancing_term 51 | loss = loss * (2 * noise_var).detach() 52 | 53 | return loss.mean() 54 | 55 | 56 | class BMCLoss(_Loss): 57 | def __init__(self, init_noise_sigma): 58 | super(BMCLoss, self).__init__() 59 | self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma, device="cuda")) 60 | 61 | def forward(self, pred, target): 62 | noise_var = self.noise_sigma ** 2 63 | loss = bmc_loss(pred, target, noise_var) 64 | return loss 65 | 66 | 67 | def bmc_loss(pred, target, noise_var): 68 | logits = - 0.5 * (pred - target.T).pow(2) / noise_var 69 | loss = F.cross_entropy(logits, torch.arange(pred.shape[0]).cuda()) 70 | loss = loss * (2 * noise_var).detach() 71 | 72 | return loss 73 | 74 | 75 | class BNILoss(_Loss): 76 | def __init__(self, init_noise_sigma, bucket_centers, bucket_weights): 77 | super(BNILoss, self).__init__() 78 | self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma, device="cuda")) 79 | self.bucket_centers = torch.tensor(bucket_centers).cuda() 80 | self.bucket_weights = torch.tensor(bucket_weights).cuda() 81 | 82 | def forward(self, pred, target): 83 | noise_var = self.noise_sigma ** 2 84 | loss = bni_loss(pred, target, noise_var, self.bucket_centers, self.bucket_weights) 85 | return loss 86 | 87 | 88 | def bni_loss(pred, target, noise_var, bucket_centers, bucket_weights): 89 | mse_term = F.mse_loss(pred, target, reduction='none') / 2 / noise_var 90 | 91 | num_bucket = bucket_centers.shape[0] 92 | bucket_center = bucket_centers.unsqueeze(0).repeat(pred.shape[0], 1) 93 | bucket_weights = bucket_weights.unsqueeze(0).repeat(pred.shape[0], 1) 94 | 95 | balancing_term = - 0.5 * (pred.expand(-1, num_bucket) - bucket_center).pow(2) / noise_var + bucket_weights.log() 96 | balancing_term = torch.logsumexp(balancing_term, dim=-1, keepdim=True) 97 | loss = mse_term + balancing_term 98 | loss = loss * (2 * noise_var).detach() 99 | return loss.mean() 100 | -------------------------------------------------------------------------------- /nyud2-dir/download_nyud2.py: 
-------------------------------------------------------------------------------- 1 | 2 | ##################################################################################### 3 | # MIT License 4 | 5 | # Copyright (c) 2022 Jiawei Ren 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | #################################################################################### 25 | import os 26 | import gdown 27 | import zipfile 28 | 29 | print("Downloading and extracting NYU v2 dataset to folder './data'...") 30 | data_file = "./data.zip" 31 | gdown.download("https://drive.google.com/uc?id=1WoOZOBpOWfmwe7bknWS5PMUCLBPFKTOw") 32 | print('Extracting...') 33 | with zipfile.ZipFile(data_file) as zip_ref: 34 | zip_ref.extractall('.') 35 | os.remove(data_file) 36 | print("Completed!") -------------------------------------------------------------------------------- /nyud2-dir/loaddata.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) 2023-present, Royal Bank of Canada. 3 | # Copyright (c) 2022 Jiawei Ren 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | ##################################################################################### 9 | # Code is based on the Balanced MSE (https://openaccess.thecvf.com/content/CVPR2022/html/Ren_Balanced_MSE_for_Imbalanced_Visual_Regression_CVPR_2022_paper.html) implementation 10 | # from https://github.com/jiawei-ren/BalancedMSE/tree/main/nyud2-dir by Jiawei Ren 11 | #################################################################################### 12 | import os 13 | import logging 14 | import pandas as pd 15 | from torch.utils.data import Dataset, DataLoader 16 | from torchvision import transforms 17 | from nyu_transform import * 18 | from scipy.ndimage import convolve1d 19 | from util import get_lds_kernel_window 20 | 21 | # for data loading efficiency 22 | TRAIN_BUCKET_NUM = [0, 0, 0, 0, 0, 0, 0, 25848691, 24732940, 53324326, 69112955, 54455432, 95637682, 71403954, 117244217, 23 | 84813007, 126524456, 84486706, 133130272, 95464874, 146051415, 146133612, 96561379, 138366677, 89680276, 24 | 127689043, 81608990, 119121178, 74360607, 106839384, 97595765, 66718296, 90661239, 53103021, 83340912, 25 | 51365604, 71262770, 42243737, 65860580, 38415940, 53647559, 54038467, 28335524, 41485143, 32106001, 26 | 35936734, 23966211, 32018765, 19297203, 31503743, 21681574, 16363187, 25743420, 12769509, 17675327, 27 | 13147819, 15798560, 9547180, 14933200, 9663019, 12887283, 11803562, 7656609, 11515700, 7756306, 9046228, 28 | 5114894, 8653419, 6859433, 8001904, 6430700, 3305839, 6318461, 3486268, 5621065, 4030498, 3839488, 3220208, 29 | 4483027, 2555777, 4685983, 3145082, 2951048, 2762369, 2367581, 2546089, 2343867, 2481579, 1722140, 3018892, 30 | 2325197, 1952354, 2047038, 1858707, 2052729, 1348558, 2487278, 1314198, 3338550, 1132666] 31 | 32 | class depthDataset(Dataset): 33 | def __init__(self, data_dir, csv_file, mask_file=None, transform=None, args=None): 34 | self.data_dir = data_dir 35 | self.frame = pd.read_csv(csv_file, header=None) 36 | self.mask = torch.tensor(np.load(mask_file), dtype=torch.bool) if mask_file is not None else None 37 | self.transform = transform 38 | self.bucket_weights = self._get_bucket_weights(args) if args is not None else None 39 | 40 | def _get_bucket_weights(self, args): 41 | assert args.reweight in {'none', 'inverse', 'sqrt_inv'} 42 | assert args.reweight != 'none' if args.lds else True, "Set reweight to \'sqrt_inv\' or \'inverse\' (default) when using LDS" 43 | if args.reweight == 'none': 44 | return None 45 | logging.info(f"Using re-weighting: [{args.reweight.upper()}]") 46 | 47 | if args.lds: 48 | value_lst = TRAIN_BUCKET_NUM[args.bucket_start:] 49 | lds_kernel_window = get_lds_kernel_window(args.lds_kernel, args.lds_ks, args.lds_sigma) 50 | logging.info(f'Using LDS: [{args.lds_kernel.upper()}] ({args.lds_ks}/{args.lds_sigma})') 51 | if args.reweight == 'sqrt_inv': 52 | value_lst = np.sqrt(value_lst) 53 | smoothed_value = convolve1d(np.asarray(value_lst), weights=lds_kernel_window, mode='reflect') 54 | smoothed_value = [smoothed_value[0]] * args.bucket_start + list(smoothed_value) 55 | scaling = np.sum(TRAIN_BUCKET_NUM) / np.sum(np.array(TRAIN_BUCKET_NUM) / np.array(smoothed_value)) 56 | bucket_weights = [np.float32(scaling / smoothed_value[bucket]) for bucket in range(args.bucket_num)] 57 | else: 58 | value_lst = [TRAIN_BUCKET_NUM[args.bucket_start]] * args.bucket_start + TRAIN_BUCKET_NUM[args.bucket_start:] 59 | if args.reweight == 'sqrt_inv': 60 | value_lst = np.sqrt(value_lst) 61 | scaling = np.sum(TRAIN_BUCKET_NUM) / np.sum(np.array(TRAIN_BUCKET_NUM) / 
np.array(value_lst)) 62 | bucket_weights = [np.float32(scaling / value_lst[bucket]) for bucket in range(args.bucket_num)] 63 | 64 | return bucket_weights 65 | 66 | def get_bin_idx(self, x): 67 | return min(int(x * np.float32(10)), 99) 68 | 69 | def _get_weights(self, depth): 70 | sp = depth.shape 71 | if self.bucket_weights is not None: 72 | depth = depth.view(-1).cpu().numpy() 73 | assert depth.dtype == np.float32 74 | weights = np.array(list(map(lambda v: self.bucket_weights[self.get_bin_idx(v)], depth))) 75 | weights = torch.tensor(weights, dtype=torch.float32).view(*sp) 76 | else: 77 | weights = torch.tensor([np.float32(1.)], dtype=torch.float32).repeat(*sp) 78 | return weights 79 | 80 | def __getitem__(self, idx): 81 | image_name = self.frame.iloc[idx, 0] 82 | depth_name = self.frame.iloc[idx, 1] 83 | 84 | image_name = os.path.join(self.data_dir, '/'.join(image_name.split('/')[1:])) 85 | depth_name = os.path.join(self.data_dir, '/'.join(depth_name.split('/')[1:])) 86 | 87 | image = Image.open(image_name) 88 | depth = Image.open(depth_name) 89 | 90 | sample = {'image': image, 'depth': depth} 91 | 92 | if self.transform: 93 | sample = self.transform(sample) 94 | 95 | sample['weight'] = self._get_weights(sample['depth']) 96 | sample['idx'] = idx 97 | 98 | if self.mask is not None: 99 | sample['mask'] = self.mask[idx].unsqueeze(0) 100 | 101 | return sample 102 | 103 | def __len__(self): 104 | return len(self.frame) 105 | 106 | 107 | def getTrainingData(args, batch_size=64): 108 | __imagenet_pca = { 109 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 110 | 'eigvec': torch.Tensor([ 111 | [-0.5675, 0.7192, 0.4009], 112 | [-0.5808, -0.0045, -0.8140], 113 | [-0.5836, -0.6948, 0.4203], 114 | ]) 115 | } 116 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 117 | 'std': [0.229, 0.224, 0.225]} 118 | 119 | transformed_training = depthDataset(data_dir=args.data_dir, 120 | csv_file=os.path.join(args.data_dir, 'nyu2_train.csv'), 121 | transform=transforms.Compose([ 122 | Scale(240), 123 | RandomHorizontalFlip(), 124 | RandomRotate(5), 125 | CenterCrop([304, 228], [152, 114]), 126 | ToTensor(), 127 | Lighting(0.1, __imagenet_pca[ 128 | 'eigval'], __imagenet_pca['eigvec']), 129 | ColorJitter( 130 | brightness=0.4, 131 | contrast=0.4, 132 | saturation=0.4, 133 | ), 134 | Normalize(__imagenet_stats['mean'], 135 | __imagenet_stats['std']) 136 | ]), args=args) 137 | 138 | dataloader_training = DataLoader(transformed_training, batch_size, 139 | shuffle=True, num_workers=8, pin_memory=False) 140 | 141 | return dataloader_training 142 | 143 | def getTrainingFDSData(args, batch_size=64): 144 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 145 | 'std': [0.229, 0.224, 0.225]} 146 | 147 | transformed_training = depthDataset(data_dir=args.data_dir, 148 | csv_file=os.path.join(args.data_dir, 'nyu2_train_FDS_subset.csv'), 149 | transform=transforms.Compose([ 150 | Scale(240), 151 | CenterCrop([304, 228], [152, 114]), 152 | ToTensor(), 153 | Normalize(__imagenet_stats['mean'], 154 | __imagenet_stats['std']) 155 | ])) 156 | 157 | dataloader_training = DataLoader(transformed_training, batch_size, 158 | shuffle=False, num_workers=8, pin_memory=False) 159 | return dataloader_training 160 | 161 | 162 | def getTestingData(args, batch_size=64): 163 | 164 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 165 | 'std': [0.229, 0.224, 0.225]} 166 | 167 | transformed_testing = depthDataset(data_dir=args.data_dir, 168 | csv_file=os.path.join(args.data_dir, 'nyu2_test.csv'), 169 | 
mask_file=os.path.join(args.data_dir, 'test_balanced_mask.npy'), 170 | transform=transforms.Compose([ 171 | Scale(240), 172 | CenterCrop([304, 228], [304, 228]), 173 | ToTensor(is_test=True), 174 | Normalize(__imagenet_stats['mean'], 175 | __imagenet_stats['std']) 176 | ])) 177 | 178 | dataloader_testing = DataLoader(transformed_testing, batch_size, 179 | shuffle=False, num_workers=0, pin_memory=False) 180 | 181 | return dataloader_testing 182 | 183 | 184 | def get_bucket_info(args): 185 | if args.lds: 186 | value_lst = TRAIN_BUCKET_NUM[args.bucket_start:] 187 | lds_kernel_window = get_lds_kernel_window(args.lds_kernel, args.lds_ks, args.lds_sigma) 188 | logging.info(f'Using LDS: [{args.lds_kernel.upper()}] ({args.lds_ks}/{args.lds_sigma})') 189 | if args.reweight == 'sqrt_inv': 190 | value_lst = np.sqrt(value_lst) 191 | smoothed_value = convolve1d(np.asarray(value_lst), weights=lds_kernel_window, mode='reflect') 192 | smoothed_value = [smoothed_value[0]] * args.bucket_start + list(smoothed_value) 193 | bucket_weights = np.asarray(smoothed_value) 194 | else: 195 | value_lst = [TRAIN_BUCKET_NUM[args.bucket_start]] * args.bucket_start + TRAIN_BUCKET_NUM[args.bucket_start:] 196 | if args.reweight == 'sqrt_inv': 197 | value_lst = np.sqrt(value_lst) 198 | bucket_weights = np.asarray(value_lst) 199 | 200 | bucket_centers = np.linspace(0, 10, 101)[:-1] + 0.05 201 | bucket_weights = bucket_weights / bucket_weights.sum() 202 | return bucket_centers, bucket_weights 203 | -------------------------------------------------------------------------------- /nyud2-dir/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | ######################################################################################## 3 | # MIT License 4 | 5 | # Copyright (c) 2022 Jiawei Ren 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
24 | # ######################################################################################## -------------------------------------------------------------------------------- /nyud2-dir/models/fds.py: -------------------------------------------------------------------------------- 1 | 2 | ######################################################################################## 3 | # MIT License 4 | 5 | # Copyright (c) 2022 Jiawei Ren 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | # ######################################################################################## 25 | import logging 26 | import numpy as np 27 | import torch 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | from scipy.ndimage import gaussian_filter1d 31 | from scipy.signal.windows import triang 32 | from util import calibrate_mean_var 33 | 34 | 35 | class FDS(nn.Module): 36 | 37 | def __init__(self, feature_dim, bucket_num=100, bucket_start=7, start_update=0, start_smooth=1, 38 | kernel='gaussian', ks=5, sigma=2, momentum=0.9): 39 | super(FDS, self).__init__() 40 | self.feature_dim = feature_dim 41 | self.bucket_num = bucket_num 42 | self.bucket_start = bucket_start 43 | self.kernel_window = self._get_kernel_window(kernel, ks, sigma) 44 | self.half_ks = (ks - 1) // 2 45 | self.momentum = momentum 46 | self.start_update = start_update 47 | self.start_smooth = start_smooth 48 | 49 | self.register_buffer('epoch', torch.zeros(1).fill_(start_update)) 50 | self.register_buffer('running_mean', torch.zeros(bucket_num - bucket_start, feature_dim)) 51 | self.register_buffer('running_var', torch.ones(bucket_num - bucket_start, feature_dim)) 52 | self.register_buffer('running_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 53 | self.register_buffer('running_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 54 | self.register_buffer('smoothed_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 55 | self.register_buffer('smoothed_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 56 | self.register_buffer('num_samples_tracked', torch.zeros(bucket_num - bucket_start)) 57 | 58 | @staticmethod 59 | def _get_kernel_window(kernel, ks, sigma): 60 | assert kernel in ['gaussian', 'triang', 'laplace'] 61 | half_ks = (ks - 1) // 2 62 | if kernel == 'gaussian': 63 | base_kernel = [0.] * half_ks + [1.] + [0.] 
* half_ks 64 | base_kernel = np.array(base_kernel, dtype=np.float32) 65 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / sum(gaussian_filter1d(base_kernel, sigma=sigma)) 66 | elif kernel == 'triang': 67 | kernel_window = triang(ks) / sum(triang(ks)) 68 | else: 69 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma) 70 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / sum(map(laplace, np.arange(-half_ks, half_ks + 1))) 71 | 72 | logging.info(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})') 73 | return torch.tensor(kernel_window, dtype=torch.float32).cuda() 74 | 75 | def _get_bucket_idx(self, label): 76 | label = np.float32(label.cpu()) 77 | return max(min(int(label * np.float32(10)), self.bucket_num - 1), self.bucket_start) 78 | 79 | def _update_last_epoch_stats(self): 80 | self.running_mean_last_epoch = self.running_mean 81 | self.running_var_last_epoch = self.running_var 82 | 83 | self.smoothed_mean_last_epoch = F.conv1d( 84 | input=F.pad(self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0), 85 | pad=(self.half_ks, self.half_ks), mode='reflect'), 86 | weight=self.kernel_window.view(1, 1, -1), padding=0 87 | ).permute(2, 1, 0).squeeze(1) 88 | self.smoothed_var_last_epoch = F.conv1d( 89 | input=F.pad(self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0), 90 | pad=(self.half_ks, self.half_ks), mode='reflect'), 91 | weight=self.kernel_window.view(1, 1, -1), padding=0 92 | ).permute(2, 1, 0).squeeze(1) 93 | 94 | assert self.smoothed_mean_last_epoch.shape == self.running_mean_last_epoch.shape, \ 95 | "Smoothed shape is not aligned with running shape!" 96 | 97 | def reset(self): 98 | self.running_mean.zero_() 99 | self.running_var.fill_(1) 100 | self.running_mean_last_epoch.zero_() 101 | self.running_var_last_epoch.fill_(1) 102 | self.smoothed_mean_last_epoch.zero_() 103 | self.smoothed_var_last_epoch.fill_(1) 104 | self.num_samples_tracked.zero_() 105 | 106 | def update_last_epoch_stats(self, epoch): 107 | if epoch == self.epoch + 1: 108 | self.epoch += 1 109 | self._update_last_epoch_stats() 110 | logging.info(f"Updated smoothed statistics of last epoch on Epoch [{epoch}]!") 111 | 112 | def _running_stats_to_device(self, device): 113 | if device == 'cpu': 114 | self.num_samples_tracked = self.num_samples_tracked.cpu() 115 | self.running_mean = self.running_mean.cpu() 116 | self.running_var = self.running_var.cpu() 117 | else: 118 | self.num_samples_tracked = self.num_samples_tracked.cuda() 119 | self.running_mean = self.running_mean.cuda() 120 | self.running_var = self.running_var.cuda() 121 | 122 | def update_running_stats(self, features, labels, epoch): 123 | if epoch < self.epoch: 124 | return 125 | 126 | assert self.feature_dim == features.size(1), "Input feature dimension is not aligned!" 127 | assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!" 
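# (Below, the per-bucket running mean/variance are updated on CPU with an exponential moving average: the factor is the fixed momentum when it is set, otherwise a count-based 1 - n_batch / n_tracked, and it is forced to 0 on the first update epoch so the statistics are initialized directly from the data.)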
128 | 129 | self._running_stats_to_device('cpu') 130 | 131 | labels = labels.squeeze(1).view(-1) 132 | features = features.permute(0, 2, 3, 1).contiguous().view(-1, self.feature_dim) 133 | 134 | buckets = np.array([self._get_bucket_idx(label) for label in labels]) 135 | for bucket in np.unique(buckets): 136 | curr_feats = features[torch.tensor((buckets == bucket).astype(np.bool))] 137 | curr_num_sample = curr_feats.size(0) 138 | curr_mean = torch.mean(curr_feats, 0) 139 | curr_var = torch.var(curr_feats, 0, unbiased=True if curr_feats.size(0) != 1 else False) 140 | 141 | self.num_samples_tracked[bucket - self.bucket_start] += curr_num_sample 142 | factor = self.momentum if self.momentum is not None else \ 143 | (1 - curr_num_sample / float(self.num_samples_tracked[bucket - self.bucket_start])) 144 | factor = 0 if epoch == self.start_update else factor 145 | self.running_mean[bucket - self.bucket_start] = \ 146 | (1 - factor) * curr_mean + factor * self.running_mean[bucket - self.bucket_start] 147 | self.running_var[bucket - self.bucket_start] = \ 148 | (1 - factor) * curr_var + factor * self.running_var[bucket - self.bucket_start] 149 | 150 | self._running_stats_to_device('cuda') 151 | logging.info(f"Updated running statistics with Epoch [{epoch}] features!") 152 | 153 | def smooth(self, features, labels, epoch): 154 | if epoch < self.start_smooth: 155 | return features 156 | 157 | sp = labels.squeeze(1).shape 158 | 159 | labels = labels.squeeze(1).view(-1) 160 | features = features.permute(0, 2, 3, 1).contiguous().view(-1, self.feature_dim) 161 | 162 | buckets = torch.max(torch.stack([torch.min(torch.stack([torch.floor(labels * torch.tensor([10.]).cuda()).int(), 163 | torch.zeros(labels.size(0)).fill_(self.bucket_num - 1).int().cuda()], 0), 0)[0], torch.zeros(labels.size(0)).fill_(self.bucket_start).int().cuda()], 0), 0)[0] 164 | for bucket in torch.unique(buckets): 165 | features[buckets.eq(bucket)] = calibrate_mean_var( 166 | features[buckets.eq(bucket)], 167 | self.running_mean_last_epoch[bucket.item() - self.bucket_start], 168 | self.running_var_last_epoch[bucket.item() - self.bucket_start], 169 | self.smoothed_mean_last_epoch[bucket.item() - self.bucket_start], 170 | self.smoothed_var_last_epoch[bucket.item() - self.bucket_start] 171 | ) 172 | 173 | return features.view(*sp, self.feature_dim).permute(0, 3, 1, 2) 174 | -------------------------------------------------------------------------------- /nyud2-dir/models/modules.py: -------------------------------------------------------------------------------- 1 | 2 | ##################################################################################### 3 | # MIT License 4 | 5 | # Copyright (c) 2022 Jiawei Ren 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 
16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | #################################################################################### 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | from .fds import FDS 29 | 30 | class _UpProjection(nn.Sequential): 31 | 32 | def __init__(self, num_input_features, num_output_features): 33 | super(_UpProjection, self).__init__() 34 | 35 | self.conv1 = nn.Conv2d(num_input_features, num_output_features, 36 | kernel_size=5, stride=1, padding=2, bias=False) 37 | self.bn1 = nn.BatchNorm2d(num_output_features) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.conv1_2 = nn.Conv2d(num_output_features, num_output_features, 40 | kernel_size=3, stride=1, padding=1, bias=False) 41 | self.bn1_2 = nn.BatchNorm2d(num_output_features) 42 | 43 | self.conv2 = nn.Conv2d(num_input_features, num_output_features, 44 | kernel_size=5, stride=1, padding=2, bias=False) 45 | self.bn2 = nn.BatchNorm2d(num_output_features) 46 | 47 | def forward(self, x, size): 48 | x = F.upsample(x, size=size, mode='bilinear') 49 | x_conv1 = self.relu(self.bn1(self.conv1(x))) 50 | bran1 = self.bn1_2(self.conv1_2(x_conv1)) 51 | bran2 = self.bn2(self.conv2(x)) 52 | 53 | out = self.relu(bran1 + bran2) 54 | 55 | return out 56 | 57 | class E_resnet(nn.Module): 58 | 59 | def __init__(self, original_model, num_features = 2048): 60 | super(E_resnet, self).__init__() 61 | self.conv1 = original_model.conv1 62 | self.bn1 = original_model.bn1 63 | self.relu = original_model.relu 64 | self.maxpool = original_model.maxpool 65 | 66 | self.layer1 = original_model.layer1 67 | self.layer2 = original_model.layer2 68 | self.layer3 = original_model.layer3 69 | self.layer4 = original_model.layer4 70 | 71 | 72 | def forward(self, x): 73 | x = self.conv1(x) 74 | x = self.bn1(x) 75 | x = self.relu(x) 76 | x = self.maxpool(x) 77 | 78 | x_block1 = self.layer1(x) 79 | x_block2 = self.layer2(x_block1) 80 | x_block3 = self.layer3(x_block2) 81 | x_block4 = self.layer4(x_block3) 82 | 83 | return x_block1, x_block2, x_block3, x_block4 84 | 85 | class D(nn.Module): 86 | 87 | def __init__(self, num_features = 2048): 88 | super(D, self).__init__() 89 | self.conv = nn.Conv2d(num_features, num_features // 90 | 2, kernel_size=1, stride=1, bias=False) 91 | num_features = num_features // 2 92 | self.bn = nn.BatchNorm2d(num_features) 93 | 94 | self.up1 = _UpProjection( 95 | num_input_features=num_features, num_output_features=num_features // 2) 96 | num_features = num_features // 2 97 | 98 | self.up2 = _UpProjection( 99 | num_input_features=num_features, num_output_features=num_features // 2) 100 | num_features = num_features // 2 101 | 102 | self.up3 = _UpProjection( 103 | num_input_features=num_features, num_output_features=num_features // 2) 104 | num_features = num_features // 2 105 | 106 | self.up4 = _UpProjection( 107 | num_input_features=num_features, num_output_features=num_features // 2) 108 | num_features = num_features // 2 109 | 110 | 111 | def forward(self, x_block1, x_block2, x_block3, x_block4): 112 | x_d0 = 
F.relu(self.bn(self.conv(x_block4))) 113 | x_d1 = self.up1(x_d0, [x_block3.size(2), x_block3.size(3)]) 114 | x_d2 = self.up2(x_d1, [x_block2.size(2), x_block2.size(3)]) 115 | x_d3 = self.up3(x_d2, [x_block1.size(2), x_block1.size(3)]) 116 | x_d4 = self.up4(x_d3, [x_block1.size(2)*2, x_block1.size(3)*2]) 117 | 118 | return x_d4 119 | 120 | class MFF(nn.Module): 121 | 122 | def __init__(self, block_channel, num_features=64): 123 | 124 | super(MFF, self).__init__() 125 | 126 | self.up1 = _UpProjection( 127 | num_input_features=block_channel[0], num_output_features=16) 128 | 129 | self.up2 = _UpProjection( 130 | num_input_features=block_channel[1], num_output_features=16) 131 | 132 | self.up3 = _UpProjection( 133 | num_input_features=block_channel[2], num_output_features=16) 134 | 135 | self.up4 = _UpProjection( 136 | num_input_features=block_channel[3], num_output_features=16) 137 | 138 | self.conv = nn.Conv2d( 139 | num_features, num_features, kernel_size=5, stride=1, padding=2, bias=False) 140 | self.bn = nn.BatchNorm2d(num_features) 141 | 142 | 143 | def forward(self, x_block1, x_block2, x_block3, x_block4, size): 144 | x_m1 = self.up1(x_block1, size) 145 | x_m2 = self.up2(x_block2, size) 146 | x_m3 = self.up3(x_block3, size) 147 | x_m4 = self.up4(x_block4, size) 148 | 149 | x = self.bn(self.conv(torch.cat((x_m1, x_m2, x_m3, x_m4), 1))) 150 | x = F.relu(x) 151 | 152 | return x 153 | 154 | 155 | class R(nn.Module): 156 | def __init__(self, args, block_channel): 157 | 158 | super(R, self).__init__() 159 | 160 | num_features = 64 + block_channel[3] // 32 161 | self.conv0 = nn.Conv2d(num_features, num_features, 162 | kernel_size=5, stride=1, padding=2, bias=False) 163 | self.bn0 = nn.BatchNorm2d(num_features) 164 | 165 | self.conv1 = nn.Conv2d(num_features, num_features, 166 | kernel_size=5, stride=1, padding=2, bias=False) 167 | self.bn1 = nn.BatchNorm2d(num_features) 168 | 169 | self.conv2 = nn.Conv2d(num_features, 1, kernel_size=5, stride=1, padding=2, bias=True) 170 | 171 | self.args = args 172 | 173 | if args is not None and args.fds: 174 | self.FDS = FDS(feature_dim=num_features, bucket_num=args.bucket_num, bucket_start=args.bucket_start, 175 | start_update=args.start_update, start_smooth=args.start_smooth, kernel=args.fds_kernel, 176 | ks=args.fds_ks, sigma=args.fds_sigma, momentum=args.fds_mmt) 177 | 178 | def forward(self, x, depth=None, epoch=None): 179 | x0 = self.conv0(x) 180 | x0 = self.bn0(x0) 181 | x0 = F.relu(x0) 182 | 183 | x1 = self.conv1(x0) 184 | x1 = self.bn1(x1) 185 | x1 = F.relu(x1) 186 | 187 | x1_s = x1 188 | 189 | if self.training and self.args.fds: 190 | if epoch >= self.args.start_smooth: 191 | x1_s = self.FDS.smooth(x1_s, depth, epoch) 192 | 193 | x2 = self.conv2(x1_s) 194 | 195 | 196 | return x2, x1 197 | # if self.training and self.args.fds: 198 | # return x2, x1 199 | # else: 200 | # return x2 -------------------------------------------------------------------------------- /nyud2-dir/models/net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-present, Royal Bank of Canada. 2 | # Copyright (c) 2022 Jiawei Ren 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | ##################################################################################### 8 | # Code is based on the Balanced MSE (https://openaccess.thecvf.com/content/CVPR2022/html/Ren_Balanced_MSE_for_Imbalanced_Visual_Regression_CVPR_2022_paper.html) implementation 9 | # from https://github.com/jiawei-ren/BalancedMSE/tree/main/nyud2-dir by Jiawei Ren 10 | #################################################################################### 11 | import torch 12 | import torch.nn as nn 13 | from models import modules 14 | 15 | class model(nn.Module): 16 | def __init__(self, args, Encoder, num_features, block_channel): 17 | 18 | super(model, self).__init__() 19 | 20 | self.E = Encoder 21 | self.D = modules.D(num_features) 22 | self.MFF = modules.MFF(block_channel) 23 | self.R = modules.R(args, block_channel) 24 | 25 | 26 | def forward(self, x, depth=None, epoch=None): 27 | x_block1, x_block2, x_block3, x_block4 = self.E(x) 28 | x_decoder = self.D(x_block1, x_block2, x_block3, x_block4) 29 | x_mff = self.MFF(x_block1, x_block2, x_block3, x_block4,[x_decoder.size(2),x_decoder.size(3)]) 30 | out,features = self.R(torch.cat((x_decoder, x_mff), 1), depth, epoch) 31 | 32 | return out,features 33 | -------------------------------------------------------------------------------- /nyud2-dir/models/resnet.py: -------------------------------------------------------------------------------- 1 | 2 | ##################################################################################### 3 | # MIT License 4 | 5 | # Copyright (c) 2022 Jiawei Ren 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
24 | #################################################################################### 25 | import torch.nn as nn 26 | import math 27 | import torch.utils.model_zoo as model_zoo 28 | 29 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 30 | 31 | 32 | model_urls = { 33 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 34 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 35 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 36 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 37 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 38 | } 39 | 40 | 41 | def conv3x3(in_planes, out_planes, stride=1): 42 | "3x3 convolution with padding" 43 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 44 | padding=1, bias=False) 45 | 46 | 47 | class BasicBlock(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(BasicBlock, self).__init__() 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x): 61 | residual = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | 70 | if self.downsample is not None: 71 | residual = self.downsample(x) 72 | 73 | out += residual 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class Bottleneck(nn.Module): 80 | expansion = 4 81 | 82 | def __init__(self, inplanes, planes, stride=1, downsample=None): 83 | super(Bottleneck, self).__init__() 84 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 85 | self.bn1 = nn.BatchNorm2d(planes) 86 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 87 | padding=1, bias=False) 88 | self.bn2 = nn.BatchNorm2d(planes) 89 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 90 | self.bn3 = nn.BatchNorm2d(planes * 4) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | residual = x 97 | 98 | out = self.conv1(x) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | residual = self.downsample(x) 111 | 112 | out += residual 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | 118 | class ResNet(nn.Module): 119 | 120 | def __init__(self, block, layers, num_classes=1000): 121 | self.inplanes = 64 122 | super(ResNet, self).__init__() 123 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 124 | bias=False) 125 | self.bn1 = nn.BatchNorm2d(64) 126 | self.relu = nn.ReLU(inplace=True) 127 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 128 | self.layer1 = self._make_layer(block, 64, layers[0]) 129 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 130 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 131 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 132 | self.avgpool = nn.AvgPool2d(7, stride=1) 133 | self.fc = nn.Linear(512 * block.expansion, 
num_classes) 134 | 135 | for m in self.modules(): 136 | if isinstance(m, nn.Conv2d): 137 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 138 | m.weight.data.normal_(0, math.sqrt(2. / n)) 139 | elif isinstance(m, nn.BatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | 143 | def _make_layer(self, block, planes, blocks, stride=1): 144 | downsample = None 145 | if stride != 1 or self.inplanes != planes * block.expansion: 146 | downsample = nn.Sequential( 147 | nn.Conv2d(self.inplanes, planes * block.expansion, 148 | kernel_size=1, stride=stride, bias=False), 149 | nn.BatchNorm2d(planes * block.expansion), 150 | ) 151 | 152 | layers = [] 153 | layers.append(block(self.inplanes, planes, stride, downsample)) 154 | self.inplanes = planes * block.expansion 155 | for i in range(1, blocks): 156 | layers.append(block(self.inplanes, planes)) 157 | 158 | return nn.Sequential(*layers) 159 | 160 | def forward(self, x): 161 | x = self.conv1(x) 162 | x = self.bn1(x) 163 | x = self.relu(x) 164 | x = self.maxpool(x) 165 | 166 | x = self.layer1(x) 167 | x = self.layer2(x) 168 | x = self.layer3(x) 169 | x = self.layer4(x) 170 | 171 | x = self.avgpool(x) 172 | x = x.view(x.size(0), -1) 173 | x = self.fc(x) 174 | 175 | return x 176 | 177 | def resnet18(pretrained=False, **kwargs): 178 | """Constructs a ResNet-18 model. 179 | Args: 180 | pretrained (bool): If True, returns a model pre-trained on ImageNet 181 | """ 182 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 183 | if pretrained: 184 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 185 | return model 186 | 187 | 188 | def resnet34(pretrained=False, **kwargs): 189 | """Constructs a ResNet-34 model. 190 | Args: 191 | pretrained (bool): If True, returns a model pre-trained on ImageNet 192 | """ 193 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 194 | if pretrained: 195 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 196 | return model 197 | 198 | 199 | def resnet50(pretrained=False, **kwargs): 200 | """Constructs a ResNet-50 model. 201 | Args: 202 | pretrained (bool): If True, returns a model pre-trained on ImageNet 203 | """ 204 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 205 | if pretrained: 206 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], 'pretrained_model/encoder')) 207 | return model 208 | 209 | 210 | def resnet101(pretrained=False, **kwargs): 211 | """Constructs a ResNet-101 model. 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 216 | if pretrained: 217 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 218 | return model 219 | 220 | 221 | def resnet152(pretrained=False, **kwargs): 222 | """Constructs a ResNet-152 model. 
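The 152-layer variant stacks Bottleneck blocks with the layout [3, 8, 36, 3].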
223 | Args: 224 | pretrained (bool): If True, returns a model pre-trained on ImageNet 225 | """ 226 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 227 | if pretrained: 228 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 229 | return model 230 | -------------------------------------------------------------------------------- /nyud2-dir/nyu_transform.py: -------------------------------------------------------------------------------- 1 | 2 | ##################################################################################### 3 | # MIT License 4 | 5 | # Copyright (c) 2022 Jiawei Ren 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | #################################################################################### 25 | import torch 26 | import numpy as np 27 | from PIL import Image 28 | import collections 29 | try: 30 | import accimage 31 | except ImportError: 32 | accimage = None 33 | import random 34 | import scipy.ndimage as ndimage 35 | 36 | 37 | def _is_pil_image(img): 38 | if accimage is not None: 39 | return isinstance(img, (Image.Image, accimage.Image)) 40 | else: 41 | return isinstance(img, Image.Image) 42 | 43 | 44 | def _is_numpy_image(img): 45 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 46 | 47 | 48 | class RandomRotate(object): 49 | """Random rotation of the image from -angle to angle (in degrees) 50 | This is useful for dataAugmentation, especially for geometric problems such as FlowEstimation 51 | angle: max angle of the rotation 52 | interpolation order: Default: 2 (bilinear) 53 | reshape: Default: false. If set to true, image size will be set to keep every pixel in the image. 54 | diff_angle: Default: 0. Must stay less than 10 degrees, or linear approximation of flowmap will be off. 
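Note: the same randomly drawn angle is applied to both the image and the depth map so the pair stays aligned; diff_angle is accepted for interface compatibility but is not used by this implementation.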
55 | """ 56 | 57 | def __init__(self, angle, diff_angle=0, order=2, reshape=False): 58 | self.angle = angle 59 | self.reshape = reshape 60 | self.order = order 61 | 62 | def __call__(self, sample): 63 | image, depth = sample['image'], sample['depth'] 64 | 65 | applied_angle = random.uniform(-self.angle, self.angle) 66 | angle1 = applied_angle 67 | angle1_rad = angle1 * np.pi / 180 68 | 69 | image = ndimage.interpolation.rotate( 70 | image, angle1, reshape=self.reshape, order=self.order) 71 | depth = ndimage.interpolation.rotate( 72 | depth, angle1, reshape=self.reshape, order=self.order) 73 | 74 | image = Image.fromarray(image) 75 | depth = Image.fromarray(depth) 76 | 77 | return {'image': image, 'depth': depth} 78 | 79 | class RandomHorizontalFlip(object): 80 | 81 | def __call__(self, sample): 82 | image, depth = sample['image'], sample['depth'] 83 | 84 | if not _is_pil_image(image): 85 | raise TypeError( 86 | 'img should be PIL Image. Got {}'.format(type(image))) 87 | if not _is_pil_image(depth): 88 | raise TypeError( 89 | 'img should be PIL Image. Got {}'.format(type(depth))) 90 | 91 | if random.random() < 0.5: 92 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 93 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 94 | 95 | return {'image': image, 'depth': depth} 96 | 97 | 98 | class Scale(object): 99 | """ Rescales the inputs and target arrays to the given 'size'. 100 | 'size' will be the size of the smaller edge. 101 | For example, if height > width, then image will be 102 | rescaled to (size * height / width, size) 103 | size: size of the smaller edge 104 | interpolation order: Default: 2 (bilinear) 105 | """ 106 | 107 | def __init__(self, size): 108 | self.size = size 109 | 110 | def __call__(self, sample): 111 | image, depth = sample['image'], sample['depth'] 112 | 113 | image = self.changeScale(image, self.size) 114 | depth = self.changeScale(depth, self.size,Image.NEAREST) 115 | 116 | return {'image': image, 'depth': depth} 117 | 118 | def changeScale(self, img, size, interpolation=Image.BILINEAR): 119 | 120 | if not _is_pil_image(img): 121 | raise TypeError( 122 | 'img should be PIL Image. 
Got {}'.format(type(img))) 123 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): 124 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 125 | 126 | if isinstance(size, int): 127 | w, h = img.size 128 | if (w <= h and w == size) or (h <= w and h == size): 129 | return img 130 | if w < h: 131 | ow = size 132 | oh = int(size * h / w) 133 | return img.resize((ow, oh), interpolation) 134 | else: 135 | oh = size 136 | ow = int(size * w / h) 137 | return img.resize((ow, oh), interpolation) 138 | else: 139 | return img.resize(size[::-1], interpolation) 140 | 141 | 142 | class CenterCrop(object): 143 | def __init__(self, size_image, size_depth): 144 | self.size_image = size_image 145 | self.size_depth = size_depth 146 | 147 | def __call__(self, sample): 148 | image, depth = sample['image'], sample['depth'] 149 | 150 | image = self.centerCrop(image, self.size_image) 151 | depth = self.centerCrop(depth, self.size_image) 152 | 153 | ow, oh = self.size_depth 154 | depth = depth.resize((ow, oh)) 155 | 156 | return {'image': image, 'depth': depth} 157 | 158 | def centerCrop(self, image, size): 159 | 160 | w1, h1 = image.size 161 | 162 | tw, th = size 163 | 164 | if w1 == tw and h1 == th: 165 | return image 166 | 167 | x1 = int(round((w1 - tw) / 2.)) 168 | y1 = int(round((h1 - th) / 2.)) 169 | 170 | image = image.crop((x1, y1, tw + x1, th + y1)) 171 | 172 | return image 173 | 174 | 175 | class ToTensor(object): 176 | """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. 177 | Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 178 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 179 | """ 180 | def __init__(self,is_test=False): 181 | self.is_test = is_test 182 | 183 | def __call__(self, sample): 184 | image, depth = sample['image'], sample['depth'] 185 | """ 186 | Args: 187 | pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. 188 | Returns: 189 | Tensor: Converted image. 190 | """ 191 | # ground truth depth of training samples is stored in 8-bit while test samples are saved in 16 bit 192 | image = self.to_tensor(image) 193 | if self.is_test: 194 | depth = self.to_tensor(depth).float() / 1000 195 | else: 196 | depth = self.to_tensor(depth).float() * 10 197 | return {'image': image, 'depth': depth} 198 | 199 | def to_tensor(self, pic): 200 | if not(_is_pil_image(pic) or _is_numpy_image(pic)): 201 | raise TypeError( 202 | 'pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) 203 | 204 | if isinstance(pic, np.ndarray): 205 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 206 | 207 | return img.float().div(255) 208 | 209 | if accimage is not None and isinstance(pic, accimage.Image): 210 | nppic = np.zeros( 211 | [pic.channels, pic.height, pic.width], dtype=np.float32) 212 | pic.copyto(nppic) 213 | return torch.from_numpy(nppic) 214 | 215 | # handle PIL Image 216 | if pic.mode == 'I': 217 | img = torch.from_numpy(np.array(pic, np.int32)) 218 | elif pic.mode == 'I;16': 219 | img = torch.from_numpy(np.array(pic, np.int16)) 220 | else: 221 | img = torch.ByteTensor( 222 | torch.ByteStorage.from_buffer(pic.tobytes())) 223 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 224 | if pic.mode == 'YCbCr': 225 | nchannel = 3 226 | elif pic.mode == 'I;16': 227 | nchannel = 1 228 | else: 229 | nchannel = len(pic.mode) 230 | img = img.view(pic.size[1], pic.size[0], nchannel) 231 | # put it from HWC to CHW format 232 | # yikes, this transpose takes 80% of the loading time/CPU 233 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 234 | if isinstance(img, torch.ByteTensor): 235 | return img.float().div(255) 236 | else: 237 | return img 238 | 239 | 240 | class Lighting(object): 241 | 242 | def __init__(self, alphastd, eigval, eigvec): 243 | self.alphastd = alphastd 244 | self.eigval = eigval 245 | self.eigvec = eigvec 246 | 247 | def __call__(self, sample): 248 | image, depth = sample['image'], sample['depth'] 249 | if self.alphastd == 0: 250 | return image 251 | 252 | alpha = image.new().resize_(3).normal_(0, self.alphastd) 253 | rgb = self.eigvec.type_as(image).clone()\ 254 | .mul(alpha.view(1, 3).expand(3, 3))\ 255 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 256 | .sum(1).squeeze() 257 | 258 | image = image.add(rgb.view(3, 1, 1).expand_as(image)) 259 | 260 | return {'image': image, 'depth': depth} 261 | 262 | 263 | class Grayscale(object): 264 | 265 | def __call__(self, img): 266 | gs = img.clone() 267 | gs[0].mul_(0.299).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114) 268 | gs[1].copy_(gs[0]) 269 | gs[2].copy_(gs[0]) 270 | return gs 271 | 272 | 273 | class Saturation(object): 274 | 275 | def __init__(self, var): 276 | self.var = var 277 | 278 | def __call__(self, img): 279 | gs = Grayscale()(img) 280 | alpha = random.uniform(-self.var, self.var) 281 | return img.lerp(gs, alpha) 282 | 283 | 284 | class Brightness(object): 285 | 286 | def __init__(self, var): 287 | self.var = var 288 | 289 | def __call__(self, img): 290 | gs = img.new().resize_as_(img).zero_() 291 | alpha = random.uniform(-self.var, self.var) 292 | 293 | return img.lerp(gs, alpha) 294 | 295 | 296 | class Contrast(object): 297 | 298 | def __init__(self, var): 299 | self.var = var 300 | 301 | def __call__(self, img): 302 | gs = Grayscale()(img) 303 | gs.fill_(gs.mean()) 304 | alpha = random.uniform(-self.var, self.var) 305 | return img.lerp(gs, alpha) 306 | 307 | 308 | class RandomOrder(object): 309 | """ Composes several transforms together in random order. 
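The transforms are applied to the image only; the depth map is passed through unchanged.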
310 | """ 311 | 312 | def __init__(self, transforms): 313 | self.transforms = transforms 314 | 315 | def __call__(self, sample): 316 | image, depth = sample['image'], sample['depth'] 317 | 318 | if self.transforms is None: 319 | return {'image': image, 'depth': depth} 320 | order = torch.randperm(len(self.transforms)) 321 | for i in order: 322 | image = self.transforms[i](image) 323 | 324 | return {'image': image, 'depth': depth} 325 | 326 | 327 | class ColorJitter(RandomOrder): 328 | 329 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 330 | self.transforms = [] 331 | if brightness != 0: 332 | self.transforms.append(Brightness(brightness)) 333 | if contrast != 0: 334 | self.transforms.append(Contrast(contrast)) 335 | if saturation != 0: 336 | self.transforms.append(Saturation(saturation)) 337 | 338 | 339 | class Normalize(object): 340 | def __init__(self, mean, std): 341 | self.mean = mean 342 | self.std = std 343 | 344 | def __call__(self, sample): 345 | """ 346 | Args: 347 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 348 | Returns: 349 | Tensor: Normalized image. 350 | """ 351 | image, depth = sample['image'], sample['depth'] 352 | 353 | image = self.normalize(image, self.mean, self.std) 354 | 355 | return {'image': image, 'depth': depth} 356 | 357 | def normalize(self, tensor, mean, std): 358 | """Normalize a tensor image with mean and standard deviation. 359 | See ``Normalize`` for more details. 360 | Args: 361 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 362 | mean (sequence): Sequence of means for R, G, B channels respectively. 363 | std (sequence): Sequence of standard deviations for R, G, B channels 364 | respectively. 365 | Returns: 366 | Tensor: Normalized image. 367 | """ 368 | 369 | for t, m, s in zip(tensor, mean, std): 370 | t.sub_(m).div_(s) 371 | return tensor 372 | -------------------------------------------------------------------------------- /nyud2-dir/preprocess_gmm.py: -------------------------------------------------------------------------------- 1 | 2 | ##################################################################################### 3 | # MIT License 4 | 5 | # Copyright (c) 2022 Jiawei Ren 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE.
24 | ####################################################################################
25 | import loaddata
26 | import argparse
27 | from sklearn.mixture import GaussianMixture
28 | import torch
29 | import joblib
30 | import time
31 | from loaddata import TRAIN_BUCKET_NUM
32 |
33 | parser = argparse.ArgumentParser(description='')
34 |
35 | # Args for GMM
36 | parser.add_argument('--K', type=int, default=16, help='GMM number of components')
37 | parser.add_argument('--batch_size', default=32, type=int, help='batch size number')
38 |
39 | bucket_centers = torch.linspace(0, 10, 101)[:-1] + 0.05
40 | TRAIN_BUCKET_NUM = [TRAIN_BUCKET_NUM[7]] * 7 + TRAIN_BUCKET_NUM[7:]
41 |
42 | def fit_gmm(args):
43 |     start_time = time.time()
44 |     all_labels = []
45 |     # There are too many pixels in NYUD2-DIR to fit a GMM on directly,
46 |     # so we use the precomputed bucket statistics from the original code.
47 |     for i in range(100):
48 |         all_labels += [bucket_centers[i] for _ in range(TRAIN_BUCKET_NUM[i] // 1000000)]
49 |     all_labels = torch.tensor(all_labels).reshape(1, -1)
50 |     print('All labels shape: ', all_labels.shape)
51 |     print(time.time() - start_time)
52 |     start_time = time.time()
53 |     print('Training labels curated')
54 |     print('Fitting GMM...')
55 |     gmm = GaussianMixture(n_components=args.K, random_state=0, verbose=2).fit(
56 |         all_labels.reshape(-1, 1).numpy())
57 |     print(time.time() - start_time)
58 |     print('GMM fitted')
59 |     print("Dumping...")
60 |     gmm_dict = {}
61 |     gmm_dict['means'] = gmm.means_
62 |     gmm_dict['weights'] = gmm.weights_
63 |     gmm_dict['variances'] = gmm.covariances_
64 |     return gmm_dict
65 |
66 | def main():
67 |     args = parser.parse_args()
68 |     # No data loader is needed: the GMM is fit from the precomputed TRAIN_BUCKET_NUM statistics above.
69 |     gmm_dict = fit_gmm(args)
70 |     gmm_path = 'gmm.pkl'
71 |     joblib.dump(gmm_dict, gmm_path)
72 |     print('Dumped at {}'.format(gmm_path))
73 |
74 |
75 | if __name__ == '__main__':
76 |     main()
77 |
--------------------------------------------------------------------------------
/nyud2-dir/preprocess_nyud2.py:
--------------------------------------------------------------------------------
 1 |
 2 | #####################################################################################
 3 | # MIT License
 4 |
 5 | # Copyright (c) 2021 Yuzhe Yang
 6 |
 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
 8 | # of this software and associated documentation files (the "Software"), to deal
 9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 |
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 |
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 | #################################################################################### 25 | 26 | import os 27 | import argparse 28 | import pandas as pd 29 | from tqdm import tqdm 30 | from torchvision import transforms 31 | from torch.utils.data import DataLoader 32 | from nyu_transform import * 33 | from loaddata import depthDataset 34 | 35 | def load_data(args): 36 | train_dataset = depthDataset( 37 | csv_file=os.path.join(args.data_dir, 'nyu2_train.csv'), 38 | transform=transforms.Compose([ 39 | Scale(240), 40 | CenterCrop([304, 228], [304, 228]), 41 | ToTensor(is_test=False), 42 | ]) 43 | ) 44 | train_dataloader = DataLoader(train_dataset, 256, shuffle=False, num_workers=16, pin_memory=False) 45 | 46 | test_dataset = depthDataset( 47 | csv_file=os.path.join(args.data_dir, 'nyu2_test.csv'), 48 | transform=transforms.Compose([ 49 | Scale(240), 50 | CenterCrop([304, 228], [304, 228]), 51 | ToTensor(is_test=True), 52 | ]) 53 | ) 54 | # print(train_dataset.__len__(), test_dataset.__len__()) 55 | test_dataloader = DataLoader(test_dataset, 256, shuffle=False, num_workers=16, pin_memory=False) 56 | 57 | return train_dataloader, test_dataloader 58 | 59 | def create_FDS_train_subset_id(args): 60 | print('Creating FDS statistics updating subset ids...') 61 | train_dataloader, _ = load_data(args) 62 | train_depth_values = [] 63 | for i, sample in enumerate(tqdm(train_dataloader)): 64 | train_depth_values.append(sample['depth'].squeeze()) 65 | train_depth_values = torch.cat(train_depth_values, 0) 66 | select_idx = np.random.choice(a=list(range(train_depth_values.size(0))), size=600, replace=False) 67 | np.save(os.path.join(args.data_dir, 'FDS_train_subset_id.npy'), select_idx) 68 | 69 | def create_FDS_train_subset(args): 70 | print('Creating FDS statistics updating subset...') 71 | frame = pd.read_csv(os.path.join(args.data_dir, 'nyu2_train.csv'), header=None) 72 | select_id = np.load(os.path.join(args.data_dir, 'FDS_train_subset_id.npy')) 73 | frame.iloc[select_id].to_csv(os.path.join(args.data_dir, 'nyu2_train_FDS_subset.csv'), index=False, header=False) 74 | 75 | def get_bin_idx(x): 76 | return min(int(x * np.float32(10)), 99) 77 | 78 | def create_balanced_testset(args): 79 | print('Creating balanced test set mask...') 80 | _, test_dataloader = load_data(args) 81 | test_depth_values = [] 82 | 83 | for i, sample in enumerate(tqdm(test_dataloader)): 84 | test_depth_values.append(sample['depth'].squeeze()) 85 | test_depth_values = torch.cat(test_depth_values, 0) 86 | test_depth_values_flatten = test_depth_values.view(-1).numpy() 87 | test_bins_number, _ = np.histogram(a=test_depth_values_flatten, bins=100, range=(0., 10.)) 88 | 89 | select_pixel_num = min(test_bins_number[test_bins_number != 0]) 90 | test_depth_values_flatten_bins = np.array(list(map(lambda v: get_bin_idx(v), test_depth_values_flatten))) 91 | 92 | test_depth_flatten_mask = np.zeros(test_depth_values_flatten.shape[0], dtype=np.uint8) 93 | for i in range(7, 100): 94 | bucket_idx = np.where(test_depth_values_flatten_bins == i)[0] 95 | select_bucket_idx = np.random.choice(a=bucket_idx, size=select_pixel_num, replace=False) 96 | test_depth_flatten_mask[select_bucket_idx] = np.uint8(1) 97 | test_depth_mask = test_depth_flatten_mask.reshape(test_depth_values.numpy().shape) 98 | np.save(os.path.join(args.data_dir, 'test_balanced_mask.npy'), test_depth_mask) 99 | 100 | if __name__ == '__main__': 101 | parser = argparse.ArgumentParser(description='') 102 | parser.add_argument('--data_dir', type=str, default='./data', help='data 
directory') 103 | args = parser.parse_args() 104 | 105 | create_FDS_train_subset_id(args) 106 | create_FDS_train_subset(args) 107 | create_balanced_testset(args) -------------------------------------------------------------------------------- /nyud2-dir/test.py: -------------------------------------------------------------------------------- 1 | ##################################################################################### 2 | # MIT License 3 | 4 | # Copyright (c) 2022 Jiawei Ren 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | #################################################################################### 24 | import os 25 | import logging 26 | import argparse 27 | from tqdm import tqdm 28 | 29 | import torch 30 | import torch.nn as nn 31 | import torch.nn.parallel 32 | 33 | import loaddata 34 | from models import modules, net, resnet 35 | from util import Evaluator 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--eval_model', type=str, default='', help='evaluation model path') 40 | parser.add_argument('--data_dir', type=str, default='./data', help='data directory') 41 | args = parser.parse_args() 42 | 43 | logging.root.handlers = [] 44 | logging.basicConfig( 45 | level=logging.INFO, 46 | format="%(asctime)s | %(message)s", 47 | handlers=[ 48 | logging.StreamHandler() 49 | ]) 50 | 51 | model = define_model() 52 | assert os.path.isfile(args.eval_model), f"No checkpoint found at '{args.eval_model}'" 53 | model = torch.nn.DataParallel(model).cuda() 54 | model_state = torch.load(args.eval_model) 55 | logging.info(f"Loading checkpoint from {args.eval_model}") 56 | model.load_state_dict(model_state['state_dict'], strict=False) 57 | logging.info('Loaded successfully!') 58 | 59 | test_loader = loaddata.getTestingData(args, 1) 60 | test(test_loader, model) 61 | 62 | def test(test_loader, model): 63 | model.eval() 64 | 65 | logging.info('Starting testing...') 66 | 67 | evaluator = Evaluator() 68 | 69 | with torch.no_grad(): 70 | for i, sample_batched in enumerate(tqdm(test_loader)): 71 | image, depth, mask = sample_batched['image'], sample_batched['depth'], sample_batched['mask'] 72 | depth = depth.cuda(non_blocking=True) 73 | image = image.cuda() 74 | output,_ = model(image) 75 | output = nn.functional.interpolate(output, size=[depth.size(2),depth.size(3)], mode='bilinear', align_corners=True) 76 | 77 | evaluator(output[mask], depth[mask]) 78 | 79 | logging.info('Finished 
testing. Start printing statistics below...') 80 | metric_dict = evaluator.evaluate_shot() 81 | 82 | return metric_dict['overall']['RMSE'], metric_dict 83 | 84 | def define_model(): 85 | original_model = resnet.resnet50(pretrained = True) 86 | Encoder = modules.E_resnet(original_model) 87 | model = net.model(None, Encoder, num_features=2048, block_channel = [256, 512, 1024, 2048]) 88 | 89 | return model 90 | 91 | if __name__ == '__main__': 92 | main() -------------------------------------------------------------------------------- /nyud2-dir/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-present, Royal Bank of Canada. 2 | # Copyright (c) 2022 Jiawei Ren 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | ##################################################################################### 8 | # Code is based on the Balanced MSE (https://openaccess.thecvf.com/content/CVPR2022/html/Ren_Balanced_MSE_for_Imbalanced_Visual_Regression_CVPR_2022_paper.html) implementation 9 | # from https://github.com/jiawei-ren/BalancedMSE/tree/main/nyud2-dir by Jiawei Ren 10 | #################################################################################### 11 | import argparse 12 | import time 13 | import os 14 | import datetime 15 | import shutil 16 | import logging 17 | import torch 18 | import torch.backends.cudnn as cudnn 19 | import loaddata 20 | from tqdm import tqdm 21 | from models import modules, net, resnet 22 | from util import query_yes_no 23 | from test import test 24 | from tensorboardX import SummaryWriter 25 | from balanaced_mse import * 26 | from ConR import ConR 27 | 28 | parser = argparse.ArgumentParser(description='') 29 | 30 | # training/optimization related 31 | parser.add_argument('--epochs', default=20, type=int, 32 | help='number of total epochs to run') 33 | parser.add_argument('--start_epoch', default=0, type=int, 34 | help='manual epoch number (useful on restarts)') 35 | parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float, 36 | help='initial learning rate') 37 | parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, 38 | help='weight decay (default: 1e-4)') 39 | parser.add_argument('--batch_size', default=32, type=int, help='batch size number') # 1 GPU - 8 40 | parser.add_argument('--store_root', type=str, default='checkpoint') 41 | parser.add_argument('--store_name', type=str, default='nyud2') 42 | parser.add_argument('--data_dir', type=str, default='./data', help='data directory') 43 | parser.add_argument('--resume', action='store_true', default=False, help='whether to resume training') 44 | 45 | # imbalanced related 46 | # LDS 47 | parser.add_argument('--lds', action='store_true', default=False, help='whether to enable LDS') 48 | parser.add_argument('--lds_kernel', type=str, default='gaussian', 49 | choices=['gaussian', 'triang', 'laplace'], help='LDS kernel type') 50 | parser.add_argument('--lds_ks', type=int, default=5, help='LDS kernel size: should be odd number') 51 | parser.add_argument('--lds_sigma', type=float, default=2, help='LDS gaussian/laplace kernel sigma') 52 | # FDS 53 | parser.add_argument('--fds', action='store_true', default=False, help='whether to enable FDS') 54 | parser.add_argument('--fds_kernel', type=str, default='gaussian', 55 | choices=['gaussian', 'triang', 'laplace'], help='FDS kernel type') 56 | parser.add_argument('--fds_ks', 
type=int, default=5, help='FDS kernel size: should be odd number') 57 | parser.add_argument('--fds_sigma', type=float, default=2, help='FDS gaussian/laplace kernel sigma') 58 | parser.add_argument('--start_update', type=int, default=0, help='which epoch to start FDS updating') 59 | parser.add_argument('--start_smooth', type=int, default=1, help='which epoch to start using FDS to smooth features') 60 | parser.add_argument('--bucket_num', type=int, default=100, help='maximum bucket considered for FDS') 61 | parser.add_argument('--bucket_start', type=int, default=7, help='minimum(starting) bucket for FDS, 7 for NYUDv2') 62 | parser.add_argument('--fds_mmt', type=float, default=0.9, help='FDS momentum') 63 | 64 | # BMSE 65 | parser.add_argument('--bmse', action='store_true', default=False, help='use Balanced MSE') 66 | parser.add_argument('--imp', type=str, default='gai', choices=['gai', 'bmc', 'bni'], help='implementation options') 67 | parser.add_argument('--gmm', type=str, default='gmm.pkl', help='path to preprocessed GMM') 68 | parser.add_argument('--init_noise_sigma', type=float, default=1., help='initial scale of the noise') 69 | parser.add_argument('--sigma_lr', type=float, default=1e-2, help='learning rate of the noise scale') 70 | parser.add_argument('--fix_noise_sigma', action='store_true', default=False, help='disable joint optimization') 71 | 72 | 73 | # re-weighting: SQRT_INV / INV 74 | parser.add_argument('--reweight', type=str, default='none', choices=['none', 'inverse', 'sqrt_inv'], 75 | help='cost-sensitive reweighting scheme') 76 | # two-stage training: RRT 77 | parser.add_argument('--retrain_fc', action='store_true', default=False, 78 | help='whether to retrain last regression layer (regressor)') 79 | parser.add_argument('--pretrained', type=str, default='', help='pretrained checkpoint file path to load backbone weights for RRT') 80 | 81 | # ConR 82 | parser.add_argument('--conr', action='store_true', default=False, help='whether to enable conr') 83 | parser.add_argument('-w', type=float, default=0.2, help='similarity window for conR loss') 84 | parser.add_argument('--beta', type=float, default=0.2, help='conR loss coeff') 85 | parser.add_argument('-t', type=float, default=.2, help='temperature') 86 | parser.add_argument('-e', type=float, default=0.2, help="coeff for eta in ConR") 87 | 88 | 89 | 90 | def define_model(args): 91 | original_model = resnet.resnet50(pretrained=True) 92 | Encoder = modules.E_resnet(original_model) 93 | model = net.model(args, Encoder, num_features=2048, block_channel = [256, 512, 1024, 2048]) 94 | 95 | return model 96 | 97 | def main(): 98 | error_best = 1e5 99 | metric_dict_best = {} 100 | epoch_best = -1 101 | 102 | global args 103 | args = parser.parse_args() 104 | 105 | if not args.lds and args.reweight != 'none': 106 | args.store_name += f'_{args.reweight}' 107 | if args.lds: 108 | args.store_name += f'_lds_{args.lds_kernel[:3]}_{args.lds_ks}' 109 | if args.lds_kernel in ['gaussian', 'laplace']: 110 | args.store_name += f'_{args.lds_sigma}' 111 | if args.fds: 112 | args.store_name += f'_fds_{args.fds_kernel[:3]}_{args.fds_ks}' 113 | if args.fds_kernel in ['gaussian', 'laplace']: 114 | args.store_name += f'_{args.fds_sigma}' 115 | args.store_name += f'_{args.start_update}_{args.start_smooth}_{args.fds_mmt}' 116 | if args.retrain_fc: 117 | args.store_name += f'_retrain_fc' 118 | if args.bmse: 119 | args.store_name += f'_{args.imp}_{args.init_noise_sigma}_{args.sigma_lr}' 120 | if args.imp == 'gai': 121 | args.store_name += f'_{args.gmm[:-4]}' 
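    # Note (added for clarity, not in the original): at this point store_name has accumulated one tag per enabled option, e.g. roughly 'nyud2_lds_gau_5_2.0' when only the default LDS flags are set; the lr/batch-size suffix and a timestamp are appended below.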
122 | args.store_name += f'_lr_{args.lr}_bs_{args.batch_size}' 123 | 124 | timestamp = str(datetime.datetime.now()) 125 | timestamp = '-'.join(timestamp.split(' ')) 126 | args.store_name = args.store_name + '_' + timestamp 127 | 128 | args.store_dir = os.path.join(args.store_root, args.store_name) 129 | 130 | if not args.resume: 131 | if os.path.exists(args.store_dir): 132 | if query_yes_no('overwrite previous folder: {} ?'.format(args.store_dir)): 133 | shutil.rmtree(args.store_dir) 134 | print(args.store_dir + ' removed.') 135 | else: 136 | raise RuntimeError('Output folder {} already exists'.format(args.store_dir)) 137 | print(f"===> Creating folder: {args.store_dir}") 138 | os.makedirs(args.store_dir) 139 | 140 | logging.root.handlers = [] 141 | log_file = os.path.join(args.store_dir, 'training_log.log') 142 | logging.basicConfig( 143 | level=logging.INFO, 144 | format="%(asctime)s | %(message)s", 145 | handlers=[ 146 | logging.FileHandler(log_file), 147 | logging.StreamHandler() 148 | ]) 149 | logging.info(args) 150 | 151 | writer = SummaryWriter(args.store_dir) 152 | 153 | model = define_model(args) 154 | model = torch.nn.DataParallel(model).cuda() 155 | 156 | if args.resume: 157 | model_state = torch.load(os.path.join(args.store_dir, 'checkpoint.pth.tar')) 158 | logging.info(f"Loading checkpoint from {os.path.join(args.store_dir, 'checkpoint.pth.tar')}" 159 | f" (Epoch [{model_state['epoch']}], RMSE: {model_state['error']:.3f})") 160 | model.load_state_dict(model_state['state_dict']) 161 | 162 | args.start_epoch = model_state['epoch'] + 1 163 | epoch_best = model_state['epoch'] 164 | error_best = model_state['error'] 165 | metric_dict_best = model_state['metric'] 166 | 167 | if args.retrain_fc: 168 | assert os.path.isfile(args.pretrained), f"No checkpoint found at '{args.pretrained}'" 169 | model_state = torch.load(args.pretrained, map_location="cpu") 170 | from collections import OrderedDict 171 | new_state_dict = OrderedDict() 172 | for k, v in model_state['state_dict'].items(): 173 | if 'R' not in k: 174 | new_state_dict[k] = v 175 | model.load_state_dict(new_state_dict, strict=False) 176 | logging.info(f'===> Pretrained weights found in total: [{len(list(new_state_dict.keys()))}]') 177 | logging.info(f'===> Pre-trained model loaded: {args.pretrained}') 178 | for name, param in model.named_parameters(): 179 | if 'R' not in name: 180 | param.requires_grad = False 181 | logging.info(f'Only optimize parameters: {[n for n, p in model.named_parameters() if p.requires_grad]}') 182 | 183 | cudnn.benchmark = True 184 | if not args.retrain_fc: 185 | optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) 186 | else: 187 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 188 | optimizer = torch.optim.Adam(parameters, args.lr, weight_decay=args.weight_decay) 189 | 190 | train_loader = loaddata.getTrainingData(args, args.batch_size) 191 | train_fds_loader = loaddata.getTrainingFDSData(args, args.batch_size) 192 | test_loader = loaddata.getTestingData(args, 1) 193 | 194 | if args.bmse: 195 | if args.imp == 'gai': 196 | criterion = GAILoss(args.init_noise_sigma, args.gmm) 197 | elif args.imp == 'bmc': 198 | criterion = BMCLoss(args.init_noise_sigma) 199 | elif args.imp == 'bni': 200 | bucket_centers, bucket_weights = loaddata.get_bucket_info(args) 201 | criterion = BNILoss(args.init_noise_sigma, bucket_centers, bucket_weights) 202 | if not args.fix_noise_sigma: 203 | optimizer.add_param_group({'params': criterion.noise_sigma, 'lr': 
args.sigma_lr, 'name': 'noise_sigma'})
204 |     else:
205 |         criterion = None
206 |
207 |     for epoch in range(args.start_epoch, args.epochs):
208 |         adjust_learning_rate(optimizer, epoch)
209 |         train(train_loader, train_fds_loader, model, optimizer, epoch, writer, criterion)
210 |         error, metric_dict = test(test_loader, model)
211 |         if error < error_best:
212 |             error_best = error
213 |             metric_dict_best = metric_dict
214 |             epoch_best = epoch
215 |             save_checkpoint(model.state_dict(), epoch, error, metric_dict, 'checkpoint_best.pth.tar')
216 |         save_checkpoint(model.state_dict(), epoch, error, metric_dict, 'checkpoint.pth.tar')
217 |
218 |     save_checkpoint(model.state_dict(), epoch, error, metric_dict, 'checkpoint_final.pth.tar')
219 |     logging.info(f'Best epoch: {epoch_best}; RMSE: {error_best:.3f}')
220 |     logging.info('***** TEST RESULTS *****')
221 |     for shot in ['Overall', 'Many', 'Medium', 'Few']:
222 |         logging.info(f" * {shot}: RMSE {metric_dict_best[shot.lower()]['RMSE']:.3f}\t"
223 |                      f"ABS_REL {metric_dict_best[shot.lower()]['ABS_REL']:.3f}\t"
224 |                      f"LG10 {metric_dict_best[shot.lower()]['LG10']:.3f}\t"
225 |                      f"MAE {metric_dict_best[shot.lower()]['MAE']:.3f}\t"
226 |                      f"DELTA1 {metric_dict_best[shot.lower()]['DELTA1']:.3f}\t"
227 |                      f"DELTA2 {metric_dict_best[shot.lower()]['DELTA2']:.3f}\t"
228 |                      f"DELTA3 {metric_dict_best[shot.lower()]['DELTA3']:.3f}\t"
229 |                      f"NUM {metric_dict_best[shot.lower()]['NUM']}")
230 |
231 |     writer.close()
232 |
233 | def train(train_loader, train_fds_loader, model, optimizer, epoch, writer, criterion):
234 |     batch_time = AverageMeter()
235 |     losses = AverageMeter()
236 |     noise_var = AverageMeter()
237 |     l2 = AverageMeter()
238 |     if args.conr:
239 |         loss_c = AverageMeter()
240 |
241 |     model.train()
242 |
243 |     end = time.time()
244 |     for i, sample_batched in enumerate(train_loader):
245 |         image, depth, weight = sample_batched['image'], sample_batched['depth'], sample_batched['weight']
246 |
247 |         depth = depth.cuda(non_blocking=True)
248 |         weight = weight.cuda(non_blocking=True)
249 |         image = image.cuda()
250 |         optimizer.zero_grad()
251 |         # if args.fds:
252 |         #     output, feature = model(image, depth, epoch)
253 |         # else:
254 |         #     output = model(image, depth, epoch)
255 |
256 |         output, feature = model(image, depth, epoch)
257 |         loss = 0
258 |         if args.conr:
259 |             l_c = ConR(feature, depth, output, weights=weight, w=args.w, t=args.t, e=args.e)
260 |             loss += args.beta * l_c
261 |             loss_c.update(l_c.item(), image.size(0))
262 |
263 |         if args.bmse:
264 |             output = output[depth >= 0.7]
265 |             depth = depth[depth >= 0.7]
266 |             output = output.reshape(-1, 1)
267 |             depth = depth.reshape(-1, 1)
268 |             loss += criterion(output, depth)
269 |             noise_var.update(criterion.noise_sigma.item() ** 2)
270 |             l2.update(F.mse_loss(output, depth).item())
271 |         else:
272 |             loss += torch.mean(((output - depth) ** 2) * weight)
273 |
274 |
275 |
276 |
277 |         losses.update(loss.item(), image.size(0))
278 |
279 |         loss.backward()
280 |         optimizer.step()
281 |
282 |         batch_time.update(time.time() - end)
283 |         end = time.time()
284 |
285 |         writer.add_scalar('data/loss', loss.item(), i + epoch * len(train_loader))
286 |
287 |         if args.conr:
288 |             logging.info('Epoch: [{0}][{1}/{2}]\t'
289 |                          'Time {batch_time.val:.3f} ({batch_time.sum:.3f})\t'
290 |                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
291 |                          'L2 {l2.val:.4f} ({l2.avg:.4f})\t'
292 |                          'Loss(ConR) {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
293 |                          .format(epoch, i, len(train_loader), batch_time=batch_time, loss=losses, l2=l2, loss_c=loss_c))
294 |         else:
295 |             logging.info('Epoch: [{0}][{1}/{2}]\t'
296 |                          'Time {batch_time.val:.3f} ({batch_time.sum:.3f})\t'
297 |                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
298 |                          .format(epoch, i, len(train_loader), batch_time=batch_time, loss=losses))
299 |
300 |     if args.fds and epoch >= args.start_update:
301 |         logging.info(f"Creating Epoch [{epoch}] features of subsampled training data...")
302 |         encodings, depths = [], []
303 |         with torch.no_grad():
304 |             for i, sample_batched in enumerate(tqdm(train_fds_loader)):
305 |                 image, depth = sample_batched['image'].cuda(), sample_batched['depth'].cuda()
306 |                 _, feature = model(image, depth, epoch)
307 |                 encodings.append(feature.data.cpu())
308 |                 depths.append(depth.data.cpu())
309 |         encodings, depths = torch.cat(encodings, 0), torch.cat(depths, 0)
310 |         logging.info(f"Created Epoch [{epoch}] features of subsampled training data (size: {encodings.size(0)})!")
311 |         model.module.R.FDS.update_last_epoch_stats(epoch)
312 |         model.module.R.FDS.update_running_stats(encodings, depths, epoch)
313 |
314 |
315 | def adjust_learning_rate(optimizer, epoch):
316 |     lr = args.lr * (0.1 ** (epoch // 5))
317 |
318 |     for param_group in optimizer.param_groups:
319 |         if 'name' in param_group and param_group['name'] == 'noise_sigma':
320 |             continue
321 |         param_group['lr'] = lr
322 |
323 |
324 | class AverageMeter(object):
325 |     def __init__(self):
326 |         self.reset()
327 |
328 |     def reset(self):
329 |         self.val = 0
330 |         self.avg = 0
331 |         self.sum = 0
332 |         self.count = 0
333 |
334 |     def update(self, val, n=1):
335 |         self.val = val
336 |         self.sum += val * n
337 |         self.count += n
338 |         self.avg = self.sum / self.count
339 |
340 |
341 | def save_checkpoint(state_dict, epoch, error, metric_dict, filename='checkpoint.pth.tar'):
342 |     logging.info(f'Saving checkpoint to {os.path.join(args.store_dir, filename)}...')
343 |     torch.save({
344 |         'state_dict': state_dict,
345 |         'epoch': epoch,
346 |         'error': error,
347 |         'metric': metric_dict
348 |     }, os.path.join(args.store_dir, filename))
349 |
350 | if __name__ == '__main__':
351 |     main()
352 |
--------------------------------------------------------------------------------
/nyud2-dir/util.py:
--------------------------------------------------------------------------------
 1 | #####################################################################################
 2 | # MIT License
 3 |
 4 | # Copyright (c) 2022 Jiawei Ren
 5 |
 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
 7 | # of this software and associated documentation files (the "Software"), to deal
 8 | # in the Software without restriction, including without limitation the rights
 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE.
23 | ####################################################################################
24 | import math
25 | import logging
26 | import torch
27 | import numpy as np
28 | from scipy.ndimage import gaussian_filter1d
29 | from scipy.signal.windows import triang
30 |
31 | def lg10(x):
32 |     return torch.div(torch.log(x), math.log(10))
33 |
34 | def maxOfTwo(x, y):
35 |     z = x.clone()
36 |     maskYLarger = torch.lt(x, y)
37 |     z[maskYLarger.detach()] = y[maskYLarger.detach()]
38 |     return z
39 |
40 | def nValid(x):
41 |     return torch.sum(torch.eq(x, x).float())
42 |
43 | def getNanMask(x):
44 |     return torch.ne(x, x)
45 |
46 | def setNanToZero(input, target):
47 |     nanMask = getNanMask(target)
48 |     nValidElement = nValid(target)
49 |
50 |     _input = input.clone()
51 |     _target = target.clone()
52 |
53 |     _input[nanMask] = 0
54 |     _target[nanMask] = 0
55 |
56 |     return _input, _target, nanMask, nValidElement
57 |
58 | class Evaluator:
59 |     def __init__(self):
60 |         self.shot_idx = {
61 |             'many': [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
62 |                      28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 49],
63 |             'medium': [7, 8, 46, 48, 50, 51, 52, 53, 54, 55, 56, 58, 60, 61, 63],
64 |             'few': [57, 59, 62, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
65 |                     78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
66 |         }
67 |         self.output = torch.tensor([], dtype=torch.float32)
68 |         self.depth = torch.tensor([], dtype=torch.float32)
69 |
70 |     def __call__(self, output, depth):
71 |         output = output.squeeze().view(-1).cpu()
72 |         depth = depth.squeeze().view(-1).cpu()
73 |         self.output = torch.cat([self.output, output])
74 |         self.depth = torch.cat([self.depth, depth])
75 |
76 |     def evaluate_shot_balanced(self):
77 |         metric_dict = {'overall': {}, 'many': {}, 'medium': {}, 'few': {}}
78 |         self.depth_bucket = np.array(list(map(lambda v: self.get_bin_idx(v), self.depth.cpu().numpy())))
79 |
80 |         bin_cnt = []
81 |         for i in range(100):
82 |             cnt = np.count_nonzero(self.depth_bucket == i)
83 |             cnt = 1 if cnt >= 1 else 0
84 |             bin_cnt.append(cnt)
85 |
86 |         bin_metric = []
87 |         for i in range(100):
88 |             mask = np.zeros(self.depth.size(0), dtype=bool)  # np.bool is removed in modern NumPy; use the builtin
89 |             mask[np.where(self.depth_bucket == i)[0]] = True
90 |             mask = torch.tensor(mask, dtype=torch.bool)
91 |             bin_metric.append(self.evaluate(self.output[mask], self.depth[mask]))
92 |
93 |         for shot in metric_dict.keys():
94 |             if shot == 'overall':
95 |                 for k in bin_metric[0].keys():
96 |                     metric_dict[shot][k] = 0.
97 |                     for i in range(7, 100):
98 |                         metric_dict[shot][k] += bin_metric[i][k]
99 |                     if k != 'NUM':
100 |                         metric_dict[shot][k] /= sum(bin_cnt)
101 |             else:
102 |                 for k in bin_metric[0].keys():
103 |                     metric_dict[shot][k] = 0.
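                    # Per-shot metrics (comment added for clarity): sum each bin's metric over that shot's bucket indices, then average over the shot's non-empty bins; NUM stays a raw count.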
104 |                     for i in self.shot_idx[shot]:
105 |                         metric_dict[shot][k] += bin_metric[i][k]
106 |                     if k != 'NUM':
107 |                         metric_dict[shot][k] /= sum([bin_cnt[i] for i in self.shot_idx[shot]])
108 |
109 |         logging.info('\n***** TEST RESULTS *****')
110 |         for shot in ['Overall', 'Many', 'Medium', 'Few']:
111 |             logging.info(f" * {shot}: RMSE {metric_dict[shot.lower()]['MSE'] ** 0.5:.3f}\t"
112 |                          f"MSE {metric_dict[shot.lower()]['MSE']:.3f}\t"
113 |                          f"ABS_REL {metric_dict[shot.lower()]['ABS_REL']:.3f}\t"
114 |                          f"LG10 {metric_dict[shot.lower()]['LG10']:.3f}\t"
115 |                          f"MAE {metric_dict[shot.lower()]['MAE']:.3f}\t"
116 |                          f"DELTA1 {metric_dict[shot.lower()]['DELTA1']:.3f}\t"
117 |                          f"DELTA2 {metric_dict[shot.lower()]['DELTA2']:.3f}\t"
118 |                          f"DELTA3 {metric_dict[shot.lower()]['DELTA3']:.3f}\t"
119 |                          f"NUM {metric_dict[shot.lower()]['NUM']}")
120 |
121 |         return metric_dict
122 |
123 |     def evaluate_shot(self):
124 |         metric_dict = {'overall': {}, 'many': {}, 'medium': {}, 'few': {}}
125 |         self.depth_bucket = np.array(list(map(lambda v: self.get_bin_idx(v), self.depth.cpu().numpy())))
126 |
127 |         for shot in metric_dict.keys():
128 |             if shot == 'overall':
129 |                 metric_dict[shot] = self.evaluate(self.output, self.depth)
130 |             else:
131 |                 mask = np.zeros(self.depth.size(0), dtype=bool)  # np.bool is removed in modern NumPy; use the builtin
132 |                 for i in self.shot_idx[shot]:
133 |                     mask[np.where(self.depth_bucket == i)[0]] = True
134 |                 mask = torch.tensor(mask, dtype=torch.bool)
135 |                 metric_dict[shot] = self.evaluate(self.output[mask], self.depth[mask])
136 |
137 |         logging.info('\n***** TEST RESULTS *****')
138 |         for shot in ['Overall', 'Many', 'Medium', 'Few']:
139 |             logging.info(f" * {shot}: RMSE {metric_dict[shot.lower()]['RMSE']:.3f}\t"
140 |                          f"ABS_REL {metric_dict[shot.lower()]['ABS_REL']:.3f}\t"
141 |                          f"LG10 {metric_dict[shot.lower()]['LG10']:.3f}\t"
142 |                          f"MAE {metric_dict[shot.lower()]['MAE']:.3f}\t"
143 |                          f"DELTA1 {metric_dict[shot.lower()]['DELTA1']:.3f}\t"
144 |                          f"DELTA2 {metric_dict[shot.lower()]['DELTA2']:.3f}\t"
145 |                          f"DELTA3 {metric_dict[shot.lower()]['DELTA3']:.3f}\t"
146 |                          f"NUM {metric_dict[shot.lower()]['NUM']}")
147 |
148 |         return metric_dict
149 |
150 |     def reset(self):
151 |         self.output = torch.tensor([], dtype=torch.float32)
152 |         self.depth = torch.tensor([], dtype=torch.float32)
153 |
154 |     @staticmethod
155 |     def get_bin_idx(x):
156 |         return min(int(x * np.float32(10)), 99)
157 |
158 |     @staticmethod
159 |     def evaluate(output, target):
160 |         errors = {'MSE': 0, 'RMSE': 0, 'ABS_REL': 0, 'LG10': 0,
161 |                   'MAE': 0, 'DELTA1': 0, 'DELTA2': 0, 'DELTA3': 0, 'NUM': 0}
162 |
163 |         _output, _target, nanMask, nValidElement = setNanToZero(output, target)
164 |
165 |         if (nValidElement.data.cpu().numpy() > 0):
166 |             diffMatrix = torch.abs(_output - _target)
167 |
168 |             errors['MSE'] = torch.sum(torch.pow(diffMatrix, 2)) / nValidElement
169 |
170 |             errors['MAE'] = torch.sum(diffMatrix) / nValidElement
171 |
172 |             realMatrix = torch.div(diffMatrix, _target)
173 |             realMatrix[nanMask] = 0
174 |             errors['ABS_REL'] = torch.sum(realMatrix) / nValidElement
175 |
176 |             LG10Matrix = torch.abs(lg10(_output) - lg10(_target))
177 |             LG10Matrix[nanMask] = 0
178 |             errors['LG10'] = torch.sum(LG10Matrix) / nValidElement
179 |
180 |             yOverZ = torch.div(_output, _target)
181 |             zOverY = torch.div(_target, _output)
182 |
183 |             maxRatio = maxOfTwo(yOverZ, zOverY)
184 |
185 |             errors['DELTA1'] = torch.sum(
186 |                 torch.le(maxRatio, 1.25).float()) / nValidElement
187 |             errors['DELTA2'] = torch.sum(
188 |                 torch.le(maxRatio, math.pow(1.25, 2)).float()) / nValidElement
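            # DELTA_k (comment added for clarity) is the standard depth-estimation accuracy: the fraction of valid pixels whose max(pred/gt, gt/pred) ratio falls below 1.25**k.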
189 |             errors['DELTA3'] = torch.sum(
190 |                 torch.le(maxRatio, math.pow(1.25, 3)).float()) / nValidElement
191 |
192 |             errors['MSE'] = float(errors['MSE'].data.cpu().numpy())
193 |             errors['ABS_REL'] = float(errors['ABS_REL'].data.cpu().numpy())
194 |             errors['LG10'] = float(errors['LG10'].data.cpu().numpy())
195 |             errors['MAE'] = float(errors['MAE'].data.cpu().numpy())
196 |             errors['DELTA1'] = float(errors['DELTA1'].data.cpu().numpy())
197 |             errors['DELTA2'] = float(errors['DELTA2'].data.cpu().numpy())
198 |             errors['DELTA3'] = float(errors['DELTA3'].data.cpu().numpy())
199 |             errors['NUM'] = int(nValidElement)
200 |
201 |         errors['RMSE'] = np.sqrt(errors['MSE'])
202 |
203 |         return errors
204 |
205 |
206 | def query_yes_no(question):
207 |     """ Ask a yes/no question via input() and return the user's answer. """
208 |     valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
209 |     prompt = " [Y/n] "
210 |
211 |     while True:
212 |         print(question + prompt, end=':')
213 |         choice = input().lower()
214 |         if choice == '':
215 |             return valid['y']
216 |         elif choice in valid:
217 |             return valid[choice]
218 |         else:
219 |             print("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")
220 |
221 | def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.2, clip_max=5.):
222 |     if torch.sum(v1) < 1e-10:
223 |         return matrix
224 |     if (v1 <= 0.).any() or (v2 < 0.).any():
225 |         valid_pos = (((v1 > 0.) + (v2 >= 0.)) == 2)
226 |         # print(torch.sum(valid_pos))
227 |         factor = torch.clamp(v2[valid_pos] / v1[valid_pos], clip_min, clip_max)
228 |         matrix[:, valid_pos] = (matrix[:, valid_pos] - m1[valid_pos]) * torch.sqrt(factor) + m2[valid_pos]
229 |         return matrix
230 |
231 |     factor = torch.clamp(v2 / v1, clip_min, clip_max)
232 |     return (matrix - m1) * torch.sqrt(factor) + m2
233 |
234 | def get_lds_kernel_window(kernel, ks, sigma):
235 |     assert kernel in ['gaussian', 'triang', 'laplace']
236 |     half_ks = (ks - 1) // 2
237 |     if kernel == 'gaussian':
238 |         base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
239 |         kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma))
240 |     elif kernel == 'triang':
241 |         kernel_window = triang(ks)
242 |     else:
243 |         laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
244 |         kernel_window = np.array(list(map(laplace, np.arange(-half_ks, half_ks + 1)))) / max(map(laplace, np.arange(-half_ks, half_ks + 1)))  # wrap in np.array: a plain list cannot be divided by a float
245 |
246 |     return kernel_window
247 |
248 |
249 |
250 |
251 |
252 |
253 |
--------------------------------------------------------------------------------
/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BorealisAI/ConR/ad6a8c90426da9243f9e1de7bb37c7429bb99256/teaser.jpg
--------------------------------------------------------------------------------
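A short, self-contained sketch (added for illustration; not part of the repository) of how a smoothing window like the one produced by get_lds_kernel_window in nyud2-dir/util.py is typically used for label-distribution smoothing: the empirical label histogram is convolved with the window to estimate an effective density, and the inverse (or inverse square root) of that density serves as per-sample loss weights, as the --reweight option in nyud2-dir/train.py suggests. The histogram values below are made up for the example, and the import assumes the script runs inside nyud2-dir.

import numpy as np
from scipy.ndimage import convolve1d

from util import get_lds_kernel_window  # nyud2-dir/util.py

# Made-up per-bucket label counts standing in for a real 100-bin depth histogram
# (buckets 0-6, i.e. depth < 0.7m, are empty, mirroring bucket_start=7).
bin_counts = np.array([0.] * 7 + list(np.linspace(5000., 50., 93)))

# Smooth the empirical distribution with a Gaussian LDS window
# (ks=5, sigma=2, matching the --lds_ks / --lds_sigma defaults in train.py).
lds_window = get_lds_kernel_window(kernel='gaussian', ks=5, sigma=2)
eff_density = convolve1d(bin_counts, weights=lds_window, mode='constant')

# Inverse square-root re-weighting (the 'sqrt_inv' scheme named in train.py),
# guarding against empty buckets and normalizing to mean 1 over non-empty bins.
weights = np.zeros_like(eff_density)
nonzero = eff_density > 0
weights[nonzero] = 1. / np.sqrt(eff_density[nonzero])
weights[nonzero] *= nonzero.sum() / weights[nonzero].sum()

print(weights.round(4))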